### Issues:
- The missing data are being filled in incorrectly. The NUFFT will pull missing regions down to 1, but we should really use PCA to patch the missing data. So, iteratively: low-pass filter, PCA-patch the previously missing data, and repeat.
- The low-pass filter is questionable and should be checked against Barnett/DFM
- The LSF housekeeping data is super hacky - do we need something better?

In [None]:
import os
os.environ['APOGEE_CACHE_PATH'] = "/mnt/ceph/users/apricewhelan/apogee"

import sys
import pathlib
_path = str(pathlib.Path('../').resolve())
if _path not in sys.path:
    sys.path.append(_path)

import corner
from astropy.io import fits
import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm.auto import tqdm

import jax
# Enable 64-bit floats before any JAX arrays are created. The flag is set
# through the `jax.config` attribute of the top-level package because the
# `jax.config` *submodule* import path has been removed from recent JAX.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from joaquin import Joaquin
from joaquin.features import default_phot_names as phot_names
from joaquin.logger import logger

from gala.mpl_style import hesperia, laguna

In [None]:
# Absolute path of the `cache` directory one level above the working
# directory; it is created if it does not exist yet.
cache_path = pathlib.Path('../cache').resolve()
cache_path.mkdir(exist_ok=True)

In [None]:
# APOGEE allStar table (first FITS extension), with APOGEE_ID as str.
# Previous input, kept for reference:
#   '/mnt/home/apricewhelan/data/APOGEE_DR16/allStar-r12-gaiaedr3.fits'
allstar = at.Table.read(
    '/mnt/home/apricewhelan/data/APOGEE_DR17/allStar-dr17-turbo20-beta.fits',
    hdu=1,
)
allstar['APOGEE_ID'] = allstar['APOGEE_ID'].astype(str)

# WISE table: rename its ID column to match allStar and strip surrounding
# whitespace from the IDs so the two tables can be joined on APOGEE_ID.
wise = at.Table.read(
    '/mnt/home/apricewhelan/data/APOGEE_DR17/APOGEE-DR17-wise-result.fits.gz'
)
wise.rename_column('apogee_id', 'APOGEE_ID')
wise['APOGEE_ID'] = list(map(str.strip, wise['APOGEE_ID'].astype(str)))

In [None]:
# Left-join the WISE columns onto allStar (every allStar row is kept), then
# reduce to a single row per APOGEE_ID.
allstar = at.unique(
    at.join(allstar, wise, keys='APOGEE_ID', join_type='left'),
    keys='APOGEE_ID',
)

In [None]:
# Require a finite, strictly positive value in every photometric column.
# NOTE(review): `allstar` is the result of a left join, so the WISE columns
# may be masked for stars without a match -- confirm that masked entries are
# actually rejected by this test.
phot_mask = np.ones(len(allstar), dtype=bool)
for name in phot_names:
    phot_mask &= np.isfinite(allstar[name]) & (allstar[name] > 0)

# Strip whitespace from the telescope names so the `np.isin` test below can
# match them exactly.
allstar['TELESCOPE'] = np.array([x.strip() for x in allstar['TELESCOPE']])
# Sample selection: 1.5 < LOGG < 2.2, 3500 < TEFF < 5000, SNR > 100,
# observed with apo25m or lco25m, and passing the photometry mask.
stars = allstar[(allstar['LOGG'] < 2.2) & 
                (allstar['LOGG'] > 1.5) &
                (allstar['TEFF'] > 3500) &
                (allstar['TEFF'] < 5000) &
                (allstar['SNR'] > 100) &
                np.isin(allstar['TELESCOPE'], ['apo25m', 'lco25m']) & 
                phot_mask]
# Displayed as the cell output: number of selected stars.
len(stars)

In [None]:
# Build the Joaquin model on the selected stars. The `frozen` dict passes
# fixed values for `L2_ivar` and `parallax_zpt`.
# NOTE(review): presumably these parameters are held fixed during
# optimization -- confirm against the Joaquin class. Both values are
# hand-picked.
joa = Joaquin(stars, # terms=['phot'],
              frozen={'L2_ivar': 1e-1, 
                      'parallax_zpt': -0.03})  # MAGIC NUMBER

In [None]:
# Histogram of the Gaia EDR3 parallaxes of the stars held by the model, on a
# log-scaled count axis (128 bin edges between -0.5 and 2).
plt.hist(joa.dm.stars['GAIAEDR3_PARALLAX'], 
         bins=np.linspace(-0.5, 2, 128));
plt.yscale('log')

TODO: also color by mean fiber number? MEANFIB

In [None]:
def simple_corner(X, labels=None, color_by=None, axes=None, 
                  colorbar=False, **style):
    """Draw a lower-triangle grid of pairwise scatter panels.

    For ``K = X.shape[1]`` columns a ``(K-1) x (K-1)`` grid is used. Panel
    ``(i, j)`` with ``i >= j`` shows column ``j`` on the x axis against
    column ``i + 1`` on the y axis; panels above the diagonal are hidden.

    Parameters
    ----------
    X : array-like with a ``shape`` attribute, shape ``(n_points, K)``
        Data to plot, one column per quantity.
    labels : sequence of str, optional
        One label per column of ``X``. ``labels[1:]`` label the rows (y axes
        of the first grid column) and ``labels[:-1]`` label the columns
        (x axes of the last grid row).
    color_by : array-like, optional
        If given, points are drawn with ``Axes.scatter`` and colored by these
        values; otherwise ``Axes.plot`` with marker-only styling is used.
    axes : Axes or ndarray of Axes, optional
        Existing axes to draw into. A single Axes is accepted for the
        two-column case. If omitted, a new figure is created.
    colorbar : bool, optional
        If True and ``color_by`` is given, add a colorbar spanning all axes.
    **style
        Extra keyword arguments forwarded to the plotting call; they
        override the defaults set here.

    Returns
    -------
    list
        ``[fig, axes]``, with the colorbar appended as a third item when one
        was drawn.

    Raises
    ------
    ValueError
        If ``X`` has more columns than rows (likely a transposed input).
    """
    if X.shape[1] > X.shape[0]:
        raise ValueError("I don't believe you")

    if color_by is None:
        plotfunc = 'plot'
        # Accept the long matplotlib spellings as aliases of the short ones.
        style.setdefault('marker', 'o')
        style.setdefault('mew', style.pop('markeredgewidth', 0))
        style.setdefault('ls', style.pop('linestyle', 'none'))
        style.setdefault('ms', style.pop('markersize', 2.))
    else:
        plotfunc = 'scatter'
        style.setdefault('marker', 'o')
        style.setdefault('lw', style.pop('linewidth', 0))
        style.setdefault('s', 5)
        style.setdefault('c', color_by)

    nside = X.shape[1] - 1

    fig = None
    if axes is None:
        # Some magic numbers for pretty axis layout.
        # NOTE(review): the size is computed for K panels per side although
        # only K - 1 are drawn -- kept as-is to preserve figure sizes.
        K = X.shape[1]
        factor = 2.0  # size of one side of one panel
        lbdim = 0.5 * factor  # size of left/bottom margin
        trdim = 0.2 * factor  # size of top/right margin
        whspace = 0.05  # w/hspace size
        plotdim = factor * K + factor * (K - 1.0) * whspace
        dim = lbdim + plotdim + trdim

        fig, axes = plt.subplots(nside, nside,
                                 figsize=(dim, dim),
                                 sharex='col', sharey='row',
                                 constrained_layout=True)

    # A lone Axes (from a 1x1 grid, or passed in by the caller) is wrapped
    # into a 2D array *before* it is indexed, so both cases work below.
    if not isinstance(axes, np.ndarray):
        axes = np.array([[axes]])

    if fig is None:
        fig = axes.flat[0].figure

    cs = None
    for i in range(nside):
        for j in range(nside):
            ax = axes[i, j]
            if i < j:
                ax.set_visible(False)
            else:
                # `cs` ends up as the artist of the last drawn panel; it is
                # what the optional colorbar is built from.
                cs = getattr(ax, plotfunc)(X[:, j], X[:, i+1], **style)

    if labels is not None:
        for i in range(nside):
            axes[i, 0].set_ylabel(labels[i+1])

        for j in range(nside):
            axes[-1, j].set_xlabel(labels[j])

    return_stuff = [fig, axes]

    if colorbar and color_by is not None and cs is not None:
        cb = fig.colorbar(cs, ax=axes)
        return_stuff.append(cb)

    return return_stuff

In [None]:
# Corner plot of the LSF features, one overlay per telescope, colored by
# mean fiber number with a different colormap for each telescope.
# The feature matrix and labels do not depend on the telescope, so they are
# computed once instead of inside the loop.
_lsf_X = joa.dm.get_Xy(['lsf'])[0]
_lsf_labels = [r'$a_{\rm b}$', r'$b_{\rm b}$',
               r'$a_{\rm g}$', r'$b_{\rm g}$',
               r'$a_{\rm r}$', r'$b_{\rm r}$']

axes = None  # the first call creates the figure; later calls draw into it
for tele, cmap in zip(['apo25m', 'lco25m'], [hesperia, laguna]):
    mask = joa.dm.stars['TELESCOPE'] == tele

    # Re-applying the same axis labels on the second pass is harmless.
    fig, axes = simple_corner(_lsf_X[mask],
                              color_by=joa.dm.stars['MEANFIB'][mask],
                              alpha=0.75, cmap=cmap,
                              labels=_lsf_labels, axes=axes)

fig.set_facecolor('w')

---

Optimizing the model

In [None]:
# Initial coefficient vector from the model.
# NOTE(review): presumably the starting point used by the optimizer -- confirm.
init_beta = joa.init_beta()

In [None]:
# Plot the initial coefficients against their index.
plt.figure(figsize=(15, 4))
plt.plot(init_beta)
plt.ylabel('init beta')

In [None]:
# Run the optimization, passing `maxfun=10_000` in the options dict.
# NOTE(review): presumably forwarded to the underlying optimizer as its
# function-evaluation limit -- confirm in Joaquin.optimize.
res = joa.optimize(options={'maxfun': 10_000})
# res, wrapper, ps = joa.optimize()

In [None]:
# Display the optimizer result object as the cell output.
res

In [None]:
# Unpack the optimized parameter vector; the result is indexed by name
# below (e.g. fit_pars['beta']).
fit_pars = joa.unpack_pars(res.x)

In [None]:
# Difference between initial and fitted coefficients, restricted to the
# entries selected by idx_map['spec'].
plt.figure(figsize=(15, 4))
plt.plot(init_beta[joa.idx_map['spec']] - fit_pars['beta'][joa.idx_map['spec']])
plt.ylabel('init beta - fit beta')

# The same two coefficient sets overplotted for direct comparison.
plt.figure(figsize=(15, 4))
plt.plot(init_beta[joa.idx_map['spec']], label='init')
plt.plot(fit_pars['beta'][joa.idx_map['spec']], label='fit')
plt.ylabel('beta')
plt.legend(loc='best')

In [None]:
# Predicted parallax: exponential of the linear model X . beta.
pred_plx = np.exp(np.dot(joa.X, fit_pars['beta']))

In [None]:
# Overlay histograms of the Gaia parallaxes and the predicted parallaxes on
# shared bins, with a log-scaled count axis.
bins = np.linspace(-0.5, 2, 128)
plt.hist(joa.dm.stars['GAIAEDR3_PARALLAX'], 
         bins=bins);
plt.hist(pred_plx, bins=bins)
plt.yscale('log')

In [None]:
# Per-star chi evaluated at the fitted parameters.
# NOTE(review): presumably the error-normalized parallax residual -- confirm
# against Joaquin.chi.
chi = joa.chi(**fit_pars)

In [None]:
# Two panels against Gaia parallax: predicted parallax (left) and chi (right).
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

point_style = dict(marker='o', ls='none', mew=0, ms=1.5, alpha=0.75)
panel_data = zip(axes, (pred_plx, chi), ('Joaquin plx', r'$\chi$'))
for ax, yvals, ylabel in panel_data:
    ax.plot(joa.dm.stars['GAIAEDR3_PARALLAX'], yvals, **point_style)
    ax.set_xlim(-0.5, 1.5)
    ax.set_xlabel('Gaia plx')
    ax.set_ylabel(ylabel)

# Only the parallax-vs-parallax panel uses matching x and y ranges.
axes[0].set_ylim(axes[0].get_xlim())

fig.tight_layout()

In [None]:
# Same two panels as a scatter plot colored by log(VSCATTER); the left panel
# also shows the Gaia parallax uncertainties as horizontal error bars.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

gaia_plx = joa.dm.stars['GAIAEDR3_PARALLAX']
scatter_style = dict(c=np.log(joa.dm.stars['VSCATTER']),
                     vmin=-1, vmax=4, cmap='turbo',
                     marker='o', lw=0, s=10, alpha=0.75)

panel_data = zip(axes, (pred_plx, chi), ('Joaquin plx', r'$\chi$'))
for ax, yvals, ylabel in panel_data:
    ax.scatter(gaia_plx, yvals, **scatter_style)
    ax.set_xlim(-0.5, 1.5)
    ax.set_xlabel('Gaia plx')
    ax.set_ylabel(ylabel)

# Error bars are drawn after the scatter points, on the left panel only.
axes[0].errorbar(gaia_plx,
                 pred_plx,
                 xerr=joa.dm.stars['GAIAEDR3_PARALLAX_ERROR'],
                 marker='', ls='', ecolor='#666666',
                 elinewidth=0.5, alpha=0.5)
axes[0].set_ylim(axes[0].get_xlim())

fig.tight_layout()
fig.set_facecolor('w')

In [None]:
# Short display labels for photometric column names; columns not listed
# here are labelled with their own name by the plotting code.
phot_to_label = {
    'GAIAEDR3_PHOT_BP_MEAN_MAG': 'BP',
    'GAIAEDR3_PHOT_RP_MEAN_MAG': 'RP',
    'GAIAEDR3_PHOT_G_MEAN_MAG': 'G',
    'w1mpro': 'W1',
    'w2mpro': 'W2',
    'w3mpro': 'W3',
    # Was a second 'w3mpro' key, which silently relabelled W3 as 'W4'.
    'w4mpro': 'W4',
}

In [None]:
# Photometry / colors:
# Corner plot of several photometric colors, with points colored by chi.
colors = [
    ('GAIAEDR3_PHOT_BP_MEAN_MAG', 'GAIAEDR3_PHOT_RP_MEAN_MAG'),
    ('J', 'K'),
    ('w1mpro', 'w3mpro'),
    ('GAIAEDR3_PHOT_G_MEAN_MAG', 'J'),
    ('H', 'w2mpro')
]

plot_X = []
labels = []
for p1, p2 in colors:
    # Each color is the difference of two photometry columns of the design
    # matrix, located through the position of the band in `phot_names`.
    col1 = joa.idx_map['phot'][phot_names.index(p1)]
    col2 = joa.idx_map['phot'][phot_names.index(p2)]
    plot_X.append(joa.X[:, col1] - joa.X[:, col2])

    # Use the short label when one is defined, else the raw column name.
    lbl1 = phot_to_label.get(p1, p1)
    lbl2 = phot_to_label.get(p2, p2)
    labels.append(f"{lbl1} $-$ {lbl2}")

# One column per color, one row per star.
plot_X = np.array(plot_X).T

fig, axes, cb = simple_corner(
    plot_X, 
    color_by=chi,
    colorbar=True,
    labels=labels,
    vmin=-4, vmax=4, s=8,
    alpha=0.75, cmap='RdBu')
cb.ax.set_aspect(40)

fig.set_facecolor('w')

In [None]:
# Housekeeping:
# Corner plot of the design-matrix columns selected by idx_map['lsf'],
# with points colored by chi (color scale clipped to [-4, 4]).
plot_X = joa.X[:, joa.idx_map['lsf']]
labels = [r'$a_{\rm b}$', r'$b_{\rm b}$', 
          r'$a_{\rm g}$', r'$b_{\rm g}$',
          r'$a_{\rm r}$', r'$b_{\rm r}$']

fig, axes, cb = simple_corner(
    plot_X, 
    color_by=chi,
    colorbar=True,
    labels=labels,
    vmin=-4, vmax=4, s=8,
    alpha=0.75, cmap='RdBu')
# Make the colorbar axes tall and narrow.
cb.ax.set_aspect(40)

fig.set_facecolor('w')

In [None]:
# Initial coefficients computed with an explicit L2_ivar of 0.5.
beta = joa.init_beta(L2_ivar=0.5)

In [None]:
# Parameter vector: two leading scalars followed by the coefficients, then
# evaluate the model object on it.
# NOTE(review): the two scalars look like the same values passed as the 4th
# and 5th positional arguments in the gradient cell -- confirm their meaning.
p0 = [0., 0.5] + list(beta)
joa(p0)

In [None]:
# test = jax.value_and_grad(joa.__call__)
# Value and gradient of the objective with respect to its positional
# arguments 3, 4 and 5.
# NOTE(review): `neg_ln_posterior` is not imported or defined in the visible
# cells -- this raises NameError on a fresh kernel unless defined elsewhere.
obj = jax.value_and_grad(neg_ln_posterior, argnums=[3, 4, 5])
def wrapper(*args, **kwargs):
    """Return the objective value and all gradients flattened into one 1D array."""
    val, grads = obj(*args, **kwargs)
    # Each gradient is flattened and the pieces are joined in argnums order.
    return val, jnp.concatenate([g.reshape(-1) for g in grads])

In [None]:
# test(p0)
# Evaluate the objective and its flattened gradient; the differentiated
# arguments (positions 3-5) are 0., 0.5 and `beta`.
val, grad = wrapper(joa.X, joa.y, joa.y_ivar, 
                    0., 0.5, beta, joa.L2_slice)

In [None]:
# Display the flattened gradient as the cell output.
grad

In [None]:
# Display the flattened gradient again (duplicate of the previous cell).
grad

In [None]:
# NOTE(review): scratch cell -- `jnp.dot` is called with a single argument
# here although it takes two operands, so this raises a TypeError as written.
# The intended second operand is not recoverable from this notebook.
jnp.dot(np.random.random(size=(10, 3)))

In [None]:
# Coefficients of the LSF features. `dm` is never defined in this notebook;
# the index map is read from `joa`, as in the other cells.
plt.plot(beta[joa.idx_map['lsf']])

In [None]:
# Coefficients of the photometric features, labelled by band name. `dm` is
# never defined in this notebook; the index map is read from `joa`, as in
# the other cells.
plt.figure(figsize=(15, 3))
plt.plot(phot_names, beta[joa.idx_map['phot']])
plt.xticks(rotation=45, ha='right')

In [None]:
# Coefficients of the spectral features, zoomed to indices 800-1200. `dm` is
# never defined in this notebook; the index map is read from `joa`, as in
# the other cells.
plt.figure(figsize=(15, 3))
plt.plot(beta[joa.idx_map['spec']])
plt.xlim(800, 1200)

### Old plots:

In [None]:
# Number of True entries in the mask.
# NOTE(review): `all_spec_mask` is not defined anywhere in the visible
# notebook -- stale cell.
np.where(all_spec_mask)[0].size

In [None]:
# NOTE(review): `star_hdul` is only assigned inside the loop in the last
# cell of the notebook, so this cell depends on that loop having run first.
# NOTE(review): `pix` is not used anywhere in the visible cells.
pix = np.arange(8575, dtype='f8')
# Wavelength grid from the header keywords: 10 ** (CRVAL1 + i * CDELT1)
# for i in range(NAXIS1).
wvln = 10 ** (star_hdul[1].header['CRVAL1'] +
              np.arange(star_hdul[1].header['NAXIS1']) * star_hdul[1].header['CDELT1'])
ln_wvln = np.log(wvln)
flux = star_hdul[1].data
err = star_hdul[2].data

# Flag pixels with exactly zero flux or with an error larger than three
# times the median error.
mask = (flux == 0 ) | (err > (3 * np.median(err)))

# Plot only the unflagged pixels.
plt.figure(figsize=(15, 5))
plt.plot(wvln[~mask], flux[~mask], marker='', drawstyle='steps-mid')
# plt.plot(wvln[mask], flux[mask], marker='o', ls='none', color='r')

In [None]:
# Low-pass filter the flux on the log-wavelength grid, passing the bad-pixel
# mask and a cutoff of 0.5 * 22500.
# NOTE(review): `nufft_lowpass` is not imported or defined in the visible
# cells; the units and meaning of `fcut` should be confirmed there.
new_flux = nufft_lowpass(ln_wvln, flux, 
                         fcut=0.5 * 22500, bad_mask=mask)

In [None]:
# Compare the raw and low-pass-filtered flux in a 100-Angstrom-wide window
# (16000-16100 on the x axis), marking flagged pixels in red and the unit
# level with a horizontal line.
# NOTE(review): the filtered curve is drawn in 'tab:blue'; if the active
# style's first cycle color is also blue the two curves are hard to tell
# apart -- confirm.
plt.figure(figsize=(15, 5))
plt.plot(wvln, flux, marker='', drawstyle='steps-mid')
plt.plot(wvln, new_flux, 
         marker='', drawstyle='steps-mid', color='tab:blue')
plt.plot(wvln[mask], flux[mask], 
         marker='.', ls='none', color='r')
plt.xlim(15500+500, 15600+500)
plt.axhline(1.)

---

In [None]:
# For the first four selected stars, plot column 7 of the primary-HDU data
# array of the star's LSF file, one figure per star.
# NOTE(review): `get_aspcapstar` and `get_lsf` are not imported in the
# visible cells. `star_hdul` is not used inside the loop, and neither HDU
# list is closed.
for star in stars[:4]:
    star_hdul = get_aspcapstar(star)
    lsf_hdul = get_lsf(star)
    
    plt.figure(figsize=(15, 4))
    plt.plot(lsf_hdul[0].data[:, 7], 
             marker='', drawstyle='steps-mid', alpha=0.5)