In [None]:
import glob
import os
from os import path

# Third-party
from astropy.constants import G
from astropy.io import fits, ascii
from astropy.stats import median_absolute_deviation
from astropy.table import Table, QTable, join
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import numpy as np
%matplotlib inline
import tqdm
from scipy.stats import beta, binned_statistic
import h5py
import yaml

from thejoker import JokerParams, JokerSamples, RVData

from twoface.samples_analysis import MAP_sample
from twoface.plot import plot_two_panel, plot_phase_fold, plot_data_orbits, plot_phase_fold_residual
from twoface.mass import get_m2_min
from twoface.ext import FileLoader
from twoface.util import config_to_jokerparams

In [None]:
# One ECSV radial-velocity table per candidate (file stem = APOGEE ID), plus the
# directory of cached Joker samples named '<APOGEE_ID>-joker.hdf5'.
data_files = sorted(glob.glob('../data/candidates/*.ecsv'))
cache_path = '../data/candidates/cache/'
# The paths are relative to the notebook's working directory; fail loudly here
# instead of letting the plotting loop below silently do nothing.
assert data_files, 'No candidate .ecsv files found in ../data/candidates/'

In [None]:
with open('../config/bh.yml', 'r') as f:
    # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1 and
    # raises TypeError on PyYAML >= 6. safe_load is sufficient for a plain config
    # file and refuses arbitrary Python object tags.
    config = yaml.safe_load(f)
# Convert the YAML prior/config settings into The Joker's parameter object.
joker_pars = config_to_jokerparams(config)

In [None]:
# Candidate list; later cells read its 'APOGEE_ID' and 'M1' columns.
# NOTE(review): the two arguments are presumably the spreadsheet key and the
# worksheet gid -- confirm against twoface.ext.FileLoader.GoogleSheets.
sheet = FileLoader.GoogleSheets(
    '1z0ukn8QWJL7dTrPdCFWlhQ6r-X-kwiA4z0YRqalZsY4',
    1633658731,
)
candidates = sheet.load()

In [None]:
def load_data_ecsv(fn):
    """Read one candidate's ECSV table into an ``RVData`` object.

    Parameters
    ----------
    fn : str
        Path to an ECSV file with ``time``, ``rv``, ``rv_err`` and ``source``
        columns.

    Returns
    -------
    data : RVData
        Built from the ``time``, ``rv`` and ``rv_err`` columns.
    source : column
        The table's ``source`` column, returned unchanged.
    """
    table = QTable.read(fn)
    rv_data = RVData(t=table['time'], rv=table['rv'], stddev=table['rv_err'])
    return rv_data, table['source']

In [None]:
def plot_diag_panel(data, samples, data_src, m1=None, params=None, max_m2_samples=256):
    """Build the multi-panel diagnostic figure for one candidate.

    Layout on a 3x5 GridSpec:

    * row 0 (full width): RV data with sampled orbits; each point is outlined
      in a color keyed by its data source;
    * rows 1-2, left: phase-folded data and residuals for the MAP sample;
    * row 1, right: eccentricity vs. period for every sample;
    * row 2, right: minimum companion mass vs. period.

    Parameters
    ----------
    data : RVData
        Radial-velocity data for the star.
    samples : JokerSamples
        Posterior samples from The Joker.
    data_src : array-like
        Source label for each data point (same length as ``data``); used to
        build boolean masks via ``data_src == src``.
    m1 : Quantity, optional
        Primary mass. Defaults to the notebook-global ``M1`` for backward
        compatibility. A length-1 array is reduced to a scalar; a bare number
        is interpreted as solar masses.
    params : JokerParams, optional
        Passed to ``MAP_sample``. Defaults to the notebook-global ``joker_pars``.
    max_m2_samples : int, optional
        Cap on how many samples are converted to M2,min (each one requires
        building an orbit object). Default 256.

    Returns
    -------
    fig : matplotlib.figure.Figure
    """
    # Fall back to notebook globals so existing calls keep working.
    if m1 is None:
        m1 = M1
    if params is None:
        params = joker_pars
    # The candidates lookup can yield a length-1 column; reduce to a scalar in
    # Msun so get_m2_min returns scalars and the M1 line matches the axis unit.
    m1 = u.Quantity(m1, u.Msun).ravel()[0]

    period_day = samples['P'].to(u.day).value
    period_lim = (0.8, 2500)  # days

    n_m2 = min(len(samples), max_m2_samples)
    m2_min = u.Quantity([get_m2_min(m1, samples.get_orbit(i).m_f)
                         for i in range(n_m2)]).to_value(u.Msun)

    fig = plt.figure(figsize=(9, 10))
    gs = GridSpec(3, 5)

    # --- RV data + orbit samples, points outlined by data source
    ax1 = fig.add_subplot(gs[0, :])
    source_colors = dict(DR15='k', LAMOST='tab:orange', Keck='tab:green')
    plot_data_orbits(data, samples, ax=ax1, n_times=16384,
                     xlim_choice='tight', highlight_P_extrema=False)
    for src in np.unique(data_src):
        # Unknown sources get a neutral color instead of raising KeyError.
        data[data_src == src].plot(ax=ax1, markerfacecolor='none', markeredgewidth=1,
                                   markeredgecolor=source_colors.get(src, 'tab:gray'),
                                   label=src)
    ax1.legend(loc='best')

    # --- Phase-folded data and residuals for the MAP sample
    map_sample = MAP_sample(data, samples, params)
    axes2 = [fig.add_subplot(gs[1, 0:3]),
             fig.add_subplot(gs[2, 0:3])]
    plot_phase_fold_residual(data, map_sample, axes=axes2)

    # --- Eccentricity vs. period (all samples)
    ax3 = fig.add_subplot(gs[1, 3:])
    ax3.scatter(period_day, samples['e'], marker='o', linewidth=0, alpha=0.5)
    ax3.set_ylim(0, 1)
    ax3.set_ylabel('$e$')

    # --- Minimum companion mass vs. period (first n_m2 samples only)
    ax4 = fig.add_subplot(gs[2, 3:])
    ax4.scatter(period_day[:n_m2], m2_min, marker='o', linewidth=0, alpha=0.5)
    ax4.set_yscale('log')
    ax4.set_ylim(1E-2, 1E2)
    ax4.axhline(1., zorder=-10, color='#cccccc')
    ax4.axhline(m1.to_value(u.Msun), zorder=-1, color='tab:green', alpha=0.5,
                label='$M_1$', marker='')
    ax4.legend(loc='best')
    ax4.set_ylabel(r'$M_{2,{\rm min}}$' + ' [{0:latex_inline}]'.format(u.Msun))
    ax4.set_xlabel('$P$ [day]')

    # Shared log-period x axis; y labels/ticks on the right of the grid.
    for ax in (ax3, ax4):
        ax.set_xlim(*period_lim)
        ax.set_xscale('log')
        ax.xaxis.set_ticks(10.0 ** np.arange(4))
        ax.yaxis.tick_right()
        ax.yaxis.set_label_position("right")

    return fig

In [None]:
# Single-star filter left over from debugging: only data files whose path
# contains this substring are plotted. Set to None to process every candidate.
DATA_FN_FILTER = '0618'

# savefig() does not create missing directories.
os.makedirs('../plots', exist_ok=True)

for data_fn in sorted(data_files):
    if DATA_FN_FILTER is not None and DATA_FN_FILTER not in data_fn:
        continue
    apogee_id = path.splitext(path.basename(data_fn))[0]
    samples_fn = path.join(cache_path, '{0}-joker.hdf5'.format(apogee_id))

    # Element-wise comparison instead of `in`: the type of `candidates` isn't
    # visible here, and for a pandas Series `in` tests the index, not the values.
    is_this_star = np.asarray(candidates['APOGEE_ID'] == apogee_id, dtype=bool)
    if not is_this_star.any():
        continue

    # Scalar primary mass (first matching row) rather than a length-1 column.
    M1 = float(np.asarray(candidates['M1'])[is_this_star][0]) * u.Msun

    data, src = load_data_ecsv(data_fn)
    # Read-only: older h5py defaults to mode 'a', which would create an empty
    # file when the cache is missing (or allow modifying an existing one).
    with h5py.File(samples_fn, 'r') as f:
        samples = JokerSamples.from_hdf5(f)

    fig = plot_diag_panel(data, samples, src)
    fig.tight_layout()
    fig.savefig('../plots/{0}-joker.png'.format(apogee_id), dpi=250)
    plt.close(fig)  # avoid accumulating open figures across the loop