In [None]:
# Standard library
from collections import OrderedDict
from os.path import join, exists

# Third-party
import corner
import astropy.units as u
import h5py
from matplotlib import gridspec
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

# Project
from thejoker import Paths
paths = Paths()
from thejoker.data import RVData
from thejoker.celestialmechanics import OrbitalParams
from thejoker.plot import plot_rv_curves, _truth_color

plt.style.use('../thejoker/thejoker.mplstyle')

from figurehelpers import samples_units, make_rv_curve_figure, apw_corner

In [None]:
# Output resolution (dots per inch) handed to fig.savefig; it sets the
# resolution of any rasterized artists embedded in the saved PDF.
dpi = 256

In [None]:
# Experiment identifiers: `e_number` selects the input/cache files
# (experiment{e_number}...h5) and `e_name` prefixes the saved figure name.
e_number, e_name = 3, 'numpts'

In [None]:
# Simulated-experiment data file (one HDF5 group per subset plus a 'truth' group).
data_filename = "../data/experiment" + str(e_number) + ".h5"

In [None]:
# Axis labels for the packed sample columns (used below as labels[0] / labels[1]),
# presumably formatted for the units in `samples_units`.
labels = OrbitalParams.get_labels(samples_units)

## Read the data

In [None]:
# Per-subset containers, keyed by the integer HDF5 group name and filled in
# descending key order.
all_data = OrderedDict()      # RVData for every subset in the experiment file
all_pars = OrderedDict()      # OrbitalParams read from the cache (only if cached)
all_samples = OrderedDict()   # packed, plot-transformed samples for those params
emcee_samples = OrderedDict() # packed emcee samples, when the cache has an 'emcee' group

with h5py.File(data_filename, 'r') as f:
    subset_keys = sorted((int(name) for name in f.keys() if name != 'truth'),
                         reverse=True)

    for key in subset_keys:
        all_data[key] = RVData.from_hdf5(f[str(key)])

        # Sampler output for this subset lives in a separate cache file; subsets
        # without one keep their data but get no samples.
        _path = join(paths.cache, 'experiment{}-{}.h5'.format(e_number, key))
        if exists(_path):
            with h5py.File(_path, 'r') as g:
                all_pars[key] = OrbitalParams.from_hdf5(g)
                all_samples[key] = all_pars[key].pack(plot_transform=True, units=samples_units)

                if 'emcee' in g:
                    emcee_pars = OrbitalParams.from_hdf5(g['emcee'])
                    emcee_samples[key] = emcee_pars.pack(plot_transform=True, units=samples_units)

            print(len(all_data[key]), all_samples[key].shape)

    # True orbital parameters used to generate the data
    truth_pars = OrbitalParams.from_hdf5(f['truth'])
    truth_vec = truth_pars.pack(plot_transform=True, units=samples_units)

In [None]:
# Quick-look figure: one row per subset (descending key order) showing the data
# over the true orbit. Points of the reference dataset that a subset omits are
# marked with a red 'x'.
ref_key = 11  # reference (full) dataset; assumed to be the largest subset — TODO confirm
ref_data = all_data[ref_key]

# squeeze=False keeps `axes` indexable even when there is only one subset
fig, axes = plt.subplots(len(all_data), 1, figsize=(5,10),
                         sharex=True, sharey=True, squeeze=False)
axes = axes[:, 0]

t_grid = np.linspace(ref_data.t.tcb.mjd.min()-50, ref_data.t.tcb.mjd.max()+50, 1024)
# loop-invariant: the true orbit is the same in every row
# NOTE(review): assumes rv_orbit() returns values in km/s, matching data.plot(rv_unit=km/s) — confirm
truth_rv = truth_pars.rv_orbit()(t_grid)

for i, key in enumerate(all_data.keys()):
    data = all_data[key]

    axes[i].plot(t_grid, truth_rv, marker='', color='g', linestyle='--')
    data.plot(ax=axes[i], rv_unit=u.km/u.s, markersize=3, capsize=0,
              elinewidth=0, color='k', zorder=1000, alpha=0.75)

    if key != ref_key:
        # Rejected points = reference-set RVs absent from this subset. The test
        # must be against the reference set; comparing `data` with itself never
        # finds anything.
        not_included = np.array([x not in data._rv for x in ref_data._rv])
        reject_data = ref_data[not_included]

        axes[i].plot(reject_data.t.tcb.mjd, reject_data.rv.to(u.km/u.s).value,
                     marker='x', zorder=100, linestyle='none',
                     markersize=4, markeredgewidth=1, color='#de2d26')

fig.tight_layout()

In [None]:
# Time grid for the model RV curves: the reference dataset's baseline padded by
# 25% of its span on each side, sampled at 1024 points.
_mjd_ref = all_data[11].t.tcb.mjd
dmjd = _mjd_ref.max() - _mjd_ref.min()
_pad = 0.25*dmjd
t_grid = np.linspace(_mjd_ref.min() - _pad, _mjd_ref.max() + _pad, 1024)

### Two-column figure: RV curves (left) and period–eccentricity samples (right)

In [None]:
# Two-column figure, one row per subset that has cached samples:
#   left  - RV curves for the first `n_curves` posterior samples over the data
#   right - samples in the plane of the first two packed parameters
#           (labels[0] vs labels[1]; period-eccentricity per the section title)
n_curves = 128  # posterior samples drawn as RV curves; also the low end of the marker scaling

# Subsets whose cache file was missing have no samples and cannot be plotted.
plot_keys = [key for key in all_data.keys() if key in all_pars]
assert plot_keys, "no cached samples found for experiment {}".format(e_number)
n_rows = len(plot_keys)


def _n_total_samples(key):
    """Number of samples drawn in the right panel for `key` (pars plus emcee, if any)."""
    n = len(all_pars[key])
    if key in emcee_samples:
        n += len(emcee_samples[key])
    return n


def _marker_scale(n, n_max, n_min=n_curves):
    """Map a sample count to [0, 1] on a log scale: 0 at `n_min`, 1 at `n_max`.

    The result is clipped, so counts outside [n_min, n_max] never give an
    out-of-range alpha. It returns 0 when n_max <= n_min, which avoids a
    division by zero.
    """
    log_span = np.log(n_max) - np.log(n_min)
    if log_span <= 0:
        return 0.
    return float(np.clip((np.log(n) - np.log(n_min)) / log_span, 0., 1.))


# loop-invariant: the largest total sample count gets the faintest, smallest markers
n_max = max(_n_total_samples(key) for key in plot_keys)

fig = plt.figure(figsize=(9.5, 12. * n_rows / 5))
gs = gridspec.GridSpec(n_rows, 3)

axes_l = []
axes_r = []
for i, key in enumerate(plot_keys):
    if i == 0:
        ax_l = fig.add_subplot(gs[i, :2])
        ax_r = fig.add_subplot(gs[i, 2])
    else:
        # share with the top row so the limits set below apply to every row
        ax_l = fig.add_subplot(gs[i, :2], sharex=axes_l[0], sharey=axes_l[0])
        ax_r = fig.add_subplot(gs[i, 2], sharex=axes_r[0], sharey=axes_r[0])
    axes_l.append(ax_l)
    axes_r.append(ax_r)

    data = all_data[key]
    pars = all_pars[key]
    samples = all_samples[key]

    # Left panel: curves at zorder -100, below the rasterization zorder (-1),
    # so they are rasterized in the PDF while the data stay vector.
    plot_rv_curves(pars[:n_curves], t_grid, rv_unit=u.km/u.s,
                   ax=ax_l, plot_kwargs={'color': '#aaaaaa', 'zorder': -100, 'marker': ''},
                   add_labels=False)
    ax_l.set_rasterization_zorder(-1)
    data.plot(ax=ax_l, rv_unit=u.km/u.s, markersize=4, capsize=0,
              elinewidth=0, color='k', zorder=1000, alpha=0.75)
    # NOTE: position is in data coordinates (BMJD, km/s), tuned for this experiment
    ax_l.text(55400, 40, r"$N={}$, $M={}$".format(len(data), len(samples)),
              va='top', fontsize=16)

    # Right panel: markers fade (alpha 0.75 -> 0.05) and shrink (size 4 -> 1)
    # as the number of samples grows, so dense panels stay readable.
    scale = _marker_scale(_n_total_samples(key), n_max)
    alpha = 0.75 + (0.05 - 0.75) * scale
    size = max(1, int(4 + (1 - 4) * scale))

    style = dict(alpha=alpha, marker='.', markersize=size, linestyle='none',
                 rasterized=True, color='#888888')
    ax_r.plot(samples[:,0], samples[:,1], **style)
    if key in emcee_samples:
        ax_r.plot(emcee_samples[key][:,0], emcee_samples[key][:,1], **style)

    # only the bottom row keeps x tick labels
    if i < n_rows - 1:
        plt.setp(ax_l.get_xticklabels(), visible=False)
        plt.setp(ax_r.get_xticklabels(), visible=False)

    ax_r.set_ylabel(labels[1])

axes_l[0].set_ylabel('RV [km s$^{-1}$]')
axes_l[-1].set_xlabel('BMJD')
axes_r[-1].set_xlabel(labels[0])

axes_l[0].set_xlim(t_grid.min(), t_grid.max())
axes_l[0].set_ylim(-5, 45)

axes_r[0].set_xlim(2.5, 7.5)
axes_r[0].set_ylim(-0.025, 1)

fig.tight_layout()
fig.subplots_adjust(top=0.95, hspace=0.15)
fig.suptitle("Experiment {}".format(e_number), fontsize=24)

fig.savefig(join(paths.figures, '{}-rv-curves.pdf'.format(e_name)), dpi=dpi)