### Issues:
- The missing data are being filled in incorrectly. The NUFFT pulls missing regions down to 1, but we should really use PCA to patch the missing data. So, iteratively: low-pass filter, patch the previously missing data with PCA, and repeat.
- The low-pass filter looks unreliable and should be checked against Barnett/DFM.
- The handling of the LSF housekeeping data is very hacky - do we need something better?

In [None]:
import os
import pathlib
import sys

# Set before the APOGEE-related imports further down in this cell.
# NOTE(review): presumably read by the data-access code at import time - confirm.
os.environ['APOGEE_CACHE_PATH'] = "/mnt/ceph/users/apricewhelan/apogee"

# Make the parent directory importable (once), so that the local `joaquin`
# package can be imported without being installed.
_path = str(pathlib.Path('../').resolve())
if _path not in sys.path:
    sys.path.append(_path)

import corner
from astropy.io import fits
import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm.auto import tqdm

import jax
import jax.numpy as jnp

# Enable 64-bit floats/ints in JAX; this must run before any JAX array is
# created. `jax.config.update` replaces the deprecated (and, in recent JAX
# releases, removed) `from jax.config import config` import path.
jax.config.update("jax_enable_x64", True)
# Keep the old module-level name available for any cell that still uses it.
config = jax.config

from joaquin import Joaquin
from joaquin.features import default_phot_names as phot_names
from joaquin.logger import logger

In [None]:
# Set the joaquin logger's level to 0.
# NOTE(review): if `logger` is a standard `logging.Logger`, 0 is
# `logging.NOTSET` - confirm this is the intended way to get verbose output.
logger.setLevel(0)

In [None]:
# Absolute path of `../cache` (relative to the current working directory).
# The directory is created if missing; an existing one is left as-is.
cache_path = pathlib.Path('../cache').resolve()
cache_path.mkdir(exist_ok=True)

In [None]:
# allStar catalog, read from HDU 1 of the FITS file.
allstar = at.Table.read(
    '/mnt/home/apricewhelan/data/APOGEE_DR17/allStar-dr17-turbo20-beta.fits',
    hdu=1,
)
# allstar = at.Table.read('/mnt/home/apricewhelan/data/APOGEE_DR16/allStar-r12-gaiaedr3.fits')

# `wise` table; its ID column is renamed so that both tables share the
# 'APOGEE_ID' key.
wise = at.Table.read(
    '/mnt/home/apricewhelan/data/APOGEE_DR17/APOGEE-DR17-wise-result.fits.gz'
)
wise.rename_column('apogee_id', 'APOGEE_ID')

# Store the IDs as `str` in both tables; the `wise` IDs are additionally
# stripped of surrounding whitespace.
allstar['APOGEE_ID'] = allstar['APOGEE_ID'].astype(str)
wise['APOGEE_ID'] = [
    apogee_id.strip() for apogee_id in wise['APOGEE_ID'].astype(str)
]

In [None]:
# Left-join the `wise` columns onto allstar by APOGEE_ID (every allstar row is
# kept), then reduce the result to one row per APOGEE_ID.
allstar = at.unique(
    at.join(allstar, wise, keys='APOGEE_ID', join_type='left'),
    keys='APOGEE_ID',
)

In [None]:
# Photometry mask: True only where every column named in `phot_names` is
# finite and strictly positive.
phot_mask = np.ones(len(allstar), dtype=bool)
for name in phot_names:
    phot_mask &= np.isfinite(allstar[name]) & (allstar[name] > 0)

# Strip whitespace from the telescope names so the np.isin() match below is
# an exact string comparison.
allstar['TELESCOPE'] = np.array([x.strip() for x in allstar['TELESCOPE']])
# Selection: 1.5 < LOGG < 2.2, 3500 < TEFF < 5000, SNR > 100, telescope is
# apo25m or lco25m, and the photometry mask passes. The trailing [::10] then
# keeps every 10th selected row.
stars = allstar[(allstar['LOGG'] < 2.2) & 
                (allstar['LOGG'] > 1.5) &
                (allstar['TEFF'] > 3500) &
                (allstar['TEFF'] < 5000) &
                (allstar['SNR'] > 100) &
                np.isin(allstar['TELESCOPE'], ['apo25m', 'lco25m']) & 
                phot_mask][::10]
# Display the number of selected stars.
len(stars)

In [None]:
# Build the model on the selected stars, with L2_ivar and parallax_zpt passed
# as frozen values.
joa = Joaquin(
    stars,  # terms=['phot'],
    frozen=dict(L2_ivar=1e-1, parallax_zpt=-0.03),  # MAGIC NUMBER
)

In [None]:
# Histogram of the GAIAEDR3_PARALLAX column of the model's stars, in 127 bins
# spanning [-0.5, 2], with a logarithmic count axis. The trailing `;`
# suppresses the notebook's echo of the hist() return value.
plt.hist(joa.dm.stars['GAIAEDR3_PARALLAX'], 
         bins=np.linspace(-0.5, 2, 128));
plt.yscale('log')

In [None]:
logger.handlers[0]

TODO: also color by mean fiber number?

In [None]:
# Corner plot of the 'lsf' feature block, split by telescope: apo25m ("APO")
# in blue, with lco25m ("LCO") overplotted in orange on the same figure.
_points_only = dict(plot_density=False, plot_contours=False)

fig = corner.corner(
    joa.dm.get_sub_Xy(['lsf'])[0][joa.dm.stars['TELESCOPE'] == 'apo25m'],
    color='tab:blue',
    labels=[r'$a_{\rm b}$', r'$b_{\rm b}$',
            r'$a_{\rm g}$', r'$b_{\rm g}$',
            r'$a_{\rm r}$', r'$b_{\rm r}$'],
    hist_kwargs=dict(label='APO'),
    **_points_only,
)
_ = corner.corner(
    joa.dm.get_sub_Xy(['lsf'])[0][joa.dm.stars['TELESCOPE'] == 'lco25m'],
    color='tab:orange',
    fig=fig,
    hist_kwargs=dict(label='LCO'),
    **_points_only,
)

# Recompute the data limits of every panel so the axis ranges cover both
# overplotted samples.
for ax in fig.axes:
    ax.relim()
    ax.autoscale()

# The legend goes on the first panel, which carries the labelled histograms.
fig.axes[0].legend(loc='upper left')

fig.set_facecolor('w')

In [None]:
# colors = [
#     ('GAIAEDR3_PHOT_G_MEAN_MAG', 'J'),
#     ('GAIAEDR3_PHOT_BP_MEAN_MAG', 'GAIAEDR3_PHOT_RP_MEAN_MAG'),
#     ('J', 'K'),
#     ('w1mpro', 'w2mpro'),
#     ('w3mpro', 'w4mpro')
# ]

# color_X = np.zeros((joa.X.shape[0], len(colors)))
# for i, (p1, p2) in enumerate(colors):
#     color_X[:, i] = (joa.X[:, phot_idx[phot_names.index(p1)]] -
#                      joa.X[:, phot_idx[phot_names.index(p2)]])
    
# _ = corner.corner(color_X, plot_density=False, plot_contours=False)

---

Optimizing the model

In [None]:
# Keep the initial beta vector; it is compared with the fitted beta in the
# "init beta - fit beta" plot below.
init_beta = joa.init_beta()

In [None]:
# Run the optimizer with `maxfun` set to 1000.
# NOTE(review): `options` is presumably forwarded to the underlying
# optimizer, where `maxfun` would cap the number of objective-function
# evaluations - confirm against Joaquin.optimize.
res = joa.optimize(options={'maxfun': 1000})
# res, wrapper, ps = joa.optimize()

In [None]:
res

In [None]:
# Convert the optimizer's solution vector `res.x` into named parameters;
# the cells below read `fit_pars['beta']` and pass `**fit_pars` to joa.chi().
fit_pars = joa.unpack_pars(res.x)

In [None]:
# Difference between the initial and the fitted beta, restricted to the
# entries indexed by idx_map['spec'].
plt.figure(figsize=(15, 4))
plt.plot(init_beta[joa.dm.idx_map['spec']] - fit_pars['beta'][joa.dm.idx_map['spec']])
# plt.xlim(2000, 2500)
plt.ylabel('init beta - fit beta')

In [None]:
# Model prediction: the exponential of the linear combination X . beta. It is
# plotted below as 'Joaquin plx' against the Gaia parallax.
pred_plx = np.exp(np.dot(joa.X, fit_pars['beta']))

In [None]:
# Overlay the Gaia parallax histogram and the predicted-parallax histogram on
# a shared binning, with a logarithmic count axis.
bins = np.linspace(-0.5, 2, 128)
for _sample in (joa.dm.stars['GAIAEDR3_PARALLAX'], pred_plx):
    plt.hist(_sample, bins=bins)
plt.yscale('log')

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# Both panels share the x data and the marker styling.
_gaia_plx = joa.dm.stars['GAIAEDR3_PARALLAX']
_points = dict(marker='o', ls='none', mew=0, ms=1, alpha=0.75)

# Left panel: predicted vs. Gaia parallax, with identical x and y ranges.
ax = axes[0]
ax.plot(_gaia_plx, pred_plx, **_points)
ax.set_xlim(-0.5, 1.5)
ax.set_ylim(ax.get_xlim())
ax.set_xlabel('Gaia plx')
ax.set_ylabel('Joaquin plx')

# Right panel: chi of the fitted parameters vs. Gaia parallax; the y range is
# left to autoscale.
ax = axes[1]
ax.plot(_gaia_plx, joa.chi(**fit_pars), **_points)
ax.set_xlim(-0.5, 1.5)
# ax.set_ylim(ax.get_xlim())
ax.set_xlabel('Gaia plx')
ax.set_ylabel(r'$\chi$')

fig.tight_layout()

In [None]:
# Initial beta computed with L2_ivar=0.5 (the model above was built with
# L2_ivar frozen at 1e-1); used by the gradient checks and plots below.
beta = joa.init_beta(L2_ivar=0.5)

In [None]:
# Flat parameter list: two leading scalars (0. and 0.5) followed by every
# element of beta. NOTE(review): the meaning of the two leading scalars is
# defined by Joaquin's parameter packing - confirm their order.
p0 = [0., 0.5, *beta]
# Evaluate the model's objective at p0 (result is echoed by the notebook).
joa(p0)

In [None]:
# test = jax.value_and_grad(joa.__call__)
obj = jax.value_and_grad(neg_ln_posterior, argnums=[3, 4, 5])
def wrapper(*args, **kwargs):
    val, grads = obj(*args, **kwargs)
    return val, jnp.concatenate([g.reshape(-1) for g in grads])

In [None]:
# test(p0)
val, grad = wrapper(joa.X, joa.y, joa.y_ivar, 
                    0., 0.5, beta, joa.L2_slice)

In [None]:
grad

In [None]:
grad

In [None]:
# Scratch check that jnp.dot accepts NumPy inputs. jnp.dot needs two operands
# (the original single-argument call raised TypeError); a matrix-vector
# product is used here, mirroring np.dot(joa.X, beta) above.
# NOTE(review): the intended second operand is a guess - adjust if a different
# product was meant.
rand_mat = np.random.random(size=(10, 3))
rand_vec = np.random.random(size=3)
dot_check = jnp.dot(rand_mat, rand_vec)
dot_check

In [None]:
# beta entries indexed by idx_map['lsf']. The bare name `dm` is not defined
# anywhere in this notebook; the object is reached through the model as
# `joa.dm`, as in the other cells.
plt.plot(beta[joa.dm.idx_map['lsf']])

In [None]:
plt.figure(figsize=(15, 3))
plt.plot(phot_names, beta[dm.idx_map['phot']])
plt.xticks(rotation=45, ha='right')

In [None]:
# beta entries indexed by idx_map['spec'], zoomed to indices 800-1200. The
# bare name `dm` is not defined anywhere in this notebook; the object is
# reached through the model as `joa.dm`, as in the other cells.
plt.figure(figsize=(15, 3))
plt.plot(beta[joa.dm.idx_map['spec']])
plt.xlim(800, 1200)

### Old plots:

In [None]:
# NOTE: `star_hdul` must already exist when this cell runs; in this notebook
# it is assigned by the `for star in stars[:4]` loop cell further down.
pix = np.arange(8575, dtype='f8')

# Wavelength grid, linear in log10: 10 ** (CRVAL1 + i * CDELT1) for
# i = 0 .. NAXIS1 - 1, using the header of HDU 1.
_hdr = star_hdul[1].header
wvln = 10 ** (_hdr['CRVAL1'] + np.arange(_hdr['NAXIS1']) * _hdr['CDELT1'])
ln_wvln = np.log(wvln)

# HDU 1 data is used as the flux and HDU 2 data as its error.
flux = star_hdul[1].data
err = star_hdul[2].data

# Bad pixels: exactly-zero flux, or an error above three times the median
# error.
mask = (flux == 0) | (err > (3 * np.median(err)))

# Plot only the unmasked pixels.
plt.figure(figsize=(15, 5))
plt.plot(wvln[~mask], flux[~mask], marker='', drawstyle='steps-mid')
# plt.plot(wvln[mask], flux[mask], marker='o', ls='none', color='r')

In [None]:
# Low-pass filter the flux as a function of ln(wavelength), passing the
# bad-pixel mask computed above.
# NOTE(review): nufft_lowpass is neither imported nor defined in the visible
# cells - it must be brought into scope before this cell runs. The meaning
# and units of `fcut` (here 0.5 * 22500) are not visible from this notebook -
# confirm against the function's definition.
new_flux = nufft_lowpass(ln_wvln, flux, 
                         fcut=0.5 * 22500, bad_mask=mask)

In [None]:
plt.figure(figsize=(15, 5))
# Raw spectrum, in the default first cycle color.
plt.plot(wvln, flux, marker='', drawstyle='steps-mid')
# Low-pass-filtered spectrum. It was previously drawn in 'tab:blue', which is
# also the default first cycle color used for the raw spectrum above, so the
# two curves could not be told apart; use a distinct color.
plt.plot(wvln, new_flux,
         marker='', drawstyle='steps-mid', color='tab:orange')
# Masked (bad) pixels, as red points on top of the raw spectrum.
plt.plot(wvln[mask], flux[mask],
         marker='.', ls='none', color='r')
# Zoom to the 16000-16100 wavelength window.
plt.xlim(15500+500, 15600+500)
# Reference line at flux = 1.
plt.axhline(1.)

---

In [None]:
for star in stars[:4]:
    star_hdul = get_aspcapstar(star)
    lsf_hdul = get_lsf(star)
    
    plt.figure(figsize=(15, 4))
    plt.plot(lsf_hdul[0].data[:, 7], 
             marker='', drawstyle='steps-mid', alpha=0.5)