In [None]:
import os
# Cache locations for the APOGEE / joaquin packages. These are set before
# `joaquin` is imported further down (presumably they are read at import
# time -- TODO confirm). `setdefault` keeps any value already present in the
# environment, so these machine-specific defaults no longer clobber a user's
# own configuration.
os.environ.setdefault('APOGEE_CACHE_PATH', "/mnt/ceph/users/apricewhelan/apogee-test/")
os.environ.setdefault('JOAQUIN_CACHE_PATH', "/mnt/ceph/users/apricewhelan/projects/joaquin/cache")
import warnings
# NOTE: this silences warnings of *every* category for the whole session.
warnings.filterwarnings('ignore', category=Warning)
import pickle

import sys
import pathlib
# Make the parent directory importable (presumably so the local `joaquin`
# package can be imported from a source checkout -- confirm repo layout).
_path = str(pathlib.Path('../').resolve())
if _path not in sys.path:
    sys.path.append(_path)

import corner
from astropy.io import fits
import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm.auto import tqdm
from sklearn.decomposition import PCA
from scipy.spatial import cKDTree

from joaquin.data import JoaquinData
from joaquin.config import (dr, root_cache_path, 
                            neighborhood_size, block_size)
from joaquin.plot import simple_corner

In [None]:
# Per-data-release directories for cached intermediate products and figures.
cache_path = pathlib.Path(f'../cache/{dr}').resolve()
cache_path.mkdir(parents=True, exist_ok=True)

plot_path = pathlib.Path('../plot').joinpath(dr).resolve()
plot_path.mkdir(exist_ok=True, parents=True)

Run the first two notebooks (`1-*` and `2-*`) before this one: they create the cached files that are loaded below.

In [None]:
# Parent sample, keeping only stars whose feature vector is entirely finite.
parent_data = JoaquinData.read('parent-sample')
parent_data = parent_data[np.isfinite(parent_data.X).all(axis=1)]

# Products cached by the earlier notebooks.
global_spec_mask = np.load(cache_path / 'global_spec_bad_mask.npy')
neighborhood_idx = np.load(cache_path / 'good_parent_neighborhood_indices.npy')

In [None]:
# parent_stars = parent_data.stars[parent_data.stars_mask]
# parent_d, *_ = parent_data.get_Xy(spec_mask_thresh=1.)  # disable spec mask
# assert len(parent_stars) == parent_d['X'].shape[0]

## PCA patching

In [None]:
# Develop / inspect the patching step on a single neighborhood.
# (Replaces a `for ...: break` loop that only ever ran its first iteration.)
# NOTE(review): 131 looks like an arbitrarily chosen neighborhood -- confirm.
neighborhood_i = 131
# Pixels flagged bad in more than this fraction of the neighborhood's stars
# are passed to mask_spec_pixels() after patching.
bad_pix_frac_thresh = 0.25

idx = neighborhood_idx[neighborhood_i]
data = parent_data[idx]

spec_bad_mask = (data.spec_bad_masks.sum(axis=0) / len(data.stars)) > bad_pix_frac_thresh
patched_data = data.patch_spec()
patched_data.spec_bad_masks = None
patched_data = patched_data.mask_spec_pixels(spec_bad_mask)

In [None]:
# Count zero-valued (presumably masked) pixels outside the global bad-pixel
# mask before patching.
tmp, _ = data.get_X('spec')
npix_fixed = np.count_nonzero(tmp[:, ~global_spec_mask] == 0)

# After patching, no zero-valued pixels may remain anywhere.
tmp_patched, _ = patched_data.get_X('spec')
assert not np.any(tmp_patched == 0)

print(f"{npix_fixed} pixels patched, ~{npix_fixed/tmp.shape[0]:.0f} pixels patched per star")

**TODO:** make 2D images of the spectra before and after patching, rendering masked pixels as hot pixels so that they stand out clearly in the "before" images.

## Low-pass filter

In [None]:
# Low-pass filter the patched spectra.
lowpass_data = patched_data.lowpass_filter_spec()

In [None]:
# Spectral features (`get_X('spec')`) of the low-pass-filtered data.
tmp, _ = lowpass_data.get_X('spec')

In [None]:
dist = coord.Distance(
    parallax=lowpass_data.stars['GAIAEDR3_PARALLAX'] * u.mas,
    allow_negative=True,  # keep negative-parallax stars instead of raising
)
# Absolute Gaia G magnitude: apparent magnitude minus distance modulus.
MG = lowpass_data.stars['GAIAEDR3_PHOT_G_MEAN_MAG'] - dist.distmod.value

In [None]:
# Low-pass-filtered spectra minus the mean spectrum, with stars sorted by
# LOGG, to look for residual structure that varies with surface gravity.
fig, ax = plt.subplots(figsize=(10, 10))

diff = tmp[lowpass_data.stars['LOGG'].argsort()] - np.mean(tmp, axis=0)

# Symmetric color limits so that zero sits at the center (white) of the
# diverging colormap; asymmetric percentiles would shift it off zero.
vlim = np.percentile(np.abs(diff), 99)
im = ax.imshow(diff, origin='lower',
               vmin=-vlim, vmax=vlim,
               cmap='RdBu',
               aspect='auto')  # fill the axes regardless of the array's shape
fig.colorbar(im, ax=ax, label='difference from mean spectrum')

ax.set_xticks([])
ax.set_yticks([])

ax.set_xlabel('wavelength')
ax.set_ylabel('stars, ordered by LOGG')

fig.tight_layout()

**TODO:** make before/after 1D plots showing that the low-pass filter actually changes the spectra — show both the full spectrum and a zoomed-in window.

In [None]:
# for lim in [False, 'zoom', 'zoomer']:
#     plt.figure(figsize=(16, 5))
#     plt.plot(parent_data._X_wvln, subX[i], marker='', drawstyle='steps-mid')
#     plt.plot(parent_data._X_wvln, subX_patched[i], marker='', drawstyle='steps-mid')
#     plt.plot(parent_data._X_wvln, new_ln_flux, marker='', drawstyle='steps-mid')
#     if lim == 'zoom':
#         plt.xlim(16000, 16500)
#     elif lim == 'zoomer':
#         plt.xlim(16150, 16220)

## Now try running the rest of the pipeline

The training sample is the full neighborhood, with cuts on S/N and on parallax uncertainty:

In [None]:
# Extra pixel masking is currently disabled (alternative:
# lowpass_data.mask_spec_pixels()), so the filtered data are used as-is.
masked_data = lowpass_data

In [None]:
# The first `block_size` stars form the block; the remainder is "zone 2".
block, zone2 = masked_data[:block_size], masked_data[block_size:]

In [None]:
# Kiel diagram: block stars colored by metallicity, zone2 stars underneath in
# gray (the default blue was hard to tell apart from the colormap).
fig, ax = plt.subplots()
cs = ax.scatter(block.stars['TEFF'],
                block.stars['LOGG'],
                c=block.stars['M_H'],
                s=4, vmin=-1.5, vmax=0.5)

ax.scatter(zone2.stars['TEFF'],
           zone2.stars['LOGG'],
           s=2, zorder=-10, color='#aaaaaa')

ax.set_xlim(8500, 3000)
ax.set_ylim(5.5, -0.5)
ax.set_xlabel('TEFF')
ax.set_ylabel('LOGG')
fig.colorbar(cs, ax=ax, label='M_H')

In [None]:
# Quality cuts for stars eligible for training: high S/N spectra and
# precise Gaia parallaxes. Same cut applied to the block and to zone2.
train_mask, zone2_train_mask = (
    (stars['SNR'] > 100) & (stars['GAIAEDR3_PARALLAX_ERROR'] < 0.1)
    for stars in (block.stars, zone2.stars)
)

# TODO: add RUWE selection

In [None]:
# from joaquin.crossval import get_Kfold_indices

def get_Kfold_indices(N, K, train_mask=None, rng=None):
    """
    Split ``N`` items into ``K`` shuffled cross-validation folds.

    The indices ``0..N-1`` are shuffled and cut into ``K`` contiguous folds of
    size ``N // K`` (the last fold also takes the remainder). For fold ``k``
    the training set is every item *outside* fold ``k`` that also passes
    ``train_mask``; the test set is everything else, i.e. fold ``k`` plus
    every item that fails ``train_mask``. Items failing the quality cut are
    therefore never trained on and appear in *every* test set.

    Parameters
    ----------
    N : int
        Number of items (e.g. stars).
    K : int
        Number of folds; must satisfy ``1 <= K <= N``.
    train_mask : array-like of bool, optional
        Length-``N`` mask of items eligible for training (e.g. high S/N or
        small parallax error). If None, every item is eligible.
    rng : numpy.random.Generator, optional
        Used to shuffle the indices; a fresh unseeded generator if None.

    Returns
    -------
    train_batches, test_batches : list of numpy.ndarray
        ``K`` integer index arrays each. For every ``k`` the two arrays are
        disjoint and together contain all ``N`` indices.

    Raises
    ------
    ValueError
        If ``K`` is outside ``[1, N]`` or ``train_mask`` is not length ``N``.
    """
    if not 1 <= K <= N:
        raise ValueError(f"need 1 <= K <= N, got K={K}, N={N}")

    if rng is None:
        rng = np.random.default_rng()

    idx = np.arange(N)
    rng.shuffle(idx)

    # We may want to impose other criteria on the training
    # sample, like high S/N or small parallax error.
    # eligible[j] is True when the item at shuffled position j passes the cuts.
    if train_mask is not None:
        if len(train_mask) != N:
            raise ValueError(
                f"train_mask has length {len(train_mask)}, expected N={N}")
        train_subset_idx = np.argwhere(train_mask).ravel()
        eligible = np.isin(idx, train_subset_idx)
    else:
        eligible = np.ones(N, dtype=bool)

    batch_size = N // K
    train_batches = []
    test_batches = []
    for k in range(K):
        # Fold k is a contiguous run of positions in the shuffled `idx`;
        # the last fold also absorbs the remainder when K does not divide N.
        stop = N if k == K - 1 else (k + 1) * batch_size
        in_fold = np.zeros(N, dtype=bool)
        in_fold[k * batch_size:stop] = True

        is_train = ~in_fold & eligible
        train_batches.append(idx[is_train])
        # Everything not trained on goes to the *test* sample: this fold plus
        # all items that fail the quality cuts.
        test_batches.append(idx[~is_train])

    return train_batches, test_batches

In [None]:
# Fixed seed so the 8-fold split of the block is reproducible.
rng = np.random.default_rng(seed=42)
train_idxs, test_idxs = get_Kfold_indices(
    len(block.stars), K=8, rng=rng, train_mask=train_mask,
)

In [None]:
# Photometric columns used as features.
phot_names = (
    [f'GAIAEDR3_PHOT_{band}_MEAN_MAG' for band in ('G', 'BP', 'RP')]
    + ['J', 'H', 'K']
    + ['w1mpro', 'w2mpro']
)

In [None]:
i = 0  # which cross-validation fold to work with

train_idx, test_idx = train_idxs[i], test_idxs[i]

# Held-out stars of this fold
test_block = block[test_idx]
test_X, _ = test_block.get_X(phot_names=phot_names)
test_y = test_block.stars['GAIAEDR3_PARALLAX']
test_y_ivar = 1 / test_block.stars['GAIAEDR3_PARALLAX_ERROR'] ** 2

# Training stars of this fold (only those passing train_mask)
train_block = block[train_idx]
block_train_X, idx_map = train_block.get_X(phot_names=phot_names)
block_train_y = train_block.stars['GAIAEDR3_PARALLAX']
block_train_y_ivar = 1 / train_block.stars['GAIAEDR3_PARALLAX_ERROR'] ** 2

In [None]:
# X2, idx_map2 = zone2.get_X()
# y2 = zone2.stars['GAIAEDR3_PARALLAX']
# y_ivar2 = 1 / zone2.stars['GAIAEDR3_PARALLAX_ERROR'] ** 2

# X = np.vstack((block_train_X, X2[zone2_train_mask]))
# y = np.concatenate((block_train_y, y2[zone2_train_mask]))
# y_ivar = np.concatenate((block_train_y_ivar, y_ivar2[zone2_train_mask]))

# for k in idx_map:
#     assert np.all(idx_map[k] == idx_map2[k])

# HACK: TESTING -- train on the block's training stars only; the zone2
# augmentation in the commented-out code above is not used.
X, y, y_ivar = block_train_X, block_train_y, block_train_y_ivar

In [None]:
# _ = simple_corner(X[:, idx_map['phot']])

In [None]:
# _ = simple_corner(X[:, idx_map['lsf']])

In [None]:
# Gaia parallax distributions of the training stars vs. the held-out fold.
# Step histograms so neither distribution hides the other.
bins = np.linspace(-0.5, 5, 256)
fig, ax = plt.subplots()
ax.hist(y, bins=bins, histtype='step', label='train')
ax.hist(test_block.stars['GAIAEDR3_PARALLAX'], bins=bins,
        histtype='step', label='test')

ax.set_yscale('log')
ax.set_xlabel('GAIAEDR3_PARALLAX [mas]')
ax.set_ylabel('number of stars')
ax.legend();

In [None]:
from scipy.optimize import minimize
from joaquin.logger import logger


class Joaquin:
    """
    Log-linear parallax model with an L2 (ridge) prior on the coefficients.

    The zero-point-corrected parallax is modeled as
    ``y + parallax_zpt ≈ exp(X @ beta)``, with a Gaussian likelihood using the
    inverse variances ``y_ivar`` and a Gaussian prior with inverse variance
    ``L2_ivar`` on the entries of ``beta`` selected by ``L2_slice``.

    Parameters
    ----------
    X : numpy.ndarray, shape (n_stars, n_features)
        Feature (design) matrix.
    y : array-like, shape (n_stars,)
        Observed parallaxes.
    y_ivar : array-like, shape (n_stars,)
        Inverse variances of ``y``.
    idx_map : dict
        Maps feature-group names to column selections of ``X``. If it has a
        ``'spec'`` entry only those columns are L2-regularized, otherwise all.
    frozen : dict, optional
        Parameter name -> fixed value. Frozen parameters are left out of the
        packed vector that ``optimize`` varies.
    """

    # Parameters that unpack_pars returns as scalars. Everything else (beta)
    # stays an array, even when it has length 1.
    _scalar_params = ('parallax_zpt', 'L2_ivar')

    def __init__(self, X, y, y_ivar, idx_map, frozen=None):
        self.X = X
        self.y = y
        self.y_ivar = y_ivar

        # Ordered parameter names -> number of entries in the packed vector
        self._param_info = {}

        # additive zero-point offset applied to the observed parallaxes
        self._param_info['parallax_zpt'] = 1

        # the inv-var of the prior on the spectral components in beta
        self._param_info['L2_ivar'] = 1

        # linear coefficients (in the exp argument)
        self._param_info['beta'] = self.X.shape[1]

        self.idx_map = idx_map
        if 'spec' in idx_map:
            L2_slice = self.idx_map['spec']
        else:
            L2_slice = np.ones(self.X.shape[1], dtype=bool)
        self.L2_slice = L2_slice

        if frozen is None:
            frozen = {}
        self.frozen = frozen

    def unpack_pars(self, par_list):
        """
        Convert a packed vector of free parameters into a dict of all
        parameters (frozen values are filled in from ``self.frozen``).
        """
        par_dict = {}
        i = 0
        for key, par_len in self._param_info.items():
            if key in self.frozen:
                par_dict[key] = self.frozen[key]
                continue

            value = np.array(par_list[i:i + par_len])
            if key in self._scalar_params:
                value = value[0]
            par_dict[key] = value
            i += par_len

        return par_dict

    def pack_pars(self, par_dict):
        """
        Flatten the non-frozen parameters of ``par_dict`` into one vector, in
        ``_param_info`` order. Scalars are promoted to length-1 arrays, since
        ``np.concatenate`` rejects 0-d values.
        """
        return np.concatenate([np.atleast_1d(par_dict[k])
                               for k in self._param_info
                               if k not in self.frozen])

    def init_beta(self, parallax_zpt=None, L2_ivar=None):
        """
        Initial ``beta`` from a weighted ridge regression of ln(parallax).

        Only stars whose zero-point-corrected parallax exceeds 3 sigma are
        used. Frozen values of ``parallax_zpt`` / ``L2_ivar`` take precedence
        over the arguments.

        Raises
        ------
        ValueError
            If ``parallax_zpt`` or ``L2_ivar`` is neither frozen nor given.
        """
        parallax_zpt = self.frozen.get('parallax_zpt', parallax_zpt)
        L2_ivar = self.frozen.get('L2_ivar', L2_ivar)

        if parallax_zpt is None or L2_ivar is None:
            raise ValueError(
                'parallax_zpt and L2_ivar must be passed to init_beta() '
                'or be frozen in the model')

        y = self.y + parallax_zpt
        plx_mask = y > (3 / np.sqrt(self.y_ivar))  # 3 sigma

        X = self.X[plx_mask]
        y = y[plx_mask]
        y_ivar = self.y_ivar[plx_mask]

        # first-order error propagation: sigma_ln(y) = sigma_y / y
        ln_plx_ivar = y**2 * y_ivar
        ln_y = np.log(y)

        XT_Cinv = X.T * ln_plx_ivar
        XT_Cinv_X = np.dot(XT_Cinv, X)
        # NOTE(review): the ridge term is added to *all* coefficients here,
        # whereas ln_prior only penalizes beta[self.L2_slice].
        XT_Cinv_X[np.diag_indices(X.shape[1])] += L2_ivar

        beta = np.linalg.solve(XT_Cinv_X, np.dot(XT_Cinv, ln_y))
        return beta

    def chi(self, parallax_zpt, L2_ivar, beta):
        """
        Normalized residuals ``(y + parallax_zpt - exp(X @ beta)) * sqrt(y_ivar)``.

        ``L2_ivar`` is unused; it is accepted so a full parameter dict can be
        passed as ``**par_dict``.
        """
        y = self.y + parallax_zpt
        model_ln_plx = np.dot(self.X, beta)
        model_y = np.exp(model_ln_plx)
        resid = y - model_y
        return resid * np.sqrt(self.y_ivar)

    def ln_likelihood(self, parallax_zpt, L2_ivar, beta):
        """Gaussian log-likelihood and its gradient with respect to ``beta``."""
        y = self.y + parallax_zpt
        model_ln_plx = np.dot(self.X, beta)
        model_y = np.exp(model_ln_plx)
        resid = y - model_y

        ll = -0.5 * np.sum(resid**2 * self.y_ivar)
        # chain rule through exp: d model_y / d beta = model_y * X
        ll_grad = np.dot(self.X.T,
                         model_y * self.y_ivar * resid)

        return ll, ll_grad

    def ln_prior(self, parallax_zpt, L2_ivar, beta):
        """L2 log-prior on ``beta[L2_slice]`` and its gradient w.r.t. ``beta``."""
        lp = - 0.5 * L2_ivar * np.sum(beta[self.L2_slice] ** 2)
        lp_grad = np.zeros_like(beta)
        lp_grad[self.L2_slice] = - L2_ivar * beta[self.L2_slice]
        return lp, lp_grad

    def neg_ln_posterior(self, parallax_zpt, L2_ivar, beta):
        """Negative log-posterior and its gradient with respect to ``beta``."""
        ll, ll_grad = self.ln_likelihood(parallax_zpt, L2_ivar, beta)
        lp, lp_grad = self.ln_prior(parallax_zpt, L2_ivar, beta)
        logger.log(0, f'objective function evaluation: ll={ll}, lp={lp}')
        return - (ll + lp), - (ll_grad + lp_grad)

    def __call__(self, p):
        """
        Objective for ``scipy.optimize.minimize(..., jac=True)``.

        Returns the negative log-posterior at the packed vector ``p`` and its
        gradient with respect to every *non-frozen* parameter, packed in the
        same order as ``p`` (so the lengths always match).
        """
        par_dict = self.unpack_pars(p)
        value, beta_grad = self.neg_ln_posterior(**par_dict)

        grad = {'beta': beta_grad}
        if 'parallax_zpt' not in self.frozen:
            resid = (self.y + par_dict['parallax_zpt']
                     - np.exp(np.dot(self.X, par_dict['beta'])))
            # d(-ln L)/d(parallax_zpt); the prior does not depend on it
            grad['parallax_zpt'] = np.sum(resid * self.y_ivar)
        if 'L2_ivar' not in self.frozen:
            # d(-ln prior)/d(L2_ivar). NOTE: the objective is linear in
            # L2_ivar (no normalization term), so it is unbounded below as
            # L2_ivar -> -inf; in practice L2_ivar should stay frozen.
            beta = par_dict['beta']
            grad['L2_ivar'] = 0.5 * np.sum(beta[self.L2_slice] ** 2)

        return value, self.pack_pars(grad)

    def optimize(self, init=None, **minimize_kwargs):
        """
        Maximize the posterior with ``scipy.optimize.minimize``.

        Parameters
        ----------
        init : dict, optional
            Initial values by parameter name. Missing ``parallax_zpt`` /
            ``L2_ivar`` default to 0 and 1; a missing ``beta`` comes from
            ``init_beta``. The caller's dict is not modified.
        **minimize_kwargs
            Passed to ``scipy.optimize.minimize``. ``method`` defaults to
            ``'L-BFGS-B'``, in which case ``options`` defaults to
            ``{'maxfun': 1024}``. To set the maximum number of function
            evaluations, pass:

                options={'maxfun': ...}

        Returns
        -------
        scipy.optimize.OptimizeResult
            Pass ``res.x`` to ``unpack_pars`` to get a parameter dict.
        """
        # Copy so the defaults / beta initialization don't leak to the caller.
        init = {} if init is None else dict(init)

        init.setdefault('parallax_zpt', 0.)
        init.setdefault('L2_ivar', 1.)

        if 'beta' not in init:
            init['beta'] = self.init_beta(parallax_zpt=init['parallax_zpt'],
                                          L2_ivar=init['L2_ivar'])

        x0 = self.pack_pars(init)

        minimize_kwargs.setdefault('method', 'L-BFGS-B')
        if minimize_kwargs['method'] == 'L-BFGS-B':
            minimize_kwargs.setdefault('options', {'maxfun': 1024})

        res = minimize(
            self,
            x0=x0,
            jac=True,
            **minimize_kwargs)

        return res

In [None]:
# for val in [1e-4, 1e-3, 1e-2, 1e-1]:
#     frozen = {'L2_ivar': val, 
#               'parallax_zpt': -0.03}  # MAGIC NUMBERs
#     joa = Joaquin(block_train_X, 
#                   block_train_y, 
#                   block_train_y_ivar, 
#                   idx_map, 
#                   frozen=frozen)
#     res = joa.optimize(options={'maxiter': 1000})
#     print(joa(res.x)[0])

In [None]:
# Hyperparameters are held fixed (values picked via the scan in the previous cell).
frozen = dict(L2_ivar=1e-3, parallax_zpt=-0.03)  # MAGIC NUMBERs

# Fit only the first 4096 training stars.
joa = Joaquin(*(arr[:4096] for arr in (X, y, y_ivar)),
              idx_map, frozen=frozen)

In [None]:
# Maximum-a-posteriori fit with the hyperparameters frozen above.
res = joa.optimize(options=dict(maxiter=1000))

In [None]:
# Unpack the optimized vector into a parameter dict (frozen values filled in).
fit_pars = joa.unpack_pars(res.x)

# Surface optimizer failures rather than silently using unconverged
# parameters. print(), not warnings.warn(): all warnings are filtered out at
# the top of this notebook.
if not res.success:
    print(f"WARNING: optimizer did not converge: {res.message}")

In [None]:
# Fitted coefficients, one per column of X.
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(fit_pars['beta'])
ax.axhline(0, color='#aaaaaa', zorder=-10)
ax.set_xlabel('feature index (column of X)')
ax.set_ylabel(r'$\beta$');

In [None]:
# Predicted parallax = exp(X @ beta) - zero point, for train and test stars.
pred_plx, test_pred_plx = (
    np.exp(np.dot(features, fit_pars['beta'])) - fit_pars['parallax_zpt']
    for features in (X, test_X)
)

# Normalized residuals (model - Gaia) / Gaia uncertainty.
chi, test_chi = (
    (pred - obs) * np.sqrt(ivar)
    for pred, obs, ivar in ((pred_plx, y, y_ivar),
                            (test_pred_plx, test_y, test_y_ivar))
)

In [None]:
# Training fold: Joaquin vs. Gaia parallax (left) and normalized residuals
# (right). NOTE: the model may have been fit to only a subset of these stars
# (see where `joa` is constructed).
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

ax = axes[0]
ax.plot(y,
        pred_plx,
        marker='o', ls='none', mew=0, ms=1., alpha=0.4)
ax.set_xlim(-0.5, 1.5)
ax.set_ylim(ax.get_xlim())
ax.set_xlabel('Gaia plx')
ax.set_ylabel('Joaquin plx')

_grid = np.linspace(-0.5, 1.5, 10)
ax.plot(_grid, _grid, marker='',
        zorder=-10, color='#aaaaaa')

ax = axes[1]
ax.plot(y,
        chi,
        marker='o', ls='none', mew=0, ms=1.5, alpha=0.75)
ax.set_xlim(-0.5, 1.5)
ax.set_ylim(-8, 8)
ax.set_xlabel('Gaia plx')
ax.set_ylabel(r'$\chi$')

# Without a title this figure is indistinguishable from the test-fold one.
fig.suptitle('Training fold')
fig.tight_layout()

In [None]:
# Held-out fold: Joaquin vs. Gaia parallax (left) and normalized residuals (right).
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

ax = axes[0]
ax.plot(test_block.stars['GAIAEDR3_PARALLAX'],
        test_pred_plx,
        marker='o', ls='none', mew=0, ms=1.5, alpha=0.75)
ax.set_xlim(-0.5, 1.5)
ax.set_ylim(ax.get_xlim())
ax.set_xlabel('Gaia plx')
ax.set_ylabel('Joaquin plx')

_grid = np.linspace(-0.5, 1.5, 10)
ax.plot(_grid, _grid, marker='',
        zorder=-10, color='#aaaaaa')

ax = axes[1]
ax.plot(test_block.stars['GAIAEDR3_PARALLAX'],
        test_chi,
        marker='o', ls='none', mew=0, ms=1.5, alpha=0.75)
ax.set_xlim(-0.5, 1.5)
ax.set_ylim(-8, 8)
ax.set_xlabel('Gaia plx')
ax.set_ylabel(r'$\chi$')

# Without a title this figure is indistinguishable from the training-fold one.
fig.suptitle('Test fold (held out)')
fig.tight_layout()

In [None]:
# Distribution of normalized residuals for train and test. If the model and
# the Gaia uncertainties are consistent, chi is ~unit normal and its
# 16th/84th percentiles (solid) should sit near -1/+1 (dashed).
for label, chi_vals in [('train', chi), ('test', test_chi)]:
    fig, ax = plt.subplots()
    ax.hist(chi_vals, bins=np.linspace(-5, 5, 64))
    # orange, not tab:blue: the default histogram bars are tab:blue
    for pct in np.percentile(chi_vals, [16, 84]):
        ax.axvline(pct, color='tab:orange')

    ax.axvline(1, linestyle='--', color='#666666')
    ax.axvline(-1, linestyle='--', color='#666666')
    ax.set_xlabel(r'$\chi$')
    ax.set_ylabel('number of stars')
    ax.set_title(label)

In [None]:
def _joaquin_from_data(data, frozen, phot_names=None):
    """Build a Joaquin model for ``data`` with Gaia EDR3 parallax as target.

    Assembles X / y / y_ivar the same way as the single-fold cells above.
    """
    get_X_kw = {} if phot_names is None else {'phot_names': phot_names}
    X, idx_map = data.get_X(**get_X_kw)
    y = data.stars['GAIAEDR3_PARALLAX']
    y_ivar = 1 / data.stars['GAIAEDR3_PARALLAX_ERROR'] ** 2
    return Joaquin(X, y, y_ivar, idx_map, frozen=frozen)


def fit_K_batches(data, K, frozen=None, optimize_kw=None,
                  phot_names=None, train_mask=None, rng=None):
    """
    Fit a Joaquin model on each of K cross-validation folds of ``data``.

    For each fold the model is optimized on the training stars, and the
    negative log-posterior of the fitted parameters is then evaluated on the
    held-out stars.

    Parameters
    ----------
    data
        Supports ``len(data.stars)``, integer-array indexing and
        ``get_X(...)``, as used elsewhere in this notebook.
    K : int
        Number of folds.
    frozen : dict, optional
        Parameters held fixed in every fold (passed to ``Joaquin``).
    optimize_kw : dict, optional
        Keyword arguments for ``Joaquin.optimize``; not modified.
    phot_names : list of str, optional
        Passed to ``data.get_X``; if None, ``get_X`` uses its own default.
    train_mask : array-like of bool, optional
        Quality cut forwarded to ``get_Kfold_indices``.
    rng : numpy.random.Generator, optional
        Forwarded to ``get_Kfold_indices``. If None the folds are random on
        every call.

    Returns
    -------
    batch_fit_pars : list of dict
    batch_res : list of scipy.optimize.OptimizeResult
    test_loss : list
        Negative log-posterior of each fold's fit on its test stars.

    TODO: could take a pool argument and parallelize the loop below
    """
    if frozen is None:
        frozen = dict()

    # Copy so setdefault below does not modify the caller's dict.
    optimize_kw = dict() if optimize_kw is None else dict(optimize_kw)
    optimize_kw.setdefault('options', {'maxiter': 1_000})  # TODO: make this bigger

    train_batches, test_batches = get_Kfold_indices(
        len(data.stars), K=K, train_mask=train_mask, rng=rng)

    batch_fit_pars = []
    batch_res = []
    test_loss = []
    for train_batch, test_batch in zip(train_batches, test_batches):
        joa = _joaquin_from_data(data[train_batch], frozen, phot_names)
        test_joa = _joaquin_from_data(data[test_batch], frozen, phot_names)

        res = joa.optimize(**optimize_kw)
        fit_pars = joa.unpack_pars(res.x)

        batch_res.append(res)
        batch_fit_pars.append(fit_pars)

        # evaluate the fit model on the test batch
        test_loss.append(test_joa.neg_ln_posterior(**fit_pars)[0])

    return batch_fit_pars, batch_res, test_loss

In [None]:
from scipy.optimize import minimize

def cross_validate_hyperpars(data, K, frozen, **kwargs):
    """
    Tune one hyperparameter by minimizing the summed K-fold test loss.

    ``frozen`` must contain exactly one of ``'L2_ivar'`` or
    ``'parallax_zpt'``. That one is held fixed and the *other* is optimized
    with ``scipy.optimize.minimize``.

    Parameters
    ----------
    data
        Passed to ``fit_K_batches``.
    K : int
        Number of cross-validation folds.
    frozen : dict
        Exactly one of ``{'L2_ivar': value}`` or ``{'parallax_zpt': value}``.
    **kwargs
        Passed to ``scipy.optimize.minimize``. ``method`` defaults to
        ``'powell'``. ``x0`` defaults to -0.03 when tuning ``parallax_zpt``
        and to 1e2 when tuning ``L2_ivar``.

    Returns
    -------
    best : dict
        ``{name: float}`` for the tuned hyperparameter.
    res : scipy.optimize.OptimizeResult

    Raises
    ------
    ValueError
        If ``frozen`` does not hold exactly one of the supported names.

    NOTE(review): if ``fit_K_batches`` draws new random folds on every call,
    this objective is noisy, which can confuse the optimizer. Prefer a fixed
    fold assignment.
    """
    kwargs = kwargs.copy()
    kwargs.setdefault('method', 'powell')

    # HACK / BAD: hardcoded names
    default_x0 = {'parallax_zpt': -0.03, 'L2_ivar': 1e2}
    if len(frozen) != 1 or next(iter(frozen)) not in default_x0:
        raise ValueError(
            "frozen must contain exactly one of 'L2_ivar' or "
            f"'parallax_zpt', got keys {sorted(frozen)}")

    (fixed_par,) = frozen
    xval_par = 'parallax_zpt' if fixed_par == 'L2_ivar' else 'L2_ivar'
    kwargs.setdefault('x0', default_x0[xval_par])

    def objective(p):
        pars = frozen.copy()
        # minimize() passes a length-1 array; store a plain float instead
        pars[xval_par] = float(np.ravel(p)[0])
        fit_pars, reses, losses = fit_K_batches(data, K, frozen=pars)
        return sum(losses)

    res = minimize(objective, **kwargs)
    # float() of an ndim>0 array is deprecated in NumPy; take the element.
    return {xval_par: float(np.ravel(res.x)[0])}, res