# Experiments related to Joaquin
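Technically, this notebook implements something *even dumber* than *Joaquin*.
It implements k-nearest-neighbors (kNN) regression in *Gaia*-only quantities (BP/RP spectral coefficients plus BP−RP color) to predict a stellar label — schmag, or a catalog label such as [M/H] — either as a weighted mean of the neighbors' labels or via a weighted linear least-squares fit to the neighbors.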

## Authors:
- **Adrian Price-Whelan** (Flatiron)
- **David W. Hogg** (NYU) (MPIA) (Flatiron)

## Definitions and Conventions:
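- `n_xp` (called $P$ in the plots): The number of BP and the number of RP spectral coefficients to use.
- `maxK`: The maximum number of neighbors $K$ retrieved for each validation star.
- Scaling or preprocessing of input features: currently none in this notebook, apart from multiplying the BP−RP color feature by 0.1.
- How we use the neighbors: weighted mean or weighted linear least-squares fit so far (a mixture of some kind is a possible extension).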

## TODO / questions
- Do we add "Reduced proper motion" as a feature?
- Use 2MASS or WISE photometry in features?
- Color the CMD by implied density (and store distance to Kth neighbor as proxy for density)
- Predict [M/H] or LOGG, then feed back in as a feature to predict schmag

In [None]:
import pathlib
import astropy.coordinates as coord
from astropy.stats import median_absolute_deviation as MAD
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import h5py
from tqdm import tqdm
from sklearn.neighbors import KDTree
from pyia import GaiaData
from scipy.stats import binned_statistic

from helpers import load_data, Features

Load APOGEE x Gaia data — see `Assemble-data.ipynb` for more information.

In [None]:
# Upper giant branch:
g = load_data(
    filters=dict(
        TEFF=(3000, 5100), 
        LOGG=(-0.5, 2.3),
        M_H=(-3, None),
        phot_g_mean_mag=(None, 15.5*u.mag),
        AK_WISE=(-0.1, None)
    )
)

# For red clump instead:
# g = load_data(
#     filters=dict(
#         TEFF=(4500, 5100), 
#         LOGG=(2.3, 2.6),
#         M_H=(-3, None),
#         phot_g_mean_mag=(None, 15.5*u.mag),
#         AK_WISE=(-0.1, None)
#     )
# )

# g = g[(np.abs(g.b) > 15*u.deg) & (g.SFD_EBV < 0.2)]

len(g)

In [None]:
# Absolute G magnitude from the parallax-based distance modulus.
# NOTE(review): allow_negative=True keeps stars with negative parallax; their
# distance modulus (and hence M_G) is presumably nan — confirm.
mg = (
    g.phot_g_mean_mag
    - g.get_distance(allow_negative=True).distmod
).value
# Gaia BP-RP colour as a plain array (units stripped).
bprp = (
    g.phot_bp_mean_mag
    - g.phot_rp_mean_mag
).value

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# Left panel: spectroscopic TEFF vs LOGG diagram (TEFF axis flipped).
# Right panel: Gaia colour-magnitude diagram (colour axis not flipped).
# Both panels put larger y values at the bottom.
_panel_specs = [
    # (x values, y values, x bin edges, y bin edges, flip x?, x label, y label)
    (g.TEFF, g.LOGG,
     np.linspace(3000, 8000, 128), np.linspace(-0.5, 5.5, 128),
     True, 'TEFF', 'LOGG'),
    (bprp, mg,
     np.linspace(-0.5, 3, 128), np.linspace(-4, 10.5, 128),
     False, 'BP-RP', '$M_G$'),
]

for ax, (_xv, _yv, _xedges, _yedges, _flip_x, _xlabel, _ylabel) in zip(axes, _panel_specs):
    H, xb, yb, _ = ax.hist2d(
        _xv,
        _yv,
        bins=(_xedges, _yedges),
        norm=mpl.colors.LogNorm()
    )
    _lo, _hi = xb.min(), xb.max()
    ax.set_xlim((_hi, _lo) if _flip_x else (_lo, _hi))
    ax.set_ylim(yb.max(), yb.min())
    ax.set_xlabel(_xlabel)
    ax.set_ylabel(_ylabel)

fig.tight_layout()

In [None]:
# Extra (non-XP) features passed through to Features.from_gaiadata as keyword
# arguments: the BP-RP colour, multiplied by 0.1.
# NOTE(review): the 0.1 presumably puts the colour on a scale comparable to the
# XP coefficients — confirm.
other_features = dict()
other_features[r"$G_{\rm BP}-G_{\rm RP}$"] = 0.1 * (g.phot_bp_mean_mag - g.phot_rp_mean_mag)

# 32 BP + 32 RP spectral coefficients plus the extra features above.
f_all = Features.from_gaiadata(g, n_bp=32, n_rp=32, **other_features)

Make list of possible labels (and label weights), aligned with the features.

In [None]:
# Make list of labels (and label weights), aligned with the features.
# Three parallel dicts keyed by label name: values, inverse-variance weights,
# and a LaTeX axis label.
label_ys, label_weights, label_latex = {}, {}, {}

# schmag = parallax * 10**(0.2 G) / 100; its error is the parallax error
# scaled by the same factor.
schmag_factor = 10 ** (0.2 * g.phot_g_mean_mag.value) / 100.
schmag_err = g.parallax_error.value * schmag_factor
label_ys['schmag'] = g.parallax.value * schmag_factor
label_weights['schmag'] = 1 / schmag_err**2
label_latex['schmag'] = '$G$-band schmag (absmgy$^{-1/2}$)'

# Catalogue labels: weights from the matching "<name>_ERR" column when it
# exists, otherwise unit weights for every star.
for name in ('M_H', 'LOGG', 'TEFF', 'AK_WISE'):
    err_col = f'{name}_ERR'
    label_ys[name] = g[name]
    label_weights[name] = (
        1 / g[err_col]**2
        if err_col in g.data.colnames
        else np.ones_like(label_ys[name])
    )

label_latex.update({
    'M_H': r"$[{\rm M}/{\rm H}]$",
    'LOGG': r"$\log g$",
    'TEFF': r"$T_{\rm eff}$",
    'AK_WISE': r"$A_K$",
})

Check the label uncertainties:

In [None]:
# Plot each label against its uncertainty (1/sqrt(weight)), two rows of panels.
n_labels = len(label_ys)
ny = int(np.ceil(n_labels / 2))
fig, axes = plt.subplots(
    2, 
    ny, 
    figsize=(5 * ny, 8)
)

for ax, name in zip(axes.flat, label_ys.keys()):
    y = label_ys[name]
    yerr = 1 / np.sqrt(label_weights[name])

    err_lo, err_hi = np.percentile(yerr, [1, 99])
    if name == 'AK_WISE':
        # Fixed range kept from the original analysis.
        err_bins = np.geomspace(0.5, 2, 128)
    elif err_lo == err_hi:
        # Labels without an *_ERR column get unit weights, so every yerr is
        # identical and geomspace(lo, lo) would give non-increasing bin edges
        # (hist2d rejects those). Widen to a factor-2 range around the value.
        err_bins = np.geomspace(err_lo / 2, err_hi * 2, 128)
    else:
        err_bins = np.geomspace(err_lo, err_hi, 128)

    bins = [
        np.linspace(*np.percentile(y, [1, 99]), 128),
        err_bins
    ]

    ax.hist2d(
        y, 
        yerr, 
        bins=bins,
        norm=mpl.colors.LogNorm(),
        cmap='Greys'
    )

    ax.set_yscale('log')

    _label = label_latex[name]
    ax.set_xlabel(_label)
    ax.set_ylabel(r'$\sigma$ ' + _label)

# With an odd number of labels the last panel would otherwise be an empty frame.
for ax in np.ravel(axes)[n_labels:]:
    ax.set_visible(False)

fig.tight_layout()

## Make training and validation samples

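Randomly assign every star to one of eight groups: one group (≈1/8 of the stars) is the validation set, and the other seven form the training set.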

In [None]:
# Which label to predict (alternative: schmag with fractional residuals).
# label_name = 'schmag'
# fractional = True
label_name = 'M_H'
fractional = False

# Seeded random assignment of each star to one of 8 groups; group 0 is held
# out for validation, the remaining 7/8 are used for training.
rng = np.random.default_rng(seed=42)
rando = rng.integers(8, size=len(f_all))
train = rando != 0
valid = ~train
# Optional extra validation cuts (currently disabled):
#     & (g.LOGG < 2.2)
#     & ((label_ys[label_name] * np.sqrt(label_weights[label_name])) > 4)

f_train, f_valid = f_all[train], f_all[valid]

X_train, X_valid = f_train.X, f_valid.X
y_train = label_ys[label_name][train]
y_valid = label_ys[label_name][valid]
w_train = label_weights[label_name][train]
w_valid = label_weights[label_name][valid]

for _first, _second in [(X_train, X_valid), (y_train, y_valid), (w_train, w_valid)]:
    print(_first.shape, _second.shape)

## Build a kNN model and validate it

Get all possibly useful validation-set neighbors up-front.
We'll use them in various ways below.

In [None]:
maxK = 1024      # MAGIC: largest number of neighbours retrieved per star
n_xp_tree = 16   # MAGIC: BP and RP coefficients used for the neighbour search

# The tree is built on a truncated feature set, independent of the number of
# coefficients used later in the per-neighbourhood fits.
X_tree = f_all.slice_bp(n_xp_tree).slice_rp(n_xp_tree).X
X_train_tree, X_valid_tree = X_tree[train], X_tree[valid]

# inds index rows of the training set; KDTree.query sorts by distance by
# default, so inds[:, :k] are the k nearest training neighbours.
tree = KDTree(X_train_tree, leaf_size=32)  # MAGIC leaf size
dists, inds = tree.query(X_valid_tree, k=maxK)
print(X_valid.shape, dists.shape, inds.shape)

In [None]:
# Neighbour counts to try: powers of 4 (1, 4, 16, ...) not exceeding maxK.
Ks = np.power(2, np.arange(0, int(np.log2(maxK)) + 1, 2))
Ks

# Weighted means of $K$ neighbors

In [None]:
# For each K: inverse-variance-weighted mean of the K nearest training labels,
# and its formal uncertainty sqrt(1 / sum(weights)).
weighted_means, weighted_errs = {}, {}
for k in Ks:
    _nb = inds[:, :k]
    _w_nb = w_train[_nb]
    _w_sum = np.sum(_w_nb, axis=1)
    weighted_means[k] = np.sum(y_train[_nb] * _w_nb, axis=1) / _w_sum
    weighted_errs[k] = np.sqrt(1 / _w_sum)

In [None]:
def scale(x):
    """Min-max rescale values onto [0, 1].

    Parameters
    ----------
    x : array-like
        Values to rescale.

    Returns
    -------
    numpy.ndarray
        ``(x - min) / (max - min)``. If every value is equal (zero range),
        returns zeros instead of the nan that 0/0 would give. Empty input
        returns an empty array.
    """
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        return x
    lo = x.min()
    span = x.max() - lo
    if span == 0:
        # e.g. a single K value: map everything to the bottom of the range.
        return np.zeros_like(x)
    return (x - lo) / span

In [None]:
# Let's look at a few validation-set objects: their neighbours' labels vs.
# distance (left) and where they and their neighbours sit on the CMD (right).
n_show = min(4, len(y_valid))  # MAGIC: number of validation stars to inspect
cmap = plt.get_cmap('turbo')

# Loop-invariant quantities, computed once instead of once per figure.
# One colour per K, ordered by log K (same order as weighted_means).
colors = cmap(scale(np.log(list(weighted_means.keys()))))
bprp_train, mg_train = bprp[train], mg[train]
bprp_valid, mg_valid = bprp[valid], mg[valid]
cmd_bins = (
    np.linspace(-0.5, 3.5, 128),
    np.linspace(-4, 12, 128)
)

for ii in range(n_show):
    fig, axes = plt.subplots(1, 2, figsize=(11, 5))

    # --- Left: true value (red line/band), weighted-mean prediction for each
    # K (coloured bands), and the individual neighbours (black points).
    ax = axes[0]
    sigma_true = 1 / np.sqrt(w_valid[ii])
    ax.axhline(y_valid[ii], c="r")
    ax.axhspan(
        y_valid[ii] - sigma_true,
        y_valid[ii] + sigma_true,
        color='r', alpha=0.25, linewidth=0
    )

    for color, (kk, mean) in zip(colors, weighted_means.items()):
        ax.axhline(mean[ii], linestyle='--', alpha=0.4, color=color)
        ax.axhspan(
            mean[ii] - weighted_errs[kk][ii],
            mean[ii] + weighted_errs[kk][ii],
            alpha=0.4, color=color, linewidth=0
        )

    ax.errorbar(dists[ii], 
                y_train[inds[ii]], 
                yerr=1. / np.sqrt(w_train[inds[ii]]),
                fmt="o", color="k", ecolor="k")
    ax.set_xlabel("distance to neighbor")
    ax.set_ylabel(f"{label_latex[label_name]} (neighbors)")
    ax.set_title(f"validation-set object {ii}")
    ax.set_xlim(0, None)

    # --- Right: CMD of all stars, with this validation star (red) and its
    # training-set neighbours (blue).
    ax = axes[1]
    ax.hist2d(
        bprp,
        mg,
        bins=cmd_bins,
        cmap='Greys',
        norm=mpl.colors.LogNorm()
    )
    ax.scatter(
        bprp_valid[ii],
        mg_valid[ii],
        s=10,
        color='tab:red',
        zorder=100
    )
    ax.scatter(
        bprp_train[inds[ii]],
        mg_train[inds[ii]],
        s=4,
        color='tab:blue',
        alpha=0.5,
        zorder=10
    )
    ax.set_xlim(0., 4.)
    ax.set_ylim(10, -4)

    ax.set_xlabel('$G_{BP}-G_{RP}$')
    ax.set_ylabel('$M_G$')

    fig.tight_layout()

CMD colored by discrepancy (true minus predicted label) for each $K$ (this cell is currently commented out).

In [None]:
# for color, (kk, y_pred) in zip(colors, weighted_means.items()):
# #     dy = (Y_valid - Y_pred) / Y_valid
#     dy = (y_valid - y_pred)
    
#     fig, ax = plt.subplots(1, 1, figsize=(7, 6))
#     cs = ax.scatter(
#         bprp[valid],
#         mg[valid],
#         c=dy,
#         vmin=-.25, vmax=.25,
#         cmap='RdBu',
#         s=2
#     )
#     ax.set_xlim(0., 4.)
#     ax.set_ylim(10, -4)
    
#     cb = fig.colorbar(cs)
    
#     ax.set_xlabel('$G_{BP}-G_{RP}$')
#     ax.set_ylabel('$M_G$')
#     ax.set_title(f'K={kk}')
#     fig.tight_layout()

Baseline (originally posed as a red-clump test): how well do we do by simply predicting the inverse-variance-weighted mean of the training set for every validation star?

In [None]:
# Trivial baseline: the inverse-variance-weighted mean of all training labels,
# used as the prediction for every validation star.
mean_pred = np.sum(w_train * y_train) / np.sum(w_train)
mean_pred

In [None]:
# Robust scatter (1.5 * MAD) of the baseline residuals, optionally fractional.
diff = y_valid - mean_pred
if fractional:
    diff = diff / y_valid
print(1.5 * MAD(diff))

# Weighted linear least-squares method for KNN.

In [None]:
# kNN + weighted linear least squares: for every validation star, fit a
# ridge-regularized weighted linear model (intercept + features) to its K
# nearest training neighbours and evaluate it at the star, for several P=n_xp.
Ks = 2 ** np.arange(0, int(np.log2(maxK)) + 1, 2)
n_xps = [1, 4, 16, 32]

y_valid_preds = {(n_xp, k): np.zeros(len(f_valid)) for k in Ks for n_xp in n_xps}

# Ridge strength: prior precision alpha * I on every coefficient, intercept
# included. TODO: tune / make the regularization more principled.
alpha = 1e-8


def _solve(A, b):
    """Solve ``A @ x = b``, falling back to least squares if ``A`` is singular."""
    try:
        return np.linalg.solve(A, b)
    except np.linalg.LinAlgError:
        return np.linalg.lstsq(A, b, rcond=None)[0]


def _lls_predict(x_star, X_nb, y_nb, w_nb, alpha):
    """Predict at ``x_star`` from a ridge-regularized weighted linear fit.

    Parameters
    ----------
    x_star : (n_par,) array
        Design-matrix row (intercept + features) of the star to predict.
    X_nb : (n_nb, n_par) array
        Design matrix of the neighbours.
    y_nb, w_nb : (n_nb,) arrays
        Neighbour labels and inverse-variance weights.
    alpha : float
        Ridge strength.

    Returns
    -------
    ``x_star @ beta`` with ``beta = (X^T W X + alpha I)^{-1} X^T W y``.
    """
    n_nb, n_par = X_nb.shape
    if n_nb > n_par:
        # Over-determined: (n_par x n_par) normal equations. Broadcasting by
        # w_nb replaces the dense diag(w) matrix.
        XtW = X_nb.T * w_nb
        A = XtW @ X_nb + alpha * np.eye(n_par)
        return x_star @ _solve(A, XtW @ y_nb)
    # Under-determined: algebraically identical dual form that needs only an
    # (n_nb x n_nb) solve:
    #   beta = alpha^{-1} X^T (alpha^{-1} X X^T + W^{-1})^{-1} y
    A = (X_nb @ X_nb.T) / alpha + np.diag(1 / w_nb)
    return (x_star @ X_nb.T) / alpha @ _solve(A, y_nb)


for n_xp in n_xps:
    f_train_cut = f_train.slice_bp(n_xp).slice_rp(n_xp)
    f_valid_cut = f_valid.slice_bp(n_xp).slice_rp(n_xp)

    # Prepend a column of ones so the fit includes an intercept.
    X_fit_train = np.hstack((np.ones((f_train_cut.X.shape[0], 1)), f_train_cut.X))
    X_fit_valid = np.hstack((np.ones((f_valid_cut.X.shape[0], 1)), f_valid_cut.X))
    Nvalid = X_fit_valid.shape[0]

    for k in Ks:
        preds = y_valid_preds[n_xp, k]
        # Neighbours come from the fixed n_xp_tree search, regardless of n_xp.
        # The primal/dual choice inside _lls_predict compares K with the actual
        # number of fit parameters (1 + all feature columns), not with n_xp.
        for ii, ind in tqdm(enumerate(inds[:, :k]), total=Nvalid,
                            desc=f'P={n_xp}, K={k}'):
            preds[ii] = _lls_predict(
                X_fit_valid[ii], X_fit_train[ind], y_train[ind], w_train[ind], alpha
            )

# TODO: per-star prediction uncertainties for the LLS fits are not computed yet.

In [None]:
# fig, axes = plt.subplots(
#     len(Ks), len(n_xps), 
#     figsize=(12, 12),
#     sharex=True, sharey=True,
#     constrained_layout=True
# )

# for i, P in enumerate(n_xps):
#     for j, k in enumerate(Ks):
#         ax = axes[j, i]
        
#         # dy = (Y_valid - Y_valid_preds[P, k]) / Y_valid
#         dy = (y_valid - y_valid_preds[P, k]) 

#         _cs = ax.scatter(
#             bprp[valid],
#             mg[valid],
#             c=dy,
#             vmin=-0.25, vmax=0.25,
#             cmap='RdBu',
#             s=1
#         )

#         ax.set_xlim(0., 4.)
#         ax.set_ylim(10, -4)

#         ax.set_title(f'K={k}, P={P}')

# for ax in axes[-1]:
#     ax.set_xlabel('$G_{BP}-G_{RP}$')
# for ax in axes[:, 0]:
#     ax.set_ylabel('$M_G$')

# cb = fig.colorbar(_cs, ax=axes, aspect=30)

In [None]:
fig, axes = plt.subplots(
    len(Ks), len(n_xps), 
    figsize=(12, 12 * len(Ks)/len(n_xps)),
    sharex=True, sharey=True,
    constrained_layout=True,
    squeeze=False,  # always a 2D array, even for a single K or P
)

# Residuals (pred - true, optionally divided by true) vs. true label for each
# (K, P). The blue steps show +/- 1.5 * MAD in bins of the true label.
_all_dy = []
for i, P in enumerate(n_xps):
    for j, k in enumerate(Ks):
        ax = axes[j, i]

        y_pred = y_valid_preds[P, k]
        if fractional:
            dy = (y_pred - y_valid) / y_valid
        else:
            dy = y_pred - y_valid
        _all_dy.append(dy)

        ax.scatter(
            y_valid,
            dy,
            s=1
        )

        stat = binned_statistic(
            y_valid,
            dy,
            bins=np.linspace(y_valid.min(), y_valid.max(), 10),
            statistic=lambda x: 1.5 * MAD(x)
        )
        binc = 0.5 * (stat.bin_edges[:-1] + stat.bin_edges[1:])
        for sign in [1, -1]:
            ax.plot(binc, sign * stat.statistic, 
                    marker='', drawstyle='steps-mid', 
                    color='tab:blue', alpha=0.5)

        ax.set_title(f'K={k}, P={P}')

_all_dy = np.array(_all_dy)

# Axes are shared, so limits set on one panel apply to all. nan-aware
# reductions keep a single failed prediction from making the limits nan
# (which set_ylim rejects).
ax.set_xlim(np.nanpercentile(y_valid, [0.1, 99.9]))
_lim = np.abs([np.nanmin(np.nanpercentile(_all_dy, 5, axis=1)),
               np.nanmax(np.nanpercentile(_all_dy, 95, axis=1))]).max()
ax.set_ylim(-_lim, _lim)

for ax in axes[-1]:
    ax.set_xlabel(f"{label_latex[label_name]} true", 
                  fontsize=10)

# Labels match how dy is computed above (pred minus true).
for ax in axes[:, 0]:
    if fractional:
        ax.set_ylabel('(pred - true) / true')
    else:
        ax.set_ylabel('pred - true')

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))

# Binned standard deviation of the residuals vs. true label, one curve per P,
# at the largest K that was actually fitted.
_all_std = []

k = Ks[-1]  # largest K in Ks (1024 for maxK = 1024); avoids a KeyError for other maxK
for P in n_xps:
    y_pred = y_valid_preds[P, k]
    if fractional:
        dy = (y_pred - y_valid) / y_valid
    else:
        dy = y_pred - y_valid

    stat = binned_statistic(
        y_valid,
        dy,
        bins=np.linspace(y_valid.min(), y_valid.max(), 15),
        statistic=np.std
    )
    binc = 0.5 * (stat.bin_edges[:-1] + stat.bin_edges[1:])
    ax.plot(binc, stat.statistic, 
            marker='', drawstyle='steps-mid', 
            alpha=0.5, label=f'P={P}')

    _all_std.append(stat.statistic)

_all_std = np.array(_all_std)

ax.set_xlim(np.nanpercentile(y_valid, [0.1, 99.9]))
ax.set_ylim(0, np.nanmax(_all_std) + 0.1*np.nanmax(_all_std))

ax.set_xlabel(f"{label_latex[label_name]} true")
if fractional:
    ax.set_ylabel('std of (pred - true) / true')
else:
    ax.set_ylabel('std of (pred - true)')
ax.set_title(f'K={k}')

ax.legend()

In [None]:
# Robust (1.5 * MAD) scatter of the residuals for every (P, K) pair;
# row i follows n_xps, column j follows Ks.
bulk_dy = np.zeros((len(n_xps), len(Ks)))
for i, P in enumerate(n_xps):
    for j, k in enumerate(Ks):
        resid = y_valid_preds[P, k] - y_valid
        bulk_dy[i, j] = 1.5 * MAD(resid / y_valid if fractional else resid)

In [None]:
np.min(bulk_dy)  # best (smallest) robust scatter over all (P, K) pairs

In [None]:
# Heatmap of the robust residual scatter: P along x, K along y.
fig, ax = plt.subplots(figsize=(6, 6 * len(Ks)/len(n_xps)))
_mesh = ax.pcolormesh(bulk_dy.T, shading='auto')
# Without a colorbar the heatmap's values (and "darker is better") are unreadable.
fig.colorbar(
    _mesh, ax=ax,
    label='1.5 MAD of (pred - true) / true' if fractional else '1.5 MAD of (pred - true)'
)

ax.set_yticks(np.arange(len(Ks)) + 0.5)
ax.set_yticklabels(Ks.astype(str))

ax.set_xticks(np.arange(len(n_xps)) + 0.5)
ax.set_xticklabels(np.array(n_xps).astype(str))

ax.set_xlabel(r'$P$ aka $N_{\rm xp}$')
ax.set_ylabel('$K$ neighbors');
print("darker is better")

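2D "image" of $P$ vs. $K$, colored by a goodness-of-fit metric (currently the mean $\chi^2$ of the residuals) inside small CMD boxes: main sequence (MS), red clump (RC), red giant branch (RGB), and tip of the RGB (TRGB).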

In [None]:
# Small boxes in (BP-RP, M_G); each entry is a boolean mask over all stars in g.
# NOTE(review): 'rc' and 'rgb' are nearly the same box (same colour window,
# M_G centred on 0.9 vs 1) — confirm the intended M_G centres.
# NOTE(review): with the LOGG (-0.5, 2.3) load filter the 'ms' (main-sequence)
# box is probably (nearly) empty.
stat_boxes = dict(
    ms=(np.abs(bprp - 1.5) < 0.5) & (np.abs(mg - 7) < 0.5),
    rc=(np.abs(bprp - 1.2) < 0.5) & (np.abs(mg - 0.9) < 0.5),
    rgb=(np.abs(bprp - 1.2) < 0.5) & (np.abs(mg - 1) < 0.5),
    trgb=(bprp > 1) & (bprp < 4) & (np.abs(mg + 1) < 0.5),
)

In [None]:
# For each CMD box and each (P, K): mean (or median) chi^2 of the validation
# residuals, chi = sqrt(w) * (true - pred), restricted to stars in the box.
use_median = False  # True -> median chi^2 (more robust to outliers)

stats = {}
for name, box_mask in stat_boxes.items():
    in_box = box_mask[valid]
    stats[name] = np.full((len(n_xps), len(Ks)), np.nan)
    if not np.any(in_box):
        # Empty box: leave nan instead of averaging an empty array (which warns).
        continue
    for i, P in enumerate(n_xps):
        for j, k in enumerate(Ks):  # was `ks`, an undefined name (NameError)
            chi = (np.sqrt(w_valid) * (y_valid - y_valid_preds[P, k]))[in_box]
            stats[name][i, j] = np.median(chi**2) if use_median else np.mean(chi**2)

In [None]:
fig, axes = plt.subplots(
    nrows=2, ncols=2, figsize=(11, 10),
    sharex=True, sharey=True, constrained_layout=True,
)
# One image per CMD box, transposed so P runs along x and K along y.
for ax, (name, stat) in zip(axes.flat, stats.items()):
    _cs = ax.imshow(stat.T, cmap='turbo', origin='lower')
    fig.colorbar(_cs, ax=ax)
    ax.set_title(name)

    ax.set_xlabel('$P$')
    ax.set_xticks(range(len(n_xps)))
    ax.set_xticklabels(list(map(str, n_xps)))

    ax.set_ylabel('$k$')
    ax.set_yticks(range(len(Ks)))
    ax.set_yticklabels(list(map(str, Ks)))

In [None]:
# for name, stat in stats.items():
#     fig, ax = plt.subplots(figsize=(6.5, 5.5), constrained_layout=True)
#     _cs = plt.imshow(stat.T, cmap='turbo', origin='lower')
#     fig.colorbar(_cs, ax=ax)
#     ax.set_xlabel('$P$')
#     ax.set_ylabel('$k$')
    
#     ax.set_xticks(np.arange(len(Ps)))
#     ax.set_yticks(np.arange(len(ks)))
#     ax.set_xticklabels([str(x) for x in Ps])
#     ax.set_yticklabels([str(x) for x in ks])
    
#     ax.set_title(name)

In [None]:
# Test some kind of mixture model maybe??

# Iteratively reweighted linear least-squares

TODO:

# Run this model on EVERYTHING