# Fit GPs to TESS data

Can a damped, driven simple harmonic oscillator model the TESS data effectively?  To what extent can multiple peaks be explained as phase drift?

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina' 

In [None]:
from astropy.stats import LombScargle
import astropy.units as u
from astropy.time import Time

In [None]:
import celerite2
from celerite2 import terms

In [None]:
import lightkurve as lk

## Retrieve custom-made lightkurve data

We'll retrieve the custom made lightcurve that we saved.

### Sector 10

In [None]:
# Read the custom 4-pixel light curve from FITS (format='tess') and normalize its flux.
lc_raw = lk.TessLightCurve.read('../../data/TESS/lightkurve_custom_4pixel.fits', format='tess').normalize()

In [None]:
# Left/right edges of the two time windows that are retained.
threshold1L = 1572
threshold1R = 1581.7
threshold2L = 1585
threshold2R = 1595.7

# Boolean masks selecting cadences strictly inside each window.
mask1 = (threshold1L < lc_raw.time.value) & (threshold1R > lc_raw.time.value)
mask2 = (threshold2L < lc_raw.time.value) & (threshold2R > lc_raw.time.value)

In [None]:
# Keep only the cadences that fall inside either window.
lc_s10 = lc_raw[mask1 | mask2]

In [None]:
# Time span between the first and last retained cadence.
lc_s10.time.max().value - lc_s10.time.min().value

In [None]:
# Quick-look plot of the trimmed light curve.
lc_s10.plot()

### Sector 36

In [None]:
# Read the Sector 36 light curve from CSV, drop NaN cadences, and normalize the flux.
# NOTE: this rebinds `lc_raw`; the object loaded earlier is still reachable via `lc_s10`.
lc_raw = lk.LightCurve.read('../../data/TESS/LUH16_TESS_S36_lk_custom_4pixel.csv', time_format='BTJD'
                           ).remove_nans().normalize()

In [None]:
# Assign a per-point uncertainty: |flux|/200 plus the median of flux/200.
# NOTE(review): the 1/200 scale is hard-coded; its justification is not shown in this notebook.
lc_raw.flux_err = np.abs(lc_raw.flux / 200.0) + np.median(lc_raw.flux / 200.0)

In [None]:
# `lc_s36` is a second name for the same object as `lc_raw` (no copy is made).
lc_s36 = lc_raw

FYI: the `time_format` kwarg does not appear to work!  This doesn't matter for periodograms.

In [None]:
# Inspect the time column that was actually parsed.
lc_raw.time

In [None]:
# Time span between the earliest and latest cadence.
lc_raw.time.max() - lc_raw.time.min()

In [None]:
# Quick-look plot of the Sector 36 light curve.
lc_s36.plot()

### Sector 37

In [None]:
# Read the Sector 37 (orbit 1, 2-pixel) light curve from CSV, drop NaN cadences, and normalize.
# NOTE: this rebinds `lc_raw`; `lc_s36` keeps pointing at the previously loaded object.
lc_raw = lk.LightCurve.read('../../data/TESS/LUH16_TESS_S37o1_lk_custom_2pixel.csv'
                               ).remove_nans().normalize()

In [None]:
# Same uncertainty recipe as used for the previous sector: |flux|/200 plus the median of flux/200.
lc_raw.flux_err = np.abs(lc_raw.flux / 200.0) + np.median(lc_raw.flux / 200.0)

In [None]:
# Keep the two time windows of interest and drop everything else.
# The upper limit of the second window originally read 23033, which looks like
# a typo on a time axis that sits near 2300-2330; 2333.0 is used instead.
# NOTE(review): confirm the intended upper limit against the data.
in_window_a = (lc_raw.time.value > 2308.8) & (lc_raw.time.value < 2320.0)
in_window_b = (lc_raw.time.value > 2321.2) & (lc_raw.time.value < 2333.0)
lc_s37 = lc_raw[in_window_a | in_window_b]

In [None]:
# Time span between the earliest and latest retained cadence.
lc_s37.time.max() - lc_s37.time.min()

In [None]:
# Quick-look plot of the trimmed Sector 37 light curve.
lc_s37.plot()

In [None]:
# Work on an independent copy so that shifting the time axis of
# `lc_s10_shifted` does not also modify `lc_s10`.
lc_s10_shifted = lc_s10.copy()

In [None]:
# Build the shifted curve from a fresh copy of `lc_s10` every time this cell
# runs: the result is always exactly +680 d (re-running does not accumulate
# shifts) and `lc_s10` itself is left untouched.
lc_s10_shifted = lc_s10.copy()
lc_s10_shifted.time = lc_s10.time + 680 * u.day

### Attempt at binning...

In [None]:
# Bin all three light curves with the same 4.7-hour bin size and overplot them on one axis.
time_bin = 4.7*u.hour
ax = lc_s36.bin(time_bin_size=time_bin).plot()#scatter(label='Sector 36', marker='o', alpha=0.5, ec='k', fc='b', s=20,)
lc_s37.bin(time_bin_size=time_bin).plot(ax=ax)#scatter(ax=ax, label='Sector 37', marker='o', alpha=0.5, ec='k', fc='r')
lc_s10_shifted.bin(time_bin_size=time_bin).plot(ax=ax)#scatter(ax=ax, label='Sector 10 (Time shifted)', marker='o', alpha=0.5, ec='k', fc='g')
ax.legend(loc='best', ncol=3)
ax.set_xlabel('Time [BTJD]')
#ax.set_ylim(0.85, 1.1)

Meh! Might be useful, but let's make a better lightcurve figure first.

## Make the Power Spectrum

Let's create the power spectrum separately for each campaign.  This approach allows us to not worry about the mean level that we assign to each campaign, and it helps see which PSD structures persist from campaign to campaign.

In [None]:
# One shared trial-period grid (1 h to just under 10 h in 0.005 h steps) so the
# three periodograms are evaluated at identical periods.
period_grid = np.arange(1.0, 10.0, 0.005)*u.hour

pg_s10 = lc_s10.to_periodogram(normalization='psd', period=period_grid, oversample_factor=10)
pg_s36 = lc_s36.to_periodogram(normalization='psd', period=period_grid, oversample_factor=10)
pg_s37 = lc_s37.to_periodogram(normalization='psd', period=period_grid, oversample_factor=10)

In [None]:
# Nine hex colours available for the figures below.
CB_color_cycle = [
    '#377eb8',
    '#ff7f00',
    '#4daf4a',
    '#f781bf',
    '#a65628',
    '#984ea3',
    '#999999',
    '#e41a1c',
    '#dede00',
]

In [None]:
# Overlay the three per-sector periodograms versus period (hours) on a log power axis,
# mark the 5.28 h and 6.940 h periods, and save the 2-8 h view to disk.
ax = pg_s10.plot(unit=u.hour, view='period', label='Sector 10', drawstyle='steps-mid')
pg_s36.plot(unit=u.hour, view='period', label='Sector 36', ax=ax, drawstyle='steps-mid', color='#377eb8', lw=1)
pg_s37.plot(unit=u.hour, view='period', label='Sector 37 (orbit 1)', ax=ax, drawstyle='steps-mid', color='#ff7f00')

ax.axvline(5.28, color='#4daf4a', linestyle='dashed', 
           label='$P_B = $ {:0.3f} h'.format(5.28),linewidth=1.8)
ax.axvline(6.940, color='#f781bf', linestyle='solid', 
           label='{:0.3f} h (maybe $P_A$)'.format(6.940),linewidth=1.8)

ax.set_yscale('log')
ax.set_ylim(3e-8, 1e-3)
ax.set_xlim(2, 8)
ax.legend(loc='best', ncol=3)

plt.savefig('../../figures/TESS_Periodogram_S10-S37_wide.png', bbox_inches='tight', dpi=300)

In [None]:
# Same overlay restricted to the 2-3 h period range, with reference lines at
# 2.435 h and 2.565 h; saved as a separate figure.
ax = pg_s10.plot(unit=u.hour, view='period', label='Sector 10', drawstyle='steps-mid')
pg_s36.plot(unit=u.hour, view='period', label='Sector 36', ax=ax, drawstyle='steps-mid', color='#377eb8', lw=1)
pg_s37.plot(unit=u.hour, view='period', label='Sector 37', ax=ax, drawstyle='steps-mid', color='#ff7f00')

ax.axvline(2.435, color='#e41a1c', linestyle='dashed', 
           label='2.435 h',linewidth=0.8)
ax.axvline(2.565, color='#dede00', linestyle='solid', 
           label='2.565 h',linewidth=0.8)

ax.set_yscale('log')
ax.set_ylim(3e-8, 1e-4)
ax.set_xlim(2, 3)
ax.legend(loc='best', ncol=5, fontsize=9)

plt.savefig('../../figures/TESS_Periodogram_S10-S37_short.png', bbox_inches='tight', dpi=300)

In [None]:
# Select the light curve analysed in the rest of the notebook (Sector 36; same object, not a copy).
lc=lc_s36

In [None]:
# PSD-normalized periodogram of the selected light curve, with frequency in 1/day.
pg = lc.to_periodogram(normalization='psd', freq_unit=1/u.day, oversample_factor=10)

In [None]:
# Synthetic noise-only light curve: same cadences as `lc`, flux drawn from a
# normal distribution centred on 1 with per-point scale `lc.flux_err`.
lc_noise = lc.copy()
lc_noise.flux = np.random.normal(1, scale=lc.flux_err)

In [None]:
# Compare the noise realization with the observed light curve on one axis.
ax = lc_noise.plot(label='Noise')
lc.plot(ax=ax);

In [None]:
# Periodogram of the noise realization, computed with the same settings as `pg`.
pg_noise = lc_noise.to_periodogram(normalization='psd', freq_unit=1/u.day, oversample_factor=10)

In [None]:
# Number of independent noise realizations used to build the noise band.
n_draws = 251

In [None]:
# Periods of interest (hours) and the matching frequencies in 1/day.
peak_period = 5.28 * u.hour
peak_frequency = (1.0/(peak_period.to(u.day)))

# Second candidate period, converted the same way.
alt_period = 6.94 * u.hour
alt_frequency = 1.0/(alt_period.to(u.day))

Construct a noise region

In [None]:
# Accumulator for the periodogram power of each noise realization.
noise_power_draws = []

In [None]:
# Start from an empty list on every execution so that re-running this cell
# does not silently grow the sample beyond `n_draws` realizations.
noise_power_draws = []

for i in range(n_draws):
    # White-noise light curve on the same cadences, scaled by the per-point errors.
    lc_noise = lc.copy()
    lc_noise.flux = np.random.normal(1, scale=lc.flux_err)
    pg_noise = lc_noise.to_periodogram(normalization='psd', freq_unit=1/u.day, oversample_factor=10)
    # Each draw is no longer plotted here: the only `ax` in scope belonged to a
    # figure from an earlier cell, so the per-draw plot calls never appeared in
    # this cell's output.  The band is drawn later from the percentiles.
    noise_power_draws.append(pg_noise.power.value)

In [None]:
# 15.9th / 50th / 84.1st percentiles of the noise power across draws (axis 0),
# giving one value per frequency sample.
lo, med, hi = np.percentile(noise_power_draws, (15.9, 50.0, 84.1), axis=0)

In [None]:
# Data periodogram with reference frequencies, the guessed noise floor, one
# noise realization, and the 15.9-84.1 percentile noise band.
ax = pg.plot(scale='log', zorder=10)
ax.set_ylim(med.mean()/3);
ax.axvline(peak_frequency.value, color='#2980b9', linestyle='dashed',
           label='$P_B = $ {:0.3f}'.format(peak_period),linewidth=0.8)
# Raw string: `\;` is not a valid Python escape sequence.  The label has no
# placeholder, so the former `.format(...)` call had no effect and is dropped.
ax.axvline(peak_frequency.value*2, color='#2980b9', linestyle='dotted',
           label=r'$P_B \; /\; 2$',linewidth=0.8)
ax.axvline(alt_frequency.value, color='#27ae60', linestyle='solid',
           label='{:0.3f} (maybe $P_A$)'.format(alt_period),linewidth=0.8)
ax.axhline(med.mean(), color='#f1c40f', linestyle='dashed', label='Guessed Noise Floor', zorder=-1)

#plt.plot(pg.frequency, med, color='#95a5a6')

# `pg_noise` is whichever noise periodogram was computed most recently.
pg_noise.plot(ax=ax, scale='log', label='Noise draw', color='#e67e22')

plt.fill_between(pg.frequency, lo, hi, color='#f39c12', alpha=0.2, zorder=0)
#pg_noise.plot(ax=ax, scale='log', label='Noise Draw', alpha=0.5)
plt.legend(loc='best')
ax.set_xlim(pg.frequency[0].value, pg.frequency[-1].value);

The lightkurve power scale factor is:  
$$ \tilde P_{lk} = P_{lk} \cdot \frac{2 T}{N}$$

Where the tilde represents the rescaled, and  
$N$ is the number of samples  
$T$ is the total observation window duration, in say, days or $\frac{1}{\mathrm{Hz}}$

The *celerité* power scale **expects** a rescaling of Lomb Scargle power:
$$ \tilde P_{LS} = P_{LS} \cdot \frac{1}{N}$$  
*assuming* the `.get_psd()` power is scaled by $ \tilde P_c = P_c \cdot \frac{2}{T}$.

So to get them to match up, we can simply divide the lightkurve power by $2T$, **or** multiply the `celerite` power $\tilde P_c$ by $2T$, yielding: 

$ \hat P_c = \tilde P_c \cdot 2T = P_c \cdot \frac{2}{T}\cdot 2T =  4P_c $


My inclination is to leave the lightkurve power as is, and rescale the celerite power instead.

In [None]:
# Variance of the flux; used below to scale the kernel amplitude guesses.
variance = np.var(lc.flux)

#### A periodic term

In [None]:
# Initial period guess in days (converted from `peak_period`).
guess_period = peak_period.to(u.day).value
guess_period

In [None]:
# Initial guesses for the first SHO term: quality factor, angular frequency
# (2*pi / period) and amplitude scaled from the flux variance.
Q_guess = 400
w0_guess = 2.0*np.pi / guess_period
S0_guess = variance /3600

# Linear-space bounds, kept for reference (not passed to the kernel below).
# The w0 lower limit used to be np.log(w0_guess*0.8) while the upper limit was
# linear; both limits are now expressed in the same (linear) units.
bounds1 = dict(S0=(variance/1000000, variance*100000),
               Q=(3,1000000),
               w0=(w0_guess*0.8, w0_guess*1.2))

kernel_sho = terms.SHOTerm(S0=S0_guess, Q=Q_guess, w0=w0_guess)#, bounds=bounds1)

#### A second periodic term

In [None]:
# The second periodic term starts at half the first period guess.
guess_period2 = guess_period / 2 # 2.5 / 24.0 * 1.0

In [None]:
# Initial guesses for the second SHO term (angular frequency = 2*pi / period).
Q2_guess = 100
w02_guess = 2.0*np.pi / guess_period2
S02_guess = variance /9000

# Log-space bounds keyed by `log_*` names; they are not passed to the kernel
# below (the `bounds=` argument is commented out).
bounds_sho2 = dict(log_S0=(np.log(variance/10000), np.log(variance*1000)),
               log_Q=(np.log(3), np.log(2000)), 
               log_omega0=(np.log(w02_guess*0.8),np.log(w02_guess*1.2)))

kernel_sho2 = terms.SHOTerm(S0=S02_guess, Q=Q2_guess, w0=w02_guess)#, bounds=bounds_sho2)

#### A Matern term

The `Matern32Term` used below takes the parameters directly (not their logarithms):

>    sigma (float): The parameter $\sigma$.  

>    rho (float): The parameter $\rho$.   
    
>    eps (Optional[float]): The value of the parameter $\epsilon$.   
        (default: `0.01`)

In [None]:
# Initial Matern-3/2 guesses: sigma is one tenth of the flux standard deviation,
# rho is one tenth of the second period guess.
sigma_guess = np.sqrt(variance)/10
rho_guess =  guess_period2 / 10.0

kernel_mat = terms.Matern32Term(sigma=sigma_guess, rho=rho_guess)

#### A Jitter term

In [None]:
#kernel_jit = terms.JitterTerm(log_sigma=np.log(lc.flux_err.mean()))

#### Compute the GP

In [None]:
# Full model: two SHO terms plus a Matern-3/2 term.
net_kernel = kernel_sho + kernel_sho2 + kernel_mat #+ kernel_jit

# `fit_mean=True` was a celerite (v1) keyword; celerite2 does not use it, so it
# is dropped.  The mean is optimised explicitly as params[0] in `set_params`.
gp = celerite2.GaussianProcess(net_kernel, mean=lc.flux.value.mean())

# Factorize the covariance at the observed times with the per-point uncertainties.
gp.compute(lc.time.value, yerr=lc.flux_err.value)

In [None]:
# Frequency grid of the data periodogram as a plain (unitless) array.
f = pg.frequency.value

In [None]:
# Angular frequencies at which the kernel PSDs are evaluated.
angular_freq = 2*np.pi*f

# Analytic PSD of the full kernel and of each component, multiplied by 4 to
# put them on the lightkurve power scale (see the derivation above).
power_true = net_kernel.get_psd(angular_freq) * 4
power_sho1 = kernel_sho.get_psd(angular_freq) * 4
power_sho2 = kernel_sho2.get_psd(angular_freq) * 4
power_mat = kernel_mat.get_psd(angular_freq) * 4

In [None]:
# One random draw from the GP, with the mean included.
flux_draw = gp.sample(include_mean=True)

In [None]:
# Overplot a fresh GP draw on the data, offset by +0.2 so the two curves do not overlap.
ax = lc.plot()
ax.plot(lc.time.value, gp.sample()+0.2);

Plausibly in the same ballpark--- good enough for an initial guess.

In [None]:
# Treat a GP draw as if it were data: put it in a copy of the light curve
# (with the original flux unit) and compute its periodogram with the same settings.
lc_draw = lc.copy()
lc_draw.flux = gp.sample() * lc.flux.unit
pg_draw = lc_draw.to_periodogram(normalization='psd', freq_unit=1/u.day, oversample_factor=10)

In [None]:
# Compare the data periodogram, the periodogram of one GP draw, and the
# analytic PSD (total and per component) for the initial-guess kernel.
ax = pg.plot(scale='log')
ax.axvline(1.0/guess_period, color='#ddaaaa', linestyle='dotted', label='{:0.3f} d'.format(guess_period), alpha=1)
pg_draw.plot(ax=ax, label='GP Draw', scale='log')
ax.step(pg.frequency, power_true, color='#f39c12', lw=2,label="Analytic model", where='mid', zorder=0)

ax.plot(f, power_sho1, color='#f39c12', lw=1,label="SHO 1", linestyle='--')
ax.plot(f, power_sho2, color='#f39c12', lw=1,label="SHO 2", linestyle=':')
ax.plot(f, power_mat, color='#f39c12', lw=1,label="Matern", linestyle='-.')

plt.ylim(med.mean()/3)
# This full-range x limit is overridden by the set_xlim(2, 13) call below.
ax.set_xlim(pg.frequency[0].value, pg.frequency[-1].value);
plt.legend(loc='best')
ax.set_xlim(2, 13)
plt.title('Initial Guess PSD');

Awesome!  Let's spot-check our results by making a draw from the model and then computing as if it were data.

In [None]:
# Observed flux as a plain array; read as a global by the celerite2 `neg_log_like` below.
y = lc.flux.value

In [None]:
# Log-likelihood of the data under the GP with the initial-guess kernel,
# as a baseline for the optimization below.

print("Initial log-likelihood: {0}".format(gp.log_likelihood(y)))

Define a likelihood function

In [None]:
def neg_log_like(params, y, gp1):
    """Return the negative log-likelihood of `y` after loading `params` into `gp1`.

    Relies on `gp1` exposing `set_parameter_vector` and `log_likelihood`.
    NOTE: a later cell defines another `neg_log_like` with a different
    signature, which replaces this one once that cell has run.
    """
    gp1.set_parameter_vector(params)
    log_like = gp1.log_likelihood(y)
    return -log_like

def grad_neg_log_like(params, y, gp1):
    """Return the negated element [1] of `gp1.grad_log_likelihood(y)`.

    `params` is first loaded into `gp1` via `set_parameter_vector`.
    NOTE(review): element [1] is presumably the gradient (element [0] the
    likelihood value) -- confirm against the GP implementation in use.
    """
    gp1.set_parameter_vector(params)
    result = gp1.grad_log_likelihood(y)
    return -result[1]

New for celerite2

In [None]:
def set_params(params, gp):
    """Load a flat parameter vector into `gp` and return the same object.

    params[0] becomes the GP mean.  params[1:] hold natural logs and are
    exponentiated before use, in this order:
    S0, Q, w0 (first SHO), S0, Q, w0 (second SHO), sigma, rho (Matern-3/2).

    Only `gp.mean` and `gp.kernel` are assigned here; nothing is recomputed.
    """
    gp.mean = params[0]
    theta = np.exp(params[1:])

    sho_first = terms.SHOTerm(S0=theta[0], Q=theta[1], w0=theta[2])
    sho_second = terms.SHOTerm(S0=theta[3], Q=theta[4], w0=theta[5])
    matern = terms.Matern32Term(sigma=theta[6], rho=theta[7])

    gp.kernel = sho_first + sho_second + matern
    return gp

def neg_log_like(params, gp):
    """Objective for the optimizer: negative log-likelihood of the global `y`.

    Applies `params` through `set_params`, calls `recompute(quiet=True)` on the
    GP, and evaluates the likelihood of the module-level array `y`.
    """
    updated_gp = set_params(params, gp)
    updated_gp.recompute(quiet=True)
    return -updated_gp.log_likelihood(y)

In [None]:
from scipy.optimize import minimize

### Refine the GP parameters with optimization

In [None]:
# Starting point for the optimizer: the mean, followed by the natural log of
# each kernel parameter in the order that `set_params` unpacks them.
initial_params = [1.0,
                  np.log(S0_guess), np.log(Q_guess), np.log(w0_guess),
                  np.log(S02_guess), np.log(Q2_guess), np.log(w02_guess),
                  np.log(sigma_guess), np.log(rho_guess)
                 ]
soln = minimize(neg_log_like, initial_params, method="L-BFGS-B", args=(gp,))

# `set_params` only swaps in the mean and kernel; it does not refactorize.
# Without an explicit recompute, the GP would keep the factorization from the
# optimizer's last function evaluation, which need not be at `soln.x`.
opt_gp = set_params(soln.x, gp)
opt_gp.recompute(quiet=True)

 Spot check the optimization results.

In [None]:
# `soln.fun` is the minimized negative log-likelihood, so negate it for reporting.
print("Final log-likelihood: {0}".format(-soln.fun))

In [None]:
# 10,000 evenly spaced prediction times between the first and last entries of the time array.
t_pred = np.linspace(lc.time.value[0], lc.time.value[-1], num=10000)

In [None]:
# Make the maximum likelihood prediction
mu, var = opt_gp.predict(y, t_pred, return_var=True)
std = np.sqrt(var)

In [None]:
# Data, one GP sample (offset by +0.1), the predictive mean, and a +/- 1 std band.
ax = lc.plot(drawstyle='steps-mid')
ax.step(lc.time.value, gp.sample()+0.1, label='GP Sample', lw=1, linestyle=':')
ax.step(t_pred, mu, label='mean prediction', alpha=1, linestyle='dashed')
ax.fill_between(t_pred, mu-std, mu+std, label='Confidence region', alpha=0.3)
ax.legend();

In [None]:
# Same comparison without the GP sample: data, predictive mean, and +/- 1 std band.
ax = lc.plot(drawstyle='steps-mid', linewidth=2)
#ax.step(lc.time.value, gp.sample(), label='GP Sample', lw=1)
ax.step(t_pred, mu, label='mean prediction', alpha=1, linestyle='dashed')
ax.fill_between(t_pred, mu-std, mu+std, label='Confidence region', alpha=0.3)
#ax.set_xlim(1590, 1590+1)
#ax.set_ylim(0.992, 0.998)
plt.legend();

In [None]:
# Analytic PSD of the GP's current kernel, with the same factor of 4 used for the initial-guess PSD.
this_power = gp.kernel.get_psd(2*np.pi*f) * 4

In [None]:
# Start from a copy of an existing periodogram object; its power is replaced in the next cell.
pg_analytic = pg_noise.copy()

In [None]:
# Replace the copied power with the analytic PSD, keeping the periodogram's power unit.
pg_analytic.power = this_power*pg_noise.power.unit

Improve the $S/N$ on the periodogram of GP draws.

In [None]:
%%time
lc_draw = lc.copy()
many_draw = []
for i in range(100):
    lc_draw.flux = gp.sample()
    pg_draw = lc_draw.to_periodogram(normalization='psd', freq_unit=1/u.day, oversample_factor=10)
    many_draw.append(pg_draw.power)
    
pg_draw.power = np.median(np.array(many_draw), axis=0)*pg_draw.power.unit

In [None]:
# Data periodogram versus the median GP-draw periodogram and the analytic PSD
# of the optimized kernel.
ax = pg.plot(scale='log')
ax.axvline(1.0/guess_period, color='#ddaaaa', linestyle='dotted', label='{:0.3f} d'.format(guess_period), alpha=1)
pg_draw.plot(ax=ax, label='GP Draw', scale='log')
ax.step(pg.frequency, this_power, color='#f39c12', lw=2,label="Analytic model", where='mid', zorder=0)

#ax.plot(f, power_sho1, color='#f39c12', lw=1,label="SHO 1", linestyle='--')
#ax.plot(f, power_sho2, color='#f39c12', lw=1,label="SHO 2", linestyle=':')
#ax.plot(f, power_mat, color='#f39c12', lw=1,label="Matern", linestyle='-.')

plt.ylim(med.mean()/3)
# This full-range x limit is overridden by the set_xlim(2, 13) call below.
ax.set_xlim(pg.frequency[0].value, pg.frequency[-1].value);
plt.legend(loc='best')
ax.set_xlim(2, 13)
plt.title('Posterior PSD');

# Figure for paper

Import the IGRINS Epochs

In [None]:
import pandas as pd

In [None]:
# Load the IGRINS metadata log; the figure below reads its 'Binary Component' and 'BTJD' columns.
df = pd.read_csv('../../data/IGRINS/2021A_metadata_log_CORRECT.csv')

In [None]:
# Display the metadata table.
df

In [None]:
# Line colour and line style used to mark each binary component ('A' or 'B').
binary_color_dict = dict(A='#3498db', B='#27ae60')
binary_line_dict = dict(A='dotted', B='dashed')

In [None]:
# Paper figure: light-curve points (black outline drawn first, yellow fill on
# top), the GP mean as a two-layer trendline, and vertical lines at the IGRINS
# visit times.
ax = lc.scatter(alpha=1, label=None, marker='o', fc='k', ec='k', s=15)
lc.scatter(ax=ax, alpha=1, label='TESS Sector 36', marker='o', fc='#f1c40f', ec=None, s=10)
ax.step(t_pred, mu, label=None, alpha=0.6, linestyle='solid', color='#2980b9', linewidth=2.0, zorder=10)
ax.step(t_pred, mu, label='Trendline', alpha=1, linestyle='solid', color='#2c3e50', linewidth=0.8, zorder=10)
#ax.fill_between(t_pred, mu-std, mu+std, label='Confidence region', alpha=1, color='#95a5a6', zorder=0)
ax.set_xlim(2285, 2292.5)
#ax.set_ylim(0.992, 0.998)
ax.figure.set_size_inches(9, 4)
ax.set_xlabel('Time [BTJD]')

# Only the first 8 rows of the log are considered, and only component-B visits are drawn.
for i in range(8):
    AorB = df['Binary Component'][i]
    if AorB == 'B':
        ax.axvline(df.BTJD[i], color=binary_color_dict[AorB], linestyle=binary_line_dict[AorB])

# Off-screen proxy line that supplies a single legend entry.  It is styled
# explicitly as component 'B' (the only component drawn) rather than with
# whatever value the loop variable happened to hold after the last iteration.
ax.axvline(-100, color=binary_color_dict['B'], linestyle=binary_line_dict['B'], label='IGRINS visits')
plt.legend(loc='best',ncol=3);

plt.savefig('../../document/paper1/figures/TESS_S36_O1_IGRINS_overlay.png', bbox_inches='tight', dpi=300)

In [None]:
# Overlay the three light curves (Sector 10 via its time-shifted version) on one axis.
ax = lc_s36.scatter(label='Sector 36', marker='o', alpha=0.1, ec='k', fc='#d35400', s=10,)
lc_s37.plot(ax=ax)#scatter(ax=ax, label='Sector 37', marker='o', alpha=0.5, ec='k', fc='r')
lc_s10_shifted.plot(ax=ax)#scatter(ax=ax, label='Sector 10 (Time shifted)', marker='o', alpha=0.5, ec='k', fc='g')
ax.legend(loc='best', ncol=3)
ax.set_xlabel('Time [BTJD]')
#ax.set_ylim(0.85, 1.1)