In [1]:
# Imports and one-time notebook setup. Statements are interleaved with the
# imports; their relative order is kept as-is.
# NOTE(review): test_tj below uses `Time` and `u`, neither of which is
# imported in this cell.
import numpy as np
import pandas as pd
import glob
import os
from astropy.io import fits
import matplotlib.pyplot as plt
# NOTE: this import shadows Python's builtin `ascii()` for the rest of the notebook.
from astropy.io import ascii
from astropy.table import Table
import itertools
from astroquery.simbad import Simbad
from utils import PLOT_PARAMS
# Project-local helper from utils (not shown here); presumably applies the
# shared matplotlib settings -- TODO confirm.
PLOT_PARAMS()
from datetime import date
today=date.today()
# Date stamp (month_day_year) embedded in the figure file names saved by test_tj.
DATE =today.strftime("%m_%d_%y")
print(DATE)
import pickle
import thejoker as tj
import pymc as pm
import corner
import arviz as az
# Set after PLOT_PARAMS() so the white figure background takes precedence.
plt.rcParams['figure.facecolor'] = 'white'
import warnings
# Silences every warning notebook-wide, including deprecation and sampler warnings.
warnings.filterwarnings("ignore")

05_01_25


In [3]:
# Load the combined radial-velocity table, one row per target.
# NOTE(review): the preview below shows an 'Unnamed: 0' column, which suggests
# the CSV was written with its index. Array-valued columns (esp_time, comb_time,
# comb_rv, ...) appear to be stored as strings and would need parsing before
# numeric use -- confirm against the code that wrote the file.
rv_df = pd.read_csv('comb_rvs.csv')

In [5]:
# Show only the first row, as a quick check of the column layout.
rv_df.head(1)

Unnamed: 0,sobject_id,logg,teff,RG_id,galah_id,obj_name,esp_time,esp_rv,esp_rv_err,Li_val,...,dec_dr2,dr3_source_id,MJD_local,rv_nogr_obst,e_rv_nogr_obst,rv_galah,e_rv_galah,comb_time,comb_rv,comb_rv_err
0,150107004201104,1.792561,4190.2886,1,150107004201104,UCAC4 297-057956,"[2460341.53519178, 2460389.56581584, 2460371.5...","[53.7010931002614, 53.7985038354756, 53.715772...","[0.0356353249891873, 0.0194215970708735, 0.018...",3.41,...,-30.734294,5632958769593943424,57029.633,53.576,0.071,53.551998,0.072,[2460341.53519178 2460389.56581584 2460371.513...,[53.7010931 53.79850384 53.71577284 53.458772...,[0.03563532 0.0194216 0.01884311 0.00615718 0...


In [None]:
def test_tj(obj,Pi,Pf=1000,K0=20,error=0.1,chains=2, prior_size=50_000,save=False,MCMC=False):
    """Fit an orbit to one object's radial velocities with The Joker.

    Reads the time series for ``obj`` from the notebook-level ``obj_dir``
    mapping (defined outside this cell), draws prior samples, rejection-samples
    them against the data, plots the resulting RV curves and, optionally,
    continues with MCMC and a corner plot.

    Parameters
    ----------
    obj : hashable
        Key into ``obj_dir``; the entry must provide ``'time'``, ``'rv'`` and
        ``'err'`` sequences. Also used in the saved figure names.
    Pi, Pf : float
        Lower / upper period bound of the prior, multiplied by ``u.day``.
    K0 : float
        Passed as ``sigma_K0`` (multiplied by ``u.km / u.s``).
    error : float or None
        Uniform RV uncertainty assigned to every epoch (multiplied by
        ``u.km / u.s``). Pass ``None`` to keep the per-epoch uncertainties
        stored in ``obj_dir[obj]['err']``.
    chains : int
        Number of MCMC chains (only used when ``MCMC`` is true).
    prior_size : int
        Number of prior samples to draw.
    save : bool
        If true, write the RV-curve figure (and the corner plot when ``MCMC``
        is true) under ``../rv_fit/``.
    MCMC : bool
        If true, initialise MCMC from the rejection samples and run
        ``pm.sample``.

    Returns
    -------
    tuple
        ``(prior_samples, joker_samples, mcmc_samples, trace, par_names)``.
        ``joker_samples`` is the full return value of ``rejection_sample``
        (called with ``return_all_logprobs=True``); element 0 is what gets
        plotted. ``mcmc_samples`` and ``trace`` are ``None`` when ``MCMC`` is
        false.
    """
    # Local imports: neither name is imported in the notebook's import cell.
    import astropy.units as u
    from astropy.time import Time

    target = obj_dir[obj]
    measured_err = np.array(target['err'])

    rv_table = Table()
    rv_table['bjd'] = target['time']
    rv_table['rv'] = target['rv']
    print('mean error originally: %.3f km/s'%np.mean(measured_err))
    if error is None:
        # Keep the measured per-epoch uncertainties.
        rv_table['rv_err'] = measured_err
    else:
        # Replace the measured uncertainties with one uniform value.
        rv_table['rv_err'] = np.array([error]*len(rv_table['rv']))
    print('mean error now: %.3f km/s'%np.mean(rv_table['rv_err']))

    t = Time(rv_table["bjd"], format="jd", scale="tcb")

    # Fixed seed so prior draws and rejection sampling are repeatable.
    rnd = np.random.default_rng(seed=42)

    data = tj.RVData(t=t, rv=rv_table["rv"]*u.km/u.s, rv_err=rv_table["rv_err"]*u.km/u.s)

    # The prior is created inside an explicit model context, as before; the
    # MCMC step below re-enters `prior.model`.
    with pm.Model():
        prior = tj.JokerPrior.default(
            P_min=Pi * u.day,  # the default period prior needs explicit bounds
            P_max=Pf * u.day,
            sigma_K0=K0 * u.km / u.s,
            sigma_v=100 * u.km / u.s,
        )

    prior_samples = prior.sample(size=prior_size, rng=rnd)
    print(prior_samples)

    joker = tj.TheJoker(prior, rng=rnd)

    # With return_all_logprobs=True the result is indexable; element 0 holds
    # the accepted samples used for plotting and MCMC initialisation.
    joker_samples = joker.rejection_sample(data, prior_samples,
                                           max_posterior_samples=256,
                                           return_all_logprobs=True)
    accepted = joker_samples[0]
    print(accepted)

    fig, ax = plt.subplots(1, 1, figsize=(8, 4))
    _ = tj.plot_rv_curves(
        accepted,
        data=data,
        relative_to_t_ref=True,
        ax=ax,
    )
    ax.set_xlabel(f"BMJD$ - {data.t.tcb.mjd.min():.3f}$")
    if save:
        fig.savefig('../rv_fit/fit_RG%s_%s_2.png'%(obj,DATE),bbox_inches='tight',dpi=100)

    # Defined up front so the return statement is valid when MCMC is skipped
    # (previously this raised UnboundLocalError for the default MCMC=False).
    mcmc_samples = None
    trace = None

    if MCMC:
        with prior.model:
            mcmc_init = joker.setup_mcmc(data, accepted)
            trace = pm.sample(tune=2000, draws=2000, start=mcmc_init, cores=1, chains=chains)

        # Inside a function a bare expression is not displayed, so print it.
        print(az.summary(trace, var_names=prior.par_names))
        print('MCMC samples...')
        mcmc_samples = tj.JokerSamples.from_inference_data(prior, trace, data)
        mcmc_samples.wrap_K()
        df = mcmc_samples.tbl.to_pandas()

        # Parameters shown in the corner plot, kept in the table's column order.
        corner_params = ['P','e','K','v0']
        colnames = [name for name in df.columns if name in corner_params]
        fig = corner.corner(df[colnames],
                            labels=colnames,
                            quantiles=[0.16, 0.5, 0.84],
                            show_titles=True,
                            title_kwargs={"fontsize": 12})
        if save:
            fig.savefig('../rv_fit/corner_RG%s_%s.png'%(obj, DATE),bbox_inches='tight',dpi=200)

    return prior_samples, joker_samples, mcmc_samples, trace, prior.par_names

