In [None]:
import os
# Cache locations for the apogee/joaquin packages. These are set before the
# joaquin imports below; presumably the packages read them at import time — confirm.
# NOTE(review): machine-specific absolute paths; consider making these configurable.
os.environ['APOGEE_CACHE_PATH'] = "/mnt/ceph/users/apricewhelan/apogee/"
os.environ['JOAQUIN_CACHE_PATH'] = "/mnt/ceph/users/apricewhelan/projects/joaquin/cache"
import warnings
import pickle

import sys
import pathlib
# Put the notebook's parent directory on sys.path (presumably so the local
# `joaquin` package is importable without installation).
_path = str(pathlib.Path('../').resolve())
if _path not in sys.path:
    sys.path.append(_path)

import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm.auto import tqdm, trange

from scipy.spatial import cKDTree
from sklearn.decomposition import IncrementalPCA
from sklearn.neighbors import KernelDensity

from joaquin.data import JoaquinData
# Project-wide settings: data release, cache root, PCA size, colors and the
# maximum parent-sample neighborhood size.
from joaquin.config import (all_phot_names, dr, root_cache_path, neighborhood_n_components,
                            block_size, neighborhood_color_names, max_neighborhood_size)
from joaquin.plot import simple_corner, phot_to_label, plot_hr_cmd

In [None]:
# Per-data-release output directories: cached arrays live under ../cache/<dr>
# and figures under ../plot/<dr>. Both are created if they do not exist yet.
cache_path = pathlib.Path(f'../cache/{dr}').resolve()
plot_path = (pathlib.Path('../plot') / dr).resolve()
for _out_dir in (cache_path, plot_path):
    _out_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the parent sample and keep only stars whose feature vectors are
# finite in every column.
parent_data = JoaquinData.read('parent-sample')
parent_data = parent_data[np.isfinite(parent_data.X).all(axis=1)]
parent = parent_data.stars

# Global spectral-pixel mask produced by an earlier step (presumably True
# marks bad pixels, per the filename — confirm).
global_spec_mask = np.load(cache_path / 'global_spec_bad_mask.npy')

In [None]:
# Neighborhood feature matrix for the full parent sample, with globally bad
# spectral pixels removed.
parent_X = (parent_data
            .mask_spec_pixels(global_spec_mask)
            .get_neighborhood_X())
parent_X.shape

### Select a subset of stars to use for defining the neighborhoods:

In [None]:
# Seeded random draw of stars passing the quality cuts below; these become
# the candidate neighborhood nodes.
rng = np.random.default_rng(42)

min_snr = 200    # spectral S/N cut on parent['SNR']
max_ruwe = 1.2   # cut on parent['ruwe']
n_sub = 8192     # requested subsample size

idx, = np.where(
    (parent['SNR'] > min_snr) &
    (parent['ruwe'] < max_ruwe)
)
# rng.choice(..., replace=False) raises if asked for more items than it has,
# so cap the request at the number of stars that pass the cuts.
idx = rng.choice(idx, size=min(n_sub, len(idx)), replace=False)
sub_data = parent_data[idx]

len(sub_data)

In [None]:
# Kiel diagram: the parent sample as a log-scaled 2D histogram, with the
# random node subsample overplotted. `teff_logg_bins` is reused by later cells.
teff_logg_bins = (
    np.linspace(3000, 9000, 128),
    np.linspace(-0.5, 5.75, 128))

fig, ax = plt.subplots(figsize=(6, 6))
ax.hist2d(parent['TEFF'], parent['LOGG'],
          bins=teff_logg_bins,
          norm=mpl.colors.LogNorm(),
          cmap='Greys')
ax.plot(sub_data.stars['TEFF'], sub_data.stars['LOGG'],
        ls='none', marker='o', mew=0, ms=2.,
        color='tab:blue', alpha=0.75)

# Both axes are inverted: hotter stars on the left, higher LOGG at the bottom.
teff_edges, logg_edges = teff_logg_bins
ax.set_xlim(teff_edges.max(), teff_edges.min())
ax.set_ylim(logg_edges.max(), logg_edges.min())
ax.set(xlabel='TEFF', ylabel='LOGG')

fig.tight_layout()

### Construct the neighborhood feature matrix for the neighborhood node sample:

The features currently combine the masked spectra with a set of photometric colors.

In [None]:
# Display the photometric names attached to the node subsample (for
# reference when choosing the neighborhood colors).
sub_data.all_phot_names

In [None]:
# Human-readable "band1-band2" labels, one per neighborhood color.
color_labels = [
    f'{phot_to_label[band1]}-{phot_to_label[band2]}'
    for band1, band2 in neighborhood_color_names
]
# Neighborhood feature matrix (masked spectra + colors) for the node subsample.
sub_X = (sub_data
         .mask_spec_pixels(global_spec_mask)
         .get_neighborhood_X())

In [None]:
# Corner plot of the color features alone, colored by LOGG.
# Assumes get_neighborhood_X() appends the colors as the trailing columns
# of the feature matrix — TODO confirm.
n_color_cols = len(color_labels)
if n_color_cols > 0:
    # Slice from an explicit start column: `sub_X[:, -0:]` would silently
    # select *every* column if there were no colors.
    sub_colors = sub_X[:, sub_X.shape[1] - n_color_cols:]
    # With colorbar=True, simple_corner returns (fig, axes, colorbar).
    fig, axes, cb = simple_corner(
        sub_colors,
        labels=color_labels,
        color_by=sub_data.stars['LOGG'], vmin=0.5, vmax=5.5,
        colorbar=True)

### Run PCA on the neighborhood node features and project down:

In [None]:
# First-pass PCA on the full node subsample; only used to flag outliers in
# the next cell.
pca = IncrementalPCA(batch_size=1024,
                     n_components=neighborhood_n_components)
projected_X = pca.fit_transform(sub_X)
# Divide each projected component by its singular value, in place.
projected_X = np.divide(projected_X, pca.singular_values_, out=projected_X)

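This (admittedly hacky) step removes egregious outliers: stars more than 5σ from the mean in any first-pass PCA component. It then refits the PCA on the remaining stars. It reassigns `pca`, so run it only after the cells above!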

In [None]:
# Flag stars lying more than 5 standard deviations from the mean in any
# first-pass PCA component.
offsets = np.abs(projected_X - projected_X.mean(axis=0))
bad_mask = (offsets > 5 * projected_X.std(axis=0)).any(axis=1)
keep = ~bad_mask

neighborhood_node_X = sub_X[keep]
neighborhood_node_stars = sub_data.stars[keep]

# Refit the PCA on the cleaned node sample. This *replaces* `pca`; the
# refit model is the one pickled and applied to the parent sample later.
pca = IncrementalPCA(batch_size=1024,
                     n_components=neighborhood_n_components)
node_projected_X = pca.fit_transform(neighborhood_node_X)
node_projected_X /= pca.singular_values_

In [None]:
# One panel per fitted PCA component (loadings vs. feature index), two
# panels per row.
n_comp = pca.n_components_
# Ceiling division so an odd number of components still gets enough axes
# (floor division lost the last panel and gave 0 rows for 1 component).
n_rows = max(1, (n_comp + 1) // 2)
fig, axes = plt.subplots(n_rows, 2,
                         figsize=(16, 12), sharex=True, squeeze=False)

for i, component in enumerate(pca.components_):
    ax = axes.flat[i]
    ax.plot(component)
    ax.set_title(f'PCA component {i}')

# Hide the spare panel left over when the component count is odd.
for ax in axes.flat[n_comp:]:
    ax.set_visible(False)

fig.tight_layout()

In [None]:
# Cumulative fraction of variance captured by the first k components.
pca.explained_variance_ratio_.cumsum()

In [None]:
# Kiel diagram of the cleaned node sample, one panel per PCA feature, with
# points colored by that feature's value.
n_comp = pca.n_components_
n_cols = 3
# Size the grid from the component count; a fixed 3x3 grid raised
# IndexError whenever there were more than 9 components.
n_rows = max(1, (n_comp + n_cols - 1) // n_cols)
fig, axes = plt.subplots(n_rows, n_cols,
                         figsize=(10, 10 * n_rows / 3),
                         sharex=True, sharey=True, squeeze=False)

teff_edges, logg_edges = teff_logg_bins
for i in range(n_comp):
    ax = axes.flat[i]
    ax.scatter(neighborhood_node_stars['TEFF'],
               neighborhood_node_stars['LOGG'],
               c=node_projected_X[:, i], s=6)
    ax.text(teff_edges.max() - 100,
            logg_edges.min() + 0.1,
            f'PCA feature {i}', va='top', ha='left')

# Hide any unused panels in the last row.
for ax in axes.flat[n_comp:]:
    ax.set_visible(False)

# Axes are shared, so setting the (inverted) limits once applies to all panels.
axes.flat[0].set_xlim(teff_edges.max(), teff_edges.min())
axes.flat[0].set_ylim(logg_edges.max(), logg_edges.min())

fig.tight_layout()

In [None]:
# Corner plots of the node-sample PCA features, colored in turn by each
# stellar label below (value = color-scale range). Each figure is saved
# to plot_path and then closed.
color_ranges = dict(
    TEFF=(3000, 6500),
    LOGG=(0.5, 5.5),
    M_H=(-2, 0.5),
)
for name, (vmin, vmax) in color_ranges.items():
    fig, axes, cb = simple_corner(
        node_projected_X,
        labels=[f'PCA {i}' for i in range(pca.n_components_)],
        color_by=neighborhood_node_stars[name],
        vmin=vmin, vmax=vmax,
        colorbar=True)
    cb.ax.set_aspect(40)
    axes.flat[0].set_title(f'color: {name}')

    out_file = plot_path / f'neighborhood-pca-nodes-{name}.png'
    fig.savefig(out_file, dpi=200)
    plt.close(fig)

In [None]:
# Persist the refit node-sample PCA so the same projection can be reused.
# NOTE(review): this goes to root_cache_path rather than the per-release
# cache_path used for the other outputs — confirm that is intended.
pca_file = root_cache_path / 'pca_neighborizer.pkl'
with open(pca_file, 'wb') as fh:
    pickle.dump(pca, fh)

## Now use the neighborhood sample to define the neighborhoods:

First, estimate the local density at every node star with a kernel density estimate:

In [None]:
# Rule-of-thumb bandwidth: mean per-feature standard deviation, scaled by
# sqrt(n_features) / 3.
n_features = node_projected_X.shape[1]
bw = np.std(node_projected_X, axis=0).mean() * np.sqrt(n_features) / 3

# KernelDensity.fit returns the estimator itself, so construct and fit in one go.
kde = KernelDensity(bandwidth=bw, kernel='epanechnikov').fit(node_projected_X)
# score_samples returns the log-density at each node star.
dens = kde.score_samples(node_projected_X)

Now build a KD-tree on the projected features to find nearest neighbors:

In [None]:
# KD-tree over the scaled PCA features of the cleaned node sample, used for
# nearest-neighbor queries in the following cells.
tree = cKDTree(node_projected_X)

Setting `k=2` below means that each row of `idxs` is (self, nearest neighbor): the index of the star itself, followed by the index of its closest other star.

In [None]:
# Two nearest neighbors of every node star: column 0 is normally the star
# itself (distance 0), column 1 its closest other star. Only column 0 is
# read by the neighborhood-construction loop.
_, idxs = tree.query(node_projected_X, k=2)

In [None]:
# Greedily define neighborhoods by visiting node stars from highest to lowest
# KDE log-density. Each visited star seeds ("is the node of") a new
# neighborhood made of its nearest node-sample stars, unless it already
# belongs to `max_memberships` or more existing neighborhoods.
sort_idx = dens.argsort()[::-1]

n_nodes = len(node_projected_X)
# Scale the parent-sample neighborhood size by the node/parent sample-size
# ratio. (`neighborhood_size` was undefined; `max_neighborhood_size` is the
# imported config value.) Clamp to a valid k for the tree query.
this_size = int(max_neighborhood_size * len(neighborhood_node_stars) / len(parent))
this_size = min(max(1, this_size), n_nodes)

# Relaxed uniqueness: a star that already belongs to this many
# neighborhoods may not seed another one.
max_memberships = 8
# membership[j] = number of neighborhoods star j belongs to. Maintained
# incrementally instead of re-scanning np.ravel(neighborhoods) every
# iteration, which was quadratic in the number of neighborhoods.
membership = np.zeros(n_nodes, dtype=int)

neighborhoods = []
for i1, _ in idxs[sort_idx]:
    if membership[i1] >= max_memberships:
        continue

    _, results = tree.query(node_projected_X[i1], k=this_size)
    # The query result already contains the node itself (distance 0); drop
    # it so the node is listed exactly once, in the first position.
    results = np.atleast_1d(results)
    hood = np.concatenate(([i1], results[results != i1]))
    neighborhoods.append(hood)
    membership[hood] += 1

len(neighborhoods)

In [None]:
# Kiel-diagram positions of the neighborhood nodes (the first entry of each
# neighborhood).
center_rows = [hood[0] for hood in neighborhoods]
node_TEFF = [neighborhood_node_stars['TEFF'][row] for row in center_rows]
node_LOGG = [neighborhood_node_stars['LOGG'][row] for row in center_rows]

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(node_TEFF, node_LOGG)

# Inverted axes, matching the other Kiel diagrams in this notebook.
teff_edges, logg_edges = teff_logg_bins
ax.set_xlim(teff_edges.max(), teff_edges.min())
ax.set_ylim(logg_edges.max(), logg_edges.min())
ax.set(title='neighborhood nodes', xlabel='TEFF', ylabel='LOGG')

fig.tight_layout()

Plot every neighborhood (its node star plus its member stars). Existing PNGs in the output directory are deleted first:

In [None]:
# Remove any PNGs already in the output directory so stale figures from a
# previous run do not linger.
neighbor_plot_path = plot_path / 'neighborhoods'
neighbor_plot_path.mkdir(exist_ok=True)
for stale_png in tqdm(neighbor_plot_path.rglob('*.png')):
    stale_png.unlink()

# One figure per neighborhood: the node star plus its member stars.
for n, hood in enumerate(tqdm(neighborhoods)):
    node, members = hood[0], hood[1:]
    fig = plot_hr_cmd(parent, neighborhood_node_stars,
                      idx0=node, other_idx=members)
    fig.tight_layout()
    out_file = neighbor_plot_path / f'neighborhood-{max_neighborhood_size}-{n:03d}.png'
    fig.savefig(out_file, dpi=200)
    plt.close(fig)

### Apply to the full parent sample

In [None]:
# Project the full parent sample with the refit node-sample PCA, chunk by
# chunk (presumably to bound memory), applying the same singular-value
# scaling as for the nodes.
n_parent = parent_X.shape[0]
parent_projected_X = np.zeros((n_parent, node_projected_X.shape[1]),
                              dtype=np.float32)

n_chunks = 31
# linspace already ends exactly at n_parent, so the last chunk needs no
# clamping.
bounds = np.linspace(0, n_parent, n_chunks + 1).astype(int)
for start, stop in zip(bounds[:-1], bounds[1:]):
    # For samples smaller than n_chunks some chunks are empty, and
    # pca.transform rejects zero-row input.
    if stop <= start:
        continue
    parent_projected_X[start:stop] = pca.transform(parent_X[start:stop])
    parent_projected_X[start:stop] /= pca.singular_values_

In [None]:
# KD-tree over the projected features of the full parent sample, used to
# collect parent-sample members around each neighborhood node.
parent_tree = cKDTree(parent_projected_X)

In [None]:
# For each neighborhood node, collect its max_neighborhood_size nearest
# parent-sample stars, ordered by increasing distance.
zone_idx = []
for hood in tqdm(neighborhoods):
    node_features = node_projected_X[hood[0]]
    dist, nbr = parent_tree.query(node_features, k=max_neighborhood_size)
    zone_idx.append(nbr[np.argsort(dist)])

In [None]:
# Disabled: one figure per neighborhood showing its parent-sample members
# (zone_idx). Uncomment to regenerate; this writes len(zone_idx) PNGs.
# for n in trange(len(zone_idx)):
#     fig = plot_hr_cmd(parent_data.stars, parent_data.stars,
#                       idx0=zone_idx[n][0], other_idx=zone_idx[n][1:])
    
#     fig.tight_layout()
#     fig.savefig(neighbor_plot_path / f'parent-neighborhood-{max_neighborhood_size}-{n:03d}.png', dpi=200)
#     plt.close(fig)

In [None]:
# Fraction of the parent sample that lands in at least one neighborhood when
# each neighborhood keeps only its `size` nearest stars.
# NOTE(review): each zone holds at most max_neighborhood_size stars, so sizes
# above that report the same coverage.
for size in 2 ** np.arange(13, 16):
    all_indices = [zone[:size] for zone in zone_idx]
    n_covered = len(np.unique(np.ravel(all_indices)))
    frac_covered = n_covered / len(parent_projected_X)
    print(f"{frac_covered*100:.1f}% of stars end up in a neighborhood of size {size}")

In [None]:
# Sanity check: the projected parent features must have one row per parent star.
assert len(parent) == parent_projected_X.shape[0], \
    "parent_projected_X row count does not match the parent sample"
len(parent), parent_projected_X.shape

In [None]:
# Save the parent-sample member indices of every neighborhood: one row per
# neighborhood, each row ordered by increasing distance from its node.
# Built directly from zone_idx rather than `all_indices`, which was a leftover
# from the last iteration of the coverage loop (hidden state). Also
# `neighborhood_size` was undefined; `max_neighborhood_size` is the imported
# config value and matches the query size used for zone_idx.
filename = cache_path / f'good_parent_neighborhood_indices-{max_neighborhood_size}.npy'
np.save(filename, np.array(zone_idx))