In [None]:
import os
from os import path
from astropy.time import Time
from astropy.io import fits, ascii
import astropy.units as u
from astropy.table import Table
from astropy.constants import G
from astropy.stats import median_absolute_deviation

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from matplotlib.gridspec import GridSpec
import h5py
import schwimmbad

from thejoker import JokerSamples, JokerParams, TheJoker
from thejoker.sampler.mcmc import TheJokerMCMCModel
from thejoker.plot import plot_rv_curves

from twoface.config import TWOFACE_CACHE_PATH
from twoface.db import (db_connect, AllStar, AllVisit, AllVisitToAllStar, RedClump,
                        StarResult, Status, JokerRun, initialize_db)
from twoface.data import APOGEERVData
from twoface.plot import plot_data_orbits
from twoface.samples_analysis import unimodal_P

In [None]:
# HDF5 file in the twoface cache directory that the notebook later reads
# JokerSamples from (one group per APOGEE ID).
samples_file = os.path.join(TWOFACE_CACHE_PATH, 'apogee-jitter.hdf5')

In [None]:
# Connect to the cached APOGEE SQLite database and open one working session.
# db_connect returns a pair; only the first element (the session factory) is
# used here -- presumably the second is the engine, TODO confirm.
db_path = path.join(TWOFACE_CACHE_PATH, 'apogee.sqlite')
Session, _ = db_connect(db_path)
session = Session()

In [None]:
# Look up the single JokerRun row named 'apogee-jitter' (.one() raises unless
# exactly one row matches) and recover the sampler parameters stored with it.
run_query = session.query(JokerRun).filter(JokerRun.name == 'apogee-jitter')
run = run_query.one()
params = run.get_joker_params()

In [None]:
# Stars from the 'apogee-jitter' run whose result status is "needs mcmc"
# (Status.id == 2 per the original note -- TODO confirm the id-to-label
# mapping against the Status table).
#
# The joins are chained one target per call: passing several targets to a
# single Query.join() is deprecated in SQLAlchemy 1.4 and removed in 2.0, and
# the chained form is the documented equivalent that works on old and new
# releases alike.
stars = (
    session.query(AllStar)
    .join(StarResult)
    .join(JokerRun)
    .join(Status)
    .filter(JokerRun.name == 'apogee-jitter')
    .filter(Status.id == 2)
    .all()
)
len(stars)

In [None]:
# Pick one of the stars selected above and load its radial-velocity data
# together with the samples already cached for it in `samples_file`.
star = stars[1]

data = star.apogeervdata()

# Open the cache explicitly read-only. With no mode argument, h5py releases
# before 3.0 default to a read/write mode that creates the file when it is
# missing, so a wrong path would silently produce an empty file and a
# read-only cache would be opened writable for no reason.
with h5py.File(samples_file, 'r') as f:
    samples0 = JokerSamples.from_hdf5(f[star.apogee_id])

# Sanity-check plot of the cached samples over the data before running MCMC.
_ = plot_data_orbits(data, samples0, xlim_choice='tight')

In [None]:
%%time

with schwimmbad.MultiPool() as pool:
    joker = TheJoker(params, pool=pool)
    model, samples, sampler = joker.mcmc_sample(data, samples0, n_steps=32768,
                                                n_walkers=256, n_burn=1024,
                                                return_sampler=True)

In [None]:
# Chain diagnostics, one row per sampled parameter:
#   column 0 -- the trace of every walker
#   column 1 -- median across walkers at each step
#   column 2 -- 1.5 x MAD across walkers at each step
# assumes sampler.chain is laid out (walker, step, parameter) -- TODO confirm
# for the sampler version in use.
ndim = sampler.chain.shape[-1]
line_style = dict(marker='', drawstyle='steps-mid')

fig, axes = plt.subplots(ndim, 3, figsize=(12, 16))
for k in range(ndim):
    chain_k = sampler.chain[..., k]
    ax_trace, ax_median, ax_spread = axes[k]

    # Plotting the transposed 2-D array draws one line per walker, exactly
    # as a Python loop over the walkers would.
    ax_trace.plot(chain_k.T, alpha=0.1, **line_style)

    ax_median.plot(np.median(chain_k, axis=0), **line_style)

    # MAD-based spread is used here in place of np.std.
    std = 1.5 * median_absolute_deviation(chain_k, axis=0)
    ax_spread.plot(std, **line_style)

fig.tight_layout()

In [None]:
# Scatter of the MCMC samples: 'P' column against 'e' column (unit-stripped
# via .value).
periods = samples['P'].value
eccentricities = samples['e'].value
plt.scatter(periods, eccentricities, alpha=0.5, linewidth=0)

In [None]:
# Compare the orbits before and after MCMC: first the cached samples the
# chains were started from, then the MCMC output (with period-extrema
# highlighting switched off).
tight_limits = dict(xlim_choice='tight')
_ = plot_data_orbits(data, samples0, **tight_limits)
_ = plot_data_orbits(data, samples, highlight_P_extrema=False, **tight_limits)

---

In [None]:
import astropy.units as u
from astropy.stats import median_absolute_deviation
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import glob
import pickle
from os import path

In [None]:
# LaTeX axis labels for the sampled parameters; entry k labels the k-th
# slot of the chain's last axis in the diagnostic plots.
names = [
    r'$\ln P$',
    r'$\sqrt{K}\,\cos M_0$',
    r'$\sqrt{K}\,\sin M_0$',
    r'$\sqrt{e}\,\cos \omega$',
    r'$\sqrt{e}\,\sin \omega$',
    r'$\ln s^2$',
    r'$v_0$',
]

In [None]:
# For every pickled sampler matching ../scripts/test-mcmc-*.pickle, draw the
# three-column chain diagnostic (walker traces, median over walkers,
# 1.5 x MAD over walkers) and save it as a PNG with the same stem in
# ../scripts/.
# NOTE: pickle.load can execute arbitrary code from the file -- only load
# pickles you produced yourself.
trace_style = dict(marker='', drawstyle='steps-mid')
column_titles = ('walkers', 'med(walkers)', '1.5 MAD(walkers)')

for filename in glob.glob('../scripts/test-mcmc-*.pickle'):
    with open(filename, 'rb') as f:
        sampler = pickle.load(f)

    ndim = sampler.chain.shape[-1]
    fig, axes = plt.subplots(ndim, 3, figsize=(12, 16))

    for k in range(ndim):
        chain_k = sampler.chain[..., k]
        ax_trace, ax_median, ax_spread = axes[k]

        ax_trace.set_ylabel(names[k])
        # rasterized=True keeps the many overlapping traces from bloating
        # the saved figure.
        ax_trace.plot(chain_k.T, alpha=0.1, rasterized=True, **trace_style)

        ax_median.plot(np.median(chain_k, axis=0), **trace_style)

        # MAD-based spread is used here in place of np.std.
        std = 1.5 * median_absolute_deviation(chain_k, axis=0)
        ax_spread.plot(std, **trace_style)

    # Column headers go on the top row only.
    for ax, title in zip(axes[0], column_titles):
        ax.set_title(title)

    fig.tight_layout()

    stem = path.splitext(path.basename(filename))[0]
    fig.savefig('../scripts/{0}.png'.format(stem), dpi=250)
    # Release the figure before the next file so figures do not pile up.
    plt.close('all')

In [None]:
# For every pickled sampler matching ../scripts/test-mcmc-*.pickle, turn the
# walkers' last chain entry into samples and save a plot of the resulting
# orbits over the star's data as ../scripts/<apogee_id>-samples.png.
# NOTE: pickle.load can execute arbitrary code from the file -- only load
# pickles you produced yourself.
for filename in glob.glob('../scripts/test-mcmc-*.pickle'):
    # Recover the APOGEE ID by stripping the fixed 'test-mcmc-' prefix that
    # the glob pattern guarantees on the file name. Splitting the path on
    # '-' and keeping the last piece would truncate any ID that itself
    # contains a '-' (e.g. a negative-declination designation such as
    # '2M00000019-1924498'), making the database lookup below fail or hit
    # the wrong star.
    stem = path.splitext(path.basename(filename))[0]
    apogee_id = stem[len('test-mcmc-'):]

    # .one() raises unless exactly one row comes back.
    star = session.query(AllStar).filter(AllStar.apogee_id == apogee_id).limit(1).one()
    data = star.apogeervdata()
    model = TheJokerMCMCModel(joker_params=params, data=data)

    with open(filename, 'rb') as f:
        sampler = pickle.load(f)

    # chain[:, -1] takes the final entry along the second axis for every
    # entry of the first -- presumably each walker's last step; TODO confirm
    # the chain layout for the sampler version in use.
    samples = model.unpack_samples_mcmc(sampler.chain[:, -1])
    # The unpacked samples get the data's reference epoch assigned by hand.
    samples.t0 = Time(data._t0_bmjd, format='mjd', scale='tcb')

    fig = plot_data_orbits(data, samples, n_orbits=256)
    fig.savefig('../scripts/{0}-samples.png'.format(apogee_id), dpi=260)
    # Release the figure before the next file so figures do not pile up.
    plt.close('all')