In [None]:
from os import path

from astropy.constants import G
from astropy.table import Table
import astropy.units as u
import h5py
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm import tqdm
from scipy.stats import beta
from scipy.special import logsumexp
from scipy.optimize import minimize

from hq.config import HQ_CACHE_PATH
from hq.db import db_connect, AllStar, StarResult, Status, JokerRun
from hq.io import load_samples
from hq.plot import plot_two_panel, plot_phase_fold

from thejoker.plot import plot_rv_curves

In [None]:
# Name of the HQ/Joker run to analyze; it selects both the cached samples file
# ({run_name}.hdf5) and the results database ({run_name}.sqlite) below.
run_name = 'apogee-r10-l31c-58297'

# Width of the per-star eccentricity-sample array built later (max samples kept per star).
NSAMPLES = 2 ** 8  # 256

In [None]:
# Read-only handle on this run's cached Joker samples. get_samples() reads from
# this global handle, so it stays open for the rest of the session.
# NOTE(review): the handle is never closed in this notebook; call f.close() when done.
f = h5py.File(path.join(HQ_CACHE_PATH, '{}.hdf5'.format(run_name)),
              mode='r')

In [None]:
def get_samples(star, data, run, h5file=None):
    """Load the cached Joker posterior samples for one star.

    Parameters
    ----------
    star : AllStar
        Database row for the star; its ``apogee_id`` selects the HDF5 group.
    data
        RV data for the star (as returned by ``star.get_rvdata()`` in this
        notebook); only ``data.t0`` is used, as the reference epoch passed to
        ``load_samples``.
    run : JokerRun
        Run configuration; ``run.poly_trend`` is forwarded to ``load_samples``.
    h5file : h5py.File, optional
        Open samples file to read from. Defaults to the notebook-global handle
        ``f``, so existing three-argument calls behave exactly as before.

    Returns
    -------
    Whatever ``hq.io.load_samples`` returns for this star's group.
    """
    if h5file is None:
        # Fall back to the global handle; passing the file explicitly removes
        # the hidden dependency on which cells have been executed.
        h5file = f
    return load_samples(h5file[star.apogee_id], poly_trend=run.poly_trend,
                        t0=data.t0)

In [None]:
# Connect to this run's results database; `s` is the session every query below uses.
db_path = path.join(HQ_CACHE_PATH, '{}.sqlite'.format(run_name))
Session, engine = db_connect(db_path)
s = Session()

In [None]:
# The JokerRun row for this run; .one() raises if no row matches the name.
run = (s.query(JokerRun)
        .filter_by(name=run_name)
        .limit(1)
        .one())

In [None]:
# Tabulate how many distinct stars in this run carry each processing status.
status_ids = np.sort([x[0] for x in s.query(Status.id).distinct().all()])
for status_id in status_ids:
    # Chained .join() calls replace the legacy multi-target
    # .join(StarResult, Status, JokerRun), which SQLAlchemy 1.4 deprecates and
    # 2.0 removes; the join path is the same.
    N = (s.query(AllStar)
          .join(StarResult)
          .join(Status)
          .join(JokerRun)
          .filter(Status.id == status_id)
          .filter(JokerRun.name == run.name)
          .group_by(AllStar.apogee_id)
          .distinct()
          .count())
    msg = s.query(Status).filter(Status.id == status_id).limit(1).one().message
    print("Status {0} ({2}) : {1}".format(status_id, N, msg))

In [None]:
# Stars in this run with status id 4, 1.5 < logg < 4, and none of the
# BRIGHT_NEIGHBOR / STAR_WARN / ATMOS flags.
# NOTE(review): the meaning of status id 4 is not visible here; see the status
# table printed by the previous cell.
# NOTE(review): `~col.like(...)` is NULL (row dropped) when the flag column is
# NULL; confirm that is intended if starflags/aspcapflags can be NULL.
# Chained .join() calls replace the deprecated multi-target .join(A, B, C).
stars = (s.query(AllStar)
          .join(StarResult)
          .join(Status)
          .join(JokerRun)
          .filter(Status.id == 4)
          .filter(JokerRun.name == run.name)
          .filter(AllStar.logg > 1.5)
          .filter(AllStar.logg < 4.)
          .filter(~AllStar.starflags.like('%BRIGHT_NEIGHBOR%'))
          .filter(~AllStar.starflags.like('%STAR_WARN%'))
          .filter(~AllStar.starflags.like('%ATMOS%'))
          .filter(~AllStar.aspcapflags.like('%ATMOS%'))
          .group_by(AllStar.apogee_id)
          .distinct()
          .all())

len(stars)

In [None]:
# Eccentricity samples with 30 d < P < 365 d for every selected star. Stars keep
# different numbers of samples, so each row is NaN-padded to NSAMPLES columns.
all_ecc = np.full((len(stars), NSAMPLES), np.nan)

for row, star in enumerate(tqdm(stars)):
    data = star.get_rvdata()
    samples = get_samples(star, data, run)
    P_mask = (samples['P'] > 30*u.day) & (samples['P'] < 365*u.day)
    # Cap at NSAMPLES: a longer sample set would otherwise fail to broadcast
    # into the fixed-width row.
    ecc_n = samples['e'][P_mask][:NSAMPLES]
    all_ecc[row, :len(ecc_n)] = ecc_n

In [None]:
# [Fe/H] for each selected star. dtype=float turns any missing (None) value into
# NaN, so the metallicity cuts below compare cleanly instead of raising a
# TypeError on an object array; NaN stars then fail those cuts.
all_feh = np.array([star.fe_h for star in stars], dtype=float)

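Number of usable eccentricity samples per star, $K_n$ (finite samples with $30 < P < 365$ d). Keep only stars with $K_n \geq$ `K_thresh` and $-0.5 <$ [Fe/H] $< 0.5$: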

In [None]:
# Minimum number of finite eccentricity samples a star needs in order to be kept.
K_thresh = 16

n_finite = np.isfinite(all_ecc).sum(axis=-1)
# Well-sampled stars with -0.5 < [Fe/H] < 0.5 (strict on both ends; NaN fails).
mask = (n_finite >= K_thresh) & (np.abs(all_feh) < 0.5)
ecc, feh, K_n = all_ecc[mask], all_feh[mask], n_finite[mask]

In [None]:
# Distribution of usable samples per star, on log-scaled counts.
fig, ax = plt.subplots()
ax.hist(K_n)
ax.set_yscale('log')
ax.set_xlabel('$K_n$')

Re-compute the log prior density, $\ln p_0(e)$, at each eccentricity sample, using the Beta$(a=0.867,\,b=3.03)$ prior:

In [None]:
# Log-density of the prior Beta(a=0.867, b=3.03) at every eccentricity sample;
# NaN padding stays NaN.
ln_p0 = beta(0.867, 3.03).logpdf(ecc)

Build a log-space mask that nulls out the probability of padded (NaN) entries, i.e. samples that don't exist for a given star:

In [None]:
# Log-space null mask: 0 where a sample exists, -inf where the entry is NaN
# padding, so adding it to a log-probability discards missing samples.
# It gets its own name so it no longer overwrites the boolean star-selection
# `mask` defined earlier.
# NOTE(review): the Model class below masks non-finite terms itself and does
# not take this array.
ln_null_mask = np.where(np.isnan(ecc), -np.inf, 0.)

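Helpers to convert between the Beta shape parameters $(a, b)$ and the sampling parametrization $u = a/(a+b)$, $v = a+b$, followed by a class that evaluates the log-posterior of the hierarchical model: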

In [None]:
def ab_to_uv(a, b):
    """Convert Beta shape parameters (a, b) into (u, v).

    u = a / (a + b) is the mean of Beta(a, b) and v = a + b its concentration.
    Returns ``np.array([u, v])``; array inputs broadcast element-wise.
    """
    total = a + b
    return np.array([a / total, total])

def uv_to_ab(U, V):
    """Inverse of ab_to_uv: recover Beta shape parameters from (u, v).

    b = v * (1 - u) and a = v - b (i.e. u * v). Returns ``np.array([a, b])``.
    """
    b = V * (1 - U)
    return np.array([V - b, b])

In [None]:
class Model:
    """Log-posterior of a hierarchical Beta model for the eccentricity distribution.

    The population distribution is Beta(a, b), sampled in the parametrization
    ``u = a / (a + b)``, ``v = a + b``. Each star's likelihood is approximated
    by re-weighting its posterior samples y_nk, drawn under an interim prior p0:

        L_n(a, b) ~ (1 / K_n) * sum_k Beta(y_nk | a, b) / p0(y_nk)

    Parameters
    ----------
    y_nk : ndarray, 2-D (stars x samples)
        Per-star samples; in this notebook missing entries are NaN padding.
    K_n : ndarray, 1-D
        Number of valid samples per star (the denominator above).
    ln_p0 : ndarray, same shape as ``y_nk``
        Log interim-prior density at each sample.

    Instances are callable: ``model(pars)`` returns the log-posterior of the
    parameter vector ``pars = (u, v)``, so they can be handed to samplers.
    """

    def __init__(self, y_nk, K_n, ln_p0):
        self.y = y_nk
        self.K = K_n
        self.ln_p0 = ln_p0

    def ln_likelihood(self, **kw):
        """Per-star log-likelihood for Beta(kw['a'], kw['b']); one value per row."""
        delta_ln_prior = beta.logpdf(self.y, a=kw['a'], b=kw['b']) - self.ln_p0
        # Non-finite terms (NaN padding, inf - inf at the boundaries) get zero
        # weight in the sum.
        delta_ln_prior[~np.isfinite(delta_ln_prior)] = -np.inf
        return logsumexp(delta_ln_prior, axis=1) - np.log(self.K)

    def ln_prior(self, **kw):
        """Log hyperprior: u in (0, 1), 0.1 < v < 10, density proportional to 1/v."""
        # a = u*v and b = v*(1-u) must both be strictly positive for a valid
        # Beta distribution, so u = 0 and u = 1 are excluded.
        if not 0 < kw['u'] < 1:
            return -np.inf

        if not 0.1 < kw['v'] < 10:
            return -np.inf

        return -np.log(kw['v'])

    def unpack_pars(self, pars):
        """Expand a (u, v) vector into a dict with keys 'u', 'v', 'a', 'b'."""
        a, b = uv_to_ab(*pars)
        return {'u': pars[0], 'v': pars[1],
                'a': a, 'b': b}

    def pack_pars(self, a, b):
        """Convert Beta shape parameters (a, b) into a (u, v) parameter vector."""
        return ab_to_uv(a, b)

    def ln_prob(self, pars_vec):
        """Log-posterior of ``pars_vec = (u, v)``; -inf outside the prior support."""
        pars_kw = self.unpack_pars(pars_vec)

        lp = self.ln_prior(**pars_kw)
        if not np.isfinite(lp):
            return -np.inf

        ll_n = self.ln_likelihood(**pars_kw)
        if not np.all(np.isfinite(ll_n)):
            return -np.inf

        return np.sum(ll_n)

    def __call__(self, p):
        return self.ln_prob(p)

In [None]:
# Kipping (2013) Beta(a=0.867, b=3.03): the same prior evaluated for ln_p0 above.
# It is used both as the MCMC starting point and as the "prior" curve in the
# plots (previously b was mistyped as 5.03).
prior_ab = [0.867, 3.03]

mm = Model(ecc, K_n, ln_p0)
p0 = ab_to_uv(*prior_ab)
mm(p0)

In [None]:
# Time a single log-posterior evaluation at the starting point p0.
%timeit mm(p0)

In [None]:
# Optional (disabled): optimize (u, v) with Powell's method, starting from p0.
# minimize() minimizes its objective while mm returns a log-posterior, so the
# objective must be the negative log-posterior.
# res = minimize(lambda p: -mm(p), x0=p0, method='powell')
# mm(res.x)
# uv_to_ab(*res.x)

---

In [None]:
import emcee

In [None]:
# Start the walkers in a tight Gaussian ball (width 1e-3 in u and v) around p0.
# sample_ball draws from NumPy's global RNG, so seed it to make the starting
# positions reproducible on Restart & Run All.
SEED = 42
np.random.seed(SEED)

nwalkers = 24
all_p0 = emcee.utils.sample_ball(p0, [1e-3, 1e-3], size=nwalkers)

In [None]:
# Sample the (u, v) hyperparameters with emcee, starting from the ball above.
mm = Model(ecc, K_n, ln_p0)
ndim = all_p0.shape[1]
sampler = emcee.EnsembleSampler(nwalkers, ndim, mm)
_ = sampler.run_mcmc(all_p0, nsteps=128, progress=True)

In [None]:
# Trace plot: one panel per parameter (u, v), one line per walker.
# get_chain() has shape (nsteps, nwalkers, ndim); the `sampler.chain` attribute
# used previously is deprecated in emcee 3.
chain = sampler.get_chain()

fig, axes = plt.subplots(2, 1, figsize=(6, 8), sharex=True)
for k, (ax, par_label) in enumerate(zip(axes, ['$u$', '$v$'])):
    # A 2-D array is drawn column by column, i.e. one line per walker.
    ax.plot(chain[:, :, k], marker='',
            drawstyle='steps-mid', alpha=0.4, color='k')
    ax.set_ylabel(par_label)
axes[-1].set_xlabel('step')
fig.tight_layout()

In [None]:
# Prior-only figure, saved to its own file.
# NOTE(review): the title says "inferred" although only the prior is drawn;
# presumably kept identical to the posterior figure on purpose - confirm.
_x = np.linspace(0, 1, 128)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(_x, beta.pdf(_x, a=prior_ab[0], b=prior_ab[1]),
        marker='', zorder=-100, label='prior')

ax.set_xlabel('eccentricity, $e$')
ax.set_ylabel('$p(e)$')
ax.legend(loc='upper right', fontsize=18)
ax.set_title('inferred eccentricity distribution', fontsize=18)
ax.set_xlim(0, 1)
fig.savefig('../plots/p_e_prior.png', dpi=250)

In [None]:
# Overlay the Beta(a, b) curve implied by each walker's final position on the prior.
_x = np.linspace(0, 1, 128)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(_x, beta.pdf(_x, a=prior_ab[0], b=prior_ab[1]),
        marker='', zorder=-100, label='prior')

# Last step of the chain, shape (nwalkers, ndim); `sampler.chain` is deprecated
# in emcee 3. (Names avoid `u`, which is astropy.units in this notebook.)
final_uv = sampler.get_chain()[-1]
for k, uv_k in enumerate(final_uv):
    a_k, b_k = uv_to_ab(*uv_k)
    # Label only the first curve so the legend has one "posterior samples" entry.
    ax.plot(_x, beta.pdf(_x, a=a_k, b=b_k),
            color='k', alpha=0.2, marker='',
            label='posterior samples' if k == 0 else '_nolegend_')

ax.set_xlabel('eccentricity, $e$')
ax.set_ylabel('$p(e)$')
ax.legend(loc='upper right', fontsize=18)
ax.set_title('inferred eccentricity distribution', fontsize=18)
ax.set_xlim(0, 1)
fig.savefig('../plots/p_e.png', dpi=250)