TODO: 
- Cross-validate on L2_ivar and training size
    - Cross-validate on the likelihood evaluated on held-out data
    - Make a 3 (neighborhood size) × 7 (L2_ivar) grid
- Check the parallax zero-point: reverse its sign and make sure the fit gets worse

In [None]:
import os
os.environ['APOGEE_CACHE_PATH'] = "/mnt/ceph/users/apricewhelan/apogee/"
os.environ['JOAQUIN_CACHE_PATH'] = "/mnt/ceph/users/apricewhelan/projects/joaquin/cache"
import warnings
warnings.filterwarnings('ignore', category=Warning) 
import pickle

import sys
import pathlib
_path = str(pathlib.Path('../').resolve())
if _path not in sys.path:
    sys.path.append(_path)

import corner
from astropy.io import fits
import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm.auto import tqdm, trange
from sklearn.decomposition import PCA
from scipy.spatial import cKDTree
from scipy.stats import binned_statistic_2d

from joaquin import Joaquin
from joaquin.data import JoaquinData
from joaquin.config import (dr, root_cache_path, 
                            max_neighborhood_size, block_size)
from joaquin.plot import simple_corner

In [None]:
# Per-data-release cache and plot directories (`dr` comes from joaquin.config);
# both are created if they do not exist yet.
cache_path = pathlib.Path(f'../cache/{dr}').resolve()
plot_path = (pathlib.Path('../plot') / dr).resolve()
for _out_dir in (cache_path, plot_path):
    _out_dir.mkdir(parents=True, exist_ok=True)

Run the first two notebooks (1- and 2-) before this one: they create the cached files loaded below.

In [None]:
# Parent sample, keeping only stars whose feature rows are entirely finite.
parent_data = JoaquinData.read('parent-sample')
_finite_rows = np.isfinite(parent_data.X).all(axis=1)
parent_data = parent_data[_finite_rows]

# Global bad-spectral-pixel mask written by an earlier notebook (presumably
# boolean: it is inverted with ~ and used as an index further down).
global_spec_mask = np.load(cache_path / 'global_spec_bad_mask.npy')

In [None]:
# Precomputed neighborhoods: each entry is used as an index into parent_data.
# The cache file name encodes the configured maximum neighborhood size.
neighborhood_idx = np.load(
    cache_path / f'good_parent_neighborhood_indices-{max_neighborhood_size}.npy'
)

In [None]:
# Disabled code (not executed), kept for reference: pulls the parent-sample stars
# and design matrix with the spectral mask turned off and checks they line up.
# parent_stars = parent_data.stars[parent_data.stars_mask]
# parent_d, *_ = parent_data.get_Xy(spec_mask_thresh=1.)  # disable spec mask
# assert len(parent_stars) == parent_d['X'].shape[0]

## PCA patching

In [None]:
# Pick one neighborhood to explore by its row in neighborhood_idx. This used to
# be written as a loop over neighborhood_idx[190:] that broke after the first
# iteration; indexing directly says the same thing.
neighborhood_row = 190  # RGB, above the clump
# neighborhood_row = 192  # MSTO
idx = neighborhood_idx[neighborhood_row]
data = parent_data[idx]

# Pixels flagged bad in more than 25% of this neighborhood's stars (MAGIC NUMBER).
# Assumes spec_bad_masks is (n_stars, n_pix) -- TODO confirm.
spec_bad_mask = (data.spec_bad_masks.sum(axis=0) / len(data.stars)) > 0.25
patched_data = data.patch_spec()
# Drop the per-star masks from the patched copy, then mask out the pixels that
# are bad for too many stars.
patched_data.spec_bad_masks = None
patched_data = patched_data.mask_spec_pixels(spec_bad_mask)

In [None]:
# Neighborhood overview. The first star (index 0, green) is drawn on top of
# 2D-binned maps of the neighborhood in TEFF-LOGG (left) and TEFF-M_H (right).
# Each bin is colored by the mean row index of the stars that fall in it.
teff_bins = np.linspace(3000, 8500, 256)
panel_specs = [('LOGG', np.linspace(-0.5, 5.5, 256)),
               ('M_H', np.linspace(-2.5, 0.6, 256))]

fig, axes = plt.subplots(1, 2, figsize=(11, 5),
                         sharex=True,
                         constrained_layout=True)

for ax, (ycol, ybins) in zip(axes, panel_specs):
    ax.scatter(data.stars['TEFF'][0], data.stars[ycol][0],
               s=6, color='tab:green', zorder=100)

    stat = binned_statistic_2d(data.stars['TEFF'], data.stars[ycol],
                               np.arange(len(data.stars)),
                               bins=(teff_bins, ybins))
    ax.pcolormesh(stat.x_edge, stat.y_edge, stat.statistic.T)

    if ycol == 'LOGG':
        # Reversed axes (HR-diagram convention); shared x carries to the right.
        ax.set_xlim(stat.x_edge.max(), stat.x_edge.min())
        ax.set_ylim(stat.y_edge.max(), stat.y_edge.min())
    else:
        ax.set_ylim(-2.5, 0.5)

    ax.set_xlabel('TEFF')
    ax.set_ylabel(ycol)

# cb = fig.colorbar(cs, ax=axes, aspect=40)

In [None]:
# Count spectral pixels (outside the global bad-pixel mask) that are exactly zero
# before patching (presumably how missing pixels are encoded -- TODO confirm),
# then check that the patched spectra contain no zeros at all.
tmp, _ = data.get_X('spec')
npix_fixed = (tmp[:, ~global_spec_mask] == 0).sum()

tmp_patched, _ = patched_data.get_X('spec')
assert not np.any(tmp_patched == 0)

print(f"{npix_fixed} pixels patched, ~{npix_fixed/tmp.shape[0]:.0f} pixels patched per star")

TODO: make 2D images of the spectra before and after patching. Turn masked pixels into hot pixels so they are obvious in the "before" images.

## Low-pass filter

In [None]:
# Apply JoaquinData's low-pass filter to the patched spectra (the filter itself
# is implemented in the joaquin package, not in this notebook).
lowpass_data = patched_data.lowpass_filter_spec()

In [None]:
# Residual image of the high-S/N (SNR > 300) low-passed spectra. Each row is one
# star's spectrum minus the per-pixel median over these stars, and rows are
# sorted by LOGG. Color limits clip to the 1st-99th percentiles of the residuals.
tmp_data = lowpass_data[lowpass_data.stars['SNR'] > 300]
tmp, _ = tmp_data.get_X('spec')

# dist = coord.Distance(parallax=tmp_data.stars['GAIAEDR3_PARALLAX']*u.mas, allow_negative=True)
# MG = tmp_data.stars['GAIAEDR3_PHOT_G_MEAN_MAG'] - dist.distmod.value

fig, ax = plt.subplots(figsize=(10, 10 * tmp.shape[0] / tmp.shape[1]))

diff = tmp[tmp_data.stars['LOGG'].argsort()] - np.median(tmp, axis=0)
vmin, vmax = np.percentile(diff, [1, 99])
ax.imshow(diff, origin='lower', vmin=vmin, vmax=vmax, cmap='RdBu')

ax.set_xticks([])
ax.set_yticks([])

ax.set_xlabel('wavelength')
# The reference spectrum subtracted above is the per-pixel *median*, not the mean.
ax.set_ylabel('spectra - median, ordered by LOGG')

fig.tight_layout()

TODO: make some before/after 1D plots showing that the low-pass filter actually changes the spectra. Show the full spectrum and a zoomed-in window, before and after.

In [None]:
# Disabled scratch code for the TODO above (not executed). It overplots row i of
# subX, subX_patched and new_ln_flux (presumably one star's spectrum before/after
# patching and after filtering) at three zoom levels.
# NOTE(review): subX, subX_patched, new_ln_flux and i are not defined anywhere in
# this notebook; update these names before re-enabling.
# for lim in [False, 'zoom', 'zoomer']:
#     plt.figure(figsize=(16, 5))
#     plt.plot(parent_data._X_wvln, subX[i], marker='', drawstyle='steps-mid')
#     plt.plot(parent_data._X_wvln, subX_patched[i], marker='', drawstyle='steps-mid')
#     plt.plot(parent_data._X_wvln, new_ln_flux, marker='', drawstyle='steps-mid')
#     if lim == 'zoom':
#         plt.xlim(16000, 16500)
#     elif lim == 'zoomer':
#         plt.xlim(16150, 16220)

## Now try running the rest of the pipeline

The training sample is the full neighborhood, restricted to stars with small Gaia EDR3 parallax uncertainties (`GAIAEDR3_PARALLAX_ERROR < 0.1`):

TODO: also cross-validate on block_size?

In [None]:
# Number of stars in the cross-validation block for the exploratory K-fold split.
# NOTE: this shadows the `block_size` imported from joaquin.config, and it is
# redefined again in the Full Pipeline section.
block_size = 4096

In [None]:
# masked_data = lowpass_data.mask_spec_pixels()

# Subselect to stars that we want in train/test sets:
# (no extra spectral-pixel masking is applied at the moment: masked_data is just
# another name for lowpass_data, not a copy; the star selection is train_mask)
masked_data = lowpass_data

In [None]:
# All MAGIC NUMBERs
# Boolean mask of candidate training stars: Gaia EDR3 parallax uncertainty
# below 0.1. The bare tuple displays (number selected, total number of stars).
train_mask = (masked_data.stars['GAIAEDR3_PARALLAX_ERROR'] < 0.1)

train_mask.sum(), len(train_mask)

In [None]:
# from joaquin.crossval import get_Kfold_indices

def get_Kfold_indices(K, train_mask, block_size, rng=None):
    """Split the selected training stars into K cross-validation folds.

    The first ``block_size`` selected stars form the "block". The block is
    shuffled and split into ``K`` test batches; the last batch absorbs the
    remainder when ``block_size`` is not divisible by ``K``. All remaining
    selected stars ("zone 2") are never held out and are appended to every
    training batch.

    Parameters
    ----------
    K : int
        Number of folds; must satisfy ``1 <= K <= block_size``.
    train_mask : array_like
        Either a boolean mask over all stars or an integer array of star
        indices. Order matters: the block is taken from the front.
    block_size : int
        Number of stars in the block. It must be smaller than the number of
        *selected* stars.
    rng : numpy.random.Generator, optional
        Used to shuffle the block. If omitted, a fresh default generator is used.

    Returns
    -------
    train_batches, test_batches : list of ndarray
        ``K`` index arrays each. Within every fold, the train and test arrays
        are disjoint and together cover all selected stars.

    Raises
    ------
    ValueError
        If ``block_size`` is not smaller than the number of selected stars,
        or ``K`` is outside ``[1, block_size]``.
    """
    if rng is None:
        rng = np.random.default_rng()

    train_mask = np.asarray(train_mask)
    if train_mask.dtype == bool:
        # Boolean mask -> indices of the selected stars.
        train_idx = np.flatnonzero(train_mask)
    else:
        train_idx = train_mask

    # Compare against the number of *selected* stars. The length of a boolean
    # mask is usually much larger than the number of True entries.
    if not block_size < len(train_idx):
        raise ValueError(
            f"block_size={block_size} must be smaller than the number of "
            f"selected training stars ({len(train_idx)})")
    if not 1 <= K <= block_size:
        raise ValueError(f"K={K} must be between 1 and block_size={block_size}")

    # Now split into block and zone 2 stars: the K-fold will
    # only happen on the block stars, and the zone 2 stars
    # will be appended to all blocks
    block_idx = train_idx[:block_size].copy()
    rng.shuffle(block_idx)

    zone2_idx = train_idx[block_size:]

    batch_size = block_size // K
    train_batches = []
    test_batches = []
    for k in range(K):
        start = k * batch_size
        # The last fold takes whatever is left over after the integer division.
        stop = None if k == K - 1 else start + batch_size
        test_batch = block_idx[start:stop]

        train_batch = np.concatenate((block_idx[~np.isin(block_idx, test_batch)],
                                      zone2_idx))

        test_batches.append(test_batch)
        train_batches.append(train_batch)

    # Sanity check: every fold partitions all selected stars.
    assert all(len(tr) + len(te) == len(train_idx)
               for tr, te in zip(train_batches, test_batches))

    return train_batches, test_batches

In [None]:
# Exploratory 8-fold split of the parallax-quality training stars, with a fixed
# seed for reproducibility. The block size comes from the `block_size` variable
# set above instead of a duplicated literal, so the two cannot drift apart.
rng = np.random.default_rng(seed=42)
train_idxs, test_idxs = get_Kfold_indices(K=8, train_mask=train_mask,
                                          block_size=block_size, rng=rng)

In [None]:
# Fold j of the exploratory split in the TEFF-LOGG plane: training stars in
# blue, held-out test stars in orange, and the first star of the explored
# neighborhood (`data`) in red.
j = 0

marker_style = dict(marker='o', mew=0, ls='none', ms=1.5, alpha=0.5)

fig, axes = plt.subplots(1, 2, figsize=(11, 5),
                         sharex=True,
                         constrained_layout=True)
ax_hr = axes[0]

ax_hr.scatter(data.stars['TEFF'][0], data.stars['LOGG'][0],
              s=6, color='tab:red', zorder=100)
for fold_idx, fold_color in ((train_idxs[j], 'tab:blue'),
                             (test_idxs[j], 'tab:orange')):
    ax_hr.plot(masked_data.stars['TEFF'][fold_idx],
               masked_data.stars['LOGG'][fold_idx],
               color=fold_color, **marker_style)

# Reversed limits: hot stars on the left, low surface gravity at the top.
ax_hr.set_xlim(8500, 3000)
ax_hr.set_ylim(5.5, -0.5)

ax_hr.set_xlabel('TEFF')
ax_hr.set_ylabel('LOGG')

# TODO: the right-hand panel is currently unused; the TEFF vs. M_H density map
# from the neighborhood overview plot could be drawn there.

In [None]:
# Photometric columns passed to get_X(phot_names=...). They are grouped here as
# (presumably) Gaia G/BP/RP, then 2MASS J/H/K, then WISE W1/W2.
phot_names = (
    ['phot_g_mean_mag', 'phot_bp_mean_mag', 'phot_rp_mean_mag']
    + ['J', 'H', 'K']
    + ['w1mpro', 'w2mpro']
)

In [None]:
def _fold_arrays(fold_data):
    """Return (X, idx_map, parallax, parallax inverse variance) for a subset."""
    X, fold_idx_map = fold_data.get_X(phot_names=phot_names)
    plx = fold_data.stars['GAIAEDR3_PARALLAX']
    plx_ivar = 1 / fold_data.stars['GAIAEDR3_PARALLAX_ERROR'] ** 2
    return X, fold_idx_map, plx, plx_ivar


# Materialize fold k: held-out (test) stars first, then training stars.
k = 0
train_idx, test_idx = train_idxs[k], test_idxs[k]

test_block = masked_data[test_idx]
test_X, _, test_y, test_y_ivar = _fold_arrays(test_block)

train_block = masked_data[train_idx]
train_X, idx_map, train_y, train_y_ivar = _fold_arrays(train_block)

In [None]:
# Sanity check: design-matrix shapes (rows = stars, columns = features).
train_X.shape, test_X.shape

In [None]:
# Disabled: image view of the training design matrix.
# plt.figure(figsize=(10, 10))
# plt.imshow(train_X, vmin=-0.5, vmax=0.5)
# # plt.xlim(1000, 2000)

In [None]:
# Corner plot of the training design-matrix columns that idx_map labels 'phot'.
_ = simple_corner(train_X[:, idx_map['phot']])

In [None]:
# Corner plot of the training design-matrix columns that idx_map labels 'lsf'.
_ = simple_corner(train_X[:, idx_map['lsf']])

In [None]:
# Parallax distributions of this fold's training and held-out stars (log counts).
bins = np.linspace(-0.5, train_y.max(), 256)

for plx_sample in (train_y, test_block.stars['GAIAEDR3_PARALLAX']):
    plt.hist(plx_sample, bins=bins)

plt.yscale('log')

In [None]:
# Quick L2-regularization scan on this single fold. For each L2 inverse
# variance, build train and test Joaquin models with that value (and a fixed
# parallax zero-point) frozen, then record both log-likelihoods at
# joa.init_beta(). No optimization is run here.
L2_ivar_vals = 10 ** np.arange(0., 5+1, 0.5)  # 10**0 ... 10**5.5 (arange excludes the stop value 6)

train_lls = np.full_like(L2_ivar_vals, np.nan)
test_lls = np.full_like(L2_ivar_vals, np.nan)
for i, val in enumerate(L2_ivar_vals):
    frozen = {'L2_ivar': val, 
              'parallax_zpt': -0.03}  # MAGIC NUMBERs

    joa = Joaquin(
        train_X, 
        train_y,
        train_y_ivar, 
        idx_map, 
        frozen=frozen)
    
    test_joa = Joaquin(
        test_X, 
        test_y,
        test_y_ivar,
        idx_map, 
        frozen=frozen)
    
    # NOTE(review): called without **frozen here, but the Full Pipeline loop
    # calls joa.init_beta(**frozen) -- confirm which is intended.
    init_beta = joa.init_beta()
    
    # Both likelihoods use the *training* model's initial beta. [0] takes the
    # first element of ln_likelihood's return value (presumably the total
    # log-likelihood -- TODO confirm).
    test_lls[i] = test_joa.ln_likelihood(beta=init_beta, **frozen)[0]
    train_lls[i] = joa.ln_likelihood(beta=init_beta, **frozen)[0]
    
    print(f"ivar={val:.2f} \t stddev={1/np.sqrt(val):.3f} \t "
          f"train_ll={train_lls[i]:.0f} \t test_ll={test_lls[i]:.0f}")

In [None]:
# Freeze the L2 ivar with the highest held-out log-likelihood from the scan above
# (reported as the prior stddev 1/sqrt(ivar)), plus the fixed parallax
# zero-point. Then rebuild the train and test models with those values.
print(f"Best L2 stddev: {1 / np.sqrt(L2_ivar_vals[test_lls.argmax()]):.2f}")

# TODO: decide whether to use train or test loglikelihoods here!
frozen = {'L2_ivar': L2_ivar_vals[test_lls.argmax()],
          'parallax_zpt': -0.03}  # MAGIC NUMBERs

# Free zpt
# frozen = {'L2_ivar': L2_ivar_vals[train_lls.argmax()]}

joa = Joaquin(
    train_X, 
    train_y, 
    train_y_ivar, 
    idx_map, 
    frozen=frozen)

test_joa = Joaquin(
    test_X, 
    test_y,
    test_y_ivar,
    idx_map, 
    frozen=frozen)

In [None]:
# Initial parameters in unpacked (dict-like) form, seeded with the frozen
# zero-point (or -0.03 if the zero-point is free). Then optimize the training
# model; 'maxiter' is presumably forwarded to the underlying optimizer (TODO confirm).
init = joa.init(parallax_zpt=frozen.get('parallax_zpt', -0.03), 
                pack=False)
res = joa.optimize(init=init, 
                   options={'maxiter': 128})

In [None]:
# Display the optimizer result.
res

In [None]:
# Unpack the optimized parameter vector into named parameters (used as a mapping below).
fit_pars = joa.unpack_pars(res.x)

In [None]:
# Initial vs. optimized beta: full vector (top), a zoom on entries 1000-2000
# (middle), and their difference (bottom).
fig, axes = plt.subplots(3, 1, figsize=(15, 8))

for ax in axes[:2]:
    ax.plot(init['beta'])
    ax.plot(fit_pars['beta'])

axes[2].plot(fit_pars['beta'] - init['beta'])

# Size the x-range from the beta actually plotted. `init_beta` is leftover
# state from the L2 scan loop and need not match.
n_beta = len(init['beta'])
axes[0].set_xlim(0, n_beta)
axes[1].set_xlim(1000, 2000)
axes[2].set_xlim(0, n_beta)
fig.tight_layout()

In [None]:
# Residuals (chi) for the training and test models, and model parallaxes for both
# feature matrices (both computed with the training model's model_y). The
# commented-out formulas are a hand-written form of model_y/chi -- TODO confirm
# they still match the Joaquin implementation.
# pred_plx = joa.model_y(train_X, **fit_pars)  # np.exp(np.dot(X, fit_pars['beta'])) - fit_pars['parallax_zpt']
# chi = (pred_plx - train_y) * np.sqrt(train_y_ivar)

# test_pred_plx = joa.model_y(test_X, **fit_pars)  # np.exp(np.dot(test_X, fit_pars['beta'])) - fit_pars['parallax_zpt']
# test_chi = (test_pred_plx - test_y) * np.sqrt(test_y_ivar)

chi = joa.chi(**fit_pars)
test_chi = test_joa.chi(**fit_pars)

pred_plx = joa.model_y(train_X, **fit_pars)
test_pred_plx = joa.model_y(test_X, **fit_pars)

In [None]:
# Training-set diagnostics. Left: predicted parallax vs. zero-point-shifted Gaia
# parallax, with the one-to-one line in yellow. Right: chi vs. Gaia parallax.
# c = masked_data.stars['ruwe'][train_mask]
c = None
plx_lim = (-0.5, train_y.max())

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
ax_cmp, ax_chi = axes

ax_cmp.scatter(train_y + fit_pars['parallax_zpt'], pred_plx, c=c,
               marker='o', s=4, alpha=0.75)
ax_cmp.set_xlim(*plx_lim)
ax_cmp.set_ylim(ax_cmp.get_xlim())
ax_cmp.set_xlabel('Gaia plx')
ax_cmp.set_ylabel('Joaquin plx')

_grid = np.linspace(*plx_lim, 10)
ax_cmp.plot(_grid, _grid, marker='', zorder=-10, color='yellow')

ax_chi.plot(train_y, chi, marker='o', ls='none', mew=0, ms=1.5, alpha=0.75)
ax_chi.set_xlim(*plx_lim)
ax_chi.set_ylim(-8, 8)
ax_chi.set_xlabel('Gaia plx')
ax_chi.set_ylabel(r'$\chi$')

ax = ax_chi
fig.tight_layout()

In [None]:
# Held-out (test) diagnostics for this fold. Left: predicted parallax vs.
# zero-point-shifted Gaia parallax, with the one-to-one line in yellow. Right:
# chi vs. Gaia parallax. Axis ranges follow the *training* parallaxes.
test_plx = test_block.stars['GAIAEDR3_PARALLAX']
plx_lim = (-0.5, train_y.max())
pt_style = dict(marker='o', ls='none', mew=0, ms=1.5, alpha=0.75)

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
ax_cmp, ax_chi = axes

ax_cmp.plot(test_plx + fit_pars['parallax_zpt'], test_pred_plx, **pt_style)
ax_cmp.set_xlim(*plx_lim)
ax_cmp.set_ylim(ax_cmp.get_xlim())
ax_cmp.set_xlabel('Gaia plx')
ax_cmp.set_ylabel('Joaquin plx')

_grid = np.linspace(*plx_lim, 10)
ax_cmp.plot(_grid, _grid, marker='', zorder=-10, color='yellow')

ax_chi.plot(test_plx, test_chi, **pt_style)
ax_chi.set_xlim(*plx_lim)
ax_chi.set_ylim(-8, 8)
ax_chi.set_xlabel('Gaia plx')
ax_chi.set_ylabel(r'$\chi$')

ax = ax_chi
fig.tight_layout()

In [None]:
# Chi histograms for the training set, then the test set. Solid blue lines mark
# the empirical 16th/84th percentiles; dashed gray lines mark -1/+1, where those
# percentiles would sit for unit-Gaussian residuals.
for chi_vals in (chi, test_chi):
    plt.figure()
    plt.hist(chi_vals, bins=np.linspace(-5, 5, 64))
    p16, p84 = np.percentile(chi_vals, [16, 84])
    plt.axvline(p16, color='tab:blue')
    plt.axvline(p84, color='tab:blue')

    for unit_edge in (1, -1):
        plt.axvline(unit_edge, linestyle='--', color='#666666')

In [None]:
# Fractional parallax residuals (Gaia + zero-point - Joaquin) / Joaquin for
# held-out stars with parallax S/N > 20 and RUWE < 1.4 (MAGIC NUMBERs), with
# +/- 1.5 * MAD marked.
plt.figure()

plx_check_mask = (
    ((test_block.stars['GAIAEDR3_PARALLAX'] / test_block.stars['GAIAEDR3_PARALLAX_ERROR']) > 20)
    & (test_block.stars['ruwe'] < 1.4)
)
diff = (test_block.stars['GAIAEDR3_PARALLAX'] + fit_pars['parallax_zpt'] - test_pred_plx) / test_pred_plx
diff = diff[plx_check_mask]
print(len(diff))

plt.hist(diff, bins=np.linspace(-1, 1, 64))


def MAD(x):
    """Median absolute deviation of x about its median."""
    return np.median(np.abs(x - np.median(x)))


mad_width = 1.5 * MAD(diff)
for edge in (-mad_width, mad_width):
    plt.axvline(edge)
# plt.axvline(np.median(diff))

print(mad_width)

# TODO: also histogram log(Gaia plx) - log(Joaquin plx).

In [None]:
# Same test-set diagnostics as above, restricted to the high-quality parallax
# subset selected by plx_check_mask.
good_plx = test_block.stars['GAIAEDR3_PARALLAX'][plx_check_mask]
plx_lim = (-0.5, train_y.max())
pt_style = dict(marker='o', ls='none', mew=0, ms=1.5, alpha=0.75)

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
ax_cmp, ax_chi = axes

ax_cmp.plot(good_plx + fit_pars['parallax_zpt'], test_pred_plx[plx_check_mask], **pt_style)
ax_cmp.set_xlim(*plx_lim)
ax_cmp.set_ylim(ax_cmp.get_xlim())
ax_cmp.set_xlabel('Gaia plx')
ax_cmp.set_ylabel('Joaquin plx')

_grid = np.linspace(*plx_lim, 10)
ax_cmp.plot(_grid, _grid, marker='', zorder=-10, color='yellow')

ax_chi.plot(good_plx, test_chi[plx_check_mask], **pt_style)
ax_chi.set_xlim(*plx_lim)
ax_chi.set_ylim(-8, 8)
ax_chi.set_xlabel('Gaia plx')
ax_chi.set_ylabel(r'$\chi$')

ax = ax_chi
fig.tight_layout()

# Full Pipeline

In [None]:
# Full-pipeline cross-validation settings (MAGIC NUMBERs). Here train_mask is an
# *index* array of stars with parallax uncertainty < 0.1, in their existing
# order; a training set of size N uses its first N entries.
Kfold_K = 4
block_size = 1024
train_mask = np.flatnonzero(masked_data.stars['GAIAEDR3_PARALLAX_ERROR'] < 0.1)
L2_ivar_vals = 10 ** np.arange(0., 5+1, 0.5)
train_sizes = 2 ** np.arange(12, 16)  # 4096, 8192, 16384, 32768

In [None]:
# Cross-validation grid. For each training-set size (the first train_size
# entries of train_mask), build a K-fold split; for every fold and L2 ivar,
# record train/test log-likelihoods at the initial beta (no optimization).
# Results have shape (n_train_sizes, Kfold_K, n_L2_ivar).
# The single rng is shared across train sizes, so each split depends on the
# ones drawn before it.
rng = np.random.default_rng(seed=42)

train_lls = np.full((len(train_sizes), Kfold_K, len(L2_ivar_vals)), 
                    np.nan)
test_lls = np.full((len(train_sizes), Kfold_K, len(L2_ivar_vals)), 
                   np.nan)

for i, train_size in enumerate(tqdm(train_sizes)):
    train_idxs, test_idxs = get_Kfold_indices(
        K=Kfold_K, 
        train_mask=train_mask[:train_size], 
        block_size=block_size, 
        rng=rng
    )

    for k in tqdm(np.arange(Kfold_K), leave=False):
        train_idx = train_idxs[k]
        test_idx = test_idxs[k]

        test_block = masked_data[test_idx]
        test_X, _ = test_block.get_X(phot_names=phot_names)
        test_y = test_block.stars['GAIAEDR3_PARALLAX']
        test_y_ivar = 1 / test_block.stars['GAIAEDR3_PARALLAX_ERROR'] ** 2

        train_block = masked_data[train_idx]
        train_X, idx_map = train_block.get_X(phot_names=phot_names)
        train_y = train_block.stars['GAIAEDR3_PARALLAX']
        train_y_ivar = 1 / train_block.stars['GAIAEDR3_PARALLAX_ERROR'] ** 2
        
        # Unlike the single-fold scan earlier, the models are built without
        # `frozen`; the frozen values are passed to init_beta/ln_likelihood.
        joa = Joaquin(
            train_X, 
            train_y,
            train_y_ivar, 
            idx_map)

        test_joa = Joaquin(
            test_X, 
            test_y,
            test_y_ivar,
            idx_map)
        
        for j, L2_ivar in enumerate(tqdm(L2_ivar_vals, leave=False)):
            frozen = {'L2_ivar': L2_ivar, 
                      'parallax_zpt': -0.03}  # MAGIC NUMBERs

            init_beta = joa.init_beta(**frozen)

            # Both likelihoods are evaluated at the training model's initial beta.
            test_lls[i, k, j] = test_joa.ln_likelihood(beta=init_beta, **frozen)[0]
            train_lls[i, k, j] = joa.ln_likelihood(beta=init_beta, **frozen)[0]

In [None]:
# Average over the K folds (axis 1) -> shape (n_train_sizes, n_L2_ivar).
train_ll, test_ll = (lls.mean(axis=1) for lls in (train_lls, test_lls))

In [None]:
# Coordinate grids with the same shape as test_ll (rows follow train_sizes,
# columns follow L2_ivar_vals), used for plotting and the argmax lookup below.
L2_ivar_vals_2d, train_sizes_2d = np.meshgrid(L2_ivar_vals, train_sizes)

In [None]:
# Fold-averaged held-out log-likelihood over the (L2 ivar, training size) grid.
# The color scale is clipped to the 25th-99.5th percentiles so the best cells
# stand out.
ll_lo, ll_hi = np.percentile(test_ll, [25, 99.5])

fig, ax = plt.subplots(1, 1, figsize=(10, 3))
ax.scatter(L2_ivar_vals_2d, train_sizes_2d, c=test_ll,
           vmin=ll_lo, vmax=ll_hi, marker='s', s=500)
ax.set_xscale('log')
ax.set_yscale('log', base=2)
ax.set_xlabel('L2 ivar')
ax.set_ylabel('train size')

In [None]:
# Disabled alternative: the same log-likelihood grid drawn with pcolormesh.
# fig, ax = plt.subplots(1, 1, figsize=(10, 3))
# ax.pcolormesh(L2_ivar_vals, train_sizes, test_ll, 
#               vmin=np.percentile(test_ll, 25),
#               vmax=np.percentile(test_ll, 99.5),)
# ax.set_xscale('log')
# ax.set_xlabel('L2 ivar')
# ax.set_ylabel('train size')

In [None]:
# Best grid cell of the fold-averaged held-out log-likelihood. Rows of test_ll
# index train_sizes and columns index L2_ivar_vals.
best_row, best_col = np.unravel_index(test_ll.argmax(), test_ll.shape)
cross_val_L2_ivar = L2_ivar_vals[best_col]
cross_val_train_size = train_sizes[best_row]
print(cross_val_L2_ivar, cross_val_train_size)
print(f"Best L2 stddev: {1 / np.sqrt(cross_val_L2_ivar):.3f}")

In [None]:
# Final fit with the cross-validated settings. Make a K=2 split of the first
# cross_val_train_size training stars, then optimize on fold 0.
rng = np.random.default_rng(seed=42)
train_idxs, test_idxs = get_Kfold_indices(
    K=2, 
    train_mask=train_mask[:cross_val_train_size], 
    block_size=block_size, 
    rng=rng
)

frozen = {'L2_ivar': cross_val_L2_ivar,
          'parallax_zpt': -0.03}  # MAGIC NUMBERs

# tqdm needs an iterable: tqdm(len(...)) would raise TypeError when iterated.
for k in tqdm(range(len(train_idxs))):
    train_idx = train_idxs[k]
    test_idx = test_idxs[k]

    test_block = masked_data[test_idx]
    test_X, _ = test_block.get_X(phot_names=phot_names)
    test_y = test_block.stars['GAIAEDR3_PARALLAX']
    test_y_ivar = 1 / test_block.stars['GAIAEDR3_PARALLAX_ERROR'] ** 2

    train_block = masked_data[train_idx]
    train_X, idx_map = train_block.get_X(phot_names=phot_names)
    train_y = train_block.stars['GAIAEDR3_PARALLAX']
    train_y_ivar = 1 / train_block.stars['GAIAEDR3_PARALLAX_ERROR'] ** 2

    joa = Joaquin(
        train_X, 
        train_y,
        train_y_ivar, 
        idx_map,
        frozen=frozen)

    test_joa = Joaquin(
        test_X, 
        test_y,
        test_y_ivar,
        idx_map, 
        frozen=frozen)

    init = joa.init(parallax_zpt=frozen.get('parallax_zpt', -0.03), 
                    pack=False)
    res = joa.optimize(init=init, 
                       options={'maxiter': 128})

    # Only fold 0 is fit for now. Removing this break fits every fold, but
    # joa/test_joa/init/res would then be overwritten by each fold in turn.
    break

In [None]:
# Evaluation set `block`: the stars in the second test batch of the K=2 split.
# NOTE(review): with K=2, test_idxs[1] is a subset of train_idxs[0], and the fit
# above stops after fold 0 -- so these stars were in `joa`'s training set.
# Confirm that is intended (the commented line used the first block_size stars).
# block = masked_data[:block_size]
block = masked_data[test_idxs[1]]
block_X, idx_map = block.get_X(phot_names=phot_names)
block_y = block.stars['GAIAEDR3_PARALLAX']
block_y_ivar = 1 / block.stars['GAIAEDR3_PARALLAX_ERROR'] ** 2

block_joa = Joaquin(
    block_X, 
    block_y,
    block_y_ivar, 
    idx_map,
    frozen=frozen)

In [None]:
# Display the optimizer result of the final fit.
res

In [None]:
# Unpack the final optimized parameter vector into named parameters.
fit_pars = joa.unpack_pars(res.x)

In [None]:
# Initial vs. optimized beta for the final fit: full vector (top), a zoom on
# entries 1000-2000 (middle), and their difference (bottom). All panels share a
# y-range of +/- 5 L2 prior stddevs (1/sqrt(L2 ivar)).
fig, axes = plt.subplots(3, 1, figsize=(15, 8), sharey=True)

for ax in axes[:2]:
    ax.plot(init['beta'])
    ax.plot(fit_pars['beta'])

axes[2].plot(fit_pars['beta'] - init['beta'])

# Size the x-range from the beta being plotted rather than the leftover
# `init_beta` from the cross-validation loop.
n_beta = len(init['beta'])
axes[0].set_xlim(0, n_beta)
axes[1].set_xlim(1000, 2000)
axes[2].set_xlim(0, n_beta)
axes[0].set_ylim(-5 / np.sqrt(cross_val_L2_ivar),
                 5 / np.sqrt(cross_val_L2_ivar))
fig.tight_layout()

In [None]:
# Residuals and model parallaxes of the final fit for the training, held-out
# test, and `block` stars. Train/test predictions use the training model's
# model_y; the block uses block_joa.model_y.
chi = joa.chi(**fit_pars)
test_chi = test_joa.chi(**fit_pars)
block_chi = block_joa.chi(**fit_pars)

pred_plx = joa.model_y(train_X, **fit_pars)
test_pred_plx = joa.model_y(test_X, **fit_pars)
block_pred_plx = block_joa.model_y(block_X, **fit_pars)

In [None]:
# Final-fit training diagnostics. Left: predicted parallax vs. zero-point-shifted
# Gaia parallax, with the one-to-one line in yellow. Right: chi vs. Gaia parallax.
# c = masked_data.stars['ruwe'][train_mask]
c = None
plx_lim = (-0.5, train_y.max())

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
ax_cmp, ax_chi = axes

ax_cmp.scatter(train_y + fit_pars['parallax_zpt'], pred_plx, c=c,
               marker='o', s=4, alpha=0.75)
ax_cmp.set_xlim(*plx_lim)
ax_cmp.set_ylim(ax_cmp.get_xlim())
ax_cmp.set_xlabel('Gaia plx')
ax_cmp.set_ylabel('Joaquin plx')

_grid = np.linspace(*plx_lim, 10)
ax_cmp.plot(_grid, _grid, marker='', zorder=-10, color='yellow')

ax_chi.plot(train_y, chi, marker='o', ls='none', mew=0, ms=1.5, alpha=0.75)
ax_chi.set_xlim(*plx_lim)
ax_chi.set_ylim(-8, 8)
ax_chi.set_xlabel('Gaia plx')
ax_chi.set_ylabel(r'$\chi$')

ax = ax_chi
fig.tight_layout()

In [None]:
# Final-fit held-out (test) diagnostics. Left: predicted parallax vs.
# zero-point-shifted Gaia parallax. Right: chi vs. Gaia parallax. Axis ranges
# follow the training parallaxes.
test_plx = test_block.stars['GAIAEDR3_PARALLAX']
plx_lim = (-0.5, train_y.max())
pt_style = dict(marker='o', ls='none', mew=0, ms=1.5, alpha=0.75)

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
ax_cmp, ax_chi = axes

ax_cmp.plot(test_plx + fit_pars['parallax_zpt'], test_pred_plx, **pt_style)
ax_cmp.set_xlim(*plx_lim)
ax_cmp.set_ylim(ax_cmp.get_xlim())
ax_cmp.set_xlabel('Gaia plx')
ax_cmp.set_ylabel('Joaquin plx')

_grid = np.linspace(*plx_lim, 10)
ax_cmp.plot(_grid, _grid, marker='', zorder=-10, color='yellow')

ax_chi.plot(test_plx, test_chi, **pt_style)
ax_chi.set_xlim(*plx_lim)
ax_chi.set_ylim(-8, 8)
ax_chi.set_xlabel('Gaia plx')
ax_chi.set_ylabel(r'$\chi$')

ax = ax_chi
fig.tight_layout()

In [None]:
# Final-fit diagnostics for the `block` stars. Left: predicted parallax vs.
# zero-point-shifted Gaia parallax. Right: chi vs. Gaia parallax. Axis ranges
# follow the training parallaxes.
block_plx = block.stars['GAIAEDR3_PARALLAX']
plx_lim = (-0.5, train_y.max())
pt_style = dict(marker='o', ls='none', mew=0, ms=1.5, alpha=0.75)

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
ax_cmp, ax_chi = axes

ax_cmp.plot(block_plx + fit_pars['parallax_zpt'], block_pred_plx, **pt_style)
ax_cmp.set_xlim(*plx_lim)
ax_cmp.set_ylim(ax_cmp.get_xlim())
ax_cmp.set_xlabel('Gaia plx')
ax_cmp.set_ylabel('Joaquin plx')

_grid = np.linspace(*plx_lim, 10)
ax_cmp.plot(_grid, _grid, marker='', zorder=-10, color='yellow')

ax_chi.plot(block_plx, block_chi, **pt_style)
ax_chi.set_xlim(*plx_lim)
ax_chi.set_ylim(-8, 8)
ax_chi.set_xlabel('Gaia plx')
ax_chi.set_ylabel(r'$\chi$')

ax = ax_chi
fig.tight_layout()

In [None]:
# NOTE: this rebinds MAD. It replaces the lambda defined in the earlier residual
# histogram cell, so from here on MAD is astropy's median_absolute_deviation.
from astropy.stats import median_absolute_deviation as MAD

In [None]:
# Robust scatter (1.5 * MAD) of the fractional parallax residuals for `block`,
# as a function of a minimum parallax S/N cut (always with RUWE < 1.4). The scan
# stops at the first cut that leaves fewer than 10 stars; that entry and all
# later ones stay NaN.
diff = (block.stars['GAIAEDR3_PARALLAX'] + fit_pars['parallax_zpt'] - block_pred_plx) / block_pred_plx

snr_cuts = np.geomspace(5, 100, 10).astype(int)
vals = np.full(snr_cuts.shape, np.nan)

# Loop-invariant pieces of the selection.
block_plx_snr = block.stars['GAIAEDR3_PARALLAX'] / block.stars['GAIAEDR3_PARALLAX_ERROR']
block_good_ruwe = block.stars['ruwe'] < 1.4
for i, plx_snr_cut in enumerate(snr_cuts):
    plx_check_mask = (block_plx_snr > plx_snr_cut) & block_good_ruwe
    if plx_check_mask.sum() < 10:
        break
    vals[i] = 1.5 * MAD(diff[plx_check_mask])

In [None]:
# Robust scatter of the fractional residuals (Gaia + zero-point - Joaquin) /
# Joaquin vs. the minimum parallax S/N. It is compared with 1 / (S/N), which is
# Gaia's own fractional parallax uncertainty at that threshold. The residual
# definition is (gaia - joaquin), so the old '(joaquin - gaia)' label had the
# wrong sign; the curves were also unlabeled.
plt.figure()
plt.plot(snr_cuts, vals, label='1.5 MAD of fractional residuals')
plt.plot(snr_cuts, 1/snr_cuts, label='1 / (S/N)')
plt.legend()
plt.xlabel('minimum parallax S/N')
plt.ylabel('1.5 MAD[(gaia - joaquin) / joaquin]')

In [None]:
# Fractional parallax residuals for `block` stars with parallax S/N above
# plx_snr_cut and RUWE < 1.4, with +/- 1.5 * MAD marked. Here MAD refers to
# astropy's median_absolute_deviation, imported above.
plt.figure()

plx_snr_cut = 60

good_ruwe = block.stars['ruwe'] < 1.4
plx_check_mask = (
    ((block.stars['GAIAEDR3_PARALLAX'] / block.stars['GAIAEDR3_PARALLAX_ERROR']) > plx_snr_cut)
    & good_ruwe
)
diff = (block.stars['GAIAEDR3_PARALLAX'] + fit_pars['parallax_zpt'] - block_pred_plx) / block_pred_plx
diff = diff[plx_check_mask]
print(len(diff))

plt.hist(diff, bins=np.linspace(-1, 1, 64))
mad_width = 1.5 * MAD(diff)
for edge in (-mad_width, mad_width):
    plt.axvline(edge)
# plt.axvline(np.median(diff))

print(mad_width)