# Experiments related to Joaquin
Technically, this notebook implements something *even dumber* than *Joaquin*.
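It implements k-nearest-neighbors (kNN) regression in *Gaia*-only quantities (BP−RP color plus normalized XP spectral coefficients) to estimate an APOGEE label for each star. The estimate is either an inverse-variance-weighted mean of the neighbors' labels or a local weighted linear fit. The original target was the $G$-band "schmag"; the label currently used below is the APOGEE metallicity `M_H`.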

## Authors:
- **Adrian Price-Whelan** (Flatiron)
- **David W. Hogg** (NYU) (MPIA) (Flatiron)

## Definitions and Conventions:
- `ncoeff`: The number of BP and RP spectral coefficients to use.
- `maxk`: The maximum `k` to which we take neighbors.
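- `P_tree`: The number of leading feature columns used to find neighbors (KD-tree).
- `P`: The number of design-matrix columns (a constant plus leading features) used in the local linear fit.
- Feature scaling/preprocessing: each XP coefficient is divided by the star's 0th coefficient in the same band, and $BP-RP$ is multiplied by 0.1 (MAGIC); no other scaling yet.
- How we use the neighbors: an inverse-variance-weighted mean, or a weighted, ridge-regularized linear least-squares fit (a mixture model of some kind is a possible future option).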

## TODO / questions
- Do we add "reduced proper motion" as a feature?
- Use 2MASS or WISE photometry in features?
- Color the CMD by implied density (and store the distance to the $k$-th neighbor as a proxy for density).

In [None]:
import pathlib
import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import h5py
from tqdm import tqdm
from sklearn.neighbors import KDTree
from pyia import GaiaData

Load the APOGEE DR17 × Gaia DR3 XP cross-match — see `Assemble-data.ipynb` for how it was built.

In [None]:
# APOGEE DR17 x Gaia DR3 (XP) cross-match; see Assemble-data.ipynb for how it is built.
cache_dir = pathlib.Path('../cache')
gall = GaiaData(str(cache_dir / 'apogee-dr17-x-gaia-dr3-xp.fits'))

In [None]:
# Kiel diagram (TEFF vs. LOGG) of the full cross-match, with the selection box
# used by the `filter` cell below drawn on top.
TEFF_lim = (4500, 5100)  # MAGIC
LOGG_lim = (2.3, 2.6)    # MAGIC

fig, ax = plt.subplots(figsize=(6, 6))
H, xb, yb, _ = ax.hist2d(
    gall.TEFF,
    gall.LOGG,
    bins=(
        np.linspace(3000, 8000, 128),
        np.linspace(-0.5, 5.5, 128)
    ),
    norm=mpl.colors.LogNorm()
)
# Reverse both axes: hot stars on the left, low-gravity stars at the top.
ax.set_xlim(xb.max(), xb.min())
ax.set_ylim(yb.max(), yb.min())

for lim in TEFF_lim:
    ax.axvline(lim)
for lim in LOGG_lim:
    ax.axhline(lim)

ax.set_xlabel('TEFF [K]')
ax.set_ylabel('LOGG')

fig.tight_layout()

In [None]:
# Select the TEFF/LOGG box drawn above, drop stars with M_H below -3, and keep
# stars brighter than G = 15.5 mag.
sample_cuts = dict(
    TEFF=TEFF_lim,
    LOGG=LOGG_lim,
    M_H=[-3, None],
    phot_g_mean_mag=[None, 15.5 * u.mag],
)
g = gall.filter(**sample_cuts)

## Make rectangular data (feature matrix, labels, and label weights)

In [None]:
# Plain (unmasked) astropy table of the selected sample, used by every cell below.
# Per astropy's Table.filled(), masked entries are replaced by each column's fill
# value, so downstream NumPy code sees ordinary arrays.
# NOTE(review): this cell previously referenced `xp_apogee_tbl` before it was ever
# defined (NameError on a fresh kernel). It now starts from the filtered sample `g`
# (GaiaData keeps its table in `.data`); use `gall.data` instead to run on the full
# cross-match.
xp_apogee_tbl = g.data.filled()

In [None]:
# plt.plot(
#     xp_apogee_tbl['phot_g_mean_mag'] - xp_apogee_tbl['J'],
#     xp_apogee_tbl['AK_WISE'],
#     ls='none',
#     ms=1., mew=0, alpha=0.2
# )
# plt.xlim(-1, 8)
# plt.ylim(-0.1, 1)

In [None]:
# Make rectangular block of Gaia-only features (X) for training and validation

# Row selection: keep stars with AK_WISE > 0 and |b| > 30 (the 2MASS J/H/K cuts
# below are commented out, i.e. disabled).
# APW, HOGG: BUG: Why these cuts?
feature_mask = (
#     (xp_apogee_tbl['J'] < 13) &
#     (xp_apogee_tbl['H'] < 12) &
#     (xp_apogee_tbl['K'] < 11) &
    (xp_apogee_tbl['AK_WISE'] > 0) &
    (np.abs(xp_apogee_tbl['b']) > 30.)
)

# Use XP coefficients 1..ncoeff of each band; coefficient 0 is only used to normalize.
ncoeff = 54 # MAGIC
# BP (xx) and RP (yy) coefficients divided by the star's own 0th coefficient in that
# band, restricted to the rows selected by feature_mask.
xx = (xp_apogee_tbl['bp'][:, 1:ncoeff + 1] / xp_apogee_tbl['bp'][:, 0:1])[feature_mask]
yy = (xp_apogee_tbl['rp'][:, 1:ncoeff + 1] / xp_apogee_tbl['rp'][:, 0:1])[feature_mask]
# Interleave the two bands column-wise: BP[1], RP[1], BP[2], RP[2], ...
# -> shape (n_selected, 2 * ncoeff); coeff_names holds the matching column names.
coeffs = np.vstack([[xx[:, i], yy[:, i]] for i in range(ncoeff)]).T
coeff_names = np.concatenate([[f'BP[{i}]', f'RP[{i}]'] for i in range(1, ncoeff + 1)])

# Feature matrix: column 0 is 0.1 * (BP - RP); the 0.1 (MAGIC) sets how much the
# colour counts relative to the coefficients in neighbour distances. The remaining
# columns are the interleaved coefficients.
features = np.hstack((
    0.1 * (xp_apogee_tbl['phot_bp_mean_mag'] - xp_apogee_tbl['phot_rp_mean_mag'])[feature_mask, None],
    # (xp_apogee_tbl['phot_g_mean_mag'] - xp_apogee_tbl['phot_rp_mean_mag'])[feature_mask, None],
#     0.2 * (xp_apogee_tbl['phot_g_mean_mag'] - xp_apogee_tbl['J'])[feature_mask, None],
    coeffs
))
# Column index in `features` where the XP coefficients start.
coeff_idx = 1
# features = coeffs
# coeff_idx = 0

# Human-readable name for every column of `features`, in the same order.
feature_names = np.concatenate((
    ['$BP-RP$ [mag]', ],
#     ['$G-J$ [mag]', ],
    coeff_names
))
# feature_names = coeff_names

print(features.shape)
print(len(feature_names), feature_names)

In [None]:
# Photometry of the feature-selected stars, used for the CMD panels below.
_tbl = xp_apogee_tbl[feature_mask]
plx_mask = (_tbl['parallax_over_error'] > 5)  # NOTE(review): not used in the cells shown

# Distance modulus from parallax; allow_negative=True lets astropy accept negative
# parallaxes instead of raising.
plx = _tbl['parallax'].value * u.mas
DM = coord.Distance(parallax=plx, allow_negative=True).distmod

bprp = _tbl['phot_bp_mean_mag'] - _tbl['phot_rp_mean_mag']
mg = _tbl['phot_g_mean_mag'] - DM  # absolute G magnitude (G minus distance modulus)

In [None]:
# BP[1] and RP[1] are very correlated with BP-RP, even after scaling.
# The original `ax.plot(..., ls='none')` had no marker and so drew nothing; plot
# small markers instead and label the axes.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, col in zip(axes, range(2)):
    ax.plot(
        bprp,
        coeffs[:, col],
        ls='none', marker='o',
        ms=1., mew=0, alpha=0.2
    )
    ax.set_xlabel('$BP-RP$ [mag]')
    ax.set_ylabel(coeff_names[col])  # column 0 is BP[1], column 1 is RP[1]
fig.tight_layout()

In [None]:
# things = np.nanpercentile(features, [5, 95], axis=0)
# plt.hist(things[1] - things[0], bins=np.linspace(0, things.max(), 32));

In [None]:
# plt.figure(figsize=(10, 5))
# plt.hist(
#     xp_apogee_tbl['phot_g_mean_mag'][feature_mask], 
#     bins=np.linspace(5, 20.7, 121)
# );
# plt.xlabel('$G$ [mag]')
# plt.yscale('log')

In [None]:
# HACK / TEST: remove temperature dependence of coefficient prediction
# M = np.hstack((
#     np.ones((features.shape[0], 1)),
#     coeffs
# ))
# M = coeffs

# # sol, *_ = np.linalg.lstsq(M, xp_apogee_tbl['TEFF'][feature_mask], rcond=None)
# sol, *_ = np.linalg.lstsq(M, bprp, rcond=None)

# corrected = M - M.dot(sol)[:, None] * sol[None]
# # corrected = corrected[:, 1:]
# features = np.hstack((
#     features[:, :coeff_idx],
#     corrected
# ))

In [None]:
# # BP[1] and RP[1] are very correlated with BP-RP, even after scaling
# fig, axes = plt.subplots(1, 2, figsize=(10, 5))
# axes[0].plot(
#     bprp,
#     corrected[:, 0],
#     ls='none'
# )
# axes[1].plot(
#     bprp,
#     corrected[:, 1],
#     ls='none'
# )

In [None]:
# Make list of labels (and label weights), aligned with the features.
# The label is currently the APOGEE metallicity M_H; the commented-out lines are the
# earlier G-band "schmag" label, parallax * 10^(G/5).
# labels = (xp_apogee_tbl['parallax'] * 10 ** (1/5 * xp_apogee_tbl['phot_g_mean_mag']))[feature_mask]
labels = xp_apogee_tbl['M_H'][feature_mask]
print(labels.shape)

# label_errors = (xp_apogee_tbl['parallax_error'] * 10 ** (1/5 * xp_apogee_tbl['phot_g_mean_mag']))[feature_mask]
label_errors = xp_apogee_tbl['M_H_ERR'][feature_mask]
print(label_errors.shape)

# Inverse-variance weights.
label_weights = 1. / (label_errors ** 2)
print(label_weights.shape)

# Axis label used by the plots below; it must describe `labels` (it previously still
# said "G-band schmag" although the label is M_H).
label_name = '[M/H] [dex]'

In [None]:
# check that the labels aren't wack: label value vs. its reported uncertainty.
# Bin edges come from the data (0.1-99.9 percentiles) so the plot follows whichever
# label is in use; the previous hard-coded edges (-300..1000) were tuned for the old
# schmag label and put essentially every [M/H] value into one or two bins.
# The y axis shows the uncertainty itself: labels / label_errors is not a
# signal-to-noise ratio for a label that can be zero or negative.
lab_vals = np.asarray(labels, dtype=float)
lab_errs = np.asarray(label_errors, dtype=float)
bins = (
    np.linspace(*np.nanpercentile(lab_vals, [1e-1, 100 - 1e-1]), 128),
    np.linspace(*np.nanpercentile(lab_errs, [1e-1, 100 - 1e-1]), 128)
)

fig, ax = plt.subplots(figsize=(8, 5))

ax.hist2d(
    lab_vals,
    lab_errs,
    bins=bins,
    norm=mpl.colors.LogNorm(),
    cmap='Greys'
)

ax.axhline(np.nanmedian(lab_errs), color="k", ls='--')
ax.set_xlim(bins[0].min(), bins[0].max())
ax.set_ylim(bins[1].min(), bins[1].max())
ax.set_xlabel(label_name)
ax.set_ylabel("label uncertainty")

fig.tight_layout()

In [None]:
# check that the features aren't wack: label vs. each of the first (up to) 8 features.
# The label-axis bins come from the label percentiles (the old hard-coded -300..1200
# range was tuned for the schmag label); nanpercentile keeps NaNs from poisoning
# the x-axis limits.
label_range = np.nanpercentile(np.asarray(labels, dtype=float), [1e-1, 100 - 1e-1])

for i in range(min(features.shape[1], 8)):
    fig, ax = plt.subplots()
    # Show twice the central-95% range of the feature, centred on its midpoint.
    lo95, hi95 = np.nanpercentile(features[:, i], [2.5, 97.5])
    mid, width = 0.5 * (hi95 + lo95), hi95 - lo95
    bins = (
        np.linspace(*np.nanpercentile(features[:, i], [1e-1, 100 - 1e-1]), 128),
        np.linspace(*label_range, 128)
    )
    ax.hist2d(features[:, i], labels, bins=bins,
              cmap='Greys', norm=mpl.colors.LogNorm())
    ax.set_xlim(mid - width, mid + width)
    ax.set_ylim(bins[1].min(), bins[1].max())
    ax.set_xlabel(feature_names[i])
    ax.set_ylabel(label_name)

## Make training and validation samples

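Randomly assign each star to one of eight groups ("eighths"). Group 0 (about 1/8 of the stars) is the validation set; the other seven groups form the training set.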

In [None]:
# Random 1-in-8 hold-out: every star draws an integer in [0, 8); a draw of 0 puts it
# in the validation set, anything else in the training set.
rng = np.random.default_rng(seed=42)

rando = rng.integers(8, size=len(features))
valid = rando == 0
train = ~valid

X_train, Y_train, W_train = features[train], labels[train], label_weights[train]
X_valid, Y_valid, W_valid = features[valid], labels[valid], label_weights[valid]

for pair in ((X_train, X_valid), (Y_train, Y_valid), (W_train, W_valid)):
    print(pair[0].shape, pair[1].shape)

## Build a kNN model and validate it

Get all possibly useful validation-set neighbors up-front.
We'll use them in various ways below.

In [None]:
# Build the KD-tree on the training set and query the maxk nearest training
# neighbours of every validation star once, up front.
maxk = 64  # MAGIC: largest number of neighbours used anywhere below
# P_tree = 2 * 5 + 1  # MAGIC
# MAGIC: with the current feature layout, BP-RP plus the first 15 (BP, RP) pairs.
P_tree = 2 * 15 + 1
tree = KDTree(X_train[:, :P_tree], leaf_size=32)  # magic
X_query = X_valid[:, :P_tree]
dists, inds = tree.query(X_query, k=maxk)  # each row ordered nearest-first
print(X_valid.shape, dists.shape, inds.shape)

In [None]:
# Neighbour counts to try: powers of 4 up to maxk (1, 4, 16, 64 for maxk = 64).
ks = 4 ** np.arange(int(np.log2(maxk)) // 2 + 1)
ks

In [None]:
# Weighted-mean kNN: for each k, predict every validation label as the
# inverse-variance-weighted mean of its k nearest training labels, with the
# standard error 1 / sqrt(sum of weights).
weighted_means, weighted_errs = {}, {}
for k in ks:
    assert k <= maxk
    nbr = inds[:, :k]
    w_nbr = W_train[nbr]
    w_sum = np.sum(w_nbr, axis=1)
    weighted_means[k] = np.sum(Y_train[nbr] * w_nbr, axis=1) / w_sum
    weighted_errs[k] = np.sqrt(1 / w_sum)

In [None]:
def scale(x):
    """Linearly rescale values to the unit interval [0, 1].

    Parameters
    ----------
    x : array_like
        Values to rescale.

    Returns
    -------
    numpy.ndarray
        ``(x - min(x)) / (max(x) - min(x))``. If ``x`` is empty, or all of its values
        are equal (zero range), an array of zeros with the shape of ``x`` is returned
        instead of raising or producing NaNs from 0/0.
    """
    x = np.array(x)
    if x.size == 0:
        return np.zeros(x.shape)
    lo, hi = x.min(), x.max()
    if hi == lo:
        # Degenerate range (e.g. a single k value): map everything to 0.
        return np.zeros(x.shape)
    return (x - lo) / (hi - lo)

In [None]:
# Let's look at a few objects. For each of the first eight validation stars:
#  - left: neighbour labels vs. distance, the true label (red) and the weighted-mean
#    kNN estimate for each k (coloured bands);
#  - right: the star (red) and its maxk training neighbours (blue) on the CMD.
cmap = plt.get_cmap('turbo')

# One colour per k (log-spaced along the colormap), in the order of weighted_means.
# It does not depend on the object, so compute it once, before the loop.
colors = cmap(scale(np.log(list(weighted_means.keys()))))

# Bins for the background CMD density histogram (also loop-invariant).
bins = (
    np.linspace(-0.5, 3.5, 128),
    np.linspace(-4, 12, 128)
)

for ii in range(8):
    fig, axes = plt.subplots(1, 2, figsize=(11, 5))

    # --- left panel: neighbour labels vs. distance ---
    ax = axes[0]
    sigma_true = 1 / np.sqrt(W_valid[ii])
    ax.axhline(Y_valid[ii], c="r")
    ax.axhspan(
        Y_valid[ii] - sigma_true,
        Y_valid[ii] + sigma_true,
        color='r', alpha=0.25, linewidth=0
    )

    for color, (kk, mean) in zip(colors, weighted_means.items()):
        ax.axhline(mean[ii], linestyle='--', alpha=0.4, color=color)
        ax.axhspan(
            mean[ii] - weighted_errs[kk][ii],
            mean[ii] + weighted_errs[kk][ii],
            alpha=0.4, color=color, linewidth=0
        )

    ax.errorbar(dists[ii],
                Y_train[inds[ii]],
                yerr=1. / np.sqrt(W_train[inds[ii]]),
                fmt="o", color="k", ecolor="k")
    ax.set_xlabel("distance to neighbor")
    ax.set_ylabel("label of neighbor")
    ax.set_title(f"validation-set object {ii}")

    # --- right panel: CMD position of the object and its neighbours ---
    ax = axes[1]
    ax.hist2d(
        bprp,
        mg,
        bins=bins,
        cmap='Greys',
        norm=mpl.colors.LogNorm()
    )
    ax.scatter(
        bprp[valid][ii],
        mg[valid][ii],
        s=10,
        color='tab:red',
        zorder=100
    )
    ax.scatter(
        bprp[train][inds[ii]],
        mg[train][inds[ii]],
        s=4,
        color='tab:blue',
        alpha=0.5,
        zorder=10
    )
    ax.set_xlim(0., 4.)
    ax.set_ylim(10, -4)

    ax.set_xlabel('$G_{BP}-G_{RP}$')
    ax.set_ylabel('$M_G$')

    fig.tight_layout()

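CMD of the validation set, colored by the residual (true − predicted label) of the weighted-mean kNN estimate; one figure per $k$.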

In [None]:
# CMD of the validation set coloured by the residual (true - predicted label) of the
# weighted-mean kNN estimate, one figure per k. Iterates weighted_means directly
# (the old zip over `colors` from the previous cell was unused and only added a
# hidden dependency on that cell).
for kk, Y_pred in weighted_means.items():
    # dy = (Y_valid - Y_pred) / Y_valid
    dy = Y_valid - Y_pred

    fig, ax = plt.subplots(1, 1, figsize=(7, 6))
    cs = ax.scatter(
        bprp[valid],
        mg[valid],
        c=dy,
        vmin=-.25, vmax=.25,
        cmap='RdBu',
        s=2
    )
    ax.set_xlim(0., 4.)
    ax.set_ylim(10, -4)

    cb = fig.colorbar(cs)
    cb.set_label('true - predicted label')

    ax.set_xlabel('$G_{BP}-G_{RP}$')
    ax.set_ylabel('$M_G$')
    ax.set_title(f'K={kk}')
    fig.tight_layout()

In [None]:
# implement weighted linear least-squares method for KNN.
# For every validation star, fit a linear model in the first P design-matrix columns
# to its k nearest training neighbours (each weighted by its inverse label variance,
# with a weak ridge prior) and evaluate the fit at the validation star.
ks = 2 ** np.arange(0, int(np.log2(maxk)) + 1, 2)
Ps = [3, 11, 33, 101]  # MAGIC: numbers of design-matrix columns (incl. the constant)

# Design matrices: a column of ones (intercept) followed by all features.
X_fit_train = np.hstack((np.ones(X_train.shape[0])[:, None], X_train))
N, P = X_fit_train.shape

X_fit_valid = np.hstack((np.ones(X_valid.shape[0])[:, None], X_valid))
Nvalid, Pvalid = X_fit_valid.shape

assert Pvalid == P
# Fail fast instead of silently slicing fewer columns than requested.
assert max(Ps) <= P, "an entry of Ps exceeds the number of design-matrix columns"

Y_valid_preds = {(P, k): np.zeros(Nvalid) for k in ks for P in Ps}
# weighted_lls_errs = {}

# Regularization: the coefficients get an isotropic Gaussian prior with precision
# `alpha` (covariance 1 / alpha). TODO: tune it.
alpha = 1e-8


def _solve_or_lstsq(A, b):
    """Solve ``A x = b``; fall back to least squares if ``A`` is exactly singular."""
    try:
        return np.linalg.solve(A, b)
    except np.linalg.LinAlgError:
        return np.linalg.lstsq(A, b, rcond=None)[0]


for P in Ps:
    L = np.eye(P) * alpha         # prior precision of the coefficients
    Linv = np.eye(P) * 1 / alpha  # prior covariance of the coefficients

    for k in ks:
        assert k <= maxk

        for ii, ind in tqdm(enumerate(inds[:, :k]), total=Nvalid, desc=f"P={P}, k={k}"):
            X_nb = X_fit_train[ind, :P]  # neighbours' design matrix, (k, P)
            Y_nb = Y_train[ind]
            W_nb = W_train[ind]
            x_star = X_fit_valid[ii, :P]

            if k > P:
                # More neighbours than parameters: P x P normal equations,
                #   beta = (X^T C^-1 X + L)^-1 X^T C^-1 y
                Cinv_nb = np.diag(W_nb)
                beta = _solve_or_lstsq(
                    X_nb.T @ Cinv_nb @ X_nb + L,
                    X_nb.T @ Cinv_nb @ Y_nb
                )
                Y_valid_preds[P, k][ii] = x_star @ beta
            else:
                # No more neighbours than parameters: the equivalent k x k form,
                #   beta = Linv X^T (X Linv X^T + C)^-1 y
                C_nb = np.diag(1 / W_nb)
                z = _solve_or_lstsq(X_nb @ Linv @ X_nb.T + C_nb, Y_nb)
                Y_valid_preds[P, k][ii] = x_star @ Linv @ X_nb.T @ z

        # BUG: The next line is WRONG
    #     weighted_lls_errs[k] = np.sqrt(1 / np.sum(W_train[inds[:, :k]], axis=1))

In [None]:
# Grid of validation-set CMDs coloured by the residual (true - predicted label) of
# the local linear fit: one row per k, one column per P, with a shared colour bar.
fig, axes = plt.subplots(
    len(ks), len(Ps),
    figsize=(12, 12),
    sharex=True, sharey=True,
    constrained_layout=True
)

for j, k in enumerate(ks):
    for i, P in enumerate(Ps):
        ax = axes[j, i]

        # dy = (Y_valid - Y_valid_preds[P, k]) / Y_valid
        dy = Y_valid - Y_valid_preds[P, k]

        _cs = ax.scatter(bprp[valid], mg[valid], c=dy,
                         vmin=-0.25, vmax=0.25, cmap='RdBu', s=1)
        ax.set(xlim=(0., 4.), ylim=(10, -4), title=f'K={k}, P={P}')

for ax in axes[-1]:
    ax.set_xlabel('$G_{BP}-G_{RP}$')
for ax in axes[:, 0]:
    ax.set_ylabel('$M_G$')

cb = fig.colorbar(_cs, ax=axes, aspect=30)

In [None]:
# _cs = plt.scatter(
#     bprp[valid],
#     mg[valid],
#     c=Y_valid_preds[3, 64] - Y_valid_preds[101, 64],
#     vmin=-10, vmax=10,
#     cmap='RdBu',
#     s=1
# )

# plt.xlim(0., 4.)
# plt.ylim(10, -4)

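2D "image" over $P$ (number of design-matrix columns) and $k$ (number of neighbors), colored by a summary statistic of the validation residuals. The statistic is currently the median $\chi^2$ (the mean $\chi^2$ is the alternative). It is computed separately in several CMD boxes (MS, RC, RGB, TRGB).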

In [None]:
def _in_box(x, center, half_width):
    """Boolean mask of values strictly within ``half_width`` of ``center``."""
    return np.abs(x - center) < half_width


# CMD boxes (in BP-RP and M_G) inside which the residual statistics are computed.
# NOTE(review): 'rc' and 'rgb' share the colour window and differ only by 0.1 mag in
# M_G, so they overlap almost entirely. With the TEFF/LOGG-selected giant sample,
# 'ms' and 'trgb' may also be empty. Confirm the intended boxes.
_mg = mg.value
stat_boxes = {
    'ms': _in_box(bprp, 1.5, 0.5) & _in_box(_mg, 7, 0.5),
    'rc': _in_box(bprp, 1.2, 0.5) & _in_box(_mg, 0.9, 0.5),
    'rgb': _in_box(bprp, 1.2, 0.5) & _in_box(_mg, 1, 0.5),
    'trgb': (bprp > 1) & (bprp < 4) & _in_box(_mg, -1, 0.5),
}

In [None]:
# For every CMD box and every (P, k), summarize the validation residuals of the local
# linear fit by the median chi^2, chi = sqrt(w) * (true - predicted).
# stats[name] has rows indexed by P and columns indexed by k; boxes with no validation
# stars get NaN (previously this also emitted empty-slice RuntimeWarnings).
stats = {}
for name, box_mask in stat_boxes.items():
    in_box = np.asarray(box_mask[valid], dtype=bool)
    stats[name] = np.full((len(Ps), len(ks)), np.nan)
    if not in_box.any():
        continue
    for i, P in enumerate(Ps):
        for j, k in enumerate(ks):
            # np.asarray works for plain ndarrays as well as astropy Column/Quantity
            # (the old `.value` only worked for the latter).
            chi = np.asarray(np.sqrt(W_valid) * (Y_valid - Y_valid_preds[P, k]))[in_box]
            # Alternative (outlier-sensitive) statistic: np.mean(chi**2)
            stats[name][i, j] = np.median(chi**2)

In [None]:
# One heat map per CMD box of the residual statistic in `stats`, with P on the x axis
# and k on the y axis.
fig, axes = plt.subplots(2, 2, figsize=(11, 10),
                         sharex=True, sharey=True,
                         constrained_layout=True)
P_ticks = np.arange(len(Ps))
k_ticks = np.arange(len(ks))
for ax, name in zip(axes.flat, stats):
    stat = stats[name]
    # rows of `stat` index P, columns index k -> transpose so P runs along x
    _cs = ax.imshow(stat.T, cmap='turbo', origin='lower')
    fig.colorbar(_cs, ax=ax)
    ax.set_xlabel('$P$')
    ax.set_ylabel('$k$')

    ax.set_xticks(P_ticks)
    ax.set_yticks(k_ticks)
    ax.set_xticklabels(list(map(str, Ps)))
    ax.set_yticklabels(list(map(str, ks)))

    ax.set_title(name)

In [None]:
# for name, stat in stats.items():
#     fig, ax = plt.subplots(figsize=(6.5, 5.5), constrained_layout=True)
#     _cs = plt.imshow(stat.T, cmap='turbo', origin='lower')
#     fig.colorbar(_cs, ax=ax)
#     ax.set_xlabel('$P$')
#     ax.set_ylabel('$k$')
    
#     ax.set_xticks(np.arange(len(Ps)))
#     ax.set_yticks(np.arange(len(ks)))
#     ax.set_xticklabels([str(x) for x in Ps])
#     ax.set_yticklabels([str(x) for x in ks])
    
#     ax.set_title(name)

In [None]:
# Number of stars in the validation set.
np.sum(valid)

In [None]:
# Number of validation stars in each CMD box. Previously this read the leaked loop
# variable `box_mask` from the stats cell, i.e. only the last box, and only if that
# cell had been run.
{name: int(np.count_nonzero(np.asarray(mask)[valid])) for name, mask in stat_boxes.items()}

In [None]:
# Test a linear weighted least squares as a function of k
# HOGG: TBD

In [None]:
# Test some kind of mixture model maybe??

## Run this model on EVERYTHING

In [None]:
# APW: We need to figure out the above tests and then run in the data center.