In [None]:
import os
from os import path
from astropy.io import fits
import astropy.units as u
from astropy.table import Table
from astropy.constants import G

import numpy as np
import matplotlib.pyplot as plt
plt.style.use('apw-notebook')
%matplotlib inline
import h5py
from sqlalchemy import func
from scipy.optimize import root

from twoface.config import TWOFACE_CACHE_PATH
from twoface.db import (db_connect, AllStar, AllVisit, AllVisitToAllStar, RedClump,
                        StarResult, Status, JokerRun, initialize_db)

import emcee
from scipy.misc import logsumexp

In [None]:
# Deliberately shadow the TWOFACE_CACHE_PATH imported from twoface.config with
# the repository-relative cache directory, resolved against the notebook's
# current working directory.
TWOFACE_CACHE_PATH = os.path.abspath(os.path.join('..', 'cache'))

In [None]:
# Connect to the cached APOGEE SQLite database. The first value returned by
# db_connect is called to create the session used by every query below; the
# second value is not needed in this notebook.
Session, _ = db_connect(
    path.join(TWOFACE_CACHE_PATH, 'apogee.sqlite'))
session = Session()

In [None]:
# Report how many stars are in each processing status. Iterate over the rows
# actually present in the Status table (ordered by id) instead of assuming ids
# 0-4: a missing id made `.one()` raise, and extra statuses were never shown.
# NOTE(review): the multi-target `.join(A, B, C)` form is deprecated in newer
# SQLAlchemy releases; kept as-is here so the generated SQL is unchanged.
for status in session.query(Status).order_by(Status.id).all():
    n = session.query(AllStar)\
               .join(StarResult, JokerRun, Status)\
               .filter(Status.id == status.id).count()
    print("Status: {0} ({1}) - {2}".format(status.message, status.id, n))

In [None]:
# HDF5 file holding the per-star jitter samples that are loaded below.
samples_file = os.path.join(TWOFACE_CACHE_PATH, 'apogee-jitter.hdf5')

In [None]:
# Load, for each star, the jitter samples (stored in km/s) and their prior
# log-probabilities into fixed-width, padded arrays:
#   y_nk[n, k]  = 2 ln(jitter in m/s) for sample k of star n
#   ln_p0[n, k] = stored ln prior probability of that sample (+inf in padding)
#   K_n[n]      = number of real samples for star n
#   mask[n, k]  = 0 for real samples, -inf for padding, so padding drops out of
#                 any log-sum-exp over the sample axis.
N = 10000  # number of stars to load; set to None to load every star in the file
MAX_SAMPLES = 128  # width of the padded sample axis

K_n = []
with h5py.File(samples_file, 'r') as f:  # read-only: never create/modify the cache
    # Size the arrays by the stars actually present so no all-padding rows are
    # left behind when the file holds fewer than N stars.
    n_stars = len(f) if N is None else min(N, len(f))
    ln_p0 = np.full((n_stars, MAX_SAMPLES), np.inf)
    y_nk = np.zeros((n_stars, MAX_SAMPLES))

    for n, key in enumerate(f):
        if n >= n_stars:
            break

        jitter = f[key]['jitter'][:]
        K = len(jitter)
        if K > MAX_SAMPLES:
            raise ValueError("Star {0} has {1} samples; increase MAX_SAMPLES "
                             "(currently {2}).".format(key, K, MAX_SAMPLES))

        y_nk[n, :K] = 2 * np.log(jitter * 1000.)  # km/s to m/s
        ln_p0[n, :K] = f[key]['ln_prior_probs'][:]
        K_n.append(K)

        if n % 1000 == 0:
            print(n)

K_n = np.array(K_n)

# Mark padding by sample index rather than by value: a real sample can have
# y == 0 exactly (jitter of 1 m/s) and must not be masked out.
mask = np.where(np.arange(MAX_SAMPLES)[None, :] < K_n[:, None], 0., -np.inf)

In [None]:
# Distribution of y = 2 ln(jitter / (m/s)) over all real samples. Padded
# entries (mask == -inf) are excluded; previously they showed up as a spurious
# spike at y = 0.
plt.hist(y_nk[mask == 0], bins=np.linspace(-8, 8, 32))
plt.xlabel(r'$y = 2\ln s$  ($s$ in m/s)')
plt.ylabel('number of samples');

In [None]:
def ln_normal(x, mu, var):
    """Return ln N(x | mu, var), the log normal density parameterised by variance.

    Evaluated elementwise with NumPy broadcasting, so x, mu and var may be
    scalars or arrays of compatible shapes.
    """
    resid_sq = (x - mu) ** 2
    log_norm = np.log(2 * np.pi * var)
    return -0.5 * (resid_sq / var + log_norm)

In [None]:
# Replace the prior log-probabilities read from the HDF5 file with a Gaussian
# prior on y, N(y | mu=5, var=4), evaluated at every entry of y_nk (padding
# included; ln_prob adds `mask` to drop those entries).
# NOTE(review): this discards the stored `ln_prior_probs` — confirm the jitter
# samples were drawn under this prior.
ln_p0 = ln_normal(y_nk, mu=5., var=4.)

def ln_prob(pars, y, K, ln_prior=None, pad_mask=None):
    """Log-likelihood of the population parameters (mu, var) of y.

    For each star n this evaluates
        ln[(1/K_n) * sum_k N(y_nk | mu, var) / p0(y_nk)]
    with the sum restricted to real (unpadded) samples, and returns the sum
    over stars.

    Parameters
    ----------
    pars : sequence of two floats
        (mu, var): mean and variance of the population distribution of y.
    y : ndarray, shape (n_stars, n_samples)
        Padded per-star samples of y.
    K : ndarray, shape (n_stars,)
        Number of real samples per star.
    ln_prior : ndarray, optional
        ln p0 of each entry of `y`; defaults to the notebook-level `ln_p0`.
    pad_mask : ndarray, optional
        0 for real samples and -inf for padding; defaults to the
        notebook-level `mask`.

    Returns
    -------
    float
        The log-likelihood, or -inf when var <= 0 (or is NaN), so that samplers
        reject the proposal instead of receiving NaN.
    """
    # scipy.misc.logsumexp has been removed from SciPy; scipy.special is its home.
    from scipy.special import logsumexp

    mu, var = pars
    if not var > 0:  # also catches NaN
        return -np.inf

    if ln_prior is None:
        ln_prior = ln_p0
    if pad_mask is None:
        pad_mask = mask

    delta_ln_prior = ln_normal(y, mu, var) - ln_prior + pad_mask
    return np.sum(logsumexp(delta_ln_prior, axis=1) - np.log(K))

In [None]:
# Spot-check the log-likelihood at (mu, var) = (4, 4) and display it.
derp = ln_prob(pars=[4, 4], y=y_nk, K=K_n)
derp

In [None]:
# Profile the log-likelihood over a grid in mu, holding the variance at 5.
vals = np.linspace(-10, 10, 128)
lls = [ln_prob([val, 5], y_nk, K_n) for val in vals]

In [None]:
# Log-likelihood profile in mu computed above (variance fixed at 5).
plt.plot(vals, lls)
plt.xlabel(r'$\mu$')
plt.ylabel(r'$\ln \mathcal{L}(\mu, \sigma^2 = 5)$');

In [None]:
%timeit ln_prob([5, 4], all_l, all_K)

In [None]:
# Start the walkers in a tight ball around (mu, var) = (5, 4). A dedicated,
# seeded RandomState makes the initial positions reproducible without touching
# NumPy's global random state.
SEED = 42
nwalkers = 16
rng = np.random.RandomState(SEED)
p0 = rng.normal([5, 4], [1E-3, 1E-3], size=(nwalkers, 2))

In [None]:
# Sample (mu, var) with emcee. The first three arguments (nwalkers, ndim,
# log-probability function) are passed positionally because their keyword
# names differ between emcee 2 (dim/lnpostfn) and emcee 3 (ndim/log_prob_fn).
# `args` must match ln_prob(pars, y, K); the previous all_l/all_ln_prior/all_K
# were undefined and had one argument too many.
ndim = p0.shape[1]
sampler = emcee.EnsembleSampler(nwalkers, ndim, ln_prob, args=(y_nk, K_n))
pos, *_ = sampler.run_mcmc(p0, 256)
# TODO: once the traces below show convergence, discard burn-in with
# sampler.reset() and continue sampling from `pos`.

In [None]:
# Shape of the stored chain: (nwalkers, nsteps, ndim).
np.shape(sampler.chain)

In [None]:
# Trace plot of every walker, one figure per parameter, to judge burn-in and
# mixing. The number of parameters is taken from the chain shape
# (nwalkers, nsteps, ndim) because `sampler.dim` is not available in emcee 3.
param_labels = [r'$\mu$', r'$\sigma^2$']  # order matches pars = (mu, var) in ln_prob
for dim in range(sampler.chain.shape[-1]):
    fig, ax = plt.subplots()
    for walker in sampler.chain[..., dim]:
        ax.plot(walker, marker='', linestyle='-', color='k',
                alpha=0.2, drawstyle='steps-mid')
    ax.set_xlabel('step')
    ax.set_ylabel(param_labels[dim])