In [None]:
import os
# Configure the APOGEE cache location before the project imports below
# (presumably the APOGEE tooling reads it at import time — TODO confirm).
# setdefault respects a value already exported in the environment; the
# hard-coded path is only a fallback for the original author's machine.
os.environ.setdefault('APOGEE_CACHE_PATH',
                      "/mnt/ceph/users/apricewhelan/apogee-test/")
import warnings
# NOTE: this silences *every* warning category for the rest of the session.
warnings.filterwarnings('ignore', category=Warning)
import pickle

import sys
import pathlib
# Make the parent of the kernel's working directory importable
# (presumably the repository root that contains the `joaquin` package).
_path = str(pathlib.Path('../').resolve())
if _path not in sys.path:
    sys.path.append(_path)

import corner
from astropy.io import fits
import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm.auto import tqdm
from sklearn.decomposition import IncrementalPCA
from scipy.spatial import cKDTree

from joaquin import Joaquin
from joaquin.data import JoaquinData
from joaquin.config import all_phot_names, dr
from joaquin.logger import logger
from joaquin.plot import simple_corner, phot_to_label

In [None]:
# Data-release-specific cache directory, resolved to an absolute path and
# created (together with any missing parents) if it does not exist yet.
cache_path = pathlib.Path(f'../cache/{dr}').resolve()
cache_path.mkdir(parents=True, exist_ok=True)

In [None]:
parent = at.Table.read(cache_path / 'parent-sample.fits')

# Keep only high signal-to-noise spectra.
parent_stars = parent[parent['SNR'] > 100]

# HACK: random subsample for speed. The sample size is capped at the number of
# available stars, because sampling without replacement raises if `size`
# exceeds the population. When enough stars exist, the draw is identical.
n_subsample = min(8192, len(parent_stars))
np.random.seed(42)
idx = np.random.choice(len(parent_stars), size=n_subsample, replace=False)
parent_stars = parent_stars[idx]

len(parent_stars)

In [None]:
fig, ax = plt.subplots(figsize=(6, 6))

# Grey-scale density: full parent sample; blue points: the high-S/N subsample.
bins = (np.linspace(3000, 7500, 128),
        np.linspace(0, 5.5, 128))
ax.hist2d(parent['TEFF'], parent['LOGG'],
          bins=bins, norm=mpl.colors.LogNorm(),
          cmap='magma_r')

ax.plot(parent_stars['TEFF'],
        parent_stars['LOGG'],
        ls='none', marker='o', mew=0, ms=3.,
        color='tab:blue', alpha=0.75)

# Both axes are reversed (hot stars on the left, low log g at the top).
ax.set_ylim(5.5, 0)
ax.set_xlim(7500, 3000)

# Same axis labels as the validation figures further down.
ax.set_xlabel(r'ASPCAP $T_{\rm EFF}$')
ax.set_ylabel(r'ASPCAP $\log g$')

fig.tight_layout()

In [None]:
# Wrap the subsampled stars (lowpass=False), using the release-specific cache.
data = JoaquinData(
    parent_stars,
    lowpass=False,
    cache_path=cache_path,
)

In [None]:
# Spectral feature matrix; only the first element of the first returned item
# is kept as X, and idx_map records the column layout.
(X, *_), idx_map = data.get_Xy(['spec'], spec_mask_thresh=1)

# HACK: replace mags with colors
# X = np.hstack((X[:, idx_map['spec']],
#                X[:, idx_map['phot'][:-1]] - X[:, idx_map['phot'][1:]]))
# X = X[:, idx_map['phot'][:-1]] - X[:, idx_map['phot'][1:]]

# Photometric colors appended as extra feature columns: each pair is (band1, band2).
color_names = [
    ('GAIAEDR3_PHOT_BP_MEAN_MAG', 'J'),
    ('J', 'K'),
    ('J', 'w1mpro'),
    ('H', 'w2mpro'),
    ('w1mpro', 'w2mpro'),
]
color_labels = ['-'.join(pair) for pair in color_names]
color_X = data.get_colors(color_names)
X = np.hstack((X, color_X))

# Rows of X must line up one-to-one with the stars that survived masking.
good_stars = data.stars[data.stars_mask]
assert len(X) == len(good_stars)

In [None]:
# _ = simple_corner(color_X, labels=color_labels, 
#                   color_by=good_stars['LOGG'], vmin=0.5, vmax=5.5)

In [None]:
# First-pass PCA on the combined spectral + color features (fit in batches).
pca = IncrementalPCA(
    n_components=8,
    batch_size=1024,
)
projected_X = pca.fit_transform(X)

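This hacky step removes egregious outliers (stars lying more than 5σ from the mean in any principal component) and then refits the PCA. Run it only once, and only after the cells above: it overwrites `X`, `good_stars`, `pca`, and `projected_X`, so re-running it clips the already-cleaned sample again.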

In [None]:
# Flag any star whose projection is more than 5 standard deviations from the
# sample mean in at least one principal component.
# NOTE: non-idempotent — re-running this cell clips the already-clipped sample.
mean = projected_X.mean(axis=0)
std = projected_X.std(axis=0)
bad_mask = (np.abs(projected_X - mean) > 5 * std).any(axis=1)

# Drop the flagged stars from both the features and the star table...
keep = ~bad_mask
X = X[keep]
good_stars = good_stars[keep]

# ...and refit the PCA on the cleaned sample with the same settings.
pca = IncrementalPCA(n_components=8, batch_size=1024)
projected_X = pca.fit_transform(X)

In [None]:
# One figure per principal component, showing its weight on every column of X
# (the get_Xy spectral features followed by the appended colors).
for i, component in enumerate(pca.components_):
    fig, ax = plt.subplots(figsize=(12, 3))
    ax.plot(component)
    ax.set_title(f'PC {i} loadings')
    ax.set_xlabel('column of X')
    ax.set_ylabel('weight')

In [None]:
# Cumulative fraction of the variance captured by the first k components.
pca.explained_variance_ratio_.cumsum()

In [None]:
# Kiel diagram of the cleaned sample, repeated once per principal component,
# with each star colored by its projection onto that component.
n_pc = pca.n_components_
ncols = 3
nrows = -(-n_pc // ncols)  # ceiling division: enough rows for every component
fig, axes = plt.subplots(nrows, ncols,
                         figsize=(10, 10 * nrows / ncols),
                         sharex=True, sharey=True,
                         squeeze=False)

for i in range(n_pc):
    ax = axes.flat[i]
    ax.scatter(good_stars['TEFF'],
               good_stars['LOGG'],
               c=projected_X[:, i], s=6)
    ax.set_title(f'PC {i}')

# The axes are shared, so limits set on one panel apply to all of them.
axes.flat[0].set_ylim(5.5, 0)
axes.flat[0].set_xlim(7500, 3000)

for ax in axes[-1, :]:
    ax.set_xlabel(r'ASPCAP $T_{\rm EFF}$')
for ax in axes[:, 0]:
    ax.set_ylabel(r'ASPCAP $\log g$')

fig.tight_layout()

In [None]:
# Color-scale limits (vmin, vmax) for each ASPCAP label used in the corner plots.
things = dict(
    TEFF=(3000, 6500),
    LOGG=(0.5, 5.5),
    M_H=(-2, 0.5),
)

In [None]:
# One corner plot of the PCA projections per ASPCAP label, colored by that
# label with the color-scale limits taken from `things`.
for label, limits in things.items():
    vmin, vmax = limits
    fig, axes = simple_corner(
        projected_X,
        color_by=good_stars[label],
        vmin=vmin,
        vmax=vmax,
    )
    axes.flat[0].set_title(f'color: {label}')

In [None]:
# Persist the fitted (post-clipping) PCA model to the cache directory.
# NOTE: pickle files execute code on load — only load this from trusted caches.
pca_pickle_path = cache_path / 'pca_neighborizer.pkl'
with pca_pickle_path.open('wb') as pca_file:
    pickle.dump(pca, pca_file)

---

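## Validation

For a selection of stars, compare each star with its 32 nearest neighbors in the whitened PCA space (projections divided by the PCA singular values). The comparison is shown both on the ASPCAP $T_{\rm eff}$–$\log g$ diagram and on the Gaia $G-J$ vs. $M_G$ diagram.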

In [None]:
# Whiten the projections by dividing each component by its singular value,
# so that no single component dominates the Euclidean neighbor distance.
whitened_X = projected_X / pca.singular_values_
tree = cKDTree(whitened_X)

# k starts at 2 to skip each star's match to itself; keep the next 32 neighbors.
dist, idx = tree.query(whitened_X, k=np.arange(2, 2 + 32))

In [None]:
# Per-component scatter of the whitened projections fed to the KD-tree.
(projected_X / pca.singular_values_).std(axis=0)

In [None]:
# For the first 20 stars, compare each star (blue square) with its PCA-space
# neighbors from `idx` (orange circles). Left panel: ASPCAP Teff vs. log g.
# Right panel: Gaia G - J vs. M_G. Grey backgrounds show the parent sample.

# --- loop-invariant quantities, computed once rather than once per figure ---
kiel_bins = (np.linspace(3000, 7500, 128),
             np.linspace(0, 5.5, 128))
cmd_bins = (np.linspace(-0.5, 4.5, 128),
            np.linspace(-4, 12, 128))

# Parent-sample background for the right panel: parallax S/N > 5 only.
parent_plx_good, = np.where(
    (parent['GAIAEDR3_PARALLAX'] / parent['GAIAEDR3_PARALLAX_ERROR']) > 5)
parent_distmod = coord.Distance(
    parallax=parent['GAIAEDR3_PARALLAX'][parent_plx_good]*u.mas).distmod.value
parent_G_J = (parent['GAIAEDR3_PHOT_G_MEAN_MAG'] - parent['J'])[parent_plx_good]
parent_M_G = parent['GAIAEDR3_PHOT_G_MEAN_MAG'][parent_plx_good] - parent_distmod

# Color and absolute magnitude for every star in the PCA sample.
# allow_negative=True accepts negative parallaxes instead of raising.
stars_distmod = coord.Distance(parallax=good_stars['GAIAEDR3_PARALLAX']*u.mas,
                               allow_negative=True).distmod.value
stars_G_J = good_stars['GAIAEDR3_PHOT_G_MEAN_MAG'] - good_stars['J']
stars_M_G = good_stars['GAIAEDR3_PHOT_G_MEAN_MAG'] - stars_distmod

for i, js in enumerate(idx[:20]):
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    # --- left: ASPCAP Teff / log g ---
    ax = axes[0]
    ax.hist2d(parent['TEFF'], parent['LOGG'],
              bins=kiel_bins, norm=mpl.colors.LogNorm(),
              cmap='Greys')

    ax.plot(good_stars['TEFF'][i],
            good_stars['LOGG'][i],
            ls='none', marker='s', mew=0, ms=6.,
            color='tab:blue', zorder=100)

    ax.plot(good_stars['TEFF'][js],
            good_stars['LOGG'][js],
            ls='none', marker='o', mew=0, ms=4.,
            color='tab:orange', zorder=10)

    ax.set_ylim(5.5, 0)
    ax.set_xlim(7500, 3000)

    ax.set_xlabel(r'ASPCAP $T_{\rm EFF}$')
    ax.set_ylabel(r'ASPCAP $\log g$')

    # --- right: Gaia color vs. absolute magnitude ---
    ax = axes[1]
    ax.hist2d(parent_G_J, parent_M_G,
              bins=cmd_bins, norm=mpl.colors.LogNorm(),
              cmap='Greys')

    ax.plot(stars_G_J[i],
            stars_M_G[i],
            ls='none', marker='s', mew=0, ms=6.,
            color='tab:blue', zorder=100)

    # Only draw neighbors with precise parallaxes. `js` are row indices into
    # good_stars (the rows the KD-tree was built on), so the cut must read
    # good_stars; indexing `parent` with them selects unrelated stars.
    nbr_plx_good = ((good_stars['GAIAEDR3_PARALLAX'][js] > 0.5) &
                    (good_stars['GAIAEDR3_PARALLAX_ERROR'][js] < 0.1))
    ax.plot(stars_G_J[js][nbr_plx_good],
            stars_M_G[js][nbr_plx_good],
            ls='none', marker='o', mew=0, ms=4.,
            color='tab:orange', zorder=10)

    ax.set_xlim(cmd_bins[0].min(), cmd_bins[0].max())
    ax.set_ylim(cmd_bins[1].max(), cmd_bins[1].min())

    ax.set_xlabel(r'$G - J$')
    ax.set_ylabel(r'$M_G$')

    fig.tight_layout()

In [None]:
# Same neighbor comparison as above, restricted to query stars with ASPCAP
# log g > 4 (presumably main-sequence dwarfs); at most 20 figures.
is_ = np.arange(len(good_stars))
is_dwarf = good_stars['LOGG'] > 4

# --- loop-invariant quantities, computed once rather than once per figure ---
kiel_bins = (np.linspace(3000, 7500, 128),
             np.linspace(0, 5.5, 128))
cmd_bins = (np.linspace(-0.5, 4.5, 128),
            np.linspace(-4, 12, 128))

# Parent-sample background for the right panel: parallax S/N > 5 only.
parent_plx_good, = np.where(
    (parent['GAIAEDR3_PARALLAX'] / parent['GAIAEDR3_PARALLAX_ERROR']) > 5)
parent_distmod = coord.Distance(
    parallax=parent['GAIAEDR3_PARALLAX'][parent_plx_good]*u.mas).distmod.value
parent_G_J = (parent['GAIAEDR3_PHOT_G_MEAN_MAG'] - parent['J'])[parent_plx_good]
parent_M_G = parent['GAIAEDR3_PHOT_G_MEAN_MAG'][parent_plx_good] - parent_distmod

# Color and absolute magnitude for every star in the PCA sample.
# allow_negative=True accepts negative parallaxes instead of raising.
stars_distmod = coord.Distance(parallax=good_stars['GAIAEDR3_PARALLAX']*u.mas,
                               allow_negative=True).distmod.value
stars_G_J = good_stars['GAIAEDR3_PHOT_G_MEAN_MAG'] - good_stars['J']
stars_M_G = good_stars['GAIAEDR3_PHOT_G_MEAN_MAG'] - stars_distmod

for i, js in zip(is_[is_dwarf], idx[is_dwarf][:20]):
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    # --- left: ASPCAP Teff / log g ---
    ax = axes[0]
    ax.hist2d(parent['TEFF'], parent['LOGG'],
              bins=kiel_bins, norm=mpl.colors.LogNorm(),
              cmap='Greys')

    ax.plot(good_stars['TEFF'][i],
            good_stars['LOGG'][i],
            ls='none', marker='s', mew=0, ms=6.,
            color='tab:blue', zorder=100)

    ax.plot(good_stars['TEFF'][js],
            good_stars['LOGG'][js],
            ls='none', marker='o', mew=0, ms=4.,
            color='tab:orange', zorder=10)

    ax.set_ylim(5.5, 0)
    ax.set_xlim(7500, 3000)

    ax.set_xlabel(r'ASPCAP $T_{\rm EFF}$')
    ax.set_ylabel(r'ASPCAP $\log g$')

    # --- right: Gaia color vs. absolute magnitude ---
    ax = axes[1]
    ax.hist2d(parent_G_J, parent_M_G,
              bins=cmd_bins, norm=mpl.colors.LogNorm(),
              cmap='Greys')

    ax.plot(stars_G_J[i],
            stars_M_G[i],
            ls='none', marker='s', mew=0, ms=6.,
            color='tab:blue', zorder=100)

    # Only draw neighbors with precise parallaxes. `js` are row indices into
    # good_stars (the rows the KD-tree was built on), so the cut must read
    # good_stars; indexing `parent` with them selects unrelated stars.
    nbr_plx_good = ((good_stars['GAIAEDR3_PARALLAX'][js] > 0.5) &
                    (good_stars['GAIAEDR3_PARALLAX_ERROR'][js] < 0.1))
    ax.plot(stars_G_J[js][nbr_plx_good],
            stars_M_G[js][nbr_plx_good],
            ls='none', marker='o', mew=0, ms=4.,
            color='tab:orange', zorder=10)

    ax.set_xlim(cmd_bins[0].min(), cmd_bins[0].max())
    ax.set_ylim(cmd_bins[1].max(), cmd_bins[1].min())

    ax.set_xlabel(r'$G - J$')
    ax.set_ylabel(r'$M_G$')

    fig.tight_layout()