In [None]:
# Standard library
from collections import OrderedDict
import os

# Third-party
import astropy.time as atime
from astropy import log as logger
import astropy.units as u
import h5py
import matplotlib.pyplot as plt
from matplotlib import gridspec
# IPython magic: render figures inline in the notebook.
%matplotlib inline
import numpy as np
import corner

# Project
# NOTE(review): atime, logger, corner, usys, SimulatedRVOrbit, rv_from_elements,
# plot_corner and `paths` do not appear to be used by the cells in this notebook —
# confirm before pruning them.
from thejoker import Paths
paths = Paths()
from thejoker.data import RVData
from thejoker.units import usys
from thejoker.celestialmechanics import OrbitalParams, SimulatedRVOrbit, rv_from_elements
from thejoker.plot import plot_rv_curves, plot_corner

# Project-wide matplotlib style. The path is resolved relative to the kernel's
# current working directory, so the notebook must be run from the expected folder.
plt.style.use('../thejoker/thejoker.mplstyle')

In [None]:
# Where the paper figures are written; the directory is created if it is missing
# (no error if it already exists).
plot_path = "../paper/figures/"
os.makedirs(name=plot_path, exist_ok=True)

# HDF5 file holding every RV data set used in experiment 4.
data_filename = "../data/experiment4.h5"

## Read the data

In [None]:
# Load every RV data set from the experiment file. Group names are numeric strings;
# they are visited in ascending numeric order and stored under their integer value.
# The file-level APOGEE_ID attribute is kept for the figure title.
all_data = OrderedDict()
with h5py.File(data_filename, 'r') as h5_file:
    apogee_id = h5_file.attrs['APOGEE_ID']
    numeric_names = sorted(int(name) for name in h5_file.keys())
    for group_id in numeric_names:
        all_data[group_id] = RVData.from_hdf5(h5_file[str(group_id)])

In [None]:
# Quick-look figure: one row per data set, all rows sharing the same x/y ranges.
# squeeze=False keeps `axes` a 2-D array even when there is only one data set
# (otherwise plt.subplots returns a bare Axes and indexing it fails); we then take
# its single column so `axes` is a 1-D sequence of Axes, one per data set.
fig, axes = plt.subplots(len(all_data), 1, figsize=(5, 10),
                         sharex=True, sharey=True, squeeze=False)
axes = axes[:, 0]

for ax, data in zip(axes, all_data.values()):
    data.plot(ax=ax, rv_unit=u.km/u.s, markersize=3, capsize=0,
              elinewidth=0, color='k', zorder=1000, alpha=0.75)

fig.tight_layout()

In [None]:
# Model time grid: the time span of data set key 2, padded on each side by 25% of
# that span, sampled at 1024 points (TCB MJD).
# NOTE(review): the reason for using key 2 as the reference is not stated here —
# presumably it has the longest baseline; confirm.
_mjd_ref = all_data[2].t.tcb.mjd
_t_first = _mjd_ref.min()
_t_last = _mjd_ref.max()
dmjd = _t_last - _t_first
_pad = 0.25*dmjd
t_grid = np.linspace(_t_first - _pad, _t_last + _pad, 1024)

# Unit used for the RV axes of the model curves.
rv_unit = u.km/u.s

In [None]:
# One grid row per data set: a wide RV-vs-time panel on the left (2 of 3 grid
# columns) and a narrow parameter-scatter panel on the right (1 column).
# The row count follows the data, so gs[i, ...] never runs past the grid.
# The figure height scales at 3 inches per row (12 inches for 4 data sets).
n_rows = len(all_data)
fig = plt.figure(figsize=(10, 3*n_rows))
gs = gridspec.GridSpec(n_rows, 3)

axes_l = []  # left (RV curve) axes, one per data set
axes_r = []  # right (parameter scatter) axes, one per data set
def _log_interp_clipped(n, y_at_lo, y_at_hi, n_lo=219, n_hi=50000):
    """Interpolate linearly in log(n) from `y_at_lo` (at n_lo) to `y_at_hi` (at n_hi).

    The result is clipped to the interval spanned by the two anchor values. Counts
    outside [n_lo, n_hi] therefore get the nearest anchor value rather than an
    extrapolation; an unclipped extrapolation for small n pushes the marker alpha
    above 1, which matplotlib rejects.
    NOTE(review): 219 / 50000 are presumably the smallest / largest sample counts
    across the experiments — confirm.
    """
    frac = (np.log(n) - np.log(n_lo)) / (np.log(n_hi) - np.log(n_lo))
    y = (y_at_hi - y_at_lo) * frac + y_at_lo
    return float(np.clip(y, min(y_at_lo, y_at_hi), max(y_at_lo, y_at_hi)))


for i, key in enumerate(all_data.keys()):
    # The first row creates the reference axes; later rows share their x/y limits.
    if not axes_l:
        ax_l = fig.add_subplot(gs[i, :2])
        ax_r = fig.add_subplot(gs[i, 2])
    else:
        ax_l = fig.add_subplot(gs[i, :2], sharex=axes_l[0], sharey=axes_l[0])
        ax_r = fig.add_subplot(gs[i, 2], sharex=axes_r[0], sharey=axes_r[0])
    axes_l.append(ax_l)
    axes_r.append(ax_r)

    data = all_data[key]

    # Mean per-point RV uncertainty in km/s, quoted in the panel annotation.
    _mean_s = np.mean(data.stddev.to(u.km/u.s).value)
    sigma_str = r"$\left\langle \sigma_n \right\rangle = {:.1f}$ ${{\rm km}}\,{{\rm s}}^{{-1}}$".format(_mean_s)

    # Read the cached orbital-parameter samples for this data set.
    samples_filename = "../cache/experiment4-{}.h5".format(key)
    with h5py.File(samples_filename, 'r') as g:
        opars = OrbitalParams.from_hdf5(g)
        samples = opars.pack(plot_units=True)

    # Left panel: up to 128 sample RV curves drawn at zorder -100 (below the
    # rasterization threshold of -1, so they are rasterized), then the data on top.
    plot_rv_curves(opars[:128], t_grid, t_offset=data.t_offset, rv_unit=rv_unit,
                   ax=ax_l, plot_kwargs={'color': '#888888', 'zorder': -100},
                   add_labels=False)
    ax_l.set_rasterization_zorder(-1)
    # Same unit as the model curves so the points overlay them correctly.
    data.plot(ax=ax_l, rv_unit=rv_unit, markersize=3, capsize=0,
              elinewidth=0, color='k', zorder=1000, alpha=0.75)
    # Annotation is placed in data coordinates (BMJD 55650, RV -9); this assumes
    # those values lie inside the axis limits set after the loop.
    ax_l.text(55650, -9, r"{}, M={}".format(sigma_str, len(opars)),
              va='top', fontsize=18)

    # Right panel: one marker per sample. Alpha (0.75 -> 0.1) and size (3 -> 1)
    # shrink as the sample count grows, so dense clouds stay readable.
    _n = len(opars)
    alpha = _log_interp_clipped(_n, 0.75, 0.1)
    size = int(_log_interp_clipped(_n, 3, 1))

    style = dict(alpha=alpha, marker='.', markersize=size, linestyle='none',
                 rasterized=True, color='#888888')
    # Columns 0 and 2 of the packed samples match the axis labels taken from
    # OrbitalParams._latex_labels[0] (x, set after the loop) and [2] (y).
    ax_r.plot(samples[:, 0], samples[:, 2], **style)

    # Axes are shared, so only the bottom row keeps its x tick labels.
    if i < len(all_data) - 1:
        plt.setp(ax_l.get_xticklabels(), visible=False)
        plt.setp(ax_r.get_xticklabels(), visible=False)

    ax_l.set_ylabel('RV [km s$^{-1}$]')
    ax_r.set_ylabel(OrbitalParams._latex_labels[2])

# x-axis labels only on the bottom row. Limits set on the first row propagate to
# every row through sharex/sharey.
axes_l[-1].set_xlabel('BMJD')
axes_r[-1].set_xlabel(OrbitalParams._latex_labels[0])

axes_l[0].set_xlim(t_grid.min(), t_grid.max())
axes_l[0].set_ylim(-55, -5)

axes_r[0].set_xlim(2.5, 7.5)
axes_r[0].set_ylim(-0.025, 1.025)

# Pack the grid once, then override the top margin and row spacing to leave room
# for the suptitle.
fig.tight_layout()
fig.subplots_adjust(top=0.95, hspace=0.15)
fig.suptitle("Experiment 4: {}".format(apogee_id), fontsize=26, y=0.99)

# dpi sets the resolution of the rasterized artists (sample curves and markers)
# embedded in the PDF. Consider dpi=256 for production.
fig.savefig(os.path.join(plot_path, 'exp4-rv-curves.pdf'), dpi=100)