In [None]:
from os import path

from astropy.constants import G
from astropy.table import Table
import astropy.units as u
import h5py
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from scipy.stats import beta
from tqdm import tqdm

from hq.config import HQ_CACHE_PATH
from hq.db import db_connect, AllStar, StarResult, Status, JokerRun
from hq.io import load_samples
from hq.plot import plot_two_panel, plot_phase_fold

from thejoker.plot import plot_rv_curves

In [None]:
# Name of the JokerRun analysed in this notebook; it selects both the HDF5
# samples cache file and the database rows queried below.
run_name = 'apogee-sample-jitter'

In [None]:
# Open this run's cached posterior samples read-only. Build the file name from
# `run_name`, not `run.name`: `run` is only created by the JokerRun query
# further down, so referencing it here fails on a fresh kernel.
f = h5py.File(path.join(HQ_CACHE_PATH, '{0}.hdf5'.format(run_name)), 'r')

In [None]:
def get_samples(star, data, run, h5file=None):
    """Load one star's posterior samples from an open HDF5 samples file.

    Parameters
    ----------
    star
        Database row; its ``apogee_id`` is used as the key into the HDF5 file.
    data
        The star's RV data (callers pass ``star.get_rvdata()``); only
        ``data.t0`` is used, forwarded to ``load_samples``.
    run
        JokerRun row; its ``poly_trend`` is forwarded to ``load_samples``.
    h5file : h5py.File, optional
        Open samples file to read from. When omitted, the notebook-level
        handle ``f`` is used (the previous, implicit behaviour).

    Returns
    -------
    Whatever ``hq.io.load_samples`` returns for this star.
    """
    if h5file is None:
        # Backward-compatible default: fall back to the notebook-global handle.
        h5file = f
    return load_samples(h5file[star.apogee_id], poly_trend=run.poly_trend,
                        t0=data.t0)

In [None]:
# Load The Payne catalog and keep only rows with 0 < Logg < 3.5 and
# 3000 < Teff[K] < 5500.
thepayne = Table.read('../data/Apogee_The_Payne.txt',
                      format='ascii.commented_header')
logg_cut = (thepayne['Logg'] > 0) & (thepayne['Logg'] < 3.5)
teff_cut = (thepayne['Teff[K]'] > 3000) & (thepayne['Teff[K]'] < 5500)
payne_mask = logg_cut & teff_cut
thepayne = thepayne[payne_mask]

# Sorted, de-duplicated IDs: a position in this array is NOT a row index of
# `thepayne`.
apogee_ids = np.unique(thepayne['APOGEE_ID'])
len(apogee_ids)

In [None]:
# Connect to the SQLite database stored in the HQ cache directory and open a
# session for the queries below.
apogee_db_path = path.join(HQ_CACHE_PATH, 'apogee.sqlite')
Session, engine = db_connect(apogee_db_path)
s = Session()

In [None]:
# Fetch the JokerRun row whose name matches `run_name`.
run = (s.query(JokerRun)
        .filter(JokerRun.name == run_name)
        .limit(1)
        .one())

In [None]:
# For every status id in the database, count the distinct stars in this run
# with that status and print the count next to the status message.
status_ids = np.sort([row[0] for row in s.query(Status.id).distinct().all()])
for status_id in status_ids:
    n_stars = (s.query(AllStar)
                .join(StarResult, Status, JokerRun)
                .filter(Status.id == status_id)
                .filter(JokerRun.name == run.name)
                .group_by(AllStar.apogee_id)
                .distinct()
                .count())
    status_msg = (s.query(Status)
                  .filter(Status.id == status_id)
                  .limit(1)
                  .one()
                  .message)
    print("Status {0} ({2}) : {1}".format(status_id, n_stars, status_msg))

In [None]:
# Stars with status id 2 in this run, excluding any whose starflags or
# aspcapflags contain one of the listed flag substrings.
excluded_flags = [
    (AllStar.starflags, '%BRIGHT_NEIGHBOR%'),
    (AllStar.starflags, '%STAR_WARN%'),
    (AllStar.starflags, '%ATMOS%'),
    (AllStar.aspcapflags, '%ATMOS%'),
]

star_query = (s.query(AllStar)
               .join(StarResult, Status, JokerRun)
               .filter(Status.id == 2)
               .filter(JokerRun.name == run.name))
for flag_column, pattern in excluded_flags:
    star_query = star_query.filter(~flag_column.like(pattern))
stars = star_query.group_by(AllStar.apogee_id).distinct().all()

len(stars)

In [None]:
# Summarize each star's posterior by the sample mean of P, e, K and jitter,
# and record the star's catalog logg / teff alongside.
logg, teff = [], []
P, ecc, K, jitter = [], [], [], []
for star in tqdm(stars):
    star_data = star.get_rvdata()
    star_samples = get_samples(star, star_data, run)
    for summary, key in ((P, 'P'), (ecc, 'e'), (K, 'K'), (jitter, 'jitter')):
        summary.append(np.mean(star_samples[key]))
    logg.append(star.logg)
    teff.append(star.teff)

logg, teff = np.array(logg), np.array(teff)
P, K, jitter = u.Quantity(P), u.Quantity(K), u.Quantity(jitter)
ecc = np.array(ecc)

In [None]:
# Eccentricity distribution of stars with mean K > 1 km/s, 20 < mean P < 300 d
# and 2 < logg < 4, compared with a Beta(a=0.867, b=3.03) density.
# NOTE(review): these Beta parameters appear to be the Kipping (2013)
# eccentricity prior -- confirm that is the intended reference.
mask = (K > 1*u.km/u.s) & (P > 20*u.day) & (P < 300*u.day) & (logg > 2) & (logg < 4)

fig, ax = plt.subplots(1, 1, figsize=(6, 5))

bins = np.linspace(0, 1, 10)
ax.hist(ecc[mask], bins=bins, density=True,
        label='{0} stars'.format(mask.sum()))

_x = np.linspace(0, 1, 100)
ax.plot(_x, beta.pdf(_x, a=0.867, b=3.03), label='Beta(0.867, 3.03)')

ax.set_xlabel('mean eccentricity $e$')
ax.set_ylabel('density')
ax.legend();

In [None]:
# Do not close `f` here: get_samples() reads from this HDF5 handle and the cells
# below still call it. Run f.close() only once no remaining cell needs samples.

## Period vs. eccentricity

In [None]:
# Period-eccentricity plane (posterior means) for stars with mean K > 5 km/s
# and mean P > 2 days.
K_mask = (K > 5*u.km/u.s) & (P > 2*u.day)
print(K_mask.sum())

fig, ax = plt.subplots(1, 1, figsize=(6, 6))
ax.plot(P.to_value(u.day)[K_mask], ecc[K_mask],
        marker='.', color='k', alpha=0.5, ls='none')
ax.set_xscale('log')
ax.set_xlim(1, 2000)
ax.set_ylim(0, 1)
ax.set_xlabel('mean period $P$ [day]')
ax.set_ylabel('mean eccentricity $e$');

In [None]:
# Indices into `stars` of systems that pass K_mask but have mean P < 5 d and
# mean e > 0.3.
short_eccentric = (P < 5*u.day) & (ecc > 0.3) & K_mask
weirdos = np.flatnonzero(short_eccentric)
len(weirdos)

In [None]:
# Inspect (at most) the first ten `weirdos`: plot their RV data with samples,
# title each figure with the star's flags, and print ID / logg / teff.
for star_idx in weirdos[:10]:
    star = stars[star_idx]
    data = star.get_rvdata()
    samples = get_samples(star, data, run)

    fig = plot_two_panel(data, samples)
    title = star.apogee_id + ': ' + star.starflags + ' ' + star.aspcapflags
    fig.axes[0].set_title(title)
    print(star.apogee_id, star.logg, star.teff)

## Does jitter trend with surface gravity?

In [None]:
# Stars with status id 4 in this run, 7-20 visits (inclusive), that also
# appear in the cut Payne catalog. Note: this rebinds `stars`.
stars_query = (s.query(AllStar)
               .join(StarResult, Status, JokerRun)
               .filter(Status.id == 4)
               .filter(JokerRun.name == run.name)
               .filter(AllStar.nvisits >= 7, AllStar.nvisits <= 20)
               .filter(AllStar.apogee_id.in_(apogee_ids)))
stars = stars_query.group_by(AllStar.apogee_id).distinct().all()
len(stars)

In [None]:
# Quick visual check of the first four stars in this subsample.
for star in stars[:4]:
    data = star.get_rvdata()
    fig = plot_two_panel(data, get_samples(star, data, run))

In [None]:
# Collect every star's jitter posterior samples together with that star's
# The Payne log g.
#
# Row lookup: `apogee_ids` comes from np.unique, so it is sorted and
# de-duplicated, and a position in it is NOT a row index of `thepayne`.
# Map each APOGEE_ID to its first row in `thepayne` instead.
payne_row_by_id = {}
for row_idx, payne_id in enumerate(thepayne['APOGEE_ID']):
    payne_row_by_id.setdefault(payne_id, row_idx)

all_jitters = []
payne_logg = []
for star in tqdm(stars):
    data = star.get_rvdata()
    samples = get_samples(star, data, run)
    all_jitters.append(samples['jitter'])

    # Every star here passed `AllStar.apogee_id.in_(apogee_ids)`, so it has a row.
    payne_logg.append(thepayne['Logg'][payne_row_by_id[star.apogee_id]])

payne_logg = np.array(payne_logg)
all_jitters = u.Quantity(all_jitters)

In [None]:
# Jitter samples as a plain (unitless) array expressed in m/s.
s_ms = all_jitters.to(u.m/u.s).value

In [None]:
# Shape of the stacked jitter samples (presumably stars x samples -- TODO confirm).
np.shape(all_jitters)

In [None]:
# One Payne log g per star: its length must match the first axis of the
# jitter samples, because the plotting cell pairs them row by row.
assert payne_logg.shape[0] == all_jitters.shape[0], \
    "need exactly one Payne log g per row of all_jitters"
payne_logg.shape

In [None]:
# Every jitter sample of every star against that star's The Payne log g.
# The x-axis runs from 3.1 to 0.4, so log g decreases to the right.
fig, ax = plt.subplots(1, 1, figsize=(12, 8))

# Repeat each star's log g once per sample and flatten, so all points go into a
# single artist (an (N, 1) x with an (N, M) y would create one line per column).
logg_per_sample = np.broadcast_to(payne_logg[:, None], s_ms.shape)
ax.plot(logg_per_sample.ravel(), s_ms.ravel(),
        marker=',', ls='none', alpha=0.1, color='k')

ax.set_xlim(3.1, 0.4)
ax.set_yscale('log')
ax.set_xlabel(r'The Payne $\log g$')
ax.set_ylabel('jitter [m/s]');