In [None]:
import glob
from os import path
import os
import sys
path_ = path.abspath('../scripts/')
if path_ not in sys.path:
    sys.path.insert(0, path_)
import pickle
    
import astropy.coordinates as coord
from astropy.constants import G
from astropy.table import Table
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm import tqdm
from schwimmbad import MultiPool

from hq.config import HQ_CACHE_PATH, config_to_alldata
from hq.plot import plot_two_panel, plot_phase_fold
from hq.data import get_rvdata
from hq.physics_helpers import period_at_surface, stellar_radius
from hq.log import logger

from helpers import get_metadata, get_rg_mask
from model_z import Model, lntruncnorm
from run_sampler import (logg_bincenters, teff_bincenters, mh_bincenters, 
                         logg_binsize)

In [None]:
# Output directories for cached results and figures, resolved to absolute
# paths (relative to the current working directory at execution time).
cache_path, plot_path = (path.abspath(rel_dir)
                         for rel_dir in ('../cache/', '../plots/'))

In [None]:
# Load the full metadata, then keep only the rows selected by the mask that
# get_rg_mask() builds from the TEFF and LOGG columns:
metadata = get_metadata()
rg_mask = get_rg_mask(*(metadata[col] for col in ('TEFF', 'LOGG')))
metadata = metadata[rg_mask]

In [None]:
from os import path
from astropy.io import fits
def get_z_samples(apogee_ids, n_samples=256):
    """Collect log10(P) samples for each source into one padded array.

    For every ID, the file ``<HQ_CACHE_PATH>/dr16/samples/<first 4 chars of
    ID>/<ID>.fits.gz`` is read and the base-10 logarithm of its ``'P'``
    column is stored in the corresponding row of the output.

    Parameters
    ----------
    apogee_ids : sequence of str
        Source identifiers; each must support slicing (``[:4]``) and string
        formatting.
    n_samples : int, optional
        Number of columns in the output, i.e. the maximum number of samples
        kept per source.

    Returns
    -------
    z_samples : numpy.ndarray
        Array of shape ``(len(apogee_ids), n_samples)``. Rows for sources
        with fewer than ``n_samples`` samples are right-padded with NaN.
    """
    sample_dir = path.join(HQ_CACHE_PATH, 'dr16/samples')
    n_sources = len(apogee_ids)

    # Start from all-NaN so that unfilled trailing entries are flagged
    z_samples = np.full((n_sources, n_samples), np.nan)

    for row_idx, apogee_id in enumerate(apogee_ids):
        fits_file = path.join(sample_dir,
                              apogee_id[:4],
                              '{}.fits.gz'.format(apogee_id))
        sample_table = fits.getdata(fits_file)

        # Keep at most n_samples entries from this file
        n_keep = min(n_samples, len(sample_table))
        periods = sample_table['P'][:n_keep]
        z_samples[row_idx, :n_keep] = np.log10(periods)

    return z_samples

In [None]:
# Only the single bin at index 8 is processed here (the slice [8:9] keeps one
# element), so the loop index `i` is 0 for that bin.
half_width = logg_binsize / 2
for i, ctr in enumerate(logg_bincenters[8:9]):
    # Edges of this logg bin; the selection below is half-open: (l, r]
    l, r = ctr - half_width, ctr + half_width
    print(l, r)

    logg = metadata['LOGG']
    pixel_mask = (logg > l) & (logg <= r)

    # Load samples for the sources in this bin. NOTE: `i` and `z_samples`
    # remain defined after the loop and are read by the cell that calls
    # run_pixel() further down.
    z_samples = get_z_samples(metadata['APOGEE_ID'][pixel_mask])

In [None]:
from scipy.optimize import minimize
import emcee
import pickle

def run_pixel(name, i, z_samples, cache_path, plot_path, pool,
              nwalkers=80, progress=False, overwrite=False):
    """Optimize and MCMC-sample the ``Model`` for one pixel, with caching.

    The function works in three stages:

    1. If sampling is going to run and no cached optimum exists, maximize the
       model (Powell method on the negated model value) and save the result
       to ``<cache_path>/<name>_<i>_res.npy``.
    2. Run a 512-step ``emcee`` ensemble starting from a small ball around
       the cached optimum and pickle the sampler to
       ``<cache_path>/<name>_<i>_emcee.pkl``. If that pickle already exists
       and ``overwrite`` is False, it is loaded instead.
    3. Plot the walker traces for each of the two parameters and save the
       figure to ``<plot_path>/<name>_<i>_trace.png``.

    Parameters
    ----------
    name : str
        Prefix used in the cache and plot filenames and in log messages.
    i : int
        Pixel index; formatted as a zero-padded two-digit number in
        filenames and used as the plot title.
    z_samples : array-like
        Samples handed to ``Model``.
    cache_path : str
        Directory for the minimizer and sampler cache files.
    plot_path : str
        Directory in which the trace plot is written.
    pool : object
        Pool passed through to ``emcee.EnsembleSampler``.
    nwalkers : int, optional
        Number of walkers in the ensemble.
    progress : bool, optional
        Passed through to ``sampler.run_mcmc``.
    overwrite : bool, optional
        If True, re-run the sampler even when a cached pickle exists. An
        existing cached optimum is still reused.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The walker-trace figure (also saved to disk).
    sampler : emcee.EnsembleSampler
        The freshly run or unpickled sampler.
    """
    min_filename = path.join(cache_path, '{}_{:02d}_res.npy'.format(name, i))
    emcee_filename = path.join(cache_path,
                               '{}_{:02d}_emcee.pkl'.format(name, i))

    # Create a model instance so we can evaluate likelihood, etc.
    nparams = 2
    mod = Model(z_samples)

    # Sampling runs when forced, or when there is no cached sampler to load.
    need_sampling = overwrite or not path.exists(emcee_filename)

    # The optimum is only needed to initialize the walkers. Tie the check to
    # `need_sampling` (not merely to the absence of the emcee file) so that
    # overwrite=True with an existing emcee cache but no minimizer cache does
    # not fail in np.load() below.
    if need_sampling and not path.exists(min_filename):
        # Initial parameters for optimization
        p0 = mod.pack_pars({'muz': np.log10(10**5.), 'lnsigz': np.log(4.)})

        logger.debug("{} {}: Starting minimize".format(name, i))
        res = minimize(lambda *args: -mod(*args), x0=p0, method='powell')
        if not res.success:
            logger.debug("{} {}: minimize did not report success: {}"
                         .format(name, i, res.message))
        np.save(min_filename, res.x)
        logger.debug("{} {}: Done with minimize".format(name, i))

    if need_sampling:
        min_x = np.load(min_filename)

        # initialization for all walkers: small ball around the optimum
        all_p0 = emcee.utils.sample_ball(min_x, [1e-3] * nparams,
                                         size=nwalkers)

        logger.debug("{} {}: Starting emcee".format(name, i))
        sampler = emcee.EnsembleSampler(nwalkers=nwalkers,
                                        ndim=nparams,
                                        log_prob_fn=mod,
                                        pool=pool)
        sampler.run_mcmc(all_p0, 512, progress=progress)

        # Drop the pool reference before pickling (presumably because the
        # live pool should not be / cannot be serialized with the sampler).
        sampler.pool = None

        with open(emcee_filename, "wb") as f:
            pickle.dump(sampler, f)

    else:
        # NOTE: pickle.load can execute arbitrary code; this is only
        # appropriate because the file is a cache written by this function.
        with open(emcee_filename, "rb") as f:
            sampler = pickle.load(f)

    # Plot walker traces, one panel per parameter:
    fig, axes = plt.subplots(nparams, 1, figsize=(8, 4*nparams),
                             sharex=True)

    for k in range(nparams):
        for walker in sampler.chain[..., k]:
            axes[k].plot(walker, marker='',
                         drawstyle='steps-mid', alpha=0.4, color='k')
    axes[0].set_title(str(i))
    fig.tight_layout()
    fig.savefig(path.join(plot_path, '{}_{:02d}_trace.png'.format(name, i)),
                dpi=250)

    return fig, sampler

In [None]:
# Run the optimizer + sampler for the bin loaded above, using a 4-process
# pool; overwrite=True forces the sampler to be re-run even if cached.
with MultiPool(processes=4) as pool:
    _, sampler = run_pixel(
        name='test',
        i=i,
        z_samples=z_samples,
        cache_path=cache_path,
        plot_path=plot_path,
        pool=pool,
        nwalkers=80,
        progress=True,
        overwrite=True,
    )