### Issues:
- Missing data are filled in incorrectly. The NUFFT low-pass filter pulls masked regions toward 1, but we should instead use PCA to patch the missing data. Proposed iteration: low-pass filter, PCA-patch the previously missing pixels, repeat.
- The low-pass filter behaves strangely and should be checked against Barnett/DFM.
- The LSF housekeeping data are very hacky — do we need something better?

In [None]:
import os
os.environ['APOGEE_CACHE_PATH'] = "/mnt/ceph/users/apricewhelan/apogee-test/"

import sys
import pathlib
_path = str(pathlib.Path('../').resolve())
if _path not in sys.path:
    sys.path.append(_path)

import corner
from astropy.io import fits
import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm.auto import tqdm

from joaquin import Joaquin
from joaquin.design_matrix import JoaquinData
from joaquin.config import phot_names, dr, Kfold_K
from joaquin.logger import logger
from joaquin.plot import simple_corner, phot_to_label

from gala.mpl_style import hesperia, laguna

In [None]:
cache_path = pathlib.Path(f'../cache/{dr}').resolve()
cache_path.mkdir(exist_ok=True, parents=True)

In [None]:
parent = at.Table.read(cache_path / 'parent-sample.fits')

parent_stars = parent[
    (parent['LOGG'] < 2.2) & 
    (parent['LOGG'] > 1.5) &
    (parent['TEFF'] > 3500) &
    (parent['TEFF'] < 5000)]

# HACK: subselect for speed
np.random.seed(42)
idx = np.random.choice(len(parent_stars), size=4096, replace=False)
parent_stars = parent_stars[idx]

len(parent_stars)

In [None]:
# Teff–logg (Kiel) diagram: 2D log-density of the full parent table, with the
# selected subsample overplotted as blue points.
fig, ax = plt.subplots(figsize=(6, 6))

bins = (np.linspace(3000, 7500, 128),
        np.linspace(0, 5.5, 128))
ax.hist2d(parent['TEFF'], parent['LOGG'],
          bins=bins, norm=mpl.colors.LogNorm(),
          cmap='magma_r')

ax.plot(parent_stars['TEFF'],
        parent_stars['LOGG'],
        ls='none', marker='o', mew=0, ms=3., 
        color='tab:blue', alpha=0.75)

# Both axes inverted: logg increases downward, Teff increases to the left.
ax.set_ylim(5.5, 0)
ax.set_xlim(7500, 3000)

fig.tight_layout()

In [None]:
# Build the Joaquin design-matrix container for the subsample. ``star_mask``
# is a per-star selection (it indexes ``parent_stars`` below); the meaning of
# spec_mask_thresh is defined inside JoaquinData.from_stars (not shown here).
data, star_mask = JoaquinData.from_stars(
    parent_stars, cache_path=cache_path, 
    spec_mask_thresh=0.2)  # MAGIC NUMBER

In [None]:
# Diagnostic view of JoaquinData's private spectral-mask values
# (presumably the per-pixel quantity compared against spec_mask_thresh — confirm).
plt.figure(figsize=(15, 5))
plt.plot(data._spec_mask_vals)

In [None]:
# Keep only the stars accepted by JoaquinData, applying the same mask to the
# table and to the design-matrix container so their rows stay aligned.
clean_stars = parent_stars[star_mask]
clean_data = data[star_mask]

### Get training sample from parent sample

In [None]:
# Training set: SNR > 100 and Gaia EDR3 parallax error < 0.1.
train_mask = ((clean_stars['SNR'] > 100) &
              (clean_stars['GAIAEDR3_PARALLAX_ERROR'] < 0.1))
# TODO: add RUWE selection

# Stars with parallax / parallax_error > 20.
# NOTE(review): check_mask is not referenced anywhere else in the visible notebook.
check_mask = (clean_stars['GAIAEDR3_PARALLAX'] / clean_stars['GAIAEDR3_PARALLAX_ERROR']) > 20

# Display (number of training stars, number of clean stars).
train_mask.sum(), len(clean_stars)

In [None]:
# Apply the same mask to the table and the design-matrix container.
train_stars = clean_stars[train_mask]
train_data = clean_data[train_mask]

In [None]:
# Distribution of Gaia parallaxes in the training set, log-scaled counts.
plt.hist(train_stars['GAIAEDR3_PARALLAX'], 
         bins=np.linspace(-0.5, 2, 128));
plt.yscale('log')

In [None]:
from skimage.transform import downscale_local_mean

# Image of the 'spec' sub-matrix with rows (stars) sorted by the chosen label,
# block-averaged over 4x4 cells to keep the image small. Note that
# downscale_local_mean pads with zeros (cval=0) when a dimension is not
# divisible by 4, which slightly darkens the last row/column of blocks.
for sort_by in ['TEFF']:
    tmp_X, *_ = train_data.get_sub_Xy('spec')
    tmp_X = downscale_local_mean(
        tmp_X[train_stars[sort_by].argsort()],
        (4, 4))

    fig, ax = plt.subplots(figsize=(15, 7.5))
    # Clip the color scale to the 5th-95th percentiles so outliers don't dominate.
    ax.imshow(tmp_X, 
              origin='lower', 
              vmin=np.percentile(tmp_X, 5), 
              vmax=np.percentile(tmp_X, 95))
    # ax.set_aspect(2)
    fig.tight_layout()

In [None]:
# Singular values of the 'spec' sub-matrix (a quick look at its effective rank).
tmp_X, *_ = train_data.get_sub_Xy('spec')
_, vals, _ = np.linalg.svd(tmp_X, full_matrices=False)

In [None]:
plt.plot(vals)
plt.xscale('log')
plt.yscale('log')
plt.ylim(1e-4, 1e2)
# min(tmp_X.shape) is the number of singular values returned above.
plt.axvline(min(tmp_X.shape))

In [None]:
def get_Kfold_indices(stars, K, rng=None):
    """
    Randomly split the rows of ``stars`` into ``K`` train/test folds.

    Parameters
    ----------
    stars : sized
        Anything supporting ``len()``; only its length is used.
    K : int
        Number of folds, ``1 <= K <= len(stars)``.
    rng : `numpy.random.Generator`, optional
        Source of randomness for the shuffle. A fresh, unseeded generator is
        created when omitted, so the split differs between calls.

    Returns
    -------
    train_batches : list of ndarray
        For each fold, the (shuffled) row indices *not* in that fold's test
        batch.
    test_batches : list of ndarray
        Disjoint index batches that together cover every row exactly once.
        Each has ``len(stars) // K`` entries except the last, which also
        absorbs the remainder.

    Raises
    ------
    ValueError
        If ``K`` is not in ``[1, len(stars)]``; otherwise some test folds
        would be empty (and their training folds would be the whole sample).
    """
    n = len(stars)
    if not 1 <= K <= n:
        raise ValueError(
            f"K must satisfy 1 <= K <= len(stars)={n}, got K={K}")

    if rng is None:
        rng = np.random.default_rng()

    idx = np.arange(n)
    rng.shuffle(idx)

    # Fold k is the contiguous slice idx[bounds[k]:bounds[k+1]] of the shuffled
    # indices; the final bound is n so no rows are dropped when K doesn't
    # divide n evenly.
    batch_size = n // K
    bounds = [k * batch_size for k in range(K)] + [n]
    edges = list(zip(bounds[:-1], bounds[1:]))

    test_batches = [idx[lo:hi] for lo, hi in edges]
    # Because each test batch is a contiguous slice of idx, its complement is
    # just everything before and after that slice (in shuffled order).
    train_batches = [np.concatenate([idx[:lo], idx[hi:]]) for lo, hi in edges]

    return train_batches, test_batches

In [None]:
def fit_K_batches(data, K, frozen=None, optimize_kw=None, rng=None):
    """
    Fit a Joaquin model on each of K cross-validation training folds and
    evaluate it on the matching held-out fold.

    TODO: could take a pool argument and parallelize the loop below

    Parameters
    ----------
    data : JoaquinData
        Design-matrix container; must support integer-array row indexing
        (``data[indices]``).
    K : int
        Number of folds, passed to ``get_Kfold_indices``.
    frozen : dict, optional
        Hyperparameters held fixed, passed through to ``Joaquin``.
    optimize_kw : dict, optional
        Keyword arguments for ``Joaquin.optimize``. The caller's dict is not
        modified; ``options={'maxiter': 1_000}`` is used when it has no
        ``'options'`` key.
    rng : `numpy.random.Generator`, optional
        Passed to ``get_Kfold_indices``. Supply a seeded generator to reuse
        the same split across calls; by default each call draws a new split.

    Returns
    -------
    batch_fit_pars : list of dict
        Unpacked best-fit parameters for each fold.
    batch_res : list
        The raw optimizer result for each fold.
    test_loss : list
        First element of ``neg_ln_posterior`` evaluated with each fold's
        best-fit parameters on that fold's held-out stars.
    """
    if frozen is None:
        frozen = dict()

    # Work on a copy so setdefault below never mutates the caller's dict.
    optimize_kw = dict() if optimize_kw is None else dict(optimize_kw)
    optimize_kw.setdefault('options', {'maxiter': 1_000})  # TODO: make this bigger

    # The folds must index the rows of ``data`` itself, not some other table.
    try:
        n_rows = len(data)
    except TypeError:
        # NOTE(review): fallback in case JoaquinData does not define __len__;
        # the first value returned by get_sub_Xy has one row per star.
        n_rows = data.get_sub_Xy('spec')[0].shape[0]

    # get_Kfold_indices only uses len() of its first argument.
    train_batches, test_batches = get_Kfold_indices(
        np.arange(n_rows), K=K, rng=rng)

    batch_fit_pars = []
    batch_res = []
    test_loss = []
    for train_batch, test_batch in zip(train_batches, test_batches):
        joa = Joaquin(data[train_batch], frozen=frozen)
        test_joa = Joaquin(data[test_batch], frozen=frozen)

        res = joa.optimize(**optimize_kw)
        fit_pars = joa.unpack_pars(res.x)

        batch_res.append(res)
        batch_fit_pars.append(fit_pars)

        # evaluate the fit model on the test batch
        test_loss.append(test_joa.neg_ln_posterior(**fit_pars)[0])

    return batch_fit_pars, batch_res, test_loss

In [None]:
from scipy.optimize import minimize

def cross_validate_hyperpars(data, K, frozen, **kwargs):
    """
    Tune one Joaquin hyperparameter by minimizing the summed K-fold test loss.

    ``frozen`` must contain exactly one of the two supported hyperparameters;
    the *other* one is optimized:

    - ``{'L2_ivar': ...}``      -> optimize ``parallax_zpt`` (default x0 -0.03)
    - ``{'parallax_zpt': ...}`` -> optimize ``L2_ivar`` (default x0 1e2)

    Parameters
    ----------
    data : JoaquinData
        Passed to ``fit_K_batches``.
    K : int
        Number of cross-validation folds.
    frozen : dict
        Single-entry dict holding the hyperparameter that stays fixed.
    **kwargs
        Forwarded to `scipy.optimize.minimize`; ``method`` defaults to
        ``'powell'`` and ``x0`` to the start value listed above.

    Returns
    -------
    best : dict
        ``{name: value}`` for the optimized hyperparameter, value as a float.
    res : `scipy.optimize.OptimizeResult`
        The full optimizer result.

    Raises
    ------
    ValueError
        If ``frozen`` does not contain exactly one supported key.

    NOTE(review): fit_K_batches is called below without a fixed random split,
    so each objective evaluation uses a different K-fold partition and the
    objective is noisy. L2_ivar is also not constrained to stay positive.
    """
    kwargs = kwargs.copy()
    kwargs.setdefault('method', 'powell')

    # HACK / BAD: hardcoded names
    # frozen key -> (name of the parameter to optimize, default starting value)
    xval_options = {
        'L2_ivar': ('parallax_zpt', -0.03),
        'parallax_zpt': ('L2_ivar', 1e2),
    }
    if len(frozen) != 1 or next(iter(frozen)) not in xval_options:
        raise ValueError(
            "frozen must contain exactly one of "
            f"{list(xval_options)}; got keys {list(frozen)}")

    xval_par, default_x0 = xval_options[next(iter(frozen))]
    kwargs.setdefault('x0', default_x0)

    def objective(p):
        pars = frozen.copy()
        # minimize passes a length-1 array; store a plain float, matching the
        # other frozen hyperparameter values.
        pars[xval_par] = float(np.ravel(p)[0])
        fit_pars, reses, losses = fit_K_batches(data, K, frozen=pars)
        return sum(losses)

    res = minimize(objective, **kwargs)
    # float() on a size-1 ndarray is deprecated in NumPy; extract the element.
    return {xval_par: float(np.ravel(res.x)[0])}, res

In [None]:
out = cross_validate_hyperpars(clean_data, K=8, 
                               frozen={'parallax_zpt': -0.03})

In [None]:
%load_ext line_profiler

In [None]:
joa = Joaquin(clean_data, frozen=frozen)
beta = joa.init_beta()
pars = joa.unpack_pars(beta)

In [None]:
%lprun -f joa.ln_likelihood joa.ln_likelihood(**pars)

In [None]:
def cross_validate_param(param, joa, batches, frozen=None):
    """
    Unfinished stub for cross-validating a single parameter.

    Judging only from the signature, it would take the parameter name
    ``param``, a ``Joaquin`` instance ``joa``, precomputed fold ``batches``
    and fixed hyperparameters ``frozen`` — confirm when implementing.
    ``cross_validate_hyperpars`` is the working alternative for now.

    Raises
    ------
    NotImplementedError
        Always, until the function is implemented.
    """
    raise NotImplementedError(
        "cross_validate_param is not implemented; "
        "use cross_validate_hyperpars instead")

In [None]:
# Hyperparameters held fixed for the main fit on the training set.
frozen = {'L2_ivar': 1e-1, 
          'parallax_zpt': -0.03}  # MAGIC NUMBERs

joa = Joaquin(train_data, 
              frozen=frozen)

In [None]:
# Corner plot of the 'lsf' columns of the training design matrix, colored by
# MEANFIB: apo25m stars with the hesperia colormap, lco25m stars with laguna,
# both drawn on the same axes.
# NOTE(review): the labels suggest two coefficients (a, b) for each of b/g/r —
# confirm against how JoaquinData builds the 'lsf' block.
lsf_X, *_ = train_data.get_sub_Xy('lsf')
axes = None
for tele in ['apo25m', 'lco25m']:
    mask = train_stars['TELESCOPE'] == tele
    
    if axes is None:
        # First telescope creates the figure and sets the axis labels.
        fig, axes = simple_corner(lsf_X[mask], 
                                  color_by=train_stars['MEANFIB'][mask],
                                  alpha=0.75, cmap=hesperia, 
                                  labels=[r'$a_{\rm b}$', r'$b_{\rm b}$', 
                                          r'$a_{\rm g}$', r'$b_{\rm g}$',
                                          r'$a_{\rm r}$', r'$b_{\rm r}$'])
    else:
        # Subsequent telescopes draw onto the existing axes.
        fig, axes = simple_corner(lsf_X[mask], 
                                  color_by=train_stars['MEANFIB'][mask],
                                  alpha=0.75, cmap=laguna, axes=axes)
        
fig.set_facecolor('w')

---

### Optimizing the model

In [None]:
# Initial guess for the coefficient vector beta
# (presumably what joa.optimize starts from — confirm).
init_beta = joa.init_beta()

In [None]:
plt.figure(figsize=(15, 4))
plt.plot(init_beta)
plt.ylabel('init beta')

In [None]:
# Fit the model to the training set; the options dict is forwarded to joa.optimize.
res = joa.optimize(options={'maxiter': 10_000})
# res, wrapper, ps = joa.optimize()

In [None]:
res

In [None]:
# Unpack the optimizer's flat solution vector into named parameters.
fit_pars = joa.unpack_pars(res.x)

In [None]:
# How far the optimizer moved the 'spec' coefficients from their initial values,
# then both curves overlaid.
plt.figure(figsize=(15, 4))
plt.plot(init_beta[joa.idx_map['spec']] - fit_pars['beta'][joa.idx_map['spec']])
plt.ylabel('init beta - fit beta')

plt.figure(figsize=(15, 4))
plt.plot(init_beta[joa.idx_map['spec']], label='init')
plt.plot(fit_pars['beta'][joa.idx_map['spec']], label='fit')
plt.ylabel('beta')
plt.legend(loc='best')

In [None]:
# Predicted parallax = exp(X @ beta), i.e. the linear model is on the log scale
# (presumably ln parallax — confirm against Joaquin's likelihood).
pred_plx = np.exp(np.dot(joa.X, fit_pars['beta']))

In [None]:
# Gaia vs. predicted parallax distributions on shared bins, log-scaled counts.
bins = np.linspace(-0.5, 2, 128)
plt.hist(train_stars['GAIAEDR3_PARALLAX'], 
         bins=bins);
plt.hist(pred_plx, bins=bins)
plt.yscale('log')

In [None]:
# Per-star fit residual of the best-fit model
# (presumably (data - model) / uncertainty — confirm in Joaquin.chi).
chi = joa.chi(**fit_pars)

In [None]:
# Left: Joaquin vs. Gaia parallax; right: chi vs. Gaia parallax.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

ax = axes[0]
ax.plot(train_stars['GAIAEDR3_PARALLAX'], 
        pred_plx,
        marker='o', ls='none', mew=0, ms=1.5, alpha=0.75)
ax.set_xlim(-0.5, 1.5)
ax.set_ylim(ax.get_xlim())
ax.set_xlabel('Gaia plx')
ax.set_ylabel('Joaquin plx')

ax = axes[1]
ax.plot(train_stars['GAIAEDR3_PARALLAX'], 
        chi,
        marker='o', ls='none', mew=0, ms=1.5, alpha=0.75)
ax.set_xlim(-0.5, 1.5)
# ax.set_ylim(ax.get_xlim())
ax.set_xlabel('Gaia plx')
ax.set_ylabel(r'$\chi$')

fig.tight_layout()

In [None]:
# Histogram of chi. Solid blue lines mark its empirical 16th/84th percentiles;
# dashed grey lines at +/-1 are roughly where those percentiles would sit for
# a unit Gaussian.
plt.hist(chi, bins=np.linspace(-5, 5, 64));
for x in np.percentile(chi, [16, 84]):
    plt.axvline(x, color='tab:blue')
    
plt.axvline(1, linestyle='--', color='#666666')
plt.axvline(-1, linestyle='--', color='#666666')

In [None]:
# Same two panels as above, now colored by ln(VSCATTER) and with Gaia
# parallax error bars on the left panel.
# NOTE(review): np.log yields -inf (with a RuntimeWarning) for any VSCATTER == 0.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

ax = axes[0]
ax.scatter(train_stars['GAIAEDR3_PARALLAX'], 
           pred_plx,
           c=np.log(train_stars['VSCATTER']), 
           vmin=-1, vmax=4, cmap='turbo',
           marker='o', lw=0, s=10, alpha=0.75)

ax.errorbar(train_stars['GAIAEDR3_PARALLAX'], 
            pred_plx,
            xerr=train_stars['GAIAEDR3_PARALLAX_ERROR'],
            marker='', ls='', ecolor='#666666', 
            elinewidth=0.5, alpha=0.5)

ax.set_xlim(-0.5, 1.5)
ax.set_ylim(ax.get_xlim())
ax.set_xlabel('Gaia plx')
ax.set_ylabel('Joaquin plx')

ax = axes[1]
ax.scatter(train_stars['GAIAEDR3_PARALLAX'], 
           chi,
           c=np.log(train_stars['VSCATTER']), 
           vmin=-1, vmax=4, cmap='turbo',
           marker='o', lw=0, s=10, alpha=0.75)
ax.set_xlim(-0.5, 1.5)
# ax.set_ylim(ax.get_xlim())
ax.set_xlabel('Gaia plx')
ax.set_ylabel(r'$\chi$')

fig.tight_layout()
fig.set_facecolor('w')

In [None]:
# fig, ax = plt.subplots(1, 1, figsize=(5, 5))

# ax.scatter(train_stars['GAIAEDR3_PARALLAX_ERROR'], 
#            chi,
#            c=np.log(train_stars['VSCATTER']), 
#            vmin=-1, vmax=4, cmap='turbo',
#            marker='o', lw=0, s=10, alpha=0.75)
# ax.set_xlim(-0.01, 0.1)
# # ax.set_ylim(ax.get_xlim())
# ax.set_xlabel('Gaia plx error')
# ax.set_ylabel(r'$\chi$')

# fig.tight_layout()
# fig.set_facecolor('w')

In [None]:
# Photometry / colors:
# Corner plot of photometric "colors" (differences between pairs of photometry
# columns of the design matrix), colored by the fit residual chi.
# NOTE(review): the columns come from joa.X, so if the design matrix rescales
# magnitudes these are differences of the transformed features rather than
# raw colors — confirm against JoaquinData.
colors = [
    ('GAIAEDR3_PHOT_BP_MEAN_MAG', 'GAIAEDR3_PHOT_RP_MEAN_MAG'),
    ('J', 'K'),
    ('w1mpro', 'w3mpro'),
    ('GAIAEDR3_PHOT_G_MEAN_MAG', 'J'),
    ('H', 'w2mpro'),
]


def _phot_column(name):
    """Return the column of ``joa.X`` holding photometric band ``name``."""
    return joa.X[:, train_data.idx_map['phot'][phot_names.index(name)]]


def _phot_label(name):
    """Return the plot label for band ``name``, falling back to the raw name."""
    return phot_to_label[name] if name in phot_to_label else name


# Build the (n_stars, n_colors) matrix with comprehensions so this cell does
# not overwrite globals such as the singular values ``vals`` computed earlier.
plot_X = np.column_stack([_phot_column(p1) - _phot_column(p2)
                          for p1, p2 in colors])
labels = [f"{_phot_label(p1)} $-$ {_phot_label(p2)}" for p1, p2 in colors]

fig, axes, cb = simple_corner(
    plot_X,
    color_by=chi,
    colorbar=True,
    labels=labels,
    vmin=-3, vmax=3, s=8,
    alpha=0.75, cmap='RdBu')
cb.ax.set_aspect(40)

fig.set_facecolor('w')

In [None]:
# Housekeeping:
# Corner plot of the 'lsf' columns of the training design matrix, colored by
# chi (clipped to [-3, 3]) to look for residual trends with these parameters.
# NOTE(review): an earlier cell calls get_sub_Xy('lsf') with a string; here a
# list is passed — confirm both forms return the same columns.
plot_X = train_data.get_sub_Xy(['lsf'])[0]
labels = [r'$a_{\rm b}$', r'$b_{\rm b}$', 
          r'$a_{\rm g}$', r'$b_{\rm g}$',
          r'$a_{\rm r}$', r'$b_{\rm r}$']

fig, axes, cb = simple_corner(
    plot_X, 
    color_by=chi,
    colorbar=True,
    labels=labels,
    vmin=-3, vmax=3, s=8,
    alpha=0.75, cmap='RdBu')
cb.ax.set_aspect(40)

fig.set_facecolor('w')

In [None]:
# Scratch: initial beta computed with an explicit L2_ivar=0.5
# (presumably overriding the frozen value — confirm in Joaquin.init_beta).
beta = joa.init_beta(L2_ivar=0.5)

In [None]:
# Flat parameter vector: two leading scalars followed by beta.
# NOTE(review): what the two leading entries represent is not shown here.
p0 = [0., 0.5] + list(beta)
joa(p0)

In [None]:
# test = jax.value_and_grad(joa.__call__)
# NOTE(review): ``jax``, ``jnp`` and ``neg_ln_posterior`` are not imported or
# defined anywhere in the visible notebook, so this cell raises NameError as
# written. argnums=[3, 4, 5] requests gradients w.r.t. the 4th-6th positional
# arguments (in the call below: the two scalars and beta).
obj = jax.value_and_grad(neg_ln_posterior, argnums=[3, 4, 5])
def wrapper(*args, **kwargs):
    # Return the value plus all requested gradients flattened into one vector.
    val, grads = obj(*args, **kwargs)
    return val, jnp.concatenate([g.reshape(-1) for g in grads])

In [None]:
# test(p0)
val, grad = wrapper(joa.X, joa.y, joa.y_ivar, 
                    0., 0.5, beta, joa.L2_slice)

In [None]:
grad

In [None]:
grad

In [None]:
# NOTE(review): jnp.dot needs two operands, so this call raises TypeError even
# once jnp is imported.
jnp.dot(np.random.random(size=(10, 3)))

In [None]:
plt.plot(beta[dm.idx_map['lsf']])

In [None]:
plt.figure(figsize=(15, 3))
plt.plot(phot_names, beta[dm.idx_map['phot']])
plt.xticks(rotation=45, ha='right')

In [None]:
plt.figure(figsize=(15, 3))
plt.plot(beta[dm.idx_map['spec']])
plt.xlim(800, 1200)

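### Old plots (stale — several cells below use names such as `all_spec_mask`, `star_hdul`, and `nufft_lowpass` that are not defined above):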

In [None]:
# NOTE(review): ``all_spec_mask`` is not defined in the visible notebook.
np.where(all_spec_mask)[0].size

In [None]:
# Load one star's spectrum and flag bad pixels.
# NOTE(review): ``star_hdul`` is only assigned inside the loop in the final
# cell, so this cell depends on out-of-order execution; ``pix`` is unused.
pix = np.arange(8575, dtype='f8')
# Wavelength grid from the FITS header: 10 ** (CRVAL1 + i * CDELT1), i.e.
# uniformly spaced in log10(wavelength).
wvln = 10 ** (star_hdul[1].header['CRVAL1'] +
              np.arange(star_hdul[1].header['NAXIS1']) * star_hdul[1].header['CDELT1'])
ln_wvln = np.log(wvln)
flux = star_hdul[1].data
err = star_hdul[2].data

# Bad pixels: zero flux, or uncertainty above 3x the median uncertainty.
mask = (flux == 0 ) | (err > (3 * np.median(err)))

plt.figure(figsize=(15, 5))
plt.plot(wvln[~mask], flux[~mask], marker='', drawstyle='steps-mid')
# plt.plot(wvln[mask], flux[mask], marker='o', ls='none', color='r')

In [None]:
# Low-pass filter the flux as a function of ln(wavelength), ignoring bad pixels.
# NOTE(review): ``nufft_lowpass`` is not defined or imported in the visible
# notebook; 22500 is presumably the spectral resolution R — confirm.
new_flux = nufft_lowpass(ln_wvln, flux, 
                         fcut=0.5 * 22500, bad_mask=mask)

In [None]:
# Compare raw and low-pass-filtered flux in the 16000-16100 wavelength window,
# with masked (bad) pixels marked in red.
plt.figure(figsize=(15, 5))
plt.plot(wvln, flux, marker='', drawstyle='steps-mid', label='flux')
# With the default color cycle the raw flux is already drawn in tab:blue, so
# the filtered curve gets a distinct color to stay visible.
plt.plot(wvln, new_flux,
         marker='', drawstyle='steps-mid', color='tab:orange',
         label='low-pass')
plt.plot(wvln[mask], flux[mask],
         marker='.', ls='none', color='r', label='masked')
plt.xlim(15500+500, 15600+500)
plt.axhline(1.)
plt.legend(loc='best')

---

In [None]:
for star in stars[:4]:
    star_hdul = get_aspcapstar(star)
    lsf_hdul = get_lsf(star)
    
    plt.figure(figsize=(15, 4))
    plt.plot(lsf_hdul[0].data[:, 7], 
             marker='', drawstyle='steps-mid', alpha=0.5)