In [None]:
import os
os.environ['APOGEE_CACHE_PATH'] = "/mnt/ceph/users/apricewhelan/apogee-test/"
import warnings
warnings.filterwarnings('ignore', category=Warning) 
import pickle

import sys
import pathlib
_path = str(pathlib.Path('../').resolve())
if _path not in sys.path:
    sys.path.append(_path)

import corner
from astropy.io import fits
import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm.auto import tqdm
from sklearn.neighbors import KernelDensity
from scipy.spatial import cKDTree

from joaquin.data import JoaquinData, make_Xy
from joaquin.config import dr
from joaquin.logger import logger
from joaquin.plot import simple_corner

In [None]:
# Cache directory for the configured data release (`dr` is imported from
# joaquin.config). The directory, including any missing parents, is created
# if it does not exist yet.
cache_path = pathlib.Path(f'../cache/{dr}').resolve()
cache_path.mkdir(exist_ok=True, parents=True)

In [None]:
# Load the parent sample table from the cache directory.
parent = at.Table.read(cache_path / 'parent-sample.fits')

# HACK: subselect for speed
# Draw a reproducible random subsample (fixed seed). The requested size is
# clamped to the table length because np.random.choice(..., replace=False)
# raises ValueError when asked for more rows than are available.
n_subsample = 8192
np.random.seed(42)
_idx = np.random.choice(len(parent),
                        size=min(n_subsample, len(parent)),
                        replace=False)
parent_stars = parent[_idx]

len(parent_stars)

In [None]:
# Teff vs. log g overview: the full parent sample as a log-scaled 2D
# histogram, with the random subsample over-plotted as individual points.
teff_edges = np.linspace(3000, 7500, 128)
logg_edges = np.linspace(0, 5.5, 128)
bins = (teff_edges, logg_edges)

fig, ax = plt.subplots(figsize=(6, 6))

ax.hist2d(parent['TEFF'], parent['LOGG'],
          bins=bins,
          cmap='magma_r',
          norm=mpl.colors.LogNorm())

subsample_style = dict(linestyle='none', marker='o',
                       markeredgewidth=0, markersize=3.,
                       color='tab:blue', alpha=0.75)
ax.plot(parent_stars['TEFF'], parent_stars['LOGG'], **subsample_style)

# Both axes are intentionally reversed (limits passed as high -> low).
ax.set_ylim(5.5, 0)
ax.set_xlim(7500, 3000)

fig.tight_layout()

In [None]:
# Wrap the random subsample in a JoaquinData container that uses the
# release-specific cache directory.
# NOTE(review): `lowpass=False` presumably turns off low-pass filtering of the
# spectra -- confirm against joaquin.data.JoaquinData.
data = JoaquinData(parent_stars, 
                   cache_path=cache_path, 
                   lowpass=False)

In [None]:
# Photometric colors appended to the spectral features, as (band1, band2)
# pairs; each label is rendered as "band1-band2".
# See cell in PCA-neighbors-training.ipynb
color_names = [
    ('GAIAEDR3_PHOT_BP_MEAN_MAG', 'J'),
    ('J', 'K'),
    ('J', 'w1mpro'),
    ('H', 'w2mpro'),
    ('w1mpro', 'w2mpro'),
]
color_labels = ['-'.join(pair) for pair in color_names]

# Only the first element of the returned feature tuple is kept here.
(X, *_), idx_map = data.get_Xy(['spec'], spec_mask_thresh=1)

# Stack the color columns to the right of the spectral feature columns.
color_X = data.get_colors(color_names)
X = np.hstack((X, color_X))

# Sanity check: one feature row per star selected by data.stars_mask.
good_stars = data.stars[data.stars_mask]
assert X.shape[0] == len(good_stars)

In [None]:
# Load the pickled PCA object from the cache directory (presumably produced
# by PCA-neighbors-training.ipynb -- confirm).
# NOTE: unpickling can execute arbitrary code; only load cache files that you
# created yourself or otherwise trust.
with open(cache_path / 'pca_neighborizer.pkl', 'rb') as f:
    pca = pickle.load(f)

In [None]:
# Project the feature matrix with the loaded PCA, then divide each projected
# component by the corresponding singular value. The division is done in
# place, so `projected_X` holds the rescaled projections afterwards.
projected_X = pca.transform(X)
projected_X /= pca.singular_values_

First, determine local density at all points:

In [None]:
# Kernel bandwidth heuristic: mean per-component standard deviation, scaled
# by sqrt(number of components) and divided by 3.
n_components = projected_X.shape[1]
component_std = projected_X.std(axis=0)
bw = component_std.mean() * np.sqrt(n_components) / 3

# Fit the KDE on the projected points and evaluate it at those same points.
# score_samples returns log-densities, which is fine for ranking by density.
kde = KernelDensity(kernel='epanechnikov', bandwidth=bw).fit(projected_X)
dens = kde.score_samples(projected_X)

Now construct the kdtree to find neighbors:

In [None]:
# k-d tree built on the rescaled PCA projections; the nearest-neighbor
# queries in the following cells run against it.
tree = cKDTree(projected_X)

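Setting `k=2` below means that each row in `idxs` will be `(self, nearest neighbor)`: the tree is queried with its own points, so (barring exact duplicates) the closest match for every point is the point itself at distance zero.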

In [None]:
# Query the two closest tree points for every projected point. Column 0 is
# the closest match and column 1 the second closest; only the column-1
# distance is kept in `dists`.
pair_dists, idxs = tree.query(projected_X, k=2)
dists = pair_dists[:, 1]

In [None]:
# # MAGIC NUMBERs
# radius_init_factor = 4
# radius_grow_factor = 2 ** (1/projected_X.shape[1])
# radius_maxiter = 128
# target_neighborhood_size = 1024

# # sort_idx = dists.argsort()
# sort_idx = dens.argsort()[::-1]

# neighborhoods = []
# for (i1, i2), dist in zip(idxs[sort_idx],
#                           dists[sort_idx][:2048]):
#     if i1 in np.ravel(neighborhoods):
#         continue
        
#     radius = radius_init_factor * dist
#     for niter in range(radius_maxiter):
#         results = tree.query_ball_point(projected_X[i1], r=radius)
        
#         if len(results) >= target_neighborhood_size:
#             break
        
#         radius *= radius_grow_factor
#     else:
#         print(f'failed for {i1}')
#         continue
    
#     print(f"{niter} iterations")
#     neighborhoods.append(np.concatenate(([i1], results)))

# MAGIC NUMBERs
# Number of stars per neighborhood, seed star included.
target_neighborhood_size = 1024

# Visit candidate seeds from the highest to the lowest local density.
# sort_idx = dists.argsort()
sort_idx = dens.argsort()[::-1]

# A k-NN query cannot return more points than the tree holds: cKDTree pads
# missing neighbors with the out-of-range index `tree.n`, so clamp k.
n_query = min(target_neighborhood_size, tree.n)

# Boolean "already assigned to a neighborhood" mask. This replaces the
# original `i1 in np.ravel(neighborhoods)` test, which re-flattened every
# neighborhood on each iteration (quadratic in the number of stars).
in_neighborhood = np.zeros(tree.n, dtype=bool)

# Each entry of `neighborhoods` is an index array into `projected_X` /
# `good_stars`: element 0 is the seed star, elements 1: are its neighbors.
neighborhoods = []
for i1 in idxs[sort_idx, 0]:
    # Skip stars that are already covered by a denser neighborhood.
    if in_neighborhood[i1]:
        continue

    _, results = tree.query(projected_X[i1], k=n_query)

    # The query point is itself in the tree, so it is normally returned as
    # its own nearest neighbor. Drop it so the seed is not listed twice.
    results = np.atleast_1d(results)
    results = results[results != i1]

    hood = np.concatenate(([i1], results))
    in_neighborhood[hood] = True
    neighborhoods.append(hood)

len(neighborhoods)

In [None]:
# Plot every 4th neighborhood in two panels:
#   left  - Teff vs. log g
#   right - J-K color vs. absolute H magnitude
# In both panels the parent sample is the grey background histogram, the seed
# star (hood[0]) is the large gold-edged marker and the remaining members
# (hood[1:]) are the small orange points.
style_main = dict(ls='none', marker='o', mew=0.6, ms=6., 
                  color='tab:blue', zorder=100, 
                  mec='gold')
style_neighbors = dict(ls='none', marker='o', mew=0, ms=2., 
                       alpha=0.75, color='tab:orange', zorder=10)

# --- Loop-invariant setup, computed once instead of once per figure ---

# Left panel bin edges (Teff, log g).
kiel_bins = (np.linspace(3000, 8500, 128),
             np.linspace(0, 5.5, 128))

# Right panel configuration. Alternative (Gaia G - J) setup kept for reference:
# color = ('GAIAEDR3_PHOT_G_MEAN_MAG', 'J')
# mag = 'J'
# bins = (np.linspace(-0.5, 4.5, 128),
#         np.linspace(-6, 10, 128))
color = ('J', 'K')
mag = 'H'
cmd_bins = (np.linspace(-0.5, 2, 128),
            np.linspace(-6, 10, 128))

# Background for the right panel: parent stars with parallax S/N > 5 only,
# so the parallaxes used for the distance modulus are all positive.
dist_mask, = np.where((parent['GAIAEDR3_PARALLAX'] / parent['GAIAEDR3_PARALLAX_ERROR']) > 5)
parent_distmod = coord.Distance(parallax=parent['GAIAEDR3_PARALLAX'][dist_mask]*u.mas).distmod.value
parent_color = (parent[color[0]] - parent[color[1]])[dist_mask]
parent_absmag = parent[mag][dist_mask] - parent_distmod

# Foreground stars are not parallax-filtered, hence allow_negative=True.
distmod = coord.Distance(parallax=good_stars['GAIAEDR3_PARALLAX']*u.mas, 
                         allow_negative=True).distmod.value
star_color = good_stars[color[0]] - good_stars[color[1]]
star_absmag = good_stars[mag] - distmod

# for hood in neighborhoods[:20]:
for hood in neighborhoods[::4]:
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    # --- Left panel: Teff vs. log g ---
    ax = axes[0]
    ax.hist2d(parent['TEFF'], parent['LOGG'],
              bins=kiel_bins, norm=mpl.colors.LogNorm(),
              cmap='Greys')

    ax.plot(good_stars['TEFF'][hood[0]],
            good_stars['LOGG'][hood[0]],
            **style_main)

    ax.plot(good_stars['TEFF'][hood[1:]],
            good_stars['LOGG'][hood[1:]],
            **style_neighbors)

    # Both axes reversed (limits passed as max -> min).
    ax.set_xlim(kiel_bins[0].max(), kiel_bins[0].min())
    ax.set_ylim(kiel_bins[1].max(), kiel_bins[1].min())

    ax.set_xlabel(r'$T_{\rm eff}$')
    ax.set_ylabel(r'$\log g$')

    # --- Right panel: color vs. absolute magnitude ---
    ax = axes[1]
    ax.hist2d(parent_color, parent_absmag,
              bins=cmd_bins, norm=mpl.colors.LogNorm(),
              cmap='Greys')

    ax.plot(star_color[hood[0]],
            star_absmag[hood[0]],
            **style_main)

    ax.plot(star_color[hood[1:]],
            star_absmag[hood[1:]],
            **style_neighbors)

    # Magnitude axis reversed (limits passed as max -> min).
    ax.set_xlim(cmd_bins[0].min(), cmd_bins[0].max())
    ax.set_ylim(cmd_bins[1].max(), cmd_bins[1].min())

    # Labels are hard-coded: update them if `color` / `mag` above change.
    ax.set_xlabel('$J - K$')
    ax.set_ylabel('$M_H$')

    fig.tight_layout()

    # Render this figure now so that figures do not pile up in memory until
    # the end of the cell (the inline backend closes figures once shown).
    plt.show()

# Generate a feature matrix for the entire parent sample:

In [None]:
# Build the feature data for the full parent sample (not just the random
# subsample used above).
# NOTE: this rebinds `idx_map`, replacing the value returned earlier by
# data.get_Xy(). Elements 0, 1, 2 of `Xyivar` are written out below as the
# 'X', 'y' and 'y_ivar' datasets.
Xyivar, idx_map, spec_mask_vals = make_Xy(parent, lowpass=False)

In [None]:
import h5py

In [None]:
# Persist the full-sample feature data to one HDF5 file in the cache
# directory (opened with mode 'w', so an existing file is overwritten):
#   X, y, y_ivar    - elements 0, 1, 2 of `Xyivar`
#   spec_mask_vals  - as returned by make_Xy
#   idx_map/<key>   - one dataset per entry of `idx_map`
#   stars           - the parent table itself
output_file = cache_path / 'parent-sample-data.hdf5'
with h5py.File(output_file, 'w') as f:
    for position, dataset_name in enumerate(('X', 'y', 'y_ivar')):
        f.create_dataset(dataset_name, data=Xyivar[position])
    f.create_dataset('spec_mask_vals', data=spec_mask_vals)

    idx_group = f.create_group('idx_map')
    for label, indices in idx_map.items():
        idx_group.create_dataset(label, data=indices)

    parent.write(f, path='stars', serialize_meta=False)

In [None]:
# Scratch check: store a list of photometry column names as an HDF5 file
# attribute (read back in the next cell).
with h5py.File('/tmp/test.hdf5', 'w') as f:
    # parent.write(f, path='stars', serialize_meta=False)
    # Fixed: the last name used to be ' w1mpro' (stray leading space), which
    # does not match the 'w1mpro' column name used elsewhere in this notebook.
    f.attrs['all_phot_names'] = ['GAIAEDR3_PHOT_G_MEAN_MAG', 'J', 'w1mpro']

In [None]:
# Scratch check: reopen the test file read-only and print the attribute
# written in the previous cell to see how the list of names round-trips.
with h5py.File('/tmp/test.hdf5', 'r') as f:
    print(f.attrs['all_phot_names'])