In [None]:
import os

from astropy.io import fits
from astropy.time import Time
import astropy.table as at
from astropy.timeseries import BoxLeastSquares
from astropy.constants import G
import astropy.coordinates as coord
import astropy.units as u
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import requests
from tqdm.notebook import tqdm
import lightkurve as lk

import thejoker as tj
from hq.data import get_rvdata

In [None]:
# Load the binaries catalog and the curated gold sample. Masked entries in the
# gold sample are filled first, then the table is wrapped as a QTable so that
# unit-carrying columns (e.g. MAP_P) become Quantities.
binaries = at.Table.read('../catalogs/lnK0.0_logL4.6_metadata_gaia_starhorse.fits')
_gold_filled = at.Table.read('../catalogs/gold_sample.fits').filled()
gold = at.QTable(_gold_filled)

In [None]:
# Directory holding the Gaia DR2 x Kepler/K2 crossmatch files. Set the
# GAIA_KEPLER_DIR environment variable to point elsewhere instead of editing
# a hard-coded absolute path; the default reproduces the original location.
GAIA_KEPLER_DIR = os.environ.get('GAIA_KEPLER_DIR',
                                 '/mnt/home/apricewhelan/data/Gaia-Kepler')
# Keep only crossmatch pairs closer than this separation
# (presumably arcsec, given the "4arcsec" file names — TODO confirm column units).
MAX_ANG_DIST = 1.

kepler = at.Table.read(os.path.join(GAIA_KEPLER_DIR, 'kepler_dr2_4arcsec.fits'))
k2 = at.Table.read(os.path.join(GAIA_KEPLER_DIR, 'k2_dr2_4arcsec.fits'))

kepler = kepler[kepler['kepler_gaia_ang_dist'] < MAX_ANG_DIST]
k2 = k2[k2['k2_gaia_ang_dist'] < MAX_ANG_DIST]

In [None]:
# Keep Kepler targets whose Gaia G and Kepler Kp magnitudes roughly agree
# (a sanity check on the positional match) and that have a non-"N/A"
# tm_designation (presumably the 2MASS ID — confirm against the catalog docs).
MAX_MAG_DIFF = 1.  # mag
master = kepler[np.abs(kepler['phot_g_mean_mag'] - kepler['kepmag']) < MAX_MAG_DIFF]
master['tm_designation'] = master['tm_designation'].astype(str)
master = master[master['tm_designation'] != 'N/A']

# Reduce to one row per Gaia source_id. np.unique(return_index=True) returns
# the FIRST occurrence of each value, so sort by angular separation first
# (stable sort for reproducibility): the closest Kepler-Gaia pair is kept
# rather than whichever duplicate happened to come first in the file.
master = master[np.argsort(np.asarray(master['kepler_gaia_ang_dist']), kind='stable')]
_, idx = np.unique(master['source_id'], return_index=True)
master = master[idx]

In [None]:
# Attach the Kepler/Gaia crossmatch columns to both samples by Gaia source_id.
# For non-key columns present in both tables, the left table's copy keeps its
# name (suffix '') and the crossmatch copy gets a '2' suffix.
join_kwargs = dict(keys='source_id',
                   uniq_col_name='{col_name}{table_name}',
                   table_names=['', '2'])

gold_master = at.join(gold, master, **join_kwargs)
binaries_master = at.join(binaries, master, **join_kwargs)

## Cross-match with known Kepler eclipsing binaries (EBs; Kirk et al. 2016 catalog)

In [None]:
# Kepler eclipsing-binary catalog (Kirk et al. 2016). The data directory can be
# overridden with the GAIA_KEPLER_DIR environment variable; the default is the
# original absolute location.
GAIA_KEPLER_DIR = os.environ.get('GAIA_KEPLER_DIR',
                                 '/mnt/home/apricewhelan/data/Gaia-Kepler')
# The CSV has a block of preamble lines; the commented column header is line 7.
kebs = at.Table.read(os.path.join(GAIA_KEPLER_DIR, 'Kirk2016-Kepler-EBs.csv'),
                     format='ascii.commented_header',
                     delimiter=',', header_start=7)

In [None]:
# Number of catalogued Kepler EBs (matched by KIC ID) present in the binaries sample
np.sum(np.isin(kebs['KIC'], binaries_master['kepid']))

In [None]:
# Number of catalogued Kepler EBs (matched by KIC ID) present in the gold sample
np.sum(np.isin(kebs['KIC'], gold_master['kepid']))

---

## Inspect the Kepler light curve of a long-period gold-sample system

In [None]:
# Select gold-sample systems whose MAP period lies strictly between 365 and 1000 days
map_period = gold_master['MAP_P']
mask = (map_period > 365*u.day) & (map_period < 1000*u.day)
np.sum(mask)

In [None]:
# Pick one star from the long-period selection, download its Kepler light
# curve files, and stitch the PDCSAP fluxes into a single light curve.
TARGET_INDEX = 5  # position of the star to inspect within gold_master[mask]
row = gold_master[mask][TARGET_INDEX]  # IndexError (not a silent no-op) if too few rows

search_result = lk.search_lightcurvefile(f"KIC {row['kepid']}", mission='Kepler')
lcfs = search_result.download_all()
# NOTE(review): lightkurve appears to return None (with a warning) rather than
# raising when the search is empty — fail explicitly instead of with an
# AttributeError on the next line.
if lcfs is None:
    raise RuntimeError(f"No Kepler light curve files found for KIC {row['kepid']}")
stitched_lc = lcfs.PDCSAP_FLUX.stitch()

In [None]:
def get_transit_period(lc, rv_period=None, duration=0.1, n_periods=10000,
                       min_period=1.5, max_period=200., log_width=1.,
                       oversample=10):
    """Find the strongest box-shaped periodic dip in a light curve with BLS.

    The flux is normalized to parts per thousand about its median, and
    astropy's ``BoxLeastSquares`` is run over a log-uniform period grid. The
    highest-power peak is returned.

    Parameters
    ----------
    lc : light curve object
        Must provide ``astropy_time`` (an astropy ``Time``) and ``flux``
        (array-like). Presumably a lightkurve 1.x ``LightCurve`` — TODO confirm
        for other versions.
    rv_period : `~astropy.units.Quantity` or float, optional
        If given, search a grid centered on this period spanning
        ``rv_period * exp(-log_width) .. rv_period * exp(+log_width)``.
        Bare numbers are interpreted as days. If None, search
        ``min_period .. max_period`` days instead.
    duration : float, optional
        Trial box duration in days passed to ``BoxLeastSquares.power``.
    n_periods : int, optional
        Number of trial periods in the grid.
    min_period, max_period : float, optional
        Bounds (days) of the blind search grid; only used if ``rv_period`` is None.
    log_width : float, optional
        Half-width, in natural-log period, of the grid around ``rv_period``.
    oversample : int, optional
        Passed through to ``BoxLeastSquares.power``.

    Returns
    -------
    t0 : `~astropy.time.Time`
        Reference mid-transit time of the best peak (JD format, TCB scale).
    period : float
        Period of the best peak, in days.

    Raises
    ------
    ValueError
        If the light curve has no sample with finite time and flux.
    """
    # Times as JD in the TCB scale, shifted to start at ~0 so BLS works with
    # small numbers; the offset is added back to the returned epoch.
    x = lc.astropy_time.tcb.jd
    x_ref = np.nanmin(x)
    x = x - x_ref

    # Convert flux to parts per thousand relative to the median flux
    y = lc.flux
    y = (y / np.nanmedian(y) - 1) * 1e3

    # BoxLeastSquares cannot handle NaNs: keep samples with finite time AND flux
    m = np.isfinite(x) & np.isfinite(y)
    if not np.any(m):
        raise ValueError("light curve has no finite time/flux samples")
    bls = BoxLeastSquares(x[m], y[m])

    if rv_period is None:
        period_grid = np.exp(np.linspace(np.log(min_period), np.log(max_period),
                                         n_periods))
    else:
        # Accept a Quantity in any time unit, or a bare number in days
        logP = np.log(u.Quantity(rv_period, u.day).to_value(u.day))
        period_grid = np.exp(np.linspace(logP - log_width, logP + log_width,
                                         n_periods))

    bls_power = bls.power(period_grid, duration, oversample=oversample)

    # Keep the highest peak, ignoring any NaN power values
    index = np.nanargmax(bls_power.power)
    bls_period = bls_power.period[index]
    bls_t0 = bls_power.transit_time[index]

    return Time(bls_t0 + x_ref, format='jd', scale='tcb'), bls_period

In [None]:
# Plot the full stitched PDCSAP light curve.
stitched_lc.plot()
# Optional zoom on a time window / flux range:
# plt.xlim(500, 750)
# plt.ylim(0.99, 1.01)

In [None]:
# Phase-fold the light curve on the selected star's catalog MAP period
# (MAP_P, presumably the RV orbital period — confirm), converted to days.
stitched_lc.fold(row['MAP_P'].to_value(u.day)).plot()

In [None]:
# Blind BLS search over the default period grid. To search around a known
# period instead, pass e.g. rv_period=500*u.day.
bls_t0, bls_P = get_transit_period(lc=stitched_lc, rv_period=None)

In [None]:
# Offset between MJD and the light curve's native time axis, used to convert
# the BLS epoch (bls_t0) onto that axis as `bls_t0.mjd - dmjd`.
# bls_t0 is returned in the TCB scale by get_transit_period, so the offset is
# computed in TCB too; using the light curve's own scale would leave a
# TCB-vs-native scale offset in every converted epoch. nanmin ignores
# missing samples.
dmjd = np.nanmin(stitched_lc.astropy_time.tcb.mjd) - np.nanmin(stitched_lc.time)

In [None]:
stitched_lc.plot()
# Overlay the first four BLS-predicted transit times, expressed on the light
# curve's own time axis (MJD minus the dmjd offset).
epoch0 = bls_t0.mjd - dmjd
for n_transit in range(4):
    plt.axvline(epoch0 + n_transit * bls_P,
                marker='', color='tab:red')

In [None]:
# Phase-fold on the BLS period, with the BLS epoch converted to the light
# curve's time axis via dmjd; plot as small unconnected points.
stitched_lc.fold(bls_P, t0=bls_t0.mjd - dmjd).plot(ls='none', marker='o', ms=1.5, mew=0)
# Optional zoom on phases near the folded epoch:
# plt.xlim(-0.02, 0.02)