In [None]:
import os
import pathlib
import corner
import astropy.coordinates as coord
from astropy.stats import median_absolute_deviation as MAD
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import h5py
from tqdm import tqdm
from sklearn.neighbors import KDTree
from pyia import GaiaData
from scipy.stats import binned_statistic

# Enable 64-bit floats in JAX before jaxopt/optax (imported below) are used.
# `from jax.config import config` no longer exists in recent JAX releases;
# the `jax.config` attribute works on both old and new versions.
import jax
jax.config.update("jax_enable_x64", True)
import jaxopt
import optax

import schlummernd as sch
from schlummernd.lvm import LinearLVM
from schlummernd.plot import colored_corner

In [None]:
# Directory containing the per-neighborhood hyperparameter-search HDF5 files.
# Defaults to the original author's cluster path; set the environment variable
# SCHLUMMERND_HYPERPAR_PATH to point elsewhere without editing the notebook.
output_path = pathlib.Path(
    os.environ.get(
        'SCHLUMMERND_HYPERPAR_PATH',
        '/mnt/home/apricewhelan/projects/schlummernd/output/hyperpar/'
    )
)

In [None]:
# Load the project configuration, then the full training sample it describes.
config_file = '../config.yml'
conf = sch.Config.parse_yaml(config_file)
g_all = conf.load_training_data()

# Neighborhood membership: hood_idxs[i] is used below to index g_all
# (g_all[hood_idxs[i]]). allow_pickle=True lets np.load read pickled
# (presumably object-dtype / ragged) arrays.
# NOTE(review): unpickling can execute arbitrary code -- only load this
# file if it was produced by a trusted source.
hood_idxs = np.load(conf.data_path / 'training_neighborhoods.npy',
                    allow_pickle=True)

In [None]:
# (Teff, logg) binning shared by every neighborhood's HR diagram below;
# hoisted out of the loop because it never changes.
teff_logg_bins = (
    np.linspace(3000, 7000, 128),
    np.linspace(-0.5, 5.5, 128)
)


def _load_hood_table(filename):
    """Read one hyperparameter-search HDF5 file into an astropy Table.

    Each top-level group in the file becomes one table row. The group's
    attributes are copied in as columns, and for every label ``lbl`` found
    under ``true_y`` three array-valued columns are added:

    - ``true_{lbl}``: the ``true_y`` values,
    - ``{lbl}``: the residual ``predict_y - true_y``,
    - ``{lbl}_err``: the ``true_yerr`` values.
    """
    rows = []
    with h5py.File(filename, mode='r') as f:
        for num in f:
            grp = f[num]
            row = dict(grp.attrs)
            for lbl in grp['true_y'].keys():
                # Read the true values once; they are needed twice.
                true_y = grp['true_y'][lbl][:]
                row[f'true_{lbl}'] = true_y
                row[lbl] = grp['predict_y'][lbl][:] - true_y
                row[lbl + '_err'] = grp['true_yerr'][lbl][:]
            rows.append(row)
    return at.Table(rows)


def _plot_hood_hrd(g_hood, hood_idx):
    """Plot a spectroscopic HR diagram (TEFF vs. LOGG) for one neighborhood.

    The full training sample ``g_all`` is shown as a log-normalized grey 2D
    histogram, and the neighborhood's stars ``g_hood`` are overplotted as
    blue points. Both axes are inverted, so higher TEFF is on the left and
    lower LOGG is at the top. Returns the new figure.
    """
    fig, ax = plt.subplots(figsize=(6, 6))

    ax.hist2d(g_all.TEFF,
              g_all.LOGG,
              bins=teff_logg_bins,
              norm=mpl.colors.LogNorm(),
              cmap='Greys')

    ax.plot(g_hood.TEFF,
            g_hood.LOGG,
            ls='none', marker='o', mew=0, ms=2.,
            color='tab:blue', alpha=0.75)

    ax.set_xlim(teff_logg_bins[0].max(),
                teff_logg_bins[0].min())
    ax.set_ylim(teff_logg_bins[1].max(),
                teff_logg_bins[1].min())

    ax.set_xlabel('TEFF')
    ax.set_ylabel('LOGG')
    ax.set_title(f'hood: {hood_idx}')

    fig.tight_layout()
    return fig


# hoods maps neighborhood index -> Table of hyperparameter-search results.
# The index is parsed from the file name: the integer after the first '-'
# in the part before the first '.'.
hoods = {}
for filename in sorted(output_path.glob('*.hdf5')):
    hood_idx = int(filename.name.split('.')[0].split('-')[1])
    hoods[hood_idx] = _load_hood_table(filename)
    _plot_hood_hrd(g_all[hood_idxs[hood_idx]], hood_idx)

In [None]:
# Labels to evaluate, and the x-axis ranges used for each in the residual plots
# (reversed tuples give inverted axes).
cols = ['M_H', 'TEFF', 'logg']

lims = {
    'M_H': (-2.5, 0.6),
    'TEFF': (6500, 3000),
    'logg': (5.5, -0.5)
}

# Robust scale estimate: this factor times the MAD (an approximation of the
# ~1.4826 factor that makes MAD match a Gaussian sigma).
MAD_TO_SIGMA = 1.5

for k in sorted(hoods.keys()):
    print(f'hood {k}')
    tbl = hoods[k]

    # Total chi^2 per hyperparameter setting (table row): sum over labels of
    # (residual / true_err)^2, then sum over stars.
    chi2 = 0
    for col in cols:
        chi2 = chi2 + (tbl[col] / tbl[col + "_err"])**2
    chi2 = chi2.sum(axis=1)

    # nan-aware reductions: np.argmin/np.min return the first NaN, which would
    # select a setting whose predictions failed as the "best" one.
    best = np.nanargmin(chi2)
    print(f'min chi2: {np.nanmin(chi2):.2f}')
    print(tbl['alpha', 'beta', 'n_latents'][best])

    # Robust residual scatter of the best setting, per label.
    sigmas = {col: MAD_TO_SIGMA * MAD(tbl[col][best]) for col in cols}
    for col in cols:
        print(f'{col} = {sigmas[col]:.2f}')

    # ---
    # Residuals vs. true label for the best setting, with +/- 1 sigma lines.
    fig, axes = plt.subplots(
        1, len(cols),
        figsize=(5.5 * len(cols), 5.5),
        constrained_layout=True
    )
    for col, ax in zip(cols, axes):
        ax.scatter(
            tbl[f'true_{col}'][best],
            tbl[col][best],
            s=2,
        )
        ax.set_xlabel(f'true {col}')
        ax.set_xlim(lims[col])

        sigma = sigmas[col]
        ax.set_ylim(-5 * sigma, 5 * sigma)
        for level in (-sigma, sigma):
            ax.axhline(level, zorder=-10, alpha=0.5, color='tab:blue')

    axes[0].set_ylabel('predict - true')

    fig.suptitle(f'hood: {k}', fontsize=24)

    print()

In [None]:
# Only models with this beta are shown in the (n_latents, alpha) grids.
BETA_SLICE = 0.1

for k in sorted(hoods.keys()):
    tbl = hoods[k]

    # isclose rather than ==: beta is read back from HDF5 attributes and may
    # not compare bit-exactly to the literal (e.g. if it was stored as float32).
    subset = tbl[np.isclose(tbl['beta'], BETA_SLICE)]
    if len(subset) == 0:
        print(f'hood {k}: no models with beta = {BETA_SLICE}')
        continue

    fig, axes = plt.subplots(
        1, len(cols),
        figsize=(5.5 * len(cols), 5.5),
        sharex=True, sharey=True,
        constrained_layout=True
    )
    for col, ax in zip(cols, axes):
        # Robust per-model residual scatter (1.5 * MAD over stars).
        err = 1.5 * MAD(subset[col], axis=1)

        cs = ax.scatter(
            subset['n_latents'],
            subset['alpha'],
            c=err,
            s=1000,
            marker='s'
        )
        ax.set_title(col)

        ax.set_xlabel('N latents')
        cb = fig.colorbar(cs, ax=ax, aspect=50)
        # Label every panel's colorbar, not just the last one.
        cb.set_label('err')

    # Axes are shared (sharex/sharey), so setting the scale once applies to all.
    axes[0].set_xscale('log', base=2)
    axes[0].set_yscale('log')
    axes[0].set_ylabel(r'$\alpha$')

    fig.suptitle(f'hood: {k}', fontsize=24)