In [None]:
import os
# NOTE(review): absolute, user-specific cache paths -- anyone else running this
# notebook must edit them. They are set *before* the `joaquin` imports below,
# presumably because joaquin.config reads JOAQUIN_CACHE_PATH at import time
# (TODO confirm).
os.environ['APOGEE_CACHE_PATH'] = "/mnt/ceph/users/apricewhelan/apogee-test/"
os.environ['JOAQUIN_CACHE_PATH'] = "/mnt/ceph/users/apricewhelan/projects/joaquin/cache"
import warnings
# NOTE(review): `category=Warning` matches every warning class, so this silences
# all warnings for the whole kernel session; consider narrowing it.
warnings.filterwarnings('ignore', category=Warning) 
import pickle

import sys
import pathlib
# Put the parent directory on the import path (presumably so the local
# `joaquin` package is importable without installing it).
_path = str(pathlib.Path('../').resolve())
if _path not in sys.path:
    sys.path.append(_path)

import corner
from astropy.io import fits
import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm.auto import tqdm
from sklearn.neighbors import KernelDensity
from scipy.spatial import cKDTree

from joaquin.data import JoaquinData
from joaquin.config import root_cache_path
from joaquin.plot import simple_corner
from joaquin.neighborhoods import get_neighborhood_X

See `PCA-neighbord-training.ipynb` (the PCA model `pca_neighborizer.pkl` loaded below is presumably trained there).

In [None]:
# Load the two cached star samples (parent and neighborhood) and the cached
# `spec_good_mask` that get_neighborhood_X uses below (presumably a boolean
# mask over spectral pixels -- TODO confirm).
parent_data, neighbor_data = (
    JoaquinData(cache_file='parent-sample'),
    JoaquinData(cache_file='neighborhood-sample'),
)
spec_good_mask = np.load(root_cache_path / 'spec_good_mask.npy')

In [None]:
# Kiel diagram (Teff vs. logg): the parent sample as a log-scaled 2D histogram,
# with the neighborhood sample overplotted as points.
fig, ax = plt.subplots(figsize=(6, 6))

bins = (np.linspace(3000, 7500, 128),
        np.linspace(0, 5.5, 128))
ax.hist2d(parent_data.stars['TEFF'],
          parent_data.stars['LOGG'],
          bins=bins, norm=mpl.colors.LogNorm(),
          cmap='magma_r')

ax.plot(neighbor_data.stars['TEFF'],
        neighbor_data.stars['LOGG'],
        ls='none', marker='o', mew=0, ms=3.,
        color='tab:blue', alpha=0.75)

# Conventional Kiel-diagram orientation: both axes inverted, limits taken
# from the histogram bins so the two stay in sync.
ax.set_xlim(bins[0].max(), bins[0].min())
ax.set_ylim(bins[1].max(), bins[1].min())

ax.set_xlabel(r'$T_{\rm eff}$')
ax.set_ylabel(r'$\log g$')

fig.tight_layout()

In [None]:
# Build the feature matrix for the neighborhood sample. `neighbor_stars` is the
# star table returned alongside it -- presumably row-aligned with `neighbor_X`,
# so row indices of `neighbor_X` (and of `projected_X` below) index
# `neighbor_stars` (TODO confirm against get_neighborhood_X).
(neighbor_X,
 color_labels,
 neighbor_stars) = get_neighborhood_X(neighbor_data, spec_good_mask)

In [None]:
# Load the cached PCA model. NOTE: unpickling can execute arbitrary code, so
# only load this file from the trusted local cache.
pca_model_path = root_cache_path / 'pca_neighborizer.pkl'
with open(pca_model_path, mode='rb') as pca_file:
    pca = pickle.load(pca_file)

In [None]:
# Project onto the PCA basis, then divide each component by its singular value.
# The division writes into the transform output (same as `/=`), so that
# array's dtype is preserved.
projected_X = pca.transform(neighbor_X)
projected_X = np.divide(projected_X, pca.singular_values_, out=projected_X)

First, estimate the local (log) density at every point with a kernel density estimate:

In [None]:
# Bandwidth heuristic: mean per-component standard deviation, scaled by
# sqrt(number of components) and divided by an ad-hoc factor of 3.
component_spread = np.std(projected_X, axis=0)
bw = np.mean(component_spread) * np.sqrt(projected_X.shape[1]) / 3

# `fit` returns the estimator itself, so build and fit in one expression.
kde = KernelDensity(bandwidth=bw, kernel='epanechnikov').fit(projected_X)

# NOTE: score_samples returns the *log* density at each point.
dens = kde.score_samples(projected_X)

Now construct a k-d tree on the projected points for fast nearest-neighbor queries:

In [None]:
# Spatial index over the PCA-projected neighborhood sample.
tree = cKDTree(data=projected_X)

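Setting `k=2` below means each row of `idxs` is `(self, nearest neighbor)`. The closest match to every point is the point itself, at distance 0 (barring exact duplicate points), so column 1 holds the nearest *other* point.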

In [None]:
# Two nearest neighbors of every point: column 0 is (normally) the point
# itself at distance 0, column 1 the closest other point.
knn_dists, idxs = tree.query(projected_X, k=2)
# Keep only the distance to the nearest *other* point.
dists = knn_dists[:, 1]

In [None]:
# MAGIC NUMBER: number of distinct stars (center included) per neighborhood.
target_neighborhood_size = 256

# Greedy neighborhood construction. Visit stars from highest to lowest KDE
# log-density (`dens`). Each star not yet in any neighborhood becomes a new
# center; its neighborhood is its `target_neighborhood_size` nearest stars in
# the projected PCA space, with the center stored first. Stars already in an
# earlier neighborhood are never used as centers, but they may still appear as
# members of later neighborhoods (neighborhoods can overlap).
#
# An earlier variant grew a `query_ball_point` radius around each center until
# it contained a target number of stars; it was replaced by the fixed-k query.
n_points = projected_X.shape[0]
# cKDTree.query pads results with the out-of-range index `n_points` when k
# exceeds the number of points, so cap k.
k_query = min(target_neighborhood_size, n_points)

sort_idx = dens.argsort()[::-1]

# Membership mask gives O(1) "already assigned?" checks, instead of flattening
# all previous neighborhoods with np.ravel on every iteration (quadratic).
assigned = np.zeros(n_points, dtype=bool)
neighborhoods = []
for center in sort_idx:
    if assigned[center]:
        continue

    # atleast_1d: with k == 1 the query returns a scalar index.
    nearest = np.atleast_1d(tree.query(projected_X[center], k=k_query)[1])
    # The query result normally already includes `center` (distance 0). Drop it
    # so the center is not duplicated, then prepend it explicitly so hood[0] is
    # always the center. The slice keeps the size at k_query even if, because
    # of exact duplicate points, `center` was not among the results.
    members = nearest[nearest != center][:k_query - 1]
    hood = np.concatenate(([center], members))

    neighborhoods.append(hood)
    assigned[hood] = True

len(neighborhoods)

In [None]:
# Represent each neighborhood in PCA space by its center star (element 0 of
# its index array). Alternative considered: the neighborhood mean,
# np.mean(projected_X[hood], axis=0).
center_rows = [projected_X[hood[0]] for hood in neighborhoods]
zones_X = np.array(center_rows)

zones_path = root_cache_path / 'neighborhoods_projected_X.npy'
np.save(zones_path, zones_X)

In [None]:
# Teff / logg of each neighborhood's center star (element 0 of each index
# array). The indices are rows of `projected_X`, which are presumably
# row-aligned with `neighbor_stars` (returned together with `neighbor_X` by
# get_neighborhood_X), not necessarily with `neighbor_data.stars`. This matches
# how the per-neighborhood plots below index the stars.
center_idx = np.array([hood[0] for hood in neighborhoods], dtype=int)
derp_TEFF = np.asarray(neighbor_stars['TEFF'][center_idx], dtype=float)
derp_LOGG = np.asarray(neighbor_stars['LOGG'][center_idx], dtype=float)

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(derp_TEFF, derp_LOGG)
# Inverted axes: conventional Kiel-diagram orientation.
ax.set_xlim(8500, 3000)
ax.set_ylim(5.5, -0.5)
ax.set_xlabel(r'$T_{\rm eff}$')
ax.set_ylabel(r'$\log g$')
ax.set_title('Neighborhood centers');

In [None]:
style_main = dict(ls='none', marker='o', mew=0.6, ms=6.,
                  color='tab:blue', zorder=100,
                  mec='gold')
style_neighbors = dict(ls='none', marker='o', mew=0, ms=2.,
                       alpha=0.75, color='tab:orange', zorder=10)

# Which neighborhoods to show: every `plot_stride`-th one, at most `max_plots`.
plot_stride = 4
max_plots = 20

# Kiel-diagram (Teff, logg) binning for the parent-sample background.
kiel_bins = (np.linspace(3000, 8500, 128),
             np.linspace(0, 5.5, 128))

# Color-magnitude diagram: color = color[0] - color[1], absolute `mag`.
# (Alternative tried: color = ('GAIAEDR3_PHOT_G_MEAN_MAG', 'J'), mag = 'J',
#  with color bins np.linspace(-0.5, 4.5, 128).)
# NOTE: the CMD axis labels in plot_neighborhood are written for J-K / M_H.
color = ('J', 'K')
mag = 'H'
cmd_bins = (np.linspace(-0.5, 2, 128),
            np.linspace(-6, 10, 128))

# --- Loop-invariant quantities: computed once instead of once per figure. ---
parent_tbl = parent_data.stars

# CMD background: only parent stars with parallax S/N > 5.
plx_snr = parent_tbl['GAIAEDR3_PARALLAX'] / parent_tbl['GAIAEDR3_PARALLAX_ERROR']
dist_mask, = np.where(plx_snr > 5)
parent_distmod = coord.Distance(
    parallax=parent_tbl['GAIAEDR3_PARALLAX'][dist_mask] * u.mas).distmod.value
parent_cmd_x = (parent_tbl[color[0]] - parent_tbl[color[1]])[dist_mask]
parent_cmd_y = parent_tbl[mag][dist_mask] - parent_distmod

# Neighborhood stars: all rows, including non-positive parallaxes
# (allow_negative=True). Such stars presumably get a non-finite distance
# modulus and so do not appear in the CMD.
neighbor_distmod = coord.Distance(parallax=neighbor_stars['GAIAEDR3_PARALLAX'] * u.mas,
                                  allow_negative=True).distmod.value
neighbor_cmd_x = neighbor_stars[color[0]] - neighbor_stars[color[1]]
neighbor_cmd_y = neighbor_stars[mag] - neighbor_distmod


def plot_neighborhood(hood):
    """Plot one neighborhood on a Kiel diagram and a color-magnitude diagram.

    Uses the module-level tables, bins, and styles defined in this cell.

    Parameters
    ----------
    hood : array of int
        Row indices into ``neighbor_stars``. ``hood[0]`` is the neighborhood
        center (drawn with ``style_main``). ``hood[1:]`` are its neighbors
        (drawn with ``style_neighbors``).

    Returns
    -------
    fig : matplotlib Figure
    axes : array of two Axes
        ``axes[0]`` is the Kiel diagram and ``axes[1]`` the CMD.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    center, others = hood[0], hood[1:]

    # --- Kiel diagram (inverted axes) ---
    ax = axes[0]
    ax.hist2d(parent_tbl['TEFF'], parent_tbl['LOGG'],
              bins=kiel_bins, norm=mpl.colors.LogNorm(), cmap='Greys')
    ax.plot(neighbor_stars['TEFF'][center], neighbor_stars['LOGG'][center],
            **style_main)
    ax.plot(neighbor_stars['TEFF'][others], neighbor_stars['LOGG'][others],
            **style_neighbors)
    ax.set_xlim(kiel_bins[0].max(), kiel_bins[0].min())
    ax.set_ylim(kiel_bins[1].max(), kiel_bins[1].min())
    ax.set_xlabel(r'$T_{\rm eff}$')
    ax.set_ylabel(r'$\log g$')

    # --- Color-magnitude diagram (magnitude axis inverted) ---
    ax = axes[1]
    ax.hist2d(parent_cmd_x, parent_cmd_y,
              bins=cmd_bins, norm=mpl.colors.LogNorm(), cmap='Greys')
    ax.plot(neighbor_cmd_x[center], neighbor_cmd_y[center], **style_main)
    ax.plot(neighbor_cmd_x[others], neighbor_cmd_y[others], **style_neighbors)
    ax.set_xlim(cmd_bins[0].min(), cmd_bins[0].max())
    ax.set_ylim(cmd_bins[1].max(), cmd_bins[1].min())
    ax.set_xlabel('$J - K$')
    ax.set_ylabel('$M_H$')

    fig.tight_layout()
    return fig, axes


for n in np.arange(0, len(neighborhoods), plot_stride)[:max_plots]:
    plot_neighborhood(neighborhoods[n])

# Apply the PCA projection to the full parent sample

In [None]:
# Same feature construction as for the neighborhood sample, applied to the
# parent sample. The color labels are not needed here; a named throwaway avoids
# rebinding IPython's `_` (last-output) variable.
(parent_X,
 _parent_color_labels,
 parent_stars) = get_neighborhood_X(parent_data, spec_good_mask)
parent_X.shape

In [None]:
# Preallocate the float32 output: one row per parent star, one column per
# component of the projection used for `projected_X`.
n_parent_rows = parent_X.shape[0]
n_pca_components = projected_X.shape[1]
parent_projected_X = np.zeros(shape=(n_parent_rows, n_pca_components),
                              dtype=np.float32)
parent_projected_X.shape

In [None]:
# Project the parent sample in chunks (presumably to limit the peak memory of
# pca.transform), applying the same per-component singular-value scaling as
# was applied to `projected_X`.
n_chunks = 7
n_parent = parent_X.shape[0]
# np.linspace includes its endpoint, so the last edge is exactly n_parent and
# the final chunk needs no special-casing.
chunk_edges = np.linspace(0, n_parent, n_chunks + 1).astype(int)
# Skip empty chunks (possible when n_parent < n_chunks); pca.transform rejects
# zero-sample input.
chunk_bounds = [(i1, i2) for i1, i2 in zip(chunk_edges[:-1], chunk_edges[1:])
                if i2 > i1]
for i1, i2 in tqdm(chunk_bounds, desc='PCA transform'):
    # Divide before storing into the float32 output array, rather than dividing
    # the already-downcast values in place.
    parent_projected_X[i1:i2] = pca.transform(parent_X[i1:i2]) / pca.singular_values_

In [None]:
# Cache the projected parent sample to disk.
parent_projected_path = root_cache_path / 'parent_projected_X.npy'
np.save(parent_projected_path, parent_projected_X)

**TODO:** Should the cached `parent_projected_X` (saved above) be used to construct the neighborhoods?

In [None]:
from scipy.stats import binned_statistic_2d
from sklearn.manifold import TSNE

In [None]:
# 1-D t-SNE embedding of the projected parent sample. t-SNE is stochastic, so
# fix the seed to make the embedding (and the map below) reproducible across
# kernel restarts. NOTE: this can be slow on a large parent sample.
tsne_random_state = 42
tsne = TSNE(n_components=1, random_state=tsne_random_state)
X_embedded = tsne.fit_transform(parent_projected_X)

In [None]:
# Expect one row per parent star and a single embedding column.
np.shape(X_embedded)

In [None]:
teff_edges = np.linspace(3000, 8500, 128)
logg_edges = np.linspace(0, 5.5, 128)
bins = (teff_edges, logg_edges)

# Mean 1-D t-SNE coordinate in each (Teff, logg) pixel.
# (Alternative: values=parent_projected_X[:, 0], the first PCA component.)
stat = binned_statistic_2d(
    x=parent_stars['TEFF'],
    y=parent_stars['LOGG'],
    values=X_embedded[:, 0],
    statistic='mean',
    bins=bins)

In [None]:
# Map of the binned statistic over the Kiel diagram.
fig, ax = plt.subplots(figsize=(6, 6))

# `stat.statistic` is indexed [x_bin, y_bin]; pcolormesh expects [row=y, col=x],
# hence the transpose.
mesh = ax.pcolormesh(
    stat.x_edge, stat.y_edge,
    stat.statistic.T,
    cmap='magma_r')
fig.colorbar(mesh, ax=ax, label='binned statistic')

# Inverted axes: conventional Kiel-diagram orientation.
ax.set_ylim(5.5, 0)
ax.set_xlim(8500, 3000)
ax.set_xlabel(r'$T_{\rm eff}$')
ax.set_ylabel(r'$\log g$')

fig.tight_layout()