In [None]:
import glob
from os import path
import os
import sys
path_ = path.abspath('../scripts/')
if path_ not in sys.path:
    sys.path.insert(0, path_)
import pickle
    
import astropy.coordinates as coord
from astropy.constants import G
from astropy.table import Table
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm import tqdm

from hq.config import HQ_CACHE_PATH, config_to_alldata
from hq.plot import plot_two_panel, plot_phase_fold
from hq.data import get_rvdata
from hq.physics_helpers import period_at_surface, stellar_radius

from model import Model, lntruncnorm
from run_sampler import logg_bincenters, teff_bincenters, mh_bincenters, rg_mh_bincenters

### RG: model parameters vs. $\log g$ bin

In [None]:
# Collect posterior summaries (mean and std of every model parameter) for the
# RG logg-binned sampler runs. The bare 'rg*' pattern would also match the
# 'rg-mh*' metallicity-binned caches read further down; those sort first
# ('-' < '_'), which would mix them into this table and shift every logg
# bin-centre assignment, so they are excluded explicitly.
rg_files = sorted(f for f in glob.glob('../cache/rg*.pkl')
                  if not path.basename(f).startswith('rg-mh'))
assert rg_files, 'no RG logg-binned sampler caches found in ../cache/'

mean_rows = []
std_rows = []
for filename in rg_files:
    # NOTE(review): pickle can execute arbitrary code on load; only point
    # this at caches produced locally by run_sampler.
    with open(filename, 'rb') as f:
        sampler = pickle.load(f)

    # Drop the first 256 entries along axis 1, keep every 8th one, and stack
    # the first axis into a single 2-D sample array. Presumably sampler.chain
    # is emcee-style (nwalkers, nsteps, ndim), making this burn-in removal
    # plus thinning -- TODO confirm against run_sampler.
    flatchain = np.vstack(sampler.chain[:, 256:][:, ::8])
    pars = Model.unpack_pars(flatchain.T)

    mean_rows.append({k: np.mean(v) for k, v in pars.items()})
    std_rows.append({k: np.std(v) for k, v in pars.items()})

mean_pars = Table(mean_rows)
std_pars = Table(std_rows)

# Rows are matched to bins purely by sorted filename order -- TODO confirm
# the cache filenames sort in the same order as logg_bincenters.
assert len(mean_pars) <= len(logg_bincenters), 'more cache files than logg bins'
mean_pars['logg'] = logg_bincenters[:len(mean_pars)]
std_pars['logg'] = logg_bincenters[:len(mean_pars)]

# Keep only the bins with logg >= 1.
mean_pars = mean_pars[mean_pars['logg'] >= 1]
std_pars = std_pars[std_pars['logg'] >= 1]

In [None]:
# Posterior mean +/- std of each model parameter as a function of logg bin.
# Iterate over the parameter columns directly so panel placement does not
# depend on where the 'logg' column sits in the table.
param_names = [name for name in mean_pars.colnames if name != 'logg']

fig, axes = plt.subplots(2, 3, figsize=(12, 6),
                         sharex=True)
assert len(param_names) <= axes.size, 'more parameters than panels'
for ax, colname in zip(axes.flat, param_names):
    ax.errorbar(mean_pars['logg'],
                mean_pars[colname], std_pars[colname],
                marker='o', ls='none')
    ax.set_ylabel(colname)

# Hide only the panels left over once every parameter has been drawn.
for ax in axes.flat[len(param_names):]:
    ax.set_visible(False)

# With sharex=True only the bottom row carries x tick labels; label it there.
for ax in axes[-1]:
    ax.set_xlabel(r'$\log g$')

fig.tight_layout()

In [None]:
# Colour scale for the per-bin period curves, spanning the logg bins that
# survived the cut in the loading cell.
cmap = plt.get_cmap('magma')
norm = mpl.colors.Normalize(
    vmin=mean_pars['logg'].min(),
    vmax=mean_pars['logg'].max(),
)

In [None]:
# Implied period distribution for each logg bin, evaluated at the posterior
# mean parameters. lntruncnorm is called on ln P with bounds ln(P_min) and
# ln(P_max); presumably it returns the log-density of a normal truncated to
# that range -- confirm in model.py.
P_min, P_max = 2., 65536.
P_grid = np.logspace(np.log10(P_min), np.log10(P_max), 1024)

fig, ax = plt.subplots()
for row in mean_pars:
    sigma = np.exp(row['lnsigz'])
    mu = row['muz']

    y = np.exp(lntruncnorm(np.log(P_grid), mu, sigma,
                           np.log(P_min), np.log(P_max)))
    ax.plot(P_grid, y, marker='', color=cmap(norm(row['logg'])))

# `basex` was deprecated in Matplotlib 3.3 and removed in 3.5 (TypeError);
# base 10 is the default, so no keyword is needed on any version.
ax.set_xscale('log')
ax.set_xlabel('$P$')
ax.set_ylabel(r'$p(\ln P)$')

# Colour key so the curves can be matched to their logg bin.
sm = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])  # needed by Matplotlib < 3.1 for an array-less mappable
fig.colorbar(sm, ax=ax, label=r'$\log g$')

### MS: model parameters vs. $T_{\rm eff}$ bin

In [None]:
# Collect posterior summaries (mean and std of every model parameter) for the
# MS teff-binned sampler runs. 'ms_*' requires an underscore, so it does not
# pick up the 'ms-mh*' metallicity-binned caches read further down.
ms_files = sorted(glob.glob('../cache/ms_*.pkl'))
assert ms_files, 'no MS teff-binned sampler caches found in ../cache/'

mean_rows = []
std_rows = []
for filename in ms_files:
    # NOTE(review): pickle can execute arbitrary code on load; only point
    # this at caches produced locally by run_sampler.
    with open(filename, 'rb') as f:
        sampler = pickle.load(f)

    # Drop the first 256 entries along axis 1, keep every 8th one, and stack
    # the first axis into a single 2-D sample array (presumably burn-in plus
    # thinning of an emcee-style (nwalkers, nsteps, ndim) chain -- TODO confirm).
    flatchain = np.vstack(sampler.chain[:, 256:][:, ::8])
    pars = Model.unpack_pars(flatchain.T)

    mean_rows.append({k: np.mean(v) for k, v in pars.items()})
    std_rows.append({k: np.std(v) for k, v in pars.items()})

mean_pars = Table(mean_rows)
std_pars = Table(std_rows)

# Rows are matched to bins purely by sorted filename order -- TODO confirm
# the cache filenames sort in the same order as teff_bincenters.
assert len(mean_pars) <= len(teff_bincenters), 'more cache files than teff bins'
mean_pars['teff'] = teff_bincenters[:len(mean_pars)]
std_pars['teff'] = teff_bincenters[:len(mean_pars)]

# Keep only the bins with teff below 6500.
mean_pars = mean_pars[mean_pars['teff'] < 6500]
std_pars = std_pars[std_pars['teff'] < 6500]

In [None]:
# Colour scale for the per-bin period curves, spanning the teff bins that
# survived the cut in the loading cell.
cmap = plt.get_cmap('magma')
norm = mpl.colors.Normalize(
    vmin=mean_pars['teff'].min(),
    vmax=mean_pars['teff'].max(),
)

In [None]:
# Implied period distribution for each teff bin, evaluated at the posterior
# mean parameters. lntruncnorm is called on ln P with bounds ln(P_min) and
# ln(P_max); presumably it returns the log-density of a normal truncated to
# that range -- confirm in model.py.
P_min, P_max = 2., 65536.
P_grid = np.logspace(np.log10(P_min), np.log10(P_max), 1024)

fig, ax = plt.subplots()
for row in mean_pars:
    sigma = np.exp(row['lnsigz'])
    mu = row['muz']

    y = np.exp(lntruncnorm(np.log(P_grid), mu, sigma,
                           np.log(P_min), np.log(P_max)))
    ax.plot(P_grid, y, marker='', color=cmap(norm(row['teff'])))

# `basex` was deprecated in Matplotlib 3.3 and removed in 3.5 (TypeError);
# base 10 is the default, so no keyword is needed on any version.
ax.set_xscale('log')
ax.set_xlabel('$P$')
ax.set_ylabel(r'$p(\ln P)$')

# Colour key so the curves can be matched to their teff bin.
sm = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])  # needed by Matplotlib < 3.1 for an array-less mappable
fig.colorbar(sm, ax=ax, label=r'$T_{\rm eff}$')

In [None]:
# Posterior mean +/- std of each model parameter as a function of teff bin.
# Iterate over the parameter columns directly so panel placement does not
# depend on where the 'teff' column sits in the table.
param_names = [name for name in mean_pars.colnames if name != 'teff']

fig, axes = plt.subplots(2, 3, figsize=(12, 6),
                         sharex=True)
assert len(param_names) <= axes.size, 'more parameters than panels'
for ax, colname in zip(axes.flat, param_names):
    ax.errorbar(mean_pars['teff'],
                mean_pars[colname], std_pars[colname],
                marker='o', ls='none')
    ax.set_ylabel(colname)

# Hide any panels left over once every parameter has been drawn, instead of
# leaving an empty frame in the grid.
for ax in axes.flat[len(param_names):]:
    ax.set_visible(False)

# With sharex=True only the bottom row carries x tick labels; label it there.
for ax in axes[-1]:
    ax.set_xlabel(r'$T_{\rm eff}$')

fig.tight_layout()

### MS: model parameters vs. [M/H] bin

In [None]:
# Collect posterior summaries (mean and std of every model parameter) for the
# MS metallicity-binned sampler runs.
ms_mh_files = sorted(glob.glob('../cache/ms-mh*.pkl'))
assert ms_mh_files, 'no MS [M/H]-binned sampler caches found in ../cache/'

mean_rows = []
std_rows = []
for filename in ms_mh_files:
    # NOTE(review): pickle can execute arbitrary code on load; only point
    # this at caches produced locally by run_sampler.
    with open(filename, 'rb') as f:
        sampler = pickle.load(f)

    # Drop the first 256 entries along axis 1, keep every 8th one, and stack
    # the first axis into a single 2-D sample array (presumably burn-in plus
    # thinning of an emcee-style (nwalkers, nsteps, ndim) chain -- TODO confirm).
    flatchain = np.vstack(sampler.chain[:, 256:][:, ::8])
    pars = Model.unpack_pars(flatchain.T)

    mean_rows.append({k: np.mean(v) for k, v in pars.items()})
    std_rows.append({k: np.std(v) for k, v in pars.items()})

mean_pars = Table(mean_rows)
std_pars = Table(std_rows)

# Rows are matched to bins purely by sorted filename order -- TODO confirm
# the cache filenames sort in the same order as mh_bincenters.
assert len(mean_pars) <= len(mh_bincenters), 'more cache files than [M/H] bins'
mean_pars['m_h'] = mh_bincenters[:len(mean_pars)]
std_pars['m_h'] = mh_bincenters[:len(mean_pars)]

# No bin cut is applied to the metallicity bins (this table has no 'teff'
# column to cut on).

In [None]:
# Posterior mean +/- std of each model parameter as a function of [M/H] bin,
# laid out on the same 2x3 grid as the other sections instead of one loose
# figure per parameter.
param_names = [name for name in mean_pars.colnames if name != 'm_h']

fig, axes = plt.subplots(2, 3, figsize=(12, 6),
                         sharex=True)
assert len(param_names) <= axes.size, 'more parameters than panels'
for ax, colname in zip(axes.flat, param_names):
    ax.errorbar(mean_pars['m_h'],
                mean_pars[colname], std_pars[colname],
                marker='o', ls='none')
    ax.set_ylabel(colname)

# Hide any panels left over once every parameter has been drawn.
for ax in axes.flat[len(param_names):]:
    ax.set_visible(False)

# With sharex=True only the bottom row carries x tick labels; label it there.
for ax in axes[-1]:
    ax.set_xlabel('[M/H]')

fig.tight_layout()

In [None]:
# Colour scale for the per-bin period curves, spanning every [M/H] bin
# loaded for the MS sample.
cmap = plt.get_cmap('magma')
norm = mpl.colors.Normalize(
    vmin=mean_pars['m_h'].min(),
    vmax=mean_pars['m_h'].max(),
)

In [None]:
# Implied period distribution for each [M/H] bin, evaluated at the posterior
# mean parameters. lntruncnorm is called on ln P with bounds ln(P_min) and
# ln(P_max); presumably it returns the log-density of a normal truncated to
# that range -- confirm in model.py.
P_min, P_max = 2., 65536.
P_grid = np.logspace(np.log10(P_min), np.log10(P_max), 1024)

fig, ax = plt.subplots()
for row in mean_pars:
    sigma = np.exp(row['lnsigz'])
    mu = row['muz']

    y = np.exp(lntruncnorm(np.log(P_grid), mu, sigma,
                           np.log(P_min), np.log(P_max)))
    ax.plot(P_grid, y, marker='', color=cmap(norm(row['m_h'])))

# `basex` was deprecated in Matplotlib 3.3 and removed in 3.5 (TypeError);
# base 10 is the default, so no keyword is needed on any version.
ax.set_xscale('log')
ax.set_xlabel('$P$')
ax.set_ylabel(r'$p(\ln P)$')

# Colour key so the curves can be matched to their [M/H] bin.
sm = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])  # needed by Matplotlib < 3.1 for an array-less mappable
fig.colorbar(sm, ax=ax, label='[M/H]')

### RG: model parameters vs. [M/H] bin

In [None]:
# Collect posterior summaries (mean and std of every model parameter) for the
# RG metallicity-binned sampler runs.
rg_mh_files = sorted(glob.glob('../cache/rg-mh*.pkl'))
assert rg_mh_files, 'no RG [M/H]-binned sampler caches found in ../cache/'

mean_rows = []
std_rows = []
for filename in rg_mh_files:
    # NOTE(review): pickle can execute arbitrary code on load; only point
    # this at caches produced locally by run_sampler.
    with open(filename, 'rb') as f:
        sampler = pickle.load(f)

    # Drop the first 256 entries along axis 1, keep every 8th one, and stack
    # the first axis into a single 2-D sample array (presumably burn-in plus
    # thinning of an emcee-style (nwalkers, nsteps, ndim) chain -- TODO confirm).
    flatchain = np.vstack(sampler.chain[:, 256:][:, ::8])
    pars = Model.unpack_pars(flatchain.T)

    mean_rows.append({k: np.mean(v) for k, v in pars.items()})
    std_rows.append({k: np.std(v) for k, v in pars.items()})

mean_pars = Table(mean_rows)
std_pars = Table(std_rows)

# Rows are matched to bins purely by sorted filename order -- TODO confirm
# the cache filenames sort in the same order as rg_mh_bincenters. The
# assertion replaces the former debug print of the row count.
assert len(mean_pars) <= len(rg_mh_bincenters), 'more cache files than [M/H] bins'
mean_pars['m_h'] = rg_mh_bincenters[:len(mean_pars)]
std_pars['m_h'] = rg_mh_bincenters[:len(mean_pars)]

# No bin cut is applied to the metallicity bins (this table has no 'teff'
# column to cut on).

In [None]:
# Posterior mean +/- std of each model parameter as a function of [M/H] bin.
# Iterate over the parameter columns directly so panel placement does not
# depend on where the 'm_h' column sits in the table.
param_names = [name for name in mean_pars.colnames if name != 'm_h']

fig, axes = plt.subplots(2, 3, figsize=(12, 6),
                         sharex=True)
assert len(param_names) <= axes.size, 'more parameters than panels'
for ax, colname in zip(axes.flat, param_names):
    ax.errorbar(mean_pars['m_h'],
                mean_pars[colname], std_pars[colname],
                marker='o', ls='none')
    ax.set_ylabel(colname)

# Hide any panels left over once every parameter has been drawn, instead of
# leaving an empty frame in the grid.
for ax in axes.flat[len(param_names):]:
    ax.set_visible(False)

# With sharex=True only the bottom row carries x tick labels; label it there.
for ax in axes[-1]:
    ax.set_xlabel('[M/H]')

fig.tight_layout()

In [None]:
# Colour scale for the per-bin period curves, spanning every [M/H] bin
# loaded for the RG sample.
cmap = plt.get_cmap('magma')
norm = mpl.colors.Normalize(
    vmin=mean_pars['m_h'].min(),
    vmax=mean_pars['m_h'].max(),
)

In [None]:
# Implied period distribution for each [M/H] bin, evaluated at the posterior
# mean parameters. lntruncnorm is called on ln P with bounds ln(P_min) and
# ln(P_max); presumably it returns the log-density of a normal truncated to
# that range -- confirm in model.py.
P_min, P_max = 2., 65536.
P_grid = np.logspace(np.log10(P_min), np.log10(P_max), 1024)

fig, ax = plt.subplots()
for row in mean_pars:
    sigma = np.exp(row['lnsigz'])
    mu = row['muz']

    y = np.exp(lntruncnorm(np.log(P_grid), mu, sigma,
                           np.log(P_min), np.log(P_max)))
    ax.plot(P_grid, y, marker='', color=cmap(norm(row['m_h'])))

# `basex` was deprecated in Matplotlib 3.3 and removed in 3.5 (TypeError);
# base 10 is the default, so no keyword is needed on any version.
ax.set_xscale('log')
ax.set_xlabel('$P$')
ax.set_ylabel(r'$p(\ln P)$')

# Colour key so the curves can be matched to their [M/H] bin.
sm = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])  # needed by Matplotlib < 3.1 for an array-less mappable
fig.colorbar(sm, ax=ax, label='[M/H]')