In [None]:
from os import path
import os

import astropy.coordinates as coord
from astropy.constants import G
from astropy.io import fits
from astropy.table import Table, QTable, join
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm import tqdm
import emcee
import yaml

from hq.config import HQ_CACHE_PATH, config_to_alldata
from hq.plot import plot_two_panel, plot_phase_fold
from hq.data import get_rvdata
from hq.physics_helpers import period_at_surface, stellar_radius
from hq.hierarchical.period_ecc import Model

from thejoker.plot import plot_rv_curves
from thejoker import JokerSamples, JokerParams, TheJoker

from scipy.special import logsumexp
from scipy.stats import beta, truncnorm
from scipy.optimize import minimize

In [None]:
# Locations inside the HQ cache for the DR16 run: the per-star sample files,
# the metadata summary table, and the run's YAML configuration.
samples_path = path.join(HQ_CACHE_PATH, 'dr16/samples')
_metadata_file = path.join(HQ_CACHE_PATH, 'dr16/metadata-master.fits')
_config_file = path.join(HQ_CACHE_PATH, "dr16/config.yml")

metadata = QTable.read(_metadata_file)
with open(_config_file, "r") as f:
    config = yaml.safe_load(f)

# allStar / allVisit tables for the stars described by this config.
allstar, allvisit = config_to_alldata(config)

In [None]:
# Attach the allStar columns (presumably the source of TEFF / LOGG used in the
# cuts below) to the metadata table, matching rows on APOGEE_ID.
# NOTE(review): this cell is not idempotent — re-running it joins allstar a
# second time, and astropy renames the clashing columns (e.g. TEFF_1/TEFF_2).
metadata = join(metadata, allstar, keys=['APOGEE_ID'])

In [None]:
# Likelihood-ratio cut: the best unmarginalized ln-likelihood must exceed the
# `robust_constant_ln_likelihood` column by more than 6.
_delta_lnL = (metadata['max_unmarginalized_ln_likelihood']
              - metadata['robust_constant_ln_likelihood'])
llr_mask = _delta_lnL > 6

# Quality cut: flagged `unimodal` or observed at >= 4 visits, plus limits on
# the stellar parameters. (An earlier variant additionally required
# `joker_completed` for the n_visits >= 4 branch.)
qual_mask = metadata['unimodal'] | (metadata['n_visits'] >= 4)
qual_mask &= (metadata['LOGG'] > -0.5) & (metadata['TEFF'] > 3200)

In [None]:
# Stars passing the likelihood-ratio cut, and passing both cuts.
n_llr, n_llr_qual = llr_mask.sum(), (llr_mask & qual_mask).sum()
n_llr, n_llr_qual

In [None]:
def logg_f(teff):
    """Piecewise-linear log g boundary as a function of effective temperature.

    Above ``teff_crit`` (5200 K) the boundary is the straight line through
    (5500 K, 4.0) with a slope of -0.1 dex per 200 K. At or below
    ``teff_crit`` it is held constant at that line's value at 5200 K
    (log g = 4.15).

    Parameters
    ----------
    teff : float or array_like
        Effective temperature(s) in K.

    Returns
    -------
    numpy.ndarray of float
        Boundary log g with the same shape as ``teff``.
    """
    slope = -0.1 / 200
    pt = (5500, 4.)
    teff_crit = 5200

    # Cast to float explicitly: building the output with
    # ``np.zeros_like(teff)`` inherits an integer dtype from integer input
    # and silently truncates the boundary (e.g. 4.15 -> 4).
    teff = np.asarray(teff, dtype=float)
    val1 = slope * (teff - pt[0]) + pt[1]
    val2 = slope * (teff_crit - pt[0]) + pt[1]
    return np.where(teff > teff_crit, val1, val2)

def rg_f(teff):
    """Linear log g boundary as a function of effective temperature.

    The line passes through (4800 K, log g = 4.0) and rises by 0.25 dex per
    100 K. Works element-wise on scalars or numpy arrays.
    """
    teff_ref, logg_ref = 4800, 4.
    dlogg_dteff = 0.25 / 100
    return logg_ref + dlogg_dteff * (teff - teff_ref)

In [None]:
# T_eff–log g diagram of the stars passing both cuts, overlaid with the two
# boundary curves (logg_f, rg_f) used below to split the sample.
fig, ax = plt.subplots(1, 1, figsize=(7, 6))

ax.plot(metadata['TEFF'][llr_mask & qual_mask],
        metadata['LOGG'][llr_mask & qual_mask],
        marker='o', ls='none', ms=1.5, mew=0,
        color='k', alpha=0.2)

# Limits match the quality cut (TEFF > 3200, LOGG > -0.5) so no selected star
# falls outside the axes; axes are inverted in the usual HR-diagram sense.
ax.set_xlim(7000, 3200)
ax.set_ylim(5.5, -0.5)

_x = np.linspace(3200, 7000, 1024)
ax.plot(_x, logg_f(_x), marker='', label='logg_f')
ax.plot(_x, rg_f(_x), marker='', label='rg_f')
ax.legend(loc='upper left')

ax.set_xlabel(r'$T_{\rm eff}$ [K]')
ax.set_ylabel(r'$\log g$')

fig.set_facecolor('w')

In [None]:
# Split the selected stars with the two boundary curves:
#   ms_mask: log g above logg_f(T_eff);
#   rg_mask: log g not above logg_f, below rg_f(T_eff), and T_eff < 5500 K.
# Both subsets are restricted to stars passing llr_mask & qual_mask.
_teff = metadata['TEFF']
_logg = metadata['LOGG']
_above_logg_f = _logg > logg_f(_teff)
_selected = llr_mask & qual_mask

ms_mask = _above_logg_f & _selected
rg_mask = (~_above_logg_f) & (_logg < rg_f(_teff)) & (_teff < 5500) & _selected

ms_mask.sum(), rg_mask.sum()

In [None]:
# T_eff–log g diagram of the two subsets defined above, with a legend giving
# each subset's size.
fig, ax = plt.subplots(1, 1, figsize=(7, 6))

for _mask, _name in [(ms_mask, 'ms_mask'), (rg_mask, 'rg_mask')]:
    ax.plot(metadata['TEFF'][_mask],
            metadata['LOGG'][_mask],
            marker='o', ls='none', ms=1.5, mew=0,
            alpha=0.2, label='{} (N={})'.format(_name, _mask.sum()))

ax.set_xlim(7000, 3200)
ax.set_ylim(5.5, -0.5)

ax.set_xlabel(r'$T_{\rm eff}$ [K]')
ax.set_ylabel(r'$\log g$')
# Data markers are tiny (ms=1.5); enlarge them in the legend only.
ax.legend(loc='upper left', markerscale=5)

fig.set_facecolor('w')

---

In [None]:
def load_ez_samples(apogee_ids, n_samples=256):
    """Load per-star eccentricity and ln(period) samples into one array.

    Parameters
    ----------
    apogee_ids : iterable of str
        APOGEE IDs. Each star's samples are read from
        ``<samples_path>/<first 4 characters of ID>/<ID>.fits.gz``, which
        must have ``e`` and ``P`` columns.
    n_samples : int, optional
        Maximum number of samples kept per star (the first ``n_samples``
        rows of each file).

    Returns
    -------
    ez_samples : numpy.ndarray, shape (2, len(apogee_ids), n_samples)
        ``ez_samples[0]`` holds ``e`` and ``ez_samples[1]`` holds ``ln(P)``.
        Stars with fewer than ``n_samples`` rows are padded with NaN.
    """
    apogee_ids = list(apogee_ids)
    # Size the output from the IDs actually passed in, not from any global
    # selection mask, so the result depends only on the arguments.
    ez_samples = np.full((2, len(apogee_ids), n_samples), np.nan)
    for n, apogee_id in enumerate(apogee_ids):
        filename = path.join(samples_path, apogee_id[:4],
                             '{}.fits.gz'.format(apogee_id))
        t = fits.getdata(filename)
        K = min(n_samples, len(t))
        ez_samples[0, n, :K] = t['e'][:K]
        ez_samples[1, n, :K] = np.log(t['P'][:K])

    return ez_samples

In [None]:
# Frozen Beta distributions handed to ``Model`` below as B1 and B2.
B1, B2 = beta(1.5, 50.), beta(1, 1.8)

In [None]:
# Window centres for the per-bin fits below. Each window spans
# centre ± binsize/2 with binsize = 1.5 * step, so neighbouring windows overlap.
logg_step, teff_step = 0.25, 300
logg_binsize, teff_binsize = 1.5 * logg_step, 1.5 * teff_step

# The +1e-3 makes the upper end (4 and 7000) inclusive in np.arange.
logg_bins = np.arange(0, 4 + 1e-3, logg_step)
teff_bins = np.arange(3400, 7000 + 1e-3, teff_step)

In [None]:
# Optimize the hierarchical model separately in each overlapping T_eff window
# of the ms_mask sample, storing the unpacked best-fit parameters per window.
# The loop deliberately leaves ``mask``, ``ez_samples``, ``mod``, ``p0`` and
# ``res`` as notebook globals (later cells read them).
ms_pars = []
for i, ctr in enumerate(teff_bins):
    lo, hi = ctr - teff_binsize / 2, ctr + teff_binsize / 2
    in_window = (metadata['TEFF'] > lo) & (metadata['TEFF'] <= hi)
    mask = in_window & ms_mask

    print('{}: loading {} stars'.format(i, mask.sum()))
    ez_samples = load_ez_samples(metadata['APOGEE_ID'][mask])
    mod = Model(ez_samples, B1=B1, B2=B2)

    init = dict(lnk=0., z0=np.log(30.), alpha0=0.2,
                muz=np.log(100), lnsigz=np.log(4.))
    p0 = mod.pack_pars(init)

    def neg_ln_prob(*args):
        # minimize() minimizes, so negate the model's log-probability.
        return -mod(*args)

    print('starting minimize')
    res = minimize(neg_ln_prob, x0=p0, method='powell')
    ms_pars.append(mod.unpack_pars(res.x))

In [None]:
# Optimize the hierarchical model separately in each overlapping log g window
# of the rg_mask sample, storing the unpacked best-fit parameters per window.
# Like the T_eff loop, this leaves ``mask``, ``ez_samples``, ``mod``, ``p0``
# and ``res`` as notebook globals (overwriting the T_eff loop's values).
rg_pars = []
for i, ctr in enumerate(logg_bins):
    lo, hi = ctr - logg_binsize / 2, ctr + logg_binsize / 2
    in_window = (metadata['LOGG'] > lo) & (metadata['LOGG'] <= hi)
    mask = in_window & rg_mask

    print('{}: loading {} stars'.format(i, mask.sum()))
    ez_samples = load_ez_samples(metadata['APOGEE_ID'][mask])
    mod = Model(ez_samples, B1=B1, B2=B2)

    init = dict(lnk=0., z0=np.log(30.), alpha0=0.2,
                muz=np.log(100), lnsigz=np.log(4.))
    p0 = mod.pack_pars(init)

    def neg_ln_prob(*args):
        # minimize() minimizes, so negate the model's log-probability.
        return -mod(*args)

    print('starting minimize')
    res = minimize(neg_ln_prob, x0=p0, method='powell')
    rg_pars.append(mod.unpack_pars(res.x))

---

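# MCMC

Sample the hierarchical period–eccentricity `Model` in each $T_{\rm eff}$ window with `emcee`, starting the walkers close to that window's optimizer result (`ms_pars`).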

In [None]:
from schwimmbad import MultiPool

In [None]:
# emcee dimensions: length of the packed Model parameter vector, and the
# number of walkers (8 per dimension).
nparams = 5
nwalkers = nparams * 8

In [None]:
import pickle

In [None]:
# Run (or reload from a pickle cache) one emcee EnsembleSampler per T_eff
# window of the ms_mask sample. Walkers start in a small ball (scale 1e-3 per
# parameter) around that window's optimizer result ``ms_pars[i]`` and are
# advanced 512 steps in parallel via a schwimmbad MultiPool.
# Cache files ``teff_<i>_sampler.pkl`` are written to / read from the current
# working directory; delete them to force a re-run.
# NOTE(review): when a cache file exists, this cell does NOT reassign ``mod``,
# ``ez_samples``, ``p0`` or ``mask`` for that window — later cells reading
# those globals then see whatever an earlier iteration or cell left behind.
ms_samplers = []
for i, ctr in enumerate(teff_bins):
    sampler_filename = 'teff_{}_sampler.pkl'.format(i)
    if not path.exists(sampler_filename): 
        # Same T_eff window selection as the optimization loop above.
        l = ctr - teff_binsize / 2
        r = ctr + teff_binsize / 2
        pixel_mask = ((metadata['TEFF'] > l) & (metadata['TEFF'] <= r))
        mask = pixel_mask & ms_mask

        print('{}: loading {} stars'.format(i, mask.sum()))
        ez_samples = load_ez_samples(metadata['APOGEE_ID'][mask])
        mod = Model(ez_samples, B1=B1, B2=B2)
        # NOTE(review): ``p0`` is not used in this cell (walkers start from
        # ms_pars[i] below), but a later cell reads the global ``p0``.
        p0 = mod.pack_pars({'lnk': 0., 'z0': np.log(30.), 'alpha0': 0.2,
                            'muz': np.log(100), 'lnsigz': np.log(4.)})

        print('starting emcee')
        all_p0 = emcee.utils.sample_ball(mod.pack_pars(ms_pars[i]), 
                                         [1e-3] * nparams, 
                                         size=nwalkers)

        with MultiPool() as pool:
            sampler = emcee.EnsembleSampler(nwalkers=nwalkers, 
                                            ndim=nparams, 
                                            log_prob_fn=mod, 
                                            pool=pool)
            _ = sampler.run_mcmc(all_p0, 512, progress=True)
            # Detach the pool before the sampler is pickled below (a live
            # worker pool presumably cannot be pickled).
            sampler.pool = None
        
        with open(sampler_filename, 'wb') as f:
            pickle.dump(sampler, f)
    
    else:
        # pickle.load can execute arbitrary code: only load cache files that
        # this notebook wrote itself.
        with open(sampler_filename, 'rb') as f:
            sampler = pickle.load(f)

    ms_samplers.append(sampler)

In [None]:
# Walker trace plot for ``sampler``, i.e. the last T_eff window handled by the
# MCMC loop above. ``get_chain()`` replaces the deprecated emcee 3 ``.chain``
# attribute; its shape is (nsteps, nwalkers, ndim).
chain = sampler.get_chain()

fig, axes = plt.subplots(nparams, 1, figsize=(8, 4*nparams),
                         sharex=True)

for k in range(chain.shape[-1]):
    ax = axes[k]
    # A 2-D array plots one line per column, i.e. one line per walker.
    ax.plot(chain[:, :, k], marker='',
            drawstyle='steps-mid', alpha=0.4, color='k')
    ax.set_ylabel('parameter {}'.format(k))
axes[-1].set_xlabel('step')
fig.tight_layout()

In [None]:
# Parameters of a single MCMC draw: walker 0 at the final step of ``sampler``
# (not a posterior summary). get_chain() is (nsteps, nwalkers, ndim), so
# [-1, 0] is the same element as the deprecated ``sampler.chain[0, -1]``.
pp = mod.unpack_pars(sampler.get_chain()[-1, 0])

In [None]:
# Evaluate ``mod.ln_ze_dens`` on a (z, e) grid for three parameter sets:
# the initial guess, the optimizer result, and one MCMC draw. Each result is
# summed over axis 0 of whatever ln_ze_dens returns.
zgrid = np.linspace(mod._zlim[0], mod._zlim[1], 252)
egrid = np.linspace(0, 1, 256)
zz, ee = np.meshgrid(zgrid, egrid)  # each has shape (len(egrid), len(zgrid))

lnval_init = np.sum(mod.ln_ze_dens(mod.unpack_pars(p0), ee, zz), axis=0)
# Use the last T_eff window's optimum so it matches ``sampler``/``pp`` from
# the MCMC loop. The global ``res`` at this point is left over from the log g
# (rg) optimization loop and belongs to a different sample.
lnval_min = np.sum(mod.ln_ze_dens(ms_pars[-1], ee, zz), axis=0)
lnval_emcee = np.sum(mod.ln_ze_dens(pp, ee, zz), axis=0)

In [None]:
# Side-by-side comparison of the three parameter sets: first figure shows the
# ln-density, second the density. The x axis is the z grid, which spans the
# same range as the ln(P) samples plotted further below.
_panel_titles = ['initial guess', 'minimize', 'MCMC draw']
_lnvals = [lnval_init, lnval_min, lnval_emcee]

for _transform in (lambda v: v, np.exp):
    fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharex=True, sharey=True)
    for ax, _title, _lnval in zip(axes, _panel_titles, _lnvals):
        ax.pcolormesh(zz, ee, _transform(_lnval))
        ax.set_title(_title)
        ax.set_xlabel(r'$\ln P$')
    axes[0].set_ylabel('$e$')
    fig.tight_layout()

In [None]:
# First stored sample of each star (ln P vs e) from ``ez_samples``, i.e.
# whichever window the cells above loaded last, drawn over the model grid's
# z range.
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.plot(ez_samples[1, :, 0], ez_samples[0, :, 0],
        marker='o', ls='none', alpha=0.4)
ax.set_xlim(zgrid.min(), zgrid.max())
ax.set_ylim(0, 1)
ax.set_xlabel(r'$\ln P$')
ax.set_ylabel('$e$');

In [None]:
# Eccentricity of each star's first stored sample vs the model's e-marginal
# for the MCMC draw (density summed over the z grid). Both are normalized to
# unit area on [0, 1]; previously raw counts were overplotted against an
# unnormalized sum, so their scales were not comparable.
fig, ax = plt.subplots(1, 1)
ax.hist(ez_samples[0, :, 0], bins=np.linspace(0, 1, 32), density=True)

p_e = np.exp(lnval_emcee).sum(axis=1)  # one value per egrid point
p_e = p_e / (p_e.sum() * (egrid[1] - egrid[0]))  # Riemann-sum normalization
ax.plot(egrid, p_e, marker='')

ax.set_xlabel('$e$')
ax.set_ylabel('density');