In [None]:
import os
# Cache locations are exported before the `joaquin` imports further down
# (presumably joaquin.config reads them at import time -- TODO confirm).
# NOTE(review): these are absolute, user-specific paths and must be edited to
# run this notebook on another machine.
os.environ['APOGEE_CACHE_PATH'] = "/mnt/ceph/users/apricewhelan/apogee-test/"
os.environ['JOAQUIN_CACHE_PATH'] = "/mnt/ceph/users/apricewhelan/projects/joaquin/cache"
import warnings
import pickle

import sys
import pathlib
# Add the resolved parent directory to the import path (only once), presumably
# so the local `joaquin` package can be imported without installing it
# -- TODO confirm.
_path = str(pathlib.Path('../').resolve())
if _path not in sys.path:
    sys.path.append(_path)

import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm.auto import tqdm, trange

from scipy.spatial import cKDTree
from sklearn.decomposition import IncrementalPCA
from sklearn.neighbors import KernelDensity

from joaquin.data import JoaquinData
from joaquin.config import all_phot_names, dr, root_cache_path, target_neighborhood_size
from joaquin.plot import simple_corner, phot_to_label, plot_hr_cmd
from joaquin.neighborhoods import get_neighborhood_X

In [None]:
# Figures go into a per-data-release directory next to the notebook folder;
# the directory (and any missing parents) is created on first use.
plot_path = (pathlib.Path('../plot') / dr).resolve()
plot_path.mkdir(parents=True, exist_ok=True)

### Build the design matrix data for the full parent sample

In [None]:
# Read the parent-sample table from the root cache directory.
parent = at.Table.read(root_cache_path / 'parent-sample.fits')

In [None]:
# Build the JoaquinData container for the full parent sample with
# `lowpass=False`. `cache_file` presumably names an on-disk cache that is
# reused on later runs -- TODO confirm against JoaquinData.
parent_data = JoaquinData(
    parent, lowpass=False, 
    cache_file='parent-sample-raw')

### Define and save a global spectral mask based on the fraction of pixels over the full parent sample that are masked:

In [None]:
# A pixel is flagged as good when its parent-sample mask value is below 0.25.
# NOTE(review): relies on the private attribute `_spec_mask_vals`.
spec_good_mask = parent_data._spec_mask_vals < 0.25
# Persist the mask in the root cache directory so it can be reloaded elsewhere.
np.save(root_cache_path / 'spec_good_mask.npy', 
        spec_good_mask)
# Displayed as the cell output: number of pixels that pass the mask.
spec_good_mask.sum()

### Select a subset of stars to use for defining the neighborhoods:

In [None]:
# Seeded generator so that the same node sample is drawn on every run.
rng = np.random.default_rng(42)

# Only stars with SNR > 200 are eligible to be neighborhood nodes.
sub_stars = parent[(parent['SNR'] > 200)]

# Draw up to 8192 distinct stars. The request is capped at the number of
# eligible stars because `rng.choice(..., replace=False)` raises a ValueError
# when asked for more samples than exist (e.g. with a smaller parent sample).
n_nodes = min(8192, len(sub_stars))
idx = rng.choice(len(sub_stars), size=n_nodes, replace=False)
sub_stars = sub_stars[idx]

# Displayed as the cell output: size of the node sample actually drawn.
len(sub_stars)

In [None]:
# Same construction as for the parent sample, applied to the randomly drawn
# high-SNR subset and stored under its own cache name.
sub_data = JoaquinData(
    sub_stars, lowpass=False, 
    cache_file='neighborhood-node-sample')

In [None]:
fig, ax = plt.subplots(figsize=(6, 6))

# Bin edges in (TEFF, LOGG); later cells reuse them to set axis limits.
teff_logg_bins = (
    np.linspace(3000, 9000, 128),
    np.linspace(-0.5, 5.75, 128))

# Background: log-scaled 2D histogram of the full parent sample. The bins must
# be `teff_logg_bins`; the name `bins` is not defined at this point of a
# top-to-bottom run and raised a NameError on a fresh kernel.
ax.hist2d(parent_data.stars['TEFF'],
          parent_data.stars['LOGG'],
          bins=teff_logg_bins, norm=mpl.colors.LogNorm(),
          cmap='Greys')

# Foreground: the stars selected as neighborhood nodes.
ax.plot(sub_data.stars['TEFF'],
        sub_data.stars['LOGG'],
        ls='none', marker='o', mew=0, ms=2.,
        color='tab:blue', alpha=0.75)

# Both axes are drawn reversed (max -> min).
ax.set_xlim(teff_logg_bins[0].max(),
            teff_logg_bins[0].min())
ax.set_ylim(teff_logg_bins[1].max(),
            teff_logg_bins[1].min())

ax.set_xlabel('TEFF')
ax.set_ylabel('LOGG')

fig.tight_layout()

### Construct the neighborhood feature matrix for the neighborhood node sample:

This currently uses the spectra and a set of colors; the color features occupy the last `len(color_labels)` columns of the feature matrix.

In [None]:
# Build the neighborhood feature matrix for the node sample, restricted by the
# global spectral mask. Also returns the labels of the color features and the
# table of stars that made it into the matrix (the returned table is later
# indexed with row masks derived from `sub_X`).
sub_X, color_labels, good_sub_stars = get_neighborhood_X(
    sub_data, spec_good_mask)

In [None]:
# The trailing `len(color_labels)` columns of the feature matrix are plotted
# against the color labels, colored by LOGG.
tmp = sub_X[:, -len(color_labels):]
# NOTE(review): another cell unpacks three values (fig, axes, cb) from
# simple_corner(..., colorbar=True), so `fig` here may actually be a tuple
# -- confirm against joaquin.plot.simple_corner.
fig = simple_corner(
    tmp, 
    labels=color_labels, 
    color_by=good_sub_stars['LOGG'], vmin=0.5, vmax=5.5, 
    colorbar=True)

### Run PCA on the neighborhood node features and project down:

In [None]:
# Fit an 8-component incremental PCA (batches of 1024 rows) on the node
# feature matrix and project the same rows onto the components.
pca = IncrementalPCA(n_components=8, batch_size=1024)
projected_X = pca.fit_transform(sub_X)
# Divide each projected feature by its singular value, in place.
projected_X /= pca.singular_values_

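This hacky step removes egregious outliers (stars more than 5 standard deviations from the mean in any PCA feature) and then refits the PCA on the cleaned sample. Only run this after the cells above, because it overwrites the `pca` variable!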

In [None]:
# Flag every row that lies more than 5 standard deviations away from the
# per-feature mean in at least one PCA feature.
mean = projected_X.mean(axis=0)
std = projected_X.std(axis=0)
bad_mask = (np.abs(projected_X - mean) > 5 * std).any(axis=1)

# Keep only the unflagged rows of the feature matrix and of the star table.
neighborhood_node_X = sub_X[~bad_mask]
neighborhood_node_stars = good_sub_stars[~bad_mask]

# Refit the PCA on the cleaned sample. This rebinds `pca`, so from here on it
# refers to the model fit without the outliers.
pca = IncrementalPCA(n_components=8, batch_size=1024)
node_projected_X = pca.fit_transform(neighborhood_node_X)
node_projected_X /= pca.singular_values_

In [None]:
# One panel per principal component, laid out in two columns.
n_rows = pca.n_components_ // 2
fig, axes = plt.subplots(n_rows, 2, figsize=(16, 12), sharex=True)

for panel, component in zip(axes.flat, pca.components_[:pca.n_components]):
    panel.plot(component)

fig.tight_layout()

In [None]:
# Displayed as the cell output: cumulative explained-variance ratio of the
# fitted components.
np.cumsum(pca.explained_variance_ratio_)

In [None]:
# TEFF-LOGG scatter of the node stars, one panel per PCA feature, with the
# points colored by that feature's value.
n_features = pca.n_components
fig, axes = plt.subplots(3, 3, figsize=(10, 10), sharex=True, sharey=True)
panels = list(axes.flat)

for i, ax in enumerate(panels[:n_features]):
    ax.scatter(neighborhood_node_stars['TEFF'],
               neighborhood_node_stars['LOGG'],
               c=node_projected_X[:, i], s=6)
    ax.text(teff_logg_bins[0].max() - 100,
            teff_logg_bins[1].min() + 0.1,
            f'PCA feature {i}', va='top', ha='left')

# The 3x3 grid has more panels than features; hide the leftovers.
for unused in panels[n_features:]:
    unused.set_visible(False)

# The panels share both axes, so setting the (reversed) limits on the last
# drawn panel is enough.
ax.set_xlim(teff_logg_bins[0].max(),
            teff_logg_bins[0].min())
ax.set_ylim(teff_logg_bins[1].max(),
            teff_logg_bins[1].min())

fig.tight_layout()

In [None]:
# Stellar labels used to color the corner plots, with their color-scale limits.
color_limits = {
    'TEFF': (3000, 6500),
    'LOGG': (0.5, 5.5),
    'M_H': (-2, 0.5)
}

# One corner plot of the projected node features per label; each figure is
# written to disk and closed instead of being displayed inline.
for name, (vmin, vmax) in color_limits.items():
    pca_labels = [f'PCA {i}' for i in range(pca.n_components_)]
    fig, axes, cb = simple_corner(
        node_projected_X,
        color_by=neighborhood_node_stars[name],
        colorbar=True,
        vmin=vmin, vmax=vmax,
        labels=pca_labels)
    cb.ax.set_aspect(40)
    axes.flat[0].set_title(f'color: {name}')

    out_file = plot_path / f'neighborhood-pca-nodes-{name}.png'
    fig.savefig(out_file, dpi=200)
    plt.close(fig)

In [None]:
# Persist the PCA fit on the outlier-cleaned node sample in the root cache.
pca_file = root_cache_path / 'pca_neighborizer.pkl'
with open(pca_file, 'wb') as f:
    pickle.dump(pca, f)

## Now use the neighborhood sample to define the neighborhoods:

First, determine local density at all points:

In [None]:
# Bandwidth: mean per-feature standard deviation, scaled by sqrt(number of
# features) and divided by 3.
n_dim = node_projected_X.shape[1]
mean_scatter = np.mean(np.std(node_projected_X, axis=0))
bw = mean_scatter * np.sqrt(n_dim) / 3

# Fit the kernel density estimate on the node sample and evaluate it at the
# same points (score_samples returns log-density values).
kde = KernelDensity(bandwidth=bw, kernel='epanechnikov')
_ = kde.fit(node_projected_X)
dens = kde.score_samples(node_projected_X)

Now construct the kdtree to find neighbors:

In [None]:
# KD-tree over the projected node features, used for neighbor lookups below.
tree = cKDTree(node_projected_X)

Setting `k=2` below means that each row in `idxs` will be `(self, nearest neighbor)`.

In [None]:
# Indices of the two nearest tree points for every node star; the distances
# are discarded.
_, idxs = tree.query(node_projected_X, k=2)

In [None]:
# Visit the node stars from highest to lowest local density.
sort_idx = dens.argsort()[::-1]

# Neighborhood size within the node sample: the target size scaled by the
# ratio of node-sample size to parent-sample size. Clamped to [1, N] so that
# the tree query is always valid and never returns out-of-range indices.
this_size = int(target_neighborhood_size * len(neighborhood_node_stars) / len(parent))
this_size = min(max(this_size, 1), len(node_projected_X))

# Boolean bookkeeping of stars already assigned to a neighborhood. This
# replaces a membership test against the flattened list of all neighborhoods,
# which rescanned every previous neighborhood for each candidate.
claimed = np.zeros(len(node_projected_X), dtype=bool)

neighborhoods = []
for i1 in idxs[sort_idx, 0]:
    # Skip stars that already belong to an earlier (denser) neighborhood.
    if claimed[i1]:
        continue

    _, results = tree.query(node_projected_X[i1], k=this_size)
    # `results` is a scalar when this_size == 1; atleast_1d keeps the
    # concatenation valid. The query point is itself in the tree, so `i1`
    # normally appears again inside `results`; the leading `i1` is kept so
    # that element 0 of every neighborhood is always its node.
    hood = np.concatenate(([i1], np.atleast_1d(results)))
    claimed[hood] = True
    neighborhoods.append(hood)

# Displayed as the cell output: number of neighborhoods found.
len(neighborhoods)

In [None]:
# Element 0 of each neighborhood is its node; look up that star's TEFF and LOGG.
node_rows = [hood[0] for hood in neighborhoods]
node_TEFF = [neighborhood_node_stars['TEFF'][row] for row in node_rows]
node_LOGG = [neighborhood_node_stars['LOGG'][row] for row in node_rows]

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(node_TEFF, node_LOGG)

# Same reversed axis ranges as the other TEFF-LOGG figures.
teff_edges, logg_edges = teff_logg_bins
ax.set_xlim(teff_edges.max(), teff_edges.min())
ax.set_ylim(logg_edges.max(), logg_edges.min())

ax.set_title('neighborhood nodes')
ax.set_xlabel('TEFF')
ax.set_ylabel('LOGG')

fig.tight_layout()

In [None]:
# Sub-directory that collects one figure per neighborhood.
neighbor_plot_path = plot_path / 'neighborhoods'
neighbor_plot_path.mkdir(exist_ok=True)

for n, hood in enumerate(tqdm(neighborhoods)):
    # First entry is the node, the rest are its neighbors in the node sample.
    node, members = hood[0], hood[1:]
    fig = plot_hr_cmd(parent_data, neighborhood_node_stars,
                      idx0=node, other_idx=members)
    fig.tight_layout()

    out_file = neighbor_plot_path / f'neighborhood-{n:03d}.png'
    fig.savefig(out_file, dpi=200)
    # Close each figure after saving so they do not accumulate in memory.
    plt.close(fig)

### Apply to the full parent sample

In [None]:
# Feature matrix for the full parent sample, built with the same global
# spectral mask as the node sample; the color labels are not needed again.
parent_X, _, parent_good_stars = get_neighborhood_X(
    parent_data, spec_good_mask)
# Displayed as the cell output: shape of the parent feature matrix.
parent_X.shape

In [None]:
# Project the parent feature matrix onto the PCA components in 7 row chunks,
# writing into a preallocated float32 array.
n_parent = parent_X.shape[0]
parent_projected_X = np.zeros((n_parent, node_projected_X.shape[1]),
                              dtype=np.float32)

vals = np.linspace(0, n_parent, 8).astype(int)
for i1, i2 in zip(vals[:-1], vals[1:]):
    # Make sure a chunk that reaches the last row runs to the end of the array.
    i2 = n_parent if i2 >= n_parent - 1 else i2
    chunk = slice(i1, i2)

    # Same scaling as the node features: divide by the singular values after
    # the projection has been stored.
    parent_projected_X[chunk] = pca.transform(parent_X[chunk])
    parent_projected_X[chunk] /= pca.singular_values_

In [None]:
# KD-tree over the projected features of the full parent sample.
parent_tree = cKDTree(parent_projected_X)

In [None]:
# For every neighborhood node, collect the `target_neighborhood_size` nearest
# parent-sample stars, ordered by increasing distance from the node.
zone_idx = []
for nidx in tqdm(neighborhoods):
    node_features = node_projected_X[nidx[0]]
    dist, idx = parent_tree.query(node_features, k=target_neighborhood_size)
    zone_idx.append(idx[dist.argsort()])

In [None]:
# One figure per neighborhood in the parent sample: the closest star is passed
# as the reference (`idx0`), the remaining stars as `other_idx`.
for n, zone in enumerate(tqdm(zone_idx)):
    fig = plot_hr_cmd(parent_data, parent_good_stars,
                      idx0=zone[0], other_idx=zone[1:])
    fig.tight_layout()

    out_file = neighbor_plot_path / f'parent-neighborhood-{n:03d}.png'
    fig.savefig(out_file, dpi=200)
    plt.close(fig)

---

## Validation

In [None]:
# `projected_X` was already divided by the singular values when it was
# computed, and `pca` has since been refit on the outlier-cleaned sample.
# Dividing by `pca.singular_values_` again would therefore rescale the
# features with the wrong model, so the features are used as they are.
# A dedicated name keeps the node-sample `tree` from being overwritten.
validation_tree = cKDTree(projected_X)

# Asking for the 2nd through 33rd nearest neighbors leaves out the closest
# match, which for a point in the tree is normally the point itself.
dist, idx = validation_tree.query(projected_X,
                                  k=np.arange(2, 32+2))

In [None]:
# Per-feature scatter of the projected features. `projected_X` is already
# divided by the singular values, so it must not be divided a second time
# (and `pca` now refers to the refit model anyway).
np.std(projected_X, axis=0)

In [None]:
# Validation figures for the first 20 stars of the node sample: each star
# (blue square) and its nearest neighbors (orange) on top of the parent
# sample, in TEFF-LOGG (left) and in a Gaia/2MASS color-magnitude diagram
# (right).
#
# `idx` has one row per row of `projected_X`, which is row-aligned with
# `good_sub_stars`; the previously used name `good_stars` is never defined.

# Everything below is independent of the star being plotted, so it is
# computed once instead of once per figure.
hr_bins = (np.linspace(3000, 7500, 128),
           np.linspace(0, 5.5, 128))
cmd_bins = (np.linspace(-0.5, 4.5, 128),
            np.linspace(-4, 12, 128))

# Parent-sample background: only stars with parallax S/N > 5.
parent_dist_mask, = np.where(
    (parent['GAIAEDR3_PARALLAX'] / parent['GAIAEDR3_PARALLAX_ERROR']) > 5)
parent_distmod = coord.Distance(
    parallax=parent['GAIAEDR3_PARALLAX'][parent_dist_mask]*u.mas).distmod.value
parent_color = (parent['GAIAEDR3_PHOT_G_MEAN_MAG'] - parent['J'])[parent_dist_mask]
parent_abs_mag = parent['GAIAEDR3_PHOT_G_MEAN_MAG'][parent_dist_mask] - parent_distmod

# Node-sample color and absolute magnitude; negative parallaxes are allowed
# here and filtered per neighbor set below.
star_distmod = coord.Distance(parallax=good_sub_stars['GAIAEDR3_PARALLAX']*u.mas,
                              allow_negative=True).distmod.value
star_color = good_sub_stars['GAIAEDR3_PHOT_G_MEAN_MAG'] - good_sub_stars['J']
star_abs_mag = good_sub_stars['GAIAEDR3_PHOT_G_MEAN_MAG'] - star_distmod

for i, js in enumerate(idx[:20]):
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    # --- left panel: TEFF vs. LOGG
    ax = axes[0]
    ax.hist2d(parent['TEFF'], parent['LOGG'],
              bins=hr_bins, norm=mpl.colors.LogNorm(),
              cmap='Greys')

    ax.plot(good_sub_stars['TEFF'][i],
            good_sub_stars['LOGG'][i],
            ls='none', marker='s', mew=0, ms=6.,
            color='tab:blue', zorder=100)

    ax.plot(good_sub_stars['TEFF'][js],
            good_sub_stars['LOGG'][js],
            ls='none', marker='o', mew=0, ms=4.,
            color='tab:orange', zorder=10)

    ax.set_ylim(5.5, 0)
    ax.set_xlim(7500, 3000)

    ax.set_xlabel(r'ASPCAP $T_{\rm EFF}$')
    ax.set_ylabel(r'ASPCAP $\log g$')

    # --- right panel: color-magnitude diagram
    ax = axes[1]
    ax.hist2d(parent_color, parent_abs_mag,
              bins=cmd_bins, norm=mpl.colors.LogNorm(),
              cmap='Greys')

    ax.plot(star_color[i], star_abs_mag[i],
            ls='none', marker='s', mew=0, ms=6.,
            color='tab:blue', zorder=100)

    # Only neighbors with a usable parallax are drawn. `js` indexes the node
    # sample, so the cut has to be evaluated on `good_sub_stars`; indexing
    # `parent` with `js` would select unrelated stars.
    js_mask = ((good_sub_stars['GAIAEDR3_PARALLAX'][js] > 0.5) &
               (good_sub_stars['GAIAEDR3_PARALLAX_ERROR'][js] < 0.1))
    ax.plot(star_color[js][js_mask],
            star_abs_mag[js][js_mask],
            ls='none', marker='o', mew=0, ms=4.,
            color='tab:orange', zorder=10)

    ax.set_xlim(cmd_bins[0].min(), cmd_bins[0].max())
    ax.set_ylim(cmd_bins[1].max(), cmd_bins[1].min())

    ax.set_xlabel(r'$G - J$')
    ax.set_ylabel(r'$M_G$')

    fig.tight_layout()

In [None]:
# Same validation figures, restricted to the first 20 node-sample stars with
# LOGG > 4: each star (blue square) and its nearest neighbors (orange) on top
# of the parent sample.
#
# `idx` has one row per row of `projected_X`, which is row-aligned with
# `good_sub_stars`; the previously used name `good_stars` is never defined.
is_ = np.arange(len(good_sub_stars))
high_logg = good_sub_stars['LOGG'] > 4

# Everything below is independent of the star being plotted, so it is
# computed once instead of once per figure.
hr_bins = (np.linspace(3000, 7500, 128),
           np.linspace(0, 5.5, 128))
cmd_bins = (np.linspace(-0.5, 4.5, 128),
            np.linspace(-4, 12, 128))

# Parent-sample background: only stars with parallax S/N > 5.
parent_dist_mask, = np.where(
    (parent['GAIAEDR3_PARALLAX'] / parent['GAIAEDR3_PARALLAX_ERROR']) > 5)
parent_distmod = coord.Distance(
    parallax=parent['GAIAEDR3_PARALLAX'][parent_dist_mask]*u.mas).distmod.value
parent_color = (parent['GAIAEDR3_PHOT_G_MEAN_MAG'] - parent['J'])[parent_dist_mask]
parent_abs_mag = parent['GAIAEDR3_PHOT_G_MEAN_MAG'][parent_dist_mask] - parent_distmod

# Node-sample color and absolute magnitude; negative parallaxes are allowed
# here and filtered per neighbor set below.
star_distmod = coord.Distance(parallax=good_sub_stars['GAIAEDR3_PARALLAX']*u.mas,
                              allow_negative=True).distmod.value
star_color = good_sub_stars['GAIAEDR3_PHOT_G_MEAN_MAG'] - good_sub_stars['J']
star_abs_mag = good_sub_stars['GAIAEDR3_PHOT_G_MEAN_MAG'] - star_distmod

# zip() stops at the shorter input, i.e. after at most 20 stars.
for i, js in zip(is_[high_logg], idx[high_logg][:20]):
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    # --- left panel: TEFF vs. LOGG
    ax = axes[0]
    ax.hist2d(parent['TEFF'], parent['LOGG'],
              bins=hr_bins, norm=mpl.colors.LogNorm(),
              cmap='Greys')

    ax.plot(good_sub_stars['TEFF'][i],
            good_sub_stars['LOGG'][i],
            ls='none', marker='s', mew=0, ms=6.,
            color='tab:blue', zorder=100)

    ax.plot(good_sub_stars['TEFF'][js],
            good_sub_stars['LOGG'][js],
            ls='none', marker='o', mew=0, ms=4.,
            color='tab:orange', zorder=10)

    ax.set_ylim(5.5, 0)
    ax.set_xlim(7500, 3000)

    ax.set_xlabel(r'ASPCAP $T_{\rm EFF}$')
    ax.set_ylabel(r'ASPCAP $\log g$')

    # --- right panel: color-magnitude diagram
    ax = axes[1]
    ax.hist2d(parent_color, parent_abs_mag,
              bins=cmd_bins, norm=mpl.colors.LogNorm(),
              cmap='Greys')

    ax.plot(star_color[i], star_abs_mag[i],
            ls='none', marker='s', mew=0, ms=6.,
            color='tab:blue', zorder=100)

    # Only neighbors with a usable parallax are drawn. `js` indexes the node
    # sample, so the cut has to be evaluated on `good_sub_stars`; indexing
    # `parent` with `js` would select unrelated stars.
    js_mask = ((good_sub_stars['GAIAEDR3_PARALLAX'][js] > 0.5) &
               (good_sub_stars['GAIAEDR3_PARALLAX_ERROR'][js] < 0.1))
    ax.plot(star_color[js][js_mask],
            star_abs_mag[js][js_mask],
            ls='none', marker='o', mew=0, ms=4.,
            color='tab:orange', zorder=10)

    ax.set_xlim(cmd_bins[0].min(), cmd_bins[0].max())
    ax.set_ylim(cmd_bins[1].max(), cmd_bins[1].min())

    ax.set_xlabel(r'$G - J$')
    ax.set_ylabel(r'$M_G$')

    fig.tight_layout()