In [None]:
from os import path

# Third-party
from astropy.table import Table
import astropy.coordinates as coord
import astropy.units as u
from astropy.constants import G, c
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import numpy as np
plt.style.use('apw-notebook')
%matplotlib inline

import corner
import emcee
from scipy.integrate import quad
from scipy.misc import logsumexp
import schwimmbad
import tqdm

from comoving_rv.log import logger
from comoving_rv.db import Session, Base, db_connect
from comoving_rv.db.model import (Run, Observation, TGASSource, SimbadInfo, PriorRV,
                                  SpectralLineInfo, SpectralLineMeasurement, RVMeasurement)

In [None]:
# Root of the follow-up project data. Set the COMOVING_BASE_PATH environment
# variable to point elsewhere instead of editing this hard-coded local path.
import os
base_path = os.environ.get('COMOVING_BASE_PATH',
                           '/Volumes/ProjectData/gaia-comoving-followup/')
db_path = path.join(base_path, 'db.sqlite')
# NOTE(review): `engine` is not used in the cells shown; presumably db_connect
# also binds the `Session` factory used on the next line — confirm.
engine = db_connect(db_path)
session = Session()

In [None]:
def get_y_hat(tgas_source, names=['ra', 'dec', 'parallax', 'pmra', 'pmdec'], units=None):
    """Return selected astrometric values of ``tgas_source`` as a float array.

    Parameters
    ----------
    tgas_source : object
        Row-like object exposing every entry of ``names`` as an attribute.
        ``ra``/``dec`` are converted with ``.to(u.degree)``, so they must be
        unit-carrying quantities. The other attributes are multiplied by the
        native units listed below (presumably bare floats — confirm).
    names : list of str
        Attribute names to extract, in output order.
    units : list of astropy units, optional
        Target unit per entry of ``names``. Defaults to degree (ra, dec),
        mas (parallax), mas/yr (pmra, pmdec) and km/s (rv).

    Returns
    -------
    numpy.ndarray
        1D array of length ``len(names)`` with each value in ``units[i]``.
    """
    y = np.zeros(len(names))

    native_units = {
        'ra': u.degree,
        'dec': u.degree,
        'parallax': u.mas,
        'pmra': u.mas/u.yr,
        'pmdec': u.mas/u.yr,
        'rv': u.km/u.s,
    }

    if units is None:
        units = [native_units[name] for name in names]

    for idx, name in enumerate(names):
        raw = getattr(tgas_source, name)
        if name in ('ra', 'dec'):
            # Angles arrive as quantities: reduce them to a bare number of degrees
            # so they can be re-tagged with their native unit below.
            raw = raw.to(u.degree).value
        y[idx] = (raw * native_units[name]).to(units[idx]).value

    return y

In [None]:
def get_cov(tgas_source, names=('ra', 'dec', 'parallax', 'pmra', 'pmdec'), units=None):
    """Build the astrometric covariance matrix of ``tgas_source``.

    The diagonal holds the squared ``<name>_error`` values. Off-diagonal
    entries are ``rho * sigma_i * sigma_j``, using ``<name1>_<name2>_corr``
    (``name1`` listed before ``name2`` in ``names``). They are left at zero
    when that correlation attribute does not exist.

    Parameters
    ----------
    tgas_source : object
        Row-like object exposing ``<name>_error`` for every entry of
        ``names``, as numbers in the default error units below. It may also
        expose ``<name1>_<name2>_corr`` correlation coefficients.
    names : sequence of str
        Parameters to include, in matrix order.
    units : sequence of astropy units, optional
        Target unit of each parameter's uncertainty. Defaults to mas for
        ra/dec/parallax, mas/yr for pmra/pmdec and km/s for rv.

    Returns
    -------
    numpy.ndarray
        Symmetric ``(len(names), len(names))`` matrix whose ``[i, j]`` entry
        is expressed in ``units[i] * units[j]``.
    """
    default_err_units = {
        'ra': u.mas,
        'dec': u.mas,
        'parallax': u.mas,
        'pmra': u.mas/u.yr,
        'pmdec': u.mas/u.yr,
        'rv': u.km/u.s,
    }

    if units is None:
        units = [default_err_units[name] for name in names]

    n = len(names)
    C = np.zeros((n, n))

    # Diagonal: variances, converted to the requested units.
    for i, name in enumerate(names):
        err = getattr(tgas_source, "{}_error".format(name)) * default_err_units[name]
        C[i, i] = err.to(units[i]).value**2

    # Off-diagonal: cov_ij = rho_ij * sigma_i * sigma_j. The sigmas read back from
    # the diagonal are already in units[i] and units[j], and rho is dimensionless,
    # so no further unit conversion is applied.
    for i, name1 in enumerate(names):
        for j in range(i + 1, n):
            # Build the correlation name *before* checking it exists.
            corr_name = "{}_{}_corr".format(name1, names[j])
            if not hasattr(tgas_source, corr_name):  # no correlation stored -> stays 0
                continue
            cov = getattr(tgas_source, corr_name) * np.sqrt(C[i, i] * C[j, j])
            C[i, j] = cov
            C[j, i] = cov

    return C

In [None]:
def get_icrs_cartesian(obs, n_samples=2**16, random_state=None):
    """Monte Carlo sample an observed star's ICRS position and velocity.

    Draws ``n_samples`` realisations of (ra, dec, parallax, pmra, pmdec, rv)
    from a multivariate normal. The mean and covariance of the first five
    come from ``obs.tgas_source`` via ``get_y_hat``/``get_cov``; the RV
    comes from ``obs.rv_measurement`` and is treated as uncorrelated with
    the astrometry.

    Parameters
    ----------
    obs : Observation
        Must expose ``tgas_source`` and ``rv_measurement``; the latter's
        ``rv`` and ``err`` attributes must have a ``.value``.
    n_samples : int, optional
        Number of Monte Carlo samples to draw.
    random_state : optional
        Object with a ``multivariate_normal`` method (e.g. a seeded
        ``np.random.RandomState``) used to make the draws reproducible.
        ``None`` (default) draws from the global ``np.random`` state.

    Returns
    -------
    astropy.coordinates.ICRS
        Frame holding all samples, switched to a Cartesian representation
        and Cartesian differential.
    """
    names = ['ra', 'dec', 'parallax', 'pmra', 'pmdec']
    # ra/dec are requested in mas (not degrees) so values and errors share the
    # units get_cov uses for ra/dec errors.
    # NOTE(review): TGAS conventionally quotes ra_error for ra*cos(dec); using it
    # directly as the error on ra may underestimate the ra scatter away from the
    # equator — confirm.
    units = [u.mas, u.mas, u.mas, u.mas/u.yr, u.mas/u.yr]

    y_hat = np.zeros(6)
    C = np.zeros((6, 6))  # row/column 5 (rv) stays uncorrelated with the astrometry

    y_hat[:5] = get_y_hat(obs.tgas_source, names=names, units=units)
    C[:5, :5] = get_cov(obs.tgas_source, names=names, units=units)

    # Use the corrected, derived value
    y_hat[5] = obs.rv_measurement.rv.value
    C[5, 5] = obs.rv_measurement.err.value**2

    # Use the raw, pseudovalue (line-centroid offset from 6563 Angstrom)
#     y_hat[5] = (obs.measurements[0].x0 - 6563.)/6563 * c.to(u.km/u.s).value
#     C[5,5] = (obs.measurements[0].x0_error/6563 * c.to(u.km/u.s).value)**2

    draw = np.random if random_state is None else random_state
    ys = draw.multivariate_normal(y_hat, C, size=n_samples)

    # Parallax in mas -> distance in pc.
    # NOTE(review): samples with parallax <= 0 yield non-positive distances —
    # confirm astropy accepts them for low-S/N parallaxes.
    icrs = coord.ICRS(ra=ys[:, 0]*u.mas, dec=ys[:, 1]*u.mas, distance=1000./ys[:, 2]*u.pc,
                      pm_ra_cosdec=ys[:, 3]*u.mas/u.yr, pm_dec=ys[:, 4]*u.mas/u.yr,
                      radial_velocity=ys[:, 5]*u.km/u.s)

    icrs.set_representation_cls(coord.CartesianRepresentation,
                                coord.CartesianDifferential)

    return icrs

In [None]:
# Distinct, positive comoving-group ids attached to observations.
# NOTE(review): group 10 is excluded without explanation — confirm why.
group_ids = np.array([
    gid
    for (gid,) in session.query(Observation.group_id).distinct().all()
    if gid is not None and gid > 0 and gid != 10
])

In [None]:
# Pick one comoving pair whose two observations both have an RV measurement,
# then Monte Carlo sample both stars. Only the selected pair is sampled, so no
# draws are wasted on pairs that would be discarded.
pair_index = 4  # 0-based index among valid pairs; falls back to the last valid pair

# `!= None` (not `is not None`) is required for SQLAlchemy to emit IS NOT NULL.
base_q = session.query(Observation).join(RVMeasurement).filter(RVMeasurement.rv != None)  # noqa: E711

selected = None
i = 0
for gid in np.unique(group_ids):
    observations = base_q.filter(Observation.group_id == gid).all()

    if len(observations) != 2:
        print("skipping group {0}".format(gid))
        continue

    selected = observations

    if i == pair_index:
        break

    i += 1

if selected is None:
    raise RuntimeError("no group with exactly two RV-measured observations found")

obs1, obs2 = selected
icrs1 = get_icrs_cartesian(obs1)
icrs2 = get_icrs_cartesian(obs2)

In [None]:
# Sample-by-sample differences between the two stars' Monte Carlo draws.
# NOTE(review): `dx` is a position difference but is wrapped in a
# CartesianDifferential; coord.CartesianRepresentation would describe it more
# accurately.
dx = coord.CartesianDifferential(
    icrs1.cartesian.xyz
    - icrs2.cartesian.xyz
)

# Velocity difference of the same sample pairs.
dv = coord.CartesianDifferential(
    icrs1.cartesian.differentials['s'].d_xyz
    - icrs2.cartesian.differentials['s'].d_xyz
)

In [None]:
# Isotropic Gaussian reference cloud: 2**16 draws of a 3-vector with zero mean
# and a per-component standard deviation of 25 (presumably km/s, to match `dv`).
dv_random = np.random.normal(loc=np.zeros(3), scale=np.full(3, 25.), size=(2**16, 3))

In [None]:
# Corner plot of the velocity-difference samples; d_xyz is (3, N), so .T gives
# one column per Cartesian component.
_ = corner.corner(dv.d_xyz.T.value,
                  labels=[r'$\Delta v_x$', r'$\Delta v_y$', r'$\Delta v_z$'])

In [None]:
# Overlay the pair's velocity-difference samples (default colour) on the
# isotropic Gaussian reference (grey). Both layers share the same range and
# binning so their 1D histograms are directly comparable.
dv_plot_range = [(-75, 75)] * 3
dv_labels = [r'$\Delta v_x$', r'$\Delta v_y$', r'$\Delta v_z$']
fig = corner.corner(dv_random, color='#aaaaaa',
                    range=dv_plot_range, bins=64, labels=dv_labels)
fig = corner.corner(dv.d_xyz.T.value, fig=fig,
                    range=dv_plot_range, bins=64, labels=dv_labels)