In [None]:
import os
from os import path

# Third-party
import astropy.coordinates as coord
from astropy.io import ascii
from astropy.table import Table, join, hstack, vstack
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
from astropy.io import fits
from thejoker.data import RVData

import pystan

In [None]:
# Output directory for the PDF figures saved by the plotting cells below
os.makedirs('calibrate-visit-err', exist_ok=True)

In [None]:
# Compile the Stan model used for the visit-error calibration fit.
# NOTE(review): pystan compiles the model to C++ on every Restart & Run All,
# which is typically slow; consider caching the compiled model keyed on the
# contents of visit-err.stan.
sm = pystan.StanModel('visit-err.stan')

In [None]:
# allStar summary catalog; columns used later include APOGEE_ID, RA, DEC,
# H, NVISITS and VSCATTER
allstar = fits.getdata('../data/allStarLite-r12-l33.fits')

In [None]:
# Sky positions of every allStar row for cross-matching (RA/DEC taken as degrees)
allstar_c = coord.SkyCoord(allstar['RA'], allstar['DEC'], unit=u.deg)

In [None]:
# Per-visit catalog; columns used later include APOGEE_ID, VHELIO, VRELERR, SNR
allvisit = fits.getdata('../data/allVisit-r12-l33.fits')

## Load exoplanet validation sample:

In [None]:
# Exoplanet host catalog (the file name suggests known binaries were removed upstream)
exop = Table.read('../data/exoplanets_notbinaries.votable')
# Total Gaia proper motion; NOTE(review): presumably mas/yr -- confirm the VOTable units
total_pm = np.sqrt(exop['gaia_pmra']**2 + exop['gaia_pmdec']**2)
# Keep hosts whose proper motion accumulated over 15 (presumably years of epoch
# difference -- confirm) stays below 5 arcsec, presumably so that the sky
# cross-match below is not thrown off by stellar motion.
# NOTE(review): this assumes the pm columns are unitless; if they carry units,
# multiplying by u.mas may raise or produce inconsistent units.
slow_exo_mask = ((15*total_pm * u.mas).to(u.arcsec) < 5*u.arcsec)
exop = exop[slow_exo_mask]

In [None]:
# Sky positions of the exoplanet hosts. The unit is passed explicitly (as for
# allstar_c) so the match does not depend on unit metadata being present in the
# VOTable; columns already in degrees are unaffected.
exo_c = coord.SkyCoord(ra=exop['ra'], dec=exop['dec'], unit=u.deg)

In [None]:
# Nearest allStar source for each exoplanet host (3D distance is discarded)
idx, sep, _ = exo_c.match_to_catalog_sky(allstar_c)

# Hosts with an allStar counterpart within 8 arcsec; computed once and used for
# both the exoplanet rows and the matching allStar rows so they stay aligned.
matched = sep < 8*u.arcsec
sub_exop = exop[matched]['pl_hostname', 'pl_rvamp', 'pl_rvamperr1', 'pl_rvamperr2',
                         'hd_name', 'hip_name', 'st_j', 'st_h', 'pl_orbper', 'pl_bmassj']
apogee_exo = hstack((Table(allstar[idx[matched]]), sub_exop))

# Require allStar H and catalog st_h to agree within 0.01 (presumably to reject
# chance sky coincidences), at least 4 visits, and pl_rvamp < 50
# (NOTE(review): presumably m/s -- confirm catalog units).
hmask = (np.abs(apogee_exo['H'] - apogee_exo['st_h']) < 1e-2)
apogee_exo = apogee_exo[hmask & (apogee_exo['NVISITS'] >= 4) & (apogee_exo['pl_rvamp'] < 50.)]
len(apogee_exo)

## Load Gaia RV standards validation sample:

In [None]:
# Gaia DR2 RV standard stars.
# NOTE(review): home-relative path; consider a configurable data directory.
rvs = Table.read(path.expanduser('~/data/GaiaDR2/gaia_dr2_rvs_standards.fit'))
# Build APOGEE-style IDs as '2M' + the 2MASS designation minus its first 6 chars.
# NOTE(review): assumes a fixed 6-character prefix in `_2MASS` -- verify the format.
rvs['APOGEE_ID'] = ['2M'+x[6:] for x in rvs['_2MASS']]
# Keep standards with small RV uncertainty and scatter (both < 0.05,
# presumably km/s -- confirm catalog units)
rvs = rvs[(rvs['e_RV'] < 0.05) & (rvs['s_RV'] < 0.05)]

In [None]:
# allStar rows whose APOGEE_ID appears among the RV standards
mask = np.isin(allstar['APOGEE_ID'], rvs['APOGEE_ID'])

In [None]:
sub_allstar = Table(allstar[mask])

In [None]:
# Attach the standard-star columns (astropy's default inner join), then
# require at least 4 visits
join_tbl = join(sub_allstar, rvs, keys='APOGEE_ID')
join_tbl = join_tbl[join_tbl['NVISITS'] >= 4]
len(join_tbl)

---

## Combine the two validation samples:

In [None]:
# Stack RV standards and exoplanet hosts into one validation table; columns
# present in only one input are masked in the other (astropy's default outer vstack).
# NOTE(review): exoplanet rows are presumably one per planet, so an APOGEE_ID
# may repeat; the per-star loop later de-duplicates with np.unique.
valid_tbl = vstack((join_tbl, apogee_exo))

In [None]:
# All allVisit rows belonging to validation stars
visit_mask = np.isin(allvisit['APOGEE_ID'], valid_tbl['APOGEE_ID'])
visits = allvisit[visit_mask]

In [None]:
# Distribution of allStar VSCATTER over the validation table
plt.hist(valid_tbl['VSCATTER'], bins=np.linspace(0, 1, 64));

In [None]:
# First-row index and visit count per validation star (all visits, including
# ones with non-finite VHELIO; the next cell recounts finite visits only).
# np.unique returns (unique, index, [inverse,] counts) in this fixed order no
# matter how the keyword arguments are ordered, so unpack index before counts.
_, idx, all_nvisits = np.unique(visits['APOGEE_ID'],
                                return_index=True, return_counts=True)

In [None]:
# One entry per unique validation star, in sorted APOGEE_ID order
apogee_ids = np.unique(valid_tbl['APOGEE_ID'])
n_stars = len(apogee_ids)
print(n_stars)

# Per-star visit counts and median RVs, plus per-visit arrays concatenated
# star-by-star in `apogee_ids` order; all_nvisits describes the ragged layout.
all_nvisits = []
rv = []
rv_var = []
rv_snr = []
mean_rv = []
for id_ in apogee_ids:
    # Only this star's visits with a finite heliocentric RV are used
    this_visits = visits[(visits['APOGEE_ID'] == id_) & np.isfinite(visits['VHELIO'])]
    all_nvisits.append(len(this_visits))
    
    rv.append(this_visits['VHELIO'])
    # Reported per-visit variance.
    # NOTE(review): VRELERR is not checked for finiteness -- confirm it is
    # finite whenever VHELIO is, or the fit will see NaN.
    rv_var.append(this_visits['VRELERR'] ** 2)
    rv_snr.append(this_visits['SNR'])
    # Median RV, later used as this star's initial mean_rv for the optimizer.
    # NOTE(review): a star with no finite visits yields NaN here and 0 visits.
    mean_rv.append(np.median(this_visits['VHELIO']))
    
rv = np.concatenate(rv)
rv_var = np.concatenate(rv_var)
rv_snr = np.concatenate(rv_snr)

In [None]:
# Stan data block: star/visit counts plus the flattened, star-ordered
# per-visit RVs, reported variances, and SNRs.
data = {
    'n_stars': n_stars,
    'n_visits': all_nvisits,
    'total_n_visits': np.sum(all_nvisits),
    'rv': rv,
    'rv_var': rv_var,
    'rv_snr': rv_snr,
}

In [None]:
# Optimizer starting point: per-star median RVs and fixed guesses for the
# global parameters (lns is presumably ln s; the fit later exposes `s`).
init = {
    'mean_rv': mean_rv,
    'a': 1.,
    'b': -0.5,
    'lns': -2.,
}

In [None]:
fit = sm.optimizing(data=data, init=init, iter=1024)

In [None]:
fit

In [None]:
def get_inflation_factor(fit, snr):
    """Evaluate a two-term power law in SNR using fitted parameters.

    Returns ``a * snr**b + c * snr**d`` with a..d read from ``fit``.
    NOTE(review): the calibration cells only initialize a, b and lns; confirm
    the Stan model actually provides ``c`` and ``d`` before using this.
    """
    # Earlier variants (kept for reference):
    #   a + b*snr + c*snr**2
    #   a * snr**b
    a, b, c, d = (fit[key] for key in ('a', 'b', 'c', 'd'))
    return a * snr**b + c * snr**d

def get_new_err(fit, visits):
    """Calibrated per-visit RV uncertainty.

    Combines, in quadrature, the fitted floor ``s``, the reported VRELERR, and
    an SNR-dependent variance term ``a * SNR**b``:
    ``sqrt(s**2 + VRELERR**2 + a * SNR**b)``. Plotted elsewhere in km/s.
    """
    jitter_var = fit['s']**2
    reported_var = visits['VRELERR']**2
    snr_var = fit['a'] * visits['SNR']**fit['b']
    return np.sqrt(jitter_var + reported_var + snr_var)

def get_nidever_err(visits):
    """Inflate VRELERR with a fixed power-law recipe plus a constant floor.

    Returns ``sqrt((3.5 * VRELERR**1.2)**2 + 0.072**2)``.
    NOTE(review): the name suggests this is the Nidever et al. APOGEE recipe;
    confirm the coefficients against that reference.
    """
    scaled = 3.5 * visits['VRELERR']**1.2
    floor = 0.072
    return np.sqrt(scaled**2 + floor**2)

In [None]:
# Disabled diagnostic: sqrt(get_inflation_factor) as a function of SNR.
# NOTE(review): get_inflation_factor reads fit['c'] and fit['d']; confirm the
# model provides them before re-enabling, or delete this cell.
# snr = np.linspace(1, 500, 1024)
# plt.plot(snr, np.sqrt(get_inflation_factor(fit, snr)))
# plt.xlim(0, 500)
# plt.ylim(0, 10)

In [None]:
# SNR distribution of all visits versus the validation-sample visits
# (non-finite SNR values dropped from both, on a shared set of bins).
_, bins, _ = plt.hist(allvisit['SNR'][np.isfinite(allvisit['SNR'])],
                      bins=np.linspace(0, 1035, 128), label='all visits')
plt.hist(visits['SNR'][np.isfinite(visits['SNR'])], bins=bins,
         label='validation visits')
plt.yscale('log')
plt.xlabel('SNR')
plt.ylabel('number of visits')
plt.legend();

In [None]:
# Reported per-visit RV error (VRELERR) versus SNR for every visit
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
ax.plot(allvisit['SNR'], allvisit['VRELERR'],
        marker=',', color='k', alpha=0.2, ls='none',
        rasterized=True)
ax.set_xlim(1e0, 1e3)
ax.set_ylim(1e-3, 2e1)
ax.set_xscale('log')
ax.set_yscale('log')

ax.set_xlabel('SNR')
ax.set_ylabel('VRELERR [{:latex_inline}]'.format(u.km/u.s))

# Lay out once, after limits, scales and labels are all set
fig.tight_layout()

fig.savefig('calibrate-visit-err/calib-err-snr.pdf', dpi=250)

In [None]:
# Calibrated per-visit error (get_new_err) versus SNR for every visit, with
# the fitted error floor s marked as a horizontal line
fig, ax = plt.subplots(1, 1, figsize=(6, 4))

new_err = get_new_err(fit, allvisit)
ax.plot(allvisit['SNR'], new_err,
        marker=',', color='k', alpha=0.2, ls='none',
        rasterized=True)

ax.set_xlim(1e0, 1e3)
ax.set_ylim(1e-3, 2e1)

ax.set_xscale('log')
ax.set_yscale('log')

ax.axhline(fit['s'], color='tab:blue', marker='', lw=1)
ax.text(1.5, 0.8*fit['s'], '{:.3f} {:latex_inline}'.format(fit['s'], u.km/u.s),
        va='top', color='tab:blue', fontsize=22)

ax.set_xlabel('SNR')
ax.set_ylabel('adjusted err [{:latex_inline}]'.format(u.km/u.s))

# Lay out once, after limits, scales, annotations and labels are all set
fig.tight_layout()

fig.savefig('calibrate-visit-err/calib-adj-err-snr.pdf', dpi=250)

In [None]:
# Per-visit error from get_nidever_err versus SNR for every visit, for
# comparison with the calibrated errors
fig, ax = plt.subplots(1, 1, figsize=(6, 4))

nid_err = get_nidever_err(allvisit)
ax.plot(allvisit['SNR'], nid_err,
        marker=',', color='k', alpha=0.2, ls='none',
        rasterized=True)

ax.set_xlim(1e0, 1e3)
ax.set_ylim(1e-3, 2e1)

ax.set_xscale('log')
ax.set_yscale('log')

ax.set_xlabel('SNR')
ax.set_ylabel('adjusted err [{:latex_inline}]'.format(u.km/u.s))

# Lay out once, after limits, scales and labels are all set
fig.tight_layout()

fig.savefig('calibrate-visit-err/nidever-adj-err.pdf', dpi=250)

---

In [None]:
for id_ in apogee_ids[20:]:
    this_visits = visits[(visits['APOGEE_ID'] == id_) & np.isfinite(visits['VHELIO'])]
    
    snr = this_visits['SNR']
    infl_err = get_new_err(fit, this_visits)
    
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    ax.errorbar(snr, this_visits['VHELIO'], infl_err,
                marker='', ecolor='#aaaaaa', zorder=-100, ls='none')
    ax.plot(snr, this_visits['VHELIO'],
            marker='o', ls='none', color='k')