In [None]:
import os
os.environ['APOGEE_CACHE_PATH'] = "/mnt/ceph/users/apricewhelan/apogee-test/"
os.environ['JOAQUIN_CACHE_PATH'] = "/mnt/ceph/users/apricewhelan/projects/joaquin/cache"
import warnings
warnings.filterwarnings('ignore', category=Warning) 
import pickle

import sys
import pathlib
_path = str(pathlib.Path('../').resolve())
if _path not in sys.path:
    sys.path.append(_path)

import corner
from astropy.io import fits
import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm.auto import tqdm
from sklearn.decomposition import PCA
from scipy.spatial import cKDTree

from joaquin.data import JoaquinData
from joaquin.config import dr, root_cache_path, neighborhood_size, testing_zone_size
from joaquin.plot import simple_corner

In [None]:
# Number of PCA components used to reconstruct (patch) missing spectral
# pixels in the "PCA patching" section below.
patching_n_components: int = 8

In [None]:
# Per-data-release cache and plot directories, resolved relative to the
# notebook's working directory and created if missing.
cache_path = pathlib.Path(f'../cache/{dr}').resolve()
plot_path = (pathlib.Path('../plot') / dr).resolve()
for _out_dir in (cache_path, plot_path):
    _out_dir.mkdir(parents=True, exist_ok=True)

Run the first two notebooks (those prefixed `1-` and `2-`) beforehand: they create the cached files loaded below.

In [None]:
parent_data = JoaquinData(cache_file='parent-sample-raw')

# Global spectral-pixel mask (applied column-wise below) and, per
# neighborhood, the row indices of its member stars.
spec_good_mask, neighborhood_idx = (
    np.load(root_cache_path / 'spec_good_mask.npy'),
    np.load(cache_path / 'good_parent_neighborhood_indices.npy'),
)

In [None]:
parent_stars = parent_data.stars[parent_data.stars_mask]

# spec_mask_thresh=1. effectively disables get_Xy's own spectral mask; the
# explicit spec_good_mask is applied by hand further down.
parent_d, *_ = parent_data.get_Xy(spec_mask_thresh=1.)

# Rows of the design matrix must line up one-to-one with parent_stars.
n_rows_X = parent_d['X'].shape[0]
assert n_rows_X == len(parent_stars)

## PCA patching

In [None]:
# Patch missing pixels in one demo neighborhood: fit a low-rank PCA model to
# the neighborhood's spectra and use its reconstruction to fill the pixels
# that are exactly zero (treated as missing here).
demo_neighborhood = 10  # which neighborhood to explore in this notebook

# NOTE(review): this cell used to read `spec_parent_d`, which no cell in this
# notebook defines (NameError on a fresh kernel). `parent_d`, built above with
# the spectral mask disabled, is used instead; the check below fails loudly if
# its columns are not the spectral pixels described by `spec_good_mask`.
spec_X = parent_d['X']
if spec_X.shape[1] != spec_good_mask.shape[0]:
    raise ValueError(
        f"parent_d['X'] has {spec_X.shape[1]} columns but spec_good_mask has "
        f"{spec_good_mask.shape[0]} entries; select the spectral columns first.")

idx = neighborhood_idx[demo_neighborhood]
subX = spec_X[idx].copy()
# Zero out globally bad pixels so they do not influence the PCA fit.
subX[:, ~spec_good_mask] = 0.

pca = PCA(n_components=patching_n_components)
subX_pca = pca.fit_transform(subX)
subX_reconstructed = pca.inverse_transform(subX_pca)

# Replace only the missing (zero) pixels; measured values are kept as-is.
missing = subX == 0
subX_patched = subX.copy()
subX_patched[missing] = subX_reconstructed[missing]

In [None]:
# Count zero-valued pixels inside the good-pixel mask before patching, and
# check that none remain afterwards.
good_cols = subX[:, spec_good_mask]
npix_fixed = np.count_nonzero(good_cols == 0)
assert np.count_nonzero(subX_patched[:, spec_good_mask] == 0) == 0

pix_per_star = npix_fixed / subX.shape[0]
print(f"{npix_fixed} pixels patched, ~{pix_per_star:.0f} pixels per star")

In [None]:
# First 10 spectra of the demo neighborhood BEFORE patching; pixels outside
# spec_good_mask are hidden by setting them to NaN.
fig, ax = plt.subplots(figsize=(15, 3))
for i in range(10):
    masked_row = np.where(spec_good_mask, subX[i, :], np.nan)
    ax.plot(masked_row, marker='')
ax.set(xlabel='pixel index', title='Before PCA patching (first 10 stars)');

In [None]:
# The same 10 spectra AFTER PCA patching; pixels outside spec_good_mask are
# hidden by setting them to NaN.
fig, ax = plt.subplots(figsize=(15, 3))
for i in range(10):
    masked_row = np.where(spec_good_mask, subX_patched[i, :], np.nan)
    ax.plot(masked_row, marker='')
ax.set(xlabel='pixel index', title='After PCA patching (first 10 stars)');

## Low-pass filter

In [None]:
from joaquin.filters import nufft_lowpass

In [None]:
# Pick a moderate-S/N star from the demo neighborhood to illustrate the filter.
# `idx` indexes rows of `parent_stars` (the masked table that matches the rows
# of the design matrix), not the raw `parent_data.stars` table.
snr_min, snr_max = 40, 60
neighborhood_snr = parent_stars[idx]['SNR']
snr_mask = (neighborhood_snr > snr_min) & (neighborhood_snr < snr_max)
if not np.any(snr_mask):
    raise ValueError(
        f"No star in this neighborhood has {snr_min} < SNR < {snr_max}")
i = np.flatnonzero(snr_mask)[0]

In [None]:
# Cutoff for nufft_lowpass: half of 22500 (presumably APOGEE's resolving power
# R ~ 22,500). NOTE(review): confirm the units nufft_lowpass expects for fcut.
lowpass_fcut = 0.5 * 22500
wvln = parent_data._X_wvln

# Low-pass filter the patched spectrum of star `i` (in log-wavelength),
# ignoring pixels outside the global good-pixel mask.
new_ln_flux = nufft_lowpass(np.log(wvln),
                            subX_patched[i],
                            fcut=lowpass_fcut,
                            bad_mask=~spec_good_mask)
new_ln_flux[~spec_good_mask] = np.nan

# Full range, then two successively tighter zooms.
for xlim in [None, (16000, 16500), (16150, 16220)]:
    fig, ax = plt.subplots(figsize=(16, 5))
    ax.plot(wvln, subX[i], marker='', drawstyle='steps-mid', label='raw')
    ax.plot(wvln, subX_patched[i], marker='', drawstyle='steps-mid',
            label='PCA patched')
    ax.plot(wvln, new_ln_flux, marker='', drawstyle='steps-mid',
            label='low-pass filtered')
    if xlim is not None:
        ax.set_xlim(*xlim)
    ax.set_xlabel('wavelength')
    ax.legend(loc='best')

## Run the rest of the pipeline on the patched neighborhood

The training sample is the full neighborhood after S/N and parallax-precision cuts; the first `testing_zone_size` stars of the neighborhood form the test set:

In [None]:
neighborhood_stars = parent_stars[idx]

# Features: the PCA-patched good spectral pixels, followed by the columns of
# lsfphot_parent_d['X'] for the same stars (presumably LSF/photometry terms).
# NOTE(review): `lsfphot_parent_d` is not defined in any visible cell.
neighborhood_X = np.concatenate(
    [subX_patched[:, spec_good_mask], lsfphot_parent_d['X'][idx]],
    axis=1,
)

# The first `testing_zone_size` neighborhood stars form the test zone.
n_test = testing_zone_size
test_stars, test_X = neighborhood_stars[:n_test], neighborhood_X[:n_test]

len(neighborhood_stars), len(test_stars)

In [None]:
min_snr = 100
max_plx_err = 0.1
min_plx_snr = 20


def _passes_quality_cuts(stars):
    """S/N and parallax-error cut shared by the train and test selections."""
    return ((stars['SNR'] > min_snr) &
            (stars['GAIAEDR3_PARALLAX_ERROR'] < max_plx_err))


train_mask = _passes_quality_cuts(neighborhood_stars)
# TODO: add RUWE selection

test_mask = _passes_quality_cuts(test_stars)

# Validation subset: same S/N cut, plus a parallax signal-to-noise cut.
plx_snr = test_stars['GAIAEDR3_PARALLAX'] / test_stars['GAIAEDR3_PARALLAX_ERROR']
validate_mask = (test_stars['SNR'] > min_snr) & (plx_snr > min_plx_snr)

train_mask.sum(), test_mask.sum(), validate_mask.sum()

In [None]:
# Parallax distributions of the selected train and test stars (log counts).
bins = np.linspace(-0.5, 5, 256)
fig, ax = plt.subplots()
ax.hist(neighborhood_stars['GAIAEDR3_PARALLAX'][train_mask],
        bins=bins, label='train')
ax.hist(test_stars['GAIAEDR3_PARALLAX'][test_mask],
        bins=bins, label='test')
ax.set_yscale('log')
ax.set(xlabel='GAIAEDR3_PARALLAX', ylabel='count')
ax.legend();

In [None]:
# Disabled exploration: image view of the training design matrix,
# block-averaged 4x4 with skimage's downscale_local_mean and shown with a
# 5th-95th percentile color stretch. Uncomment to use (needs scikit-image).
# from skimage.transform import downscale_local_mean

# tmp_X = downscale_local_mean(
#     neighborhood_X[train_mask],
#     (4, 4))

# fig, ax = plt.subplots(figsize=(15, 7.5))
# ax.imshow(tmp_X, 
#           origin='lower', 
#           vmin=np.percentile(tmp_X, 5), 
#           vmax=np.percentile(tmp_X, 95))
# # ax.set_aspect(2)
# fig.tight_layout()

In [None]:
import logging

# Logger used by Joaquin.neg_ln_posterior for per-evaluation debug messages.
logger = logging.getLogger(__name__)


class Joaquin:
    """Exponential-linear parallax model with an L2 prior on its coefficients.

    The model parallax is ``exp(X @ beta)``. It is compared with the
    zero-point-shifted data ``y + parallax_zpt`` through a Gaussian
    likelihood with inverse variances ``y_ivar``. A Gaussian (L2) prior with
    inverse variance ``L2_ivar`` acts on the entries ``beta[L2_slice]``.

    Parameters
    ----------
    X : ndarray, shape (n_stars, n_features)
        Design matrix.
    y : ndarray, shape (n_stars,)
        Measured parallaxes.
    y_ivar : ndarray, shape (n_stars,)
        Inverse variances of ``y``.
    idx_map : mapping
        Feature-group name -> column selector (slice, index array or boolean
        mask) for ``X``. If it has a ``'spec'`` entry, only those columns are
        regularized; otherwise every column is.
    frozen : dict, optional
        Parameter values held fixed during optimization (keys among
        ``'parallax_zpt'``, ``'L2_ivar'``, ``'beta'``). Frozen parameters are
        left out of the packed parameter vector.
    """

    # Parameters that unpack to scalars; everything else stays an array.
    _scalar_pars = ('parallax_zpt', 'L2_ivar')

    def __init__(self, X, y, y_ivar, idx_map, frozen=None):
        self.X = X
        self.y = y
        self.y_ivar = y_ivar
        self.idx_map = idx_map

        # Parameter name -> number of entries in the packed parameter vector.
        # Insertion order is the packing order.
        self._param_info = {}

        # duh
        self._param_info['parallax_zpt'] = 1

        # the inv-var of the prior on the spectral components in beta
        self._param_info['L2_ivar'] = 1

        # linear coefficients (in the exp argument)
        self._param_info['beta'] = self.X.shape[1]

        # Columns of beta that the L2 prior acts on: the spectral block when
        # idx_map has one, otherwise every column.
        if 'spec' in self.idx_map:
            self.L2_slice = self.idx_map['spec']
        else:
            self.L2_slice = np.ones(self.X.shape[1], dtype=bool)

        self.frozen = {} if frozen is None else frozen

    def unpack_pars(self, par_list):
        """Turn a packed parameter vector into a ``{name: value}`` dict.

        Frozen parameters take their frozen values. Hyperparameters are
        returned as scalars. ``beta`` is always an array, even when it has a
        single entry, so that it can still be indexed by ``L2_slice``.
        """
        par_dict = {}
        i = 0
        for key, par_len in self._param_info.items():
            if key in self.frozen:
                par_dict[key] = self.frozen[key]
                continue

            vals = np.array(par_list[i:i + par_len])
            par_dict[key] = vals[0] if key in self._scalar_pars else vals
            i += par_len

        return par_dict

    def pack_pars(self, par_dict):
        """Concatenate the non-frozen entries of ``par_dict`` into a 1-D vector.

        Scalars are promoted to length-1 arrays. ``np.concatenate`` rejects
        0-d inputs, so this is needed whenever a hyperparameter is free.
        """
        return np.concatenate([
            np.atleast_1d(par_dict[key])
            for key in self._param_info
            if key not in self.frozen
        ])

    def init_beta(self, parallax_zpt=None, L2_ivar=None):
        """Initial beta from a linearized, L2-regularized least-squares fit.

        Only stars whose shifted parallax exceeds 3 sigma are used, so the log
        is defined. ``ln(y)`` is fit with inverse variances
        ``y**2 * y_ivar``, i.e. first-order error propagation through the log.
        The prior is added only on the ``L2_slice`` columns, consistent with
        ``ln_prior``. Frozen hyperparameter values take precedence over the
        arguments.

        Raises
        ------
        ValueError
            If ``parallax_zpt`` or ``L2_ivar`` is neither given nor frozen.
        """
        parallax_zpt = self.frozen.get('parallax_zpt', parallax_zpt)
        L2_ivar = self.frozen.get('L2_ivar', L2_ivar)

        if parallax_zpt is None or L2_ivar is None:
            raise ValueError(
                'parallax_zpt and L2_ivar must be passed (or frozen) to '
                'initialize beta')

        y = self.y + parallax_zpt
        plx_mask = y > (3 / np.sqrt(self.y_ivar))  # 3 sigma

        X = self.X[plx_mask]
        y = y[plx_mask]
        y_ivar = self.y_ivar[plx_mask]

        ln_plx_ivar = y**2 * y_ivar
        ln_y = np.log(y)

        XT_Cinv = X.T * ln_plx_ivar
        XT_Cinv_X = np.dot(XT_Cinv, X)

        prior_ivar = np.zeros(X.shape[1])
        prior_ivar[self.L2_slice] = L2_ivar
        XT_Cinv_X[np.diag_indices(X.shape[1])] += prior_ivar

        return np.linalg.solve(XT_Cinv_X, np.dot(XT_Cinv, ln_y))

    def chi(self, parallax_zpt, L2_ivar, beta):
        """Normalized residuals ``(y + zpt - exp(X @ beta)) * sqrt(y_ivar)``."""
        y = self.y + parallax_zpt
        model_ln_plx = np.dot(self.X, beta)
        model_y = np.exp(model_ln_plx)
        resid = y - model_y
        return resid * np.sqrt(self.y_ivar)

    def ln_likelihood(self, parallax_zpt, L2_ivar, beta):
        """Gaussian log-likelihood and its gradient with respect to beta."""
        y = self.y + parallax_zpt
        model_ln_plx = np.dot(self.X, beta)
        model_y = np.exp(model_ln_plx)
        resid = y - model_y

        ll = -0.5 * np.sum(resid**2 * self.y_ivar)
        ll_grad = np.dot(self.X.T,
                         model_y * self.y_ivar * resid)

        return ll, ll_grad

    def ln_prior(self, parallax_zpt, L2_ivar, beta):
        """L2 log-prior on ``beta[L2_slice]`` and its gradient w.r.t. beta."""
        lp = - 0.5 * L2_ivar * np.sum(beta[self.L2_slice] ** 2)
        lp_grad = np.zeros_like(beta)
        lp_grad[self.L2_slice] = - L2_ivar * beta[self.L2_slice]
        return lp, lp_grad

    def neg_ln_posterior(self, parallax_zpt, L2_ivar, beta):
        """Negative log-posterior and its gradient with respect to beta only."""
        ll, ll_grad = self.ln_likelihood(parallax_zpt, L2_ivar, beta)
        lp, lp_grad = self.ln_prior(parallax_zpt, L2_ivar, beta)
        logger.log(0, f'objective function evaluation: ll={ll}, lp={lp}')
        return - (ll + lp), - (ll_grad + lp_grad)

    def __call__(self, p):
        """Objective for ``scipy.optimize.minimize(..., jac=True)``.

        Returns the negative log-posterior and its gradient with respect to
        every free entry of the packed vector ``p``. The gradient has the same
        length and order as ``p``.

        NOTE: the log-prior is linear in L2_ivar. With L2_ivar free and
        beta != 0, the objective therefore has no lower bound. Prefer freezing
        it (or cross-validating it).
        """
        par_dict = self.unpack_pars(p)
        value, beta_grad = self.neg_ln_posterior(**par_dict)

        grads = {'beta': beta_grad}
        if 'parallax_zpt' not in self.frozen:
            # d/d(zpt) of 0.5 * sum(resid**2 * y_ivar), with d(resid)/d(zpt) = 1
            resid = (self.y + par_dict['parallax_zpt']
                     - np.exp(np.dot(self.X, par_dict['beta'])))
            grads['parallax_zpt'] = np.sum(resid * self.y_ivar)
        if 'L2_ivar' not in self.frozen:
            # d/d(L2_ivar) of 0.5 * L2_ivar * sum(beta[L2_slice]**2)
            beta = np.asarray(par_dict['beta'])
            grads['L2_ivar'] = 0.5 * np.sum(beta[self.L2_slice] ** 2)

        return value, self.pack_pars(grads)

    def optimize(self, init=None, **minimize_kwargs):
        """
        Maximize the posterior over the free parameters.

        ``init`` supplies starting values. It is copied, not modified. Missing
        hyperparameters default to ``parallax_zpt=0`` and ``L2_ivar=1``, and a
        missing ``beta`` comes from ``init_beta``. Extra keyword arguments are
        passed to ``scipy.optimize.minimize`` (default method L-BFGS-B).

        To set the maximum number of function evaluations, pass:

            options={'maxfun': ...}

        """
        from scipy.optimize import minimize

        init = {} if init is None else dict(init)
        init.setdefault('parallax_zpt', 0.)
        init.setdefault('L2_ivar', 1.)

        if 'beta' not in init:
            init['beta'] = self.init_beta(parallax_zpt=init['parallax_zpt'],
                                          L2_ivar=init['L2_ivar'])

        x0 = self.pack_pars(init)

        minimize_kwargs.setdefault('method', 'L-BFGS-B')
        if minimize_kwargs['method'] == 'L-BFGS-B':
            minimize_kwargs.setdefault('options', {'maxfun': 1024})

        res = minimize(
            self,
            x0=x0,
            jac=True,
            **minimize_kwargs)

        return res


In [None]:
def get_Kfold_indices(stars, K, rng=None):
    """Randomly partition ``range(len(stars))`` into K train/test folds.

    Parameters
    ----------
    stars : sized
        Only ``len(stars)`` is used.
    K : int
        Number of folds; must satisfy ``2 <= K <= len(stars)``.
    rng : numpy.random.Generator, optional
        Generator used for the shuffle; a fresh default generator when omitted.

    Returns
    -------
    train_batches, test_batches : list of ndarray
        ``test_batches[k]`` holds the indices in fold k and
        ``train_batches[k]`` all remaining indices. Fold sizes differ by at
        most one.

    Raises
    ------
    ValueError
        If K is outside ``[2, len(stars)]`` (otherwise some folds would be
        empty, or the training set would be empty).
    """
    n = len(stars)
    if not 2 <= K <= n:
        raise ValueError(
            f"K must satisfy 2 <= K <= len(stars) = {n}; got K={K}")

    if rng is None:
        rng = np.random.default_rng()

    idx = np.arange(n)
    rng.shuffle(idx)

    # array_split spreads the remainder over the first folds instead of
    # piling it all onto the last one.
    test_batches = np.array_split(idx, K)
    train_batches = [idx[~np.isin(idx, batch)] for batch in test_batches]

    assert all(len(tr) + len(te) == n
               for tr, te in zip(train_batches, test_batches))

    return train_batches, test_batches

In [None]:
def fit_K_batches(data, K, frozen=None, optimize_kw=None, rng=None):
    """Fit a Joaquin model on each of K cross-validation folds.

    For each fold, a model is optimized on the training rows. The fitted
    parameters are then scored on the held-out rows by the negative
    log-likelihood. The prior term is deliberately excluded so that
    cross-validating L2_ivar is not biased by the penalty itself.

    Parameters
    ----------
    data : sized, indexable
        Indexed with integer index arrays to build each fold; ``len(data)``
        sets the number of rows split into folds.
    K : int
        Number of folds.
    frozen : dict, optional
        Frozen parameter values passed to every Joaquin model.
    optimize_kw : dict, optional
        Keyword arguments for ``Joaquin.optimize`` (copied, not modified).
    rng : numpy.random.Generator, optional
        Passed to ``get_Kfold_indices`` for reproducible folds.

    Returns
    -------
    batch_fit_pars, batch_res, test_loss : list
        Per-fold fitted parameter dicts, optimizer results and held-out
        negative log-likelihoods.

    TODO: could take a pool argument and parallelize the loop below
    """
    frozen = {} if frozen is None else frozen

    # Copy so the caller's dict is not mutated by the default below.
    optimize_kw = {} if optimize_kw is None else dict(optimize_kw)
    optimize_kw.setdefault('options', {'maxiter': 1_000})  # TODO: make this bigger

    # Fold indices are positions in `data`, which is what gets indexed below.
    train_batches, test_batches = get_Kfold_indices(data, K=K, rng=rng)

    batch_fit_pars = []
    batch_res = []
    test_loss = []
    for train_batch, test_batch in zip(train_batches, test_batches):
        # NOTE(review): Joaquin.__init__ takes (X, y, y_ivar, idx_map, frozen);
        # these calls pass a single positional argument, so `data[...]` must be
        # adapted to that signature — confirm the intended API.
        joa = Joaquin(data[train_batch], frozen=frozen)
        test_joa = Joaquin(data[test_batch], frozen=frozen)

        res = joa.optimize(**optimize_kw)
        fit_pars = joa.unpack_pars(res.x)

        batch_res.append(res)
        batch_fit_pars.append(fit_pars)

        # Evaluate the fitted model on the held-out fold (likelihood only).
        ll, _ = test_joa.ln_likelihood(**fit_pars)
        test_loss.append(-ll)

    return batch_fit_pars, batch_res, test_loss

In [None]:
from scipy.optimize import minimize

def cross_validate_hyperpars(data, K, frozen, **kwargs):
    """Choose one hyperparameter by minimizing the summed K-fold test loss.

    Exactly one of ``parallax_zpt`` / ``L2_ivar`` must be frozen; the other
    is optimized with ``scipy.optimize.minimize`` (Powell by default, with a
    hard-coded starting guess unless ``x0`` is given).

    Parameters
    ----------
    data, K
        Passed to ``fit_K_batches``.
    frozen : dict
        ``{'parallax_zpt': value}`` or ``{'L2_ivar': value}``.
    **kwargs
        Passed to ``scipy.optimize.minimize``.

    Returns
    -------
    best : dict
        ``{name: value}`` for the cross-validated hyperparameter.
    res : scipy.optimize.OptimizeResult

    Raises
    ------
    ValueError
        If ``frozen`` does not contain exactly one of the two names.

    NOTE(review): each objective evaluation calls fit_K_batches, which draws
    new random folds, so the objective is noisy between evaluations.
    """
    kwargs = dict(kwargs)
    kwargs.setdefault('method', 'powell')

    # HACK / BAD: hardcoded names
    partner = {'L2_ivar': 'parallax_zpt', 'parallax_zpt': 'L2_ivar'}
    default_x0 = {'parallax_zpt': -0.03, 'L2_ivar': 1e2}
    if len(frozen) != 1 or next(iter(frozen)) not in partner:
        raise ValueError(
            "frozen must contain exactly one of 'parallax_zpt' or 'L2_ivar'; "
            f"got keys {list(frozen)}")

    xval_par = partner[next(iter(frozen))]
    kwargs.setdefault('x0', default_x0[xval_par])

    def objective(p):
        pars = dict(frozen)
        # minimize passes a length-1 array; store a plain float.
        pars[xval_par] = float(np.ravel(p)[0])
        _, _, losses = fit_K_batches(data, K, frozen=pars)
        return sum(losses)

    res = minimize(objective, **kwargs)
    return {xval_par: float(np.ravel(res.x)[0])}, res