# Define neighborhoods of similar stars

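Each "neighborhood" is defined around a "stoop." The stars used to define the neighborhoods (to fit the PCA projection and to estimate the local density) are selected to have high S/N in APOGEE and a low RUWE in Gaia.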

In [None]:
import pathlib

import astropy.coordinates as coord
from astropy.stats import median_absolute_deviation as MAD
import astropy.table as at
import astropy.units as u
import corner
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm import tqdm, trange

from scipy.spatial import cKDTree
from sklearn.decomposition import IncrementalPCA
from sklearn.neighbors import KernelDensity

import schlummernd as sch
from schlummernd.plot import colored_corner

In [None]:
# Load the pipeline configuration (random seed, data and plot paths, training-data
# loader, ...) from ../config.yml, relative to the notebook's working directory.
conf = sch.Config.parse_yaml('../config.yml')

In [None]:
# Seeded generator (seed taken from the config file) so the random subset below is reproducible
rng = np.random.default_rng(conf.seed)

## Select a subset of stars to use for defining the neighborhoods:

In [None]:
g_all = conf.load_training_data()

# Sanity check: number of rows vs. number of distinct Gaia and APOGEE identifiers
(len(g_all),
 np.unique(g_all.source_id).size,
 np.unique(g_all.APOGEE_ID).size)

In [None]:
# Candidate stars for defining the neighborhoods: high APOGEE S/N and a low Gaia
# RUWE. A random subset of them (at most `n_subset`) is kept.
min_snr = 100
max_ruwe = 1.2
n_subset = 100_000  # TODO: this number (size) should be configurable

g_idx, = np.where(
    (g_all.SNR > min_snr) &
    (g_all.ruwe < max_ruwe)
)
# rng.choice(..., replace=False) raises if asked for more items than exist, so
# cap the subset size at the number of stars that pass the cuts.
g_idx = g_idx[rng.choice(len(g_idx), size=min(n_subset, len(g_idx)), replace=False)]
g = g_all[g_idx]
len(g_idx)

Note: there are no stars with TEFF > 7000 K, because APOGEE does not provide M/H measurements for them.

Spectroscopic HR diagram of the subset stars:

In [None]:
# Histogram bin edges: (TEFF, LOGG) for the spectroscopic HR diagram and
# (BP-RP, M_G) for the Gaia color-magnitude diagram.
teff_logg_bins = tuple(
    np.linspace(lo, hi, 128) for lo, hi in [(3000, 7000), (-0.5, 5.5)]
)
mg_bprp_bins = tuple(
    np.linspace(lo, hi, 201) for lo, hi in [(-0.5, 5.5), (-6, 11)]
)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 6),
                         constrained_layout=True)
ax_hr, ax_cmd = axes

# Marker style for the highlighted subset, shared by both panels
subset_kw = dict(ls='none', marker='o', mew=0, ms=2.,
                 color='tab:blue', alpha=0.75)

# --- Left: spectroscopic HR diagram. Grey = all training stars, blue = subset.
ax_hr.hist2d(g_all.TEFF, g_all.LOGG, bins=teff_logg_bins,
             norm=mpl.colors.LogNorm(), cmap='Greys')
ax_hr.plot(g.TEFF, g.LOGG, **subset_kw)
# Both axes inverted: TEFF decreasing to the right, LOGG increasing downward
ax_hr.set_xlim(teff_logg_bins[0].max(), teff_logg_bins[0].min())
ax_hr.set_ylim(teff_logg_bins[1].max(), teff_logg_bins[1].min())
ax_hr.set_xlabel('TEFF')
ax_hr.set_ylabel('LOGG')

# --- Right: Gaia color-magnitude diagram. Negative parallaxes are allowed
# through (presumably giving NaN distance moduli - confirm).
mg = g_all.phot_g_mean_mag - g_all.get_distance(allow_negative=True).distmod
bprp = g_all.phot_bp_mean_mag - g_all.phot_rp_mean_mag
ax_cmd.hist2d(bprp.value, mg.value, bins=mg_bprp_bins,
              norm=mpl.colors.LogNorm(), cmap='Greys')
ax_cmd.plot(bprp.value[g_idx], mg.value[g_idx], **subset_kw)
ax_cmd.set_xlim(mg_bprp_bins[0].min(), mg_bprp_bins[0].max())
ax_cmd.set_ylim(mg_bprp_bins[1].max(), mg_bprp_bins[1].min())
ax_cmd.set_xlabel('bp-rp')
ax_cmd.set_ylabel('mg')

# fig.savefig(plot_path / 'subset-logg-teff.png', dpi=200)

In [None]:
# Magnitude-scaled astrometry: "schmag" from the parallax and "schpro" from the
# total proper motion / 8, each multiplied by 10**(0.2 G) / 100.
schmag_factor = 10 ** (0.2 * g_all.phot_g_mean_mag.value) / 100.

# Total proper motion and its uncertainty (components added in quadrature), / 8
pm = np.sqrt(g_all.pmra.value**2 + g_all.pmdec.value**2) / 8.
pm_err = np.sqrt(g_all.pmra_error.value**2 + g_all.pmdec_error.value**2) / 8.

schmag, schpro, schpro_err = (
    quantity * schmag_factor
    for quantity in (g_all.parallax.value, pm, pm_err)
)

# Same bin edges on both axes
sch_edges = np.linspace(-10, 150, 128)
fig, ax = plt.subplots(figsize=(6, 6))
ax.hist2d(schmag, schpro, bins=(sch_edges, sch_edges), norm=mpl.colors.LogNorm())
ax.set_xlabel('schpar')
ax.set_ylabel('schpro')

In [None]:
# G - H color vs. the magnitude-scaled proper motion, colored by SFD E(B-V)
fig, ax = plt.subplots(figsize=(7, 6))
cs = ax.scatter(
    g_all.phot_g_mean_mag.value - g_all.H,
    schpro,
    c=g_all.SFD_EBV,
    vmin=0, vmax=0.8, cmap='turbo',
    marker='o', lw=0, alpha=0.4, s=2.,
)
ax.set_xlabel('$G - H$')
ax.set_ylabel('schpro')
cb = fig.colorbar(cs)
cb.set_label('SFD E(B-V)')

In [None]:
# Ratio of the n-th RP coefficient to the zeroth one vs. G magnitude, for
# n = 1..maxn, on a random subset of stars.
maxn = 4
fig, axes = plt.subplots(
    1, maxn,
    figsize=(4.5 * maxn, 4.3),
    sharex=True,
    constrained_layout=True,
    squeeze=False,  # keep a 2-D array even when maxn == 1
)
axes = axes[0]

# One subset for every panel, drawn from the config-seeded `rng` so the figure
# is reproducible and all panels show the same stars.
_mask = rng.choice(len(g_all), size=min(30_000, len(g_all)), replace=False)

for ax, n in zip(axes, range(1, maxn + 1)):
    ratio = g_all.rp[:, n] / g_all.rp[:, 0]
    # Relative errors added in quadrature; abs() because matplotlib rejects
    # negative error bars, which a negative coefficient ratio would produce.
    ratio_err = np.abs(ratio) * np.sqrt((g_all.rp_err[:, n] / g_all.rp[:, n])**2 +
                                        (g_all.rp_err[:, 0] / g_all.rp[:, 0])**2)

    ax.errorbar(
        g_all.phot_g_mean_mag.value[_mask],
        ratio[_mask],
        yerr=ratio_err[_mask],
        ls='none', marker='o', mew=0, ms=1.,
        alpha=0.4
    )
    ax.set_title(f'rp[{n}]')
    ax.set_xlabel('G')

axes[-1].set_xlim(5, 17.5)  # x is shared, so this applies to every panel

In [None]:
# Feature matrix used to define the neighborhoods, one row per star in g_all:
#   G - H color, schpro, BP[0]/RP[0], RP[1:4]/RP[0]
_bp = g_all.bp[:, 0:1] / g_all.rp[:, 0:1]
_rp = g_all.rp[:, 1:4] / g_all.rp[:, 0:1]
X = np.hstack((
    # H (not J), to match the '$G-H$' label and the H_ERR used in X_err below
    (g_all.phot_g_mean_mag.value - g_all.H)[:, None],
    schpro[:, None],
    _bp,
    _rp
))
labels = ['$G-H$', 'schpro'] + ['BP[0]/RP[0]'] + [f'RP[{n}]/RP[0]' for n in range(1, 4)]

# Ratio uncertainties: relative errors added in quadrature. abs() keeps them
# non-negative when a coefficient ratio is negative.
_bp_err = np.abs(_bp) * np.sqrt((g_all.bp_err[:, 0:1] / g_all.bp[:, 0:1])**2 +
                                (g_all.rp_err[:, 0:1] / g_all.rp[:, 0:1])**2)
_rp_err = np.abs(_rp) * np.sqrt((g_all.rp_err[:, 1:4] / g_all.rp[:, 1:4])**2 +
                                (g_all.rp_err[:, 0:1] / g_all.rp[:, 0:1])**2)

# sigma_mag = 2.5 / ln(10) * sigma_F / F: converts the fractional G flux error to magnitudes
mag_per_frac_flux = 2.5 / np.log(10)
X_err = np.hstack((
    np.sqrt((mag_per_frac_flux / g_all.phot_g_mean_flux_over_error)**2
            + g_all.H_ERR**2)[:, None],
    schpro_err[:, None],
    _bp_err,
    _rp_err
))

# Robust per-feature centering (median) and scaling (5th-95th percentile range).
# Both are NaN-aware so that one missing value does not turn a whole column into NaN.
# X_scale = 1.5 * MAD(X, axis=0)[None]
X_scale = (np.nanpercentile(X, 95, axis=0) - np.nanpercentile(X, 5, axis=0))[None]
X = X - np.nanmedian(X, axis=0)[None]
X = X / X_scale
X_err = X_err / X_scale

X.shape, X_err.shape, len(labels)

In [None]:
# Per-feature histogram edges spanning the central 99% of each scaled feature
bins = [np.linspace(lo, hi, 128)
        for lo, hi in np.nanpercentile(X, [0.5, 99.5], axis=0).T]

# Corner plots of the subset's features, each bin colored by the nan-mean of
# an APOGEE / extinction quantity
for color_name in ('TEFF', 'LOGG', 'M_H', 'SFD_EBV'):
    fig, axes, cb = colored_corner(
        X[g_idx],
        scatter=False,
        color_by=g[color_name],
        add_colorbar=True,
        statistic=np.nanmean,
        bins=bins,
        cmap='turbo_r',
        labels=labels
    )
    cb.set_label(color_name)

## Run PCA on the features

Use the projected features to define neighborhoods.

In [None]:
# Use at most 6 PCA components, and never more than there are input features.
# (Earlier version: min(f.X.shape[1], conf.n_neighborhood_pca_components))
conf.n_neighborhood_pca_components = min(6, X.shape[1])

In [None]:
# First-pass PCA on the subset; its projection is only used to flag outliers below
pca_kwargs = dict(
    n_components=conf.n_neighborhood_pca_components,
    batch_size=1024,  # TODO: magic number, but just a tuning thing
)
pca = IncrementalPCA(**pca_kwargs)
tmp_projected_X = pca.fit_transform(X[g_idx])
# divide each projected component by its singular value
tmp_projected_X /= pca.singular_values_

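This (somewhat hacky) step removes extreme outliers. These are stars more than 5σ from the mean along any component of the preliminary PCA projection. The PCA is then refit without them. Run these cells only after the cells above: they use the variables defined there, and the refit overwrites `pca`.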

In [None]:
(np.shape(X), len(g))

In [None]:
# Flag stars more than 5 sigma from the mean along any preliminary PCA component
center = tmp_projected_X.mean(axis=0)
spread = tmp_projected_X.std(axis=0)
bad_mask = (np.abs(tmp_projected_X - center) > 5 * spread).any(axis=1)

keep = ~bad_mask
stoop_X = X[g_idx][keep]
stoop_g = g[keep]

In [None]:
# Refit the PCA without the outliers; this fit is the projection applied to all stars below
pca = IncrementalPCA(batch_size=1024,
                     n_components=conf.n_neighborhood_pca_components)
stoop_projected_X = pca.fit_transform(stoop_X)
stoop_projected_X /= pca.singular_values_

In [None]:
# Loadings of each PCA component on the input features (x = feature index, in
# the same order as `labels`), two components per row.
n_comp = pca.n_components_
# Round the row count up: `n_comp // 2` rows silently dropped the last component
# when n_comp was odd (and produced zero rows for a single component).
fig, axes = plt.subplots(int(np.ceil(n_comp / 2)), 2,
                         figsize=(16, 12), sharex=True, squeeze=False)

for i, ax in enumerate(axes.flat):
    if i < n_comp:
        ax.plot(pca.components_[i])
        ax.set_title(f'PCA component {i}')
    else:
        ax.set_visible(False)  # spare panel when n_comp is odd

fig.tight_layout()

In [None]:
# Cumulative fraction of the variance explained by components 0..k; the
# horizontal line marks 90%.
cum_explained = np.cumsum(pca.explained_variance_ratio_)
print(cum_explained[-1])

fig, ax = plt.subplots()
ax.plot(cum_explained)
ax.set_ylim(0.8, 1)
ax.axhline(0.9, zorder=-10, alpha=0.5, color='tab:blue')
ax.set_xlabel('PCA component index')
ax.set_ylabel('cumulative explained variance ratio')

## Plot some spectroscopic parameters, colored by PCA component

In [None]:
# Spectroscopic HR diagram of the outlier-cleaned subset, one panel per PCA
# component, each (TEFF, LOGG) bin colored by that component's projection.
fig, axes = plt.subplots(
    3, 3,
    figsize=(10, 10),
    sharex=True, sharey=True,
    constrained_layout=True
)

XX = np.stack((
    stoop_g.TEFF,
    stoop_g.LOGG
)).T

n_comp = pca.n_components_  # at most 9 panels are available in this 3x3 grid
for i, ax in zip(range(n_comp), axes.flat):
    colored_corner(
        XX,
        # Color by the PCA projection; `stoop_X` holds the raw input features
        color_by=stoop_projected_X[:, i],
        scatter=False,
        bins=teff_logg_bins,
        axes=np.array([[ax]])
    )
    ax.text(
        teff_logg_bins[0].max() - 100,
        teff_logg_bins[1].min() + 0.2,
        f'PCA feature {i}', va='top', ha='left'
    )

# Hide the panels with no component to show
for unused_ax in axes.flat[n_comp:]:
    unused_ax.set_visible(False)

# Axes are shared, so setting the limits once applies to every panel; both are
# inverted (TEFF decreasing to the right, LOGG increasing downward).
axes.flat[0].set_xlim(teff_logg_bins[0].max(), teff_logg_bins[0].min())
axes.flat[0].set_ylim(teff_logg_bins[1].max(), teff_logg_bins[1].min())

In [None]:
# Bin edges for [M/H] (x) and [alpha/M] (y)
m_alpha_bins = (np.linspace(-2.5, 0.6, 128),
                np.linspace(-0.2, 0.5, 128))

fig, axes = plt.subplots(3, 3, figsize=(10, 10),
                         sharex=True, sharey=True,
                         constrained_layout=True)

XX = np.stack((stoop_g.M_H, stoop_g.ALPHA_M)).T

# One panel per PCA component, each bin colored by that component's projection
n_panels = pca.n_components
for i, ax in enumerate(axes.flat[:n_panels]):
    colored_corner(
        XX,
        color_by=stoop_projected_X[:, i],
        scatter=False,
        bins=m_alpha_bins,
        axes=np.array([[ax]])
    )
    # label near the upper-left corner of the binned region
    ax.text(m_alpha_bins[0].min() + 0.1,
            m_alpha_bins[1].max() - 0.02,
            f'PCA feature {i}', va='top', ha='left')

# Hide the panels with no component to show
for ax in axes.flat[n_panels:]:
    ax.set_visible(False)

In [None]:
# things = {
#     'TEFF': (3000, 6500),
#     'LOGG': (0.5, 5.5),
#     'M_H': (-2, 0.5),
#     'AK_WISE': (0, 1)
# }
# for name, (vmin, vmax) in things.items():
#     fig, axes, cb = simple_corner(
#         node_projected_X, 
#         color_by=neighborhood_node_g[name],
#         colorbar=True,
#         vmin=vmin, vmax=vmax,
#         labels=[f'PCA {i}' 
#                 for i in range(pca.n_components_)])
#     cb.ax.set_aspect(40)
#     axes.flat[0].set_title(f'color: {name}')
    
#     # fig.savefig(plot_path / f'neighborhood-pca-{name}.png', dpi=200)
#     # plt.close(fig)

In [None]:
# Project every star in g_all with the refit PCA, in ~31 chunks (presumably to
# bound peak memory), dividing by the singular values as was done for the subset.
all_projected_X = np.zeros(
    (X.shape[0], stoop_projected_X.shape[1]),
    dtype=np.float32
)

chunk_edges = np.linspace(0, X.shape[0], 32).astype(int)
for start, stop in zip(chunk_edges[:-1], chunk_edges[1:]):
    # a chunk that ends at (or one before) the last row is stretched to the end
    if stop >= X.shape[0] - 1:
        stop = X.shape[0]
    rows = slice(start, stop)
    all_projected_X[rows] = pca.transform(X[rows])
    all_projected_X[rows] /= pca.singular_values_

## Define the neighborhoods:

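- Initialize: visit stars from highest to lowest density proxy. Each star not already in a neighborhood becomes the stoop of a new neighborhood, made of its `conf.max_neighborhood_size` nearest neighbors in PCA space.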

In [None]:
# Density proxy for every star: 1 / (distance to its 32nd-nearest star of the
# cleaned subset in PCA space). TODO: magic number 32
stoop_tree = cKDTree(stoop_projected_X)
k32_dist, k32_idx = stoop_tree.query(all_projected_X, k=[32])
all_dens = 1 / k32_dist[:, 0]

In [None]:
# Every training star in the spectroscopic HR diagram, colored by the density proxy
XX = np.stack([g_all.TEFF, g_all.LOGG]).T

cb = colored_corner(
    XX,
    color_by=all_dens,
    scatter=False,
    bins=teff_logg_bins,
    add_colorbar=True,
)[-1]
cb.set_label('density proxy')

In [None]:
# Star indices from highest to lowest density proxy
sort_idx = np.flip(np.argsort(all_dens))
# KD-tree over all projected stars, for the neighbor queries below
tree = cKDTree(all_projected_X)

In [None]:
# HACK: hard-coded here instead of coming from the config file
conf.max_neighborhood_size = 1024  # neighbors per initial neighborhood (stoop included)
max_n_neighborhoods = 8192  # upper bound on the number of initial neighborhoods

In [None]:
# Greedy initialization: visit stars from highest to lowest density proxy. Each
# star not yet in any neighborhood seeds a new one made of its
# `conf.max_neighborhood_size` nearest neighbors in PCA space. Column 0 is the
# nearest of them (normally the seed star itself, at distance 0).
neighborhood_idx = np.full(
    (max_n_neighborhoods, conf.max_neighborhood_size),
    -1,
    dtype=np.int32
)
# log-distances to the seed, excluding the first (self) neighbor
neighborhood_lndist = np.full(
    (max_n_neighborhoods, conf.max_neighborhood_size - 1),
    -1.,
    dtype=np.float32
)

# Per-star membership flags give an O(1) "already assigned?" test. The previous
# `i in neighborhood_idx[:j].ravel()` scanned every stored index on each
# iteration, which made the loop quadratic.
n_stars = all_projected_X.shape[0]
assigned = np.zeros(n_stars, dtype=bool)

j = 0
for i in tqdm(sort_idx):
    if assigned[i]:
        continue

    dists, _idx = tree.query(
        all_projected_X[i],
        k=conf.max_neighborhood_size
    )
    neighborhood_idx[j] = _idx
    neighborhood_lndist[j] = np.log(dists[1:])
    # cKDTree reports missing neighbors (k > n_stars) with index n_stars; skip those
    assigned[_idx[_idx < n_stars]] = True
    j += 1

    if j >= max_n_neighborhoods:
        print("Reached the maximum number of neighborhoods")
        break

neighborhood_idx = neighborhood_idx[:j]
neighborhood_lndist = neighborhood_lndist[:j]

In [None]:
np.shape(neighborhood_idx)

In [None]:
# Initial stoops (column 0 of each neighborhood) in the spectroscopic HR
# diagram, colored by their density proxy.
stoop_rows = neighborhood_idx[:, 0]
XX = np.stack((
    g_all.TEFF[stoop_rows],
    g_all.LOGG[stoop_rows]
)).T

*_, cb = colored_corner(
    XX,
    color_by=all_dens[stoop_rows],
    scatter=False,
    bins=teff_logg_bins,
    add_colorbar=True,
)
# The colors are `all_dens` itself (1 / distance), not its logarithm
cb.set_label('density proxy')

In [None]:
# Initial stoops in the spectroscopic HR diagram. A new name is used because
# `stoop_g` already holds the outlier-cleaned subset used by earlier cells;
# reassigning it here would silently change their input on re-run.
init_stoop_g = g_all[neighborhood_idx[:, 0]]

fig, ax = plt.subplots(figsize=(6, 6))

ax.scatter(init_stoop_g['TEFF'], init_stoop_g['LOGG'])

ax.set_xlim(teff_logg_bins[0].max(),
            teff_logg_bins[0].min())
ax.set_ylim(teff_logg_bins[1].max(),
            teff_logg_bins[1].min())

ax.set_title('stoops')

ax.set_xlabel('TEFF')
ax.set_ylabel('LOGG')

fig.tight_layout()

In [None]:
def consolidate(sets):
    """
    Merge overlapping sets until the surviving sets are pairwise disjoint.

    Adapted from https://rosettacode.org/wiki/Set_consolidation#Python

    The merge happens in place: when two sets overlap, the earlier one is
    folded into the later one and then emptied, so the input sets are mutated.

    Parameters
    ----------
    sets : iterable of set
        Empty sets are ignored.

    Returns
    -------
    list of set
        The non-empty sets remaining after merging, in their original relative order.
    """
    pool = list(filter(None, sets))
    for pos, current in enumerate(pool):
        if not current:
            # already folded into a later set
            continue
        for later in pool[pos + 1:]:
            if current.isdisjoint(later):
                continue
            later.update(current)
            current.clear()
            # keep sweeping forward with the set that now holds the merged elements
            current = later
    return list(filter(None, pool))


def _recenter_stoops(stoop_X, all_X, max_mean_size=128):
    """Move each occupied stoop to the mean of (up to) `max_mean_size` of its members."""
    # Assign every star to its nearest stoop
    _, closest_idx = cKDTree(stoop_X).query(all_X)
    occupied = np.unique(closest_idx)
    # Allocate one row per *occupied* stoop. Sizing this like `stoop_X` left
    # all-zero "phantom" stoops for stoops that had no members.
    new_stoop_X = np.zeros((occupied.size, stoop_X.shape[1]), dtype=stoop_X.dtype)
    for j, i in enumerate(occupied):
        # NOTE: the first `max_mean_size` members in `all_X` order, not the closest ones
        new_stoop_X[j] = np.mean(all_X[closest_idx == i][:max_mean_size], axis=0)
    return new_stoop_X


def _merge_close_stoops(stoop_X, all_tree, merge_threshold_K):
    """Replace each group of mutually-close stoops by the mean of the group.

    The merge radius is the median, over stoops, of the distance from a stoop to
    its `merge_threshold_K`-th nearest star. Overlapping groups are consolidated
    transitively. Merged stoops come first in the result, then the untouched ones.
    """
    K_dist, _ = all_tree.query(stoop_X, k=[merge_threshold_K])
    # TODO: weird hack? use the same threshold for all stoops
    r = np.median(K_dist[:, 0])

    to_merge = [
        set(todo)
        for todo in cKDTree(stoop_X).query_ball_point(stoop_X, r=r)
        if len(todo) > 1  # every stoop finds itself
    ]
    if not to_merge:
        return stoop_X

    # Computed before consolidate(), which mutates the sets in `to_merge`
    stoop_ids = np.arange(stoop_X.shape[0])
    no_merge = stoop_ids[np.isin(stoop_ids, list(set.union(*to_merge)), invert=True)]

    merged_stoop_X = np.array([
        stoop_X[np.array(list(idx))].mean(axis=0)
        for idx in consolidate(to_merge)
    ])
    return np.vstack((merged_stoop_X, stoop_X[no_merge]))


def _sort_stoops_by_density(stoop_X, k=16):
    """Order stoops by increasing distance to their k-th nearest stoop (densest first)."""
    # TODO: MAGIC NUMBER k = 16, but just used to sort, so probably ok
    k_dist, _ = cKDTree(stoop_X).query(stoop_X, k=[k])
    return stoop_X[k_dist[:, 0].argsort()]


def _drop_small_stoops(stoop_X, all_X, min_hood_size):
    """Remove stoops that are the nearest stoop of fewer than `min_hood_size` stars."""
    _, closest_idx = cKDTree(stoop_X).query(all_X)
    counts = np.bincount(closest_idx, minlength=stoop_X.shape[0])
    return np.delete(stoop_X, np.flatnonzero(counts < min_hood_size), axis=0)


def _tessellate(stoop_X, all_X, min_hood_size):
    """Assign every star to its nearest stoop; return per-stoop members and log-distances."""
    closest_dist, closest_idx = cKDTree(stoop_X).query(all_X)
    n_stoops = stoop_X.shape[0]
    # Fill 1-D object arrays element by element: np.array(list_of_arrays, dtype=object)
    # silently builds a 2-D array when all neighborhoods happen to be the same size.
    hood_idx = np.empty(n_stoops, dtype=object)
    hood_lndist = np.empty(n_stoops, dtype=object)
    for i in range(n_stoops):
        members = np.flatnonzero(closest_idx == i)
        # Removing stoops only moves stars *into* the remaining ones (ties aside),
        # so every surviving stoop still has at least min_hood_size members.
        assert len(members) >= min_hood_size
        hood_idx[i] = members
        hood_lndist[i] = np.log(closest_dist[members])
    return hood_idx, hood_lndist


def improve_neighborhoods_step(stoop_X, all_X, min_hood_size, merge_threshold_K=32,
                               max_mean_size=128):
    """
    Run one refinement pass over the neighborhood "stoops".

    1. Assign every star to its nearest stoop and move each occupied stoop to the
       mean position of (up to ``max_mean_size``) of its members.
    2. Merge stoops that lie within a common radius of one another.
    3. Sort the stoops by approximate density (densest first).
    4. Drop stoops with fewer than ``min_hood_size`` members, so that their stars
       fall into neighboring stoops.
    5. Tessellate: assign every star to its nearest remaining stoop.

    Parameters
    ----------
    stoop_X : array_like, shape (n_stoops, n_features)
        Current stoop positions in feature space.
    all_X : array_like, shape (n_stars, n_features)
        Positions of all stars in the same feature space.
    min_hood_size : int
        Minimum number of stars a neighborhood must contain to be kept.
    merge_threshold_K : int, optional
        The merge radius is the median, over stoops, of the distance from a stoop
        to its K-th nearest star.
    max_mean_size : int, optional
        Maximum number of members averaged when re-centering a stoop. These are the
        first members in ``all_X`` order, not the closest ones.

    Returns
    -------
    new_stoop_X : ndarray, shape (n_new_stoops, n_features)
    neighborhood_idx : ndarray of object, shape (n_new_stoops,)
        One integer index array into ``all_X`` per stoop.
    neighborhood_lndist : ndarray of object, shape (n_new_stoops,)
        The natural log of each member's distance to its stoop.

    Raises
    ------
    ValueError
        If no stoop keeps at least ``min_hood_size`` members.
    """
    stoop_X = np.asarray(stoop_X)
    all_tree = cKDTree(all_X)

    new_stoop_X = _recenter_stoops(stoop_X, all_X, max_mean_size)
    merged_stoop_X = _merge_close_stoops(new_stoop_X, all_tree, merge_threshold_K)
    merged_stoop_X = _sort_stoops_by_density(merged_stoop_X)
    merged_stoop_X = _drop_small_stoops(merged_stoop_X, all_X, min_hood_size)
    if merged_stoop_X.shape[0] == 0:
        raise ValueError(
            f"No neighborhood has at least min_hood_size={min_hood_size} members"
        )

    hood_idx, hood_lndist = _tessellate(merged_stoop_X, all_X, min_hood_size)
    return merged_stoop_X, hood_idx, hood_lndist

In [None]:
min_neighborhood_size = 128
n_refine_steps = 8  # HACK: fixed number of refinement passes

# Start from the initial stoops (column 0 of each initial neighborhood)
tmp_stoop_X = all_projected_X[neighborhood_idx[:, 0]]
print(tmp_stoop_X.shape)

for step in trange(n_refine_steps):
    tmp_stoop_X, new_neighborhood_idx, new_neighborhood_lndist = improve_neighborhoods_step(
        tmp_stoop_X,
        all_projected_X,
        min_neighborhood_size
    )
    print(tmp_stoop_X.shape)

In [None]:
# The refined neighborhoods tessellate the sample: every star must appear in
# exactly one of them. This is stricter than comparing against the initial
# neighborhoods, which need not cover every star (e.g. if max_n_neighborhoods
# was reached) and which would not catch duplicated members.
assigned_stars = np.concatenate(list(new_neighborhood_idx))
assert assigned_stars.size == all_projected_X.shape[0], "a star is in several neighborhoods or in none"
assert np.array_equal(np.sort(assigned_stars), np.arange(all_projected_X.shape[0]))

This makes diagnostic plots for 20 neighborhoods, spread evenly through the list of refined neighborhoods:

In [None]:
# Output directory for the per-neighborhood diagnostic plots. Missing parent
# directories are created too, so a fresh checkout without a plot directory works.
this_plot_path = conf.plot_path / 'neighborhoods'
this_plot_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Remove plots from a previous run so the directory only holds current ones
for ff in this_plot_path.glob('*.png'):
    ff.unlink()

# Axis limits for the APOGEE-parameter panels (TEFF and LOGG inverted)
lims = {
    'TEFF': (8000, 3200),
    'LOGG': (5.5, -0.5),
    'M_H': (-2.5, 0.5),
    'AK_WISE': (0, 2)
}

# One label per PCA column (which need not equal the number of input features)
pca_labels = [f'pca[{i}]' for i in range(all_projected_X.shape[1])]

# ~20 neighborhoods spread through the list; np.unique avoids plotting the
# same neighborhood twice when there are fewer than 20.
for n in np.unique(np.linspace(0, new_neighborhood_idx.shape[0] - 1, 20).astype(int)):
    hood = new_neighborhood_idx[n]
    lndist = new_neighborhood_lndist[n]
    block = g_all[hood]

    # Distance to the stoop relative to the closest member (>= 1)
    rel_dist = np.exp(lndist - lndist.min())

    fig, axes, cb = colored_corner(
        X[hood],
        scatter=True,
        color_by=rel_dist,
        add_colorbar=True,
        cmap='turbo_r',
        labels=labels
    )
    cb.set_label('dist. to stoop')
    for ax in axes.flat:
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        plt.setp(ax.get_yticklabels(), rotation=45, ha='right')
    fig.suptitle(f'Neighborhood {n}', fontsize=22)
    fig.savefig(this_plot_path / f'n{n:04d}-features.png', dpi=200)

    # ---

    fig, axes, cb = colored_corner(
        all_projected_X[hood],
        scatter=True,
        color_by=rel_dist,
        add_colorbar=True,
        cmap='turbo_r',
        labels=pca_labels
    )
    cb.set_label('dist. to stoop')
    for ax in axes.flat:
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=10)
        plt.setp(ax.get_yticklabels(), rotation=45, ha='right', fontsize=10)
    fig.suptitle(f'Neighborhood {n}', fontsize=22)
    fig.savefig(this_plot_path / f'n{n:04d}-pca-features.png', dpi=200)

    # ---

    fig, axes = plt.subplots(
        1, 3,
        figsize=(15, 5),
        constrained_layout=True
    )
    for ax, names in zip(axes, [('TEFF', 'LOGG'),
                                ('TEFF', 'M_H'),
                                ('M_H', 'AK_WISE')]):
        cs = ax.scatter(
            block[names[0]],
            block[names[1]],
            c=rel_dist,
            cmap='turbo_r',
            s=3, alpha=0.5, zorder=1)

        ax.set_xlabel(names[0])
        ax.set_ylabel(names[1])

        ax.set_xlim(lims[names[0]])
        ax.set_ylim(lims[names[1]])
    cb = fig.colorbar(cs, ax=axes, aspect=30)
    cb.set_label('dist. to stoop')
    fig.suptitle(f'Neighborhood {n}', fontsize=22)
    fig.savefig(this_plot_path / f'n{n:04d}-apogee-pars.png', dpi=200)

    # Close every figure from this iteration to keep memory bounded
    plt.close('all')

## Apply to the full parent sample

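Every parent-sample star should get a stoop. (TODO: in this notebook, the cells below only save the training-set neighborhoods; the parent sample is not processed yet.)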

In [None]:
np.shape(all_projected_X)

In [None]:
# Save the refined neighborhoods: one integer index array (rows of the training
# table) per neighborhood. This is an object array, so np.save pickles it; load
# it back with np.load(..., allow_pickle=True).
out_file = conf.data_path / 'training_neighborhoods.npy'
np.save(out_file, np.array(new_neighborhood_idx))

## Check on things that are far from all stoops:

In [None]:
# KD-tree over the initial stoops (column 0 of each initial neighborhood)
initial_stoop_X = all_projected_X[neighborhood_idx[:, 0]]
stoop_tree = cKDTree(initial_stoop_X)

In [None]:
# Distance from every star to its nearest initial stoop, and which stoop that is.
# Keep the indices: the cells in the "OLD" section below read `closest_idx`, and
# nothing else defines it at notebook level.
closest_dist, closest_idx = stoop_tree.query(all_projected_X)

In [None]:
# Distribution of nearest-initial-stoop distances, on log-log axes
fig, ax = plt.subplots()
ax.hist(closest_dist, bins=np.geomspace(1e-3, closest_dist.max(), 128))
ax.set_xscale('log')
ax.set_yscale('log')

In [None]:
# Stars farther from every initial stoop than 99.9% of the sample
far_threshold = np.percentile(closest_dist, 99.9)
far_from_stoop = g_all[closest_dist > far_threshold]

In [None]:
# Spectroscopic HR diagram of the subset stars:
fig, ax = plt.subplots(figsize=(6, 6))

teff_logg_bins = (
    np.linspace(3000, 7000, 128),
    np.linspace(-0.5, 5.5, 128)
)
ax.hist2d(g_all.TEFF,
          g_all.LOGG,
          bins=teff_logg_bins,
          norm=mpl.colors.LogNorm(),
          cmap='Greys')

ax.plot(far_from_stoop.TEFF, 
        far_from_stoop.LOGG,
        ls='none', marker='o', mew=0, ms=4.,
        color='tab:red', alpha=0.75)

ax.set_xlim(teff_logg_bins[0].max(),
            teff_logg_bins[0].min())
ax.set_ylim(teff_logg_bins[1].max(),
            teff_logg_bins[1].min())

ax.set_xlabel('TEFF')
ax.set_ylabel('LOGG')

fig.tight_layout()
# fig.savefig(plot_path / 'subset-logg-teff.png', dpi=200)

## OLD

TODO: collect `closest_idx` into similar structure as `neighborhood_idx`??

In [None]:
# OLD: diagnostic plots for neighborhoods defined by each star's nearest
# *initial* stoop (`closest_idx` / `closest_dist`). Files get an "old-" prefix so
# they do not overwrite the refined-neighborhood plots in the same directory,
# which use the same n{n:04d}-*.png names.
lims = {
    'TEFF': (8000, 3200),
    'LOGG': (5.5, -0.5),
    'M_H': (-2.5, 0.5),
    'AK_WISE': (0, 2)
}

unq_idx = np.unique(closest_idx)
# One label per PCA column (which need not equal the number of input features)
pca_labels = [f'pca[{i}]' for i in range(all_projected_X.shape[1])]

for n in np.linspace(0, unq_idx.shape[0]-1, 20).astype(int):
    hood = closest_idx == unq_idx[n]  # boolean mask of this stoop's stars
    block = g_all[hood]
    dist = closest_dist[hood]

    fig, axes, cb = colored_corner(
        X[hood],
        scatter=True,
        color_by=dist,
        add_colorbar=True,
        cmap='turbo_r',
        labels=labels
    )
    cb.set_label('dist. to stoop')
    for ax in axes.flat:
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        plt.setp(ax.get_yticklabels(), rotation=45, ha='right')
    fig.suptitle(f'Neighborhood {n}', fontsize=22)
    fig.savefig(this_plot_path / f'old-n{n:04d}-features.png', dpi=200)

    # ---

    fig, axes, cb = colored_corner(
        all_projected_X[hood],
        scatter=True,
        color_by=dist,
        add_colorbar=True,
        cmap='turbo_r',
        labels=pca_labels
    )
    cb.set_label('dist. to stoop')
    for ax in axes.flat:
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=10)
        plt.setp(ax.get_yticklabels(), rotation=45, ha='right', fontsize=10)
    fig.suptitle(f'Neighborhood {n}', fontsize=22)
    fig.savefig(this_plot_path / f'old-n{n:04d}-pca-features.png', dpi=200)

    # ---

    fig, axes = plt.subplots(
        1, 3,
        figsize=(15, 5),
        constrained_layout=True
    )
    for ax, names in zip(axes, [('TEFF', 'LOGG'),
                                ('TEFF', 'M_H'),
                                ('M_H', 'AK_WISE')]):
        cs = ax.scatter(
            block[names[0]],
            block[names[1]],
            c=dist,
            cmap='turbo_r',
            s=3, alpha=0.5, zorder=1)

        ax.set_xlabel(names[0])
        ax.set_ylabel(names[1])

        #ax.set_xlim(lims[names[0]])
        #ax.set_ylim(lims[names[1]])
    cb = fig.colorbar(cs, ax=axes, aspect=30)
    cb.set_label('dist. to stoop')
    fig.suptitle(f'Neighborhood {n}', fontsize=22)
    fig.savefig(this_plot_path / f'old-n{n:04d}-apogee-pars.png', dpi=200)

    plt.close('all')

In [None]:
# How many stars share each `closest_idx` value (stars per nearest stoop)
hack, closest_counts = np.unique(closest_idx, return_counts=True)
fig, ax = plt.subplots()
ax.hist(closest_counts, bins=np.geomspace(1, 5e3, 32))
ax.set_xscale('log')

In [None]:
from matplotlib.colors import ListedColormap

In [None]:
# Stars whose nearest stoop is shared by fewer than 20 stars: where do they lie?
small_hood_ids = hack[closest_counts < 20]
in_small_hood = np.isin(closest_idx, small_hood_ids)
print(in_small_hood.sum(), np.unique(closest_idx[in_small_hood]).size)

# One random color per small neighborhood. A dedicated generator is used so the
# config-seeded `rng` from the top of the notebook is not overwritten.
base_cmap = plt.get_cmap('turbo')
color_rng = np.random.default_rng(42)
cm = ListedColormap([
    base_cmap(x)
    for x in color_rng.uniform(0, 1, size=np.unique(closest_idx[in_small_hood]).size)
])

fig, axes = plt.subplots(1, 2, figsize=(11, 5))

# --- spectroscopic HR diagram
ax = axes[0]
XX = np.stack((
    g_all.TEFF,
    g_all.LOGG
)).T[in_small_hood]
colored_corner(
    XX,
    color_by=closest_idx[in_small_hood],
    scatter=True,
    axes=np.array([[ax]]),
    cmap=cm,
    s=4
)
ax.set_xlim(
    teff_logg_bins[0].max(),
    teff_logg_bins[0].min()
)
ax.set_ylim(
    teff_logg_bins[1].max(),
    teff_logg_bins[1].min()
)
ax.set_xlabel('TEFF')
ax.set_ylabel('LOGG')

# --- [M/H] vs [alpha/M]
ax = axes[1]
XX = np.stack((
    g_all.M_H,
    g_all.ALPHA_M,
)).T[in_small_hood]
colored_corner(
    XX,
    color_by=closest_idx[in_small_hood],
    scatter=True,
    axes=np.array([[ax]]),
    cmap=cm,
    s=4
)
ax.set_xlabel('M_H')
ax.set_ylabel('ALPHA_M')