# Experiments related to Joaquin
Technically, this notebook implements something *even dumber* than *Joaquin*.
It runs k-nearest neighbors (kNN) on *Gaia*-only quantities to get weighted-mean and weighted least-squares estimates of schmag or schmarrn.

## Authors:
- **Adrian Price-Whelan** (Flatiron)
- **David W. Hogg** (NYU) (MPIA) (Flatiron)

## Hyper-parameters:
- `ncoeff`: The maximum number of BP and RP spectral coefficients to use as features.
- `pee_tree`: The number of features (the first `pee_tree` columns) to use in the k-d tree.
- `maxk`: The maximum `k` to which we take neighbors; various `k` values are attempted.
- scalings or preprocessing of input features (currently just normalization by `RP[0]`).
- how we use the neighbors (weighted mean, weighted linear fit, mixture of some kind?).

## To-do items and bugs:
- We currently take ALL neighbors. But we don't need to consider neighbors that have obviously discrepant schmags given the extant Gaia data. Should we cut on schmag? Maybe?? It's complicated.
- Many of the KNN collections contain significant outliers. We should do something more robust than plain WLS. Maybe iteratively reweighted least squares (IRLS)?
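The IRLS idea in the last to-do item could look something like the following minimal sketch. It fits a straight line with inverse-variance weights, then repeatedly down-weights points whose error-normalized residuals are large. The choice of Huber weights, the tuning constant `c=1.345`, and the helper name `irls_line` are assumptions for illustration, not anything the notebook has settled on.

```python
import numpy as np

def irls_line(x, y, ivar, c=1.345, n_iter=20):
    """Robust straight-line fit y ~ a + b * x by IRLS with Huber weights.

    Hypothetical helper. Weights are ivar * min(1, c / |chi|), where chi is
    the residual in units of its reported uncertainty (an assumption).
    """
    A = np.column_stack((np.ones_like(x), x))   # design matrix, shape (n, 2)
    w = ivar.copy()                             # first pass is plain WLS
    for _ in range(n_iter):
        sw = np.sqrt(w)
        coeffs = np.linalg.lstsq(A * sw[:, None], y * sw, rcond=None)[0]
        chi = (y - A @ coeffs) * np.sqrt(ivar)
        # Points with |chi| > c get weight proportional to 1 / |chi|.
        w = ivar * np.minimum(1.0, c / np.maximum(np.abs(chi), 1e-12))
    return coeffs

# Toy demo: a line plus 10% gross outliers.
rng = np.random.default_rng(0)
x = rng.uniform(0.0, 1.0, 200)
ivar = np.full(200, 100.0)                      # sigma = 0.1 for every point
y = 1.0 + 2.0 * x + rng.normal(0.0, 0.1, 200)
y[:20] += 20.0                                  # contaminate the first 20 points
robust = irls_line(x, y, ivar)
plain = irls_line(x, y, ivar, n_iter=1)         # one pass = ordinary WLS
print(robust, plain)
```

Huber weights keep every point in the fit but cap each outlier's pull, so the estimate stays close to what a clean WLS fit would give.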

## Read in and munge all data

In [None]:
import pathlib
import astropy.table as at
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import h5py
from tqdm import tqdm
from sklearn.neighbors import KDTree

In [None]:
datadir = "./"
xm = at.Table.read(datadir + 'allStar-dr17-synspec-gaiadr3.fits')
xm2 = at.Table.read(datadir + 'allStar-dr17-synspec-gaiadr3-gaiasourcelite.fits')
xm2.rename_column('source_id', 'GAIADR3_SOURCE_ID')
allstar = at.Table.read(datadir + 'allStarLite-dr17-synspec_rev1.fits')

In [None]:
tbl = at.unique(at.hstack((allstar, xm)), keys='APOGEE_ID')
tbl = tbl[tbl['GAIADR3_SOURCE_ID'] != 0]
tbl = at.join(tbl, xm2, keys='GAIADR3_SOURCE_ID')
len(tbl)

In [None]:
apogee_xp_cont_filename = pathlib.Path(datadir + 'apogee-dr17-xpcontinuous.hdf5')

In [None]:
# Read data and lightly rearrange
xp_tbl = at.Table()
with h5py.File(apogee_xp_cont_filename, 'r') as f:
    xp_tbl['GAIADR3_SOURCE_ID'] = f['source_id'][:]
    xp_tbl['bp'] = f['bp_coefficients'][:]
    xp_tbl['rp'] = f['rp_coefficients'][:]

In [None]:
# Join the APOGEE table to the XP coefficients and make simple cuts
# Hogg: Why these cuts?
xp_apogee_tbl = at.join(tbl, xp_tbl, keys='GAIADR3_SOURCE_ID')
xp_apogee_tbl = xp_apogee_tbl[
    (xp_apogee_tbl['TEFF'] > 3500.) &
    (xp_apogee_tbl['TEFF'] < 6000.) &
    (xp_apogee_tbl['LOGG'] > -0.5) &
    (xp_apogee_tbl['LOGG'] < 5.5) &
    (xp_apogee_tbl['M_H'] > -2.)
]
len(xp_apogee_tbl)

## Make rectangular data

In [None]:
# Replace masked entries with each column's fill value, so all columns become plain (unmasked) arrays.
xp_apogee_tbl = xp_apogee_tbl.filled()

In [None]:
# Make rectangular block of Gaia-only features (X) for training and validation
# Note the gymnastics around normalizing by RP[0].

# APW, HOGG: BUG: Why these cuts?
feature_mask = (
    (xp_apogee_tbl['J'] < 13) &
    (xp_apogee_tbl['H'] < 12) &
    (xp_apogee_tbl['K'] < 11) &
    (xp_apogee_tbl['AK_WISE'] > -0.1))

ncoeff = 50 # MAGIC
features = np.hstack((
    (xp_apogee_tbl['phot_bp_mean_mag'] - xp_apogee_tbl['phot_rp_mean_mag'])[feature_mask, None],
    (xp_apogee_tbl['bp'][:, 0:ncoeff] / xp_apogee_tbl['rp'][:, 0:1])[feature_mask],
    (xp_apogee_tbl['rp'][:, 1:ncoeff + 1] / xp_apogee_tbl['rp'][:, 0:1])[feature_mask],
))

feature_names = np.concatenate((
    ['$BP-RP$ (mag)', ],
    [f'BP[{i}]' for i in range(0, ncoeff)],
    [f'RP[{i}]' for i in range(1, ncoeff + 1)],
))

print(features.shape)
print(len(feature_names), feature_names)

In [None]:
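# Rearrange feature order (because Hogg has issues): interleave as
# BP-RP, BP[0], RP[1], BP[1], RP[2], ... so the first few columns hold low-order BP and RP terms.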

index = np.concatenate((
    [0, ], 
    *([i, ncoeff + i, ] for i in range(1, ncoeff + 1))
))
print(feature_names[index])

features = features[:, index]
feature_names = feature_names[index]
print(features.shape, feature_names.shape)
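Because every XP coefficient is divided by `RP[0]`, and $BP-RP$ is already a magnitude difference, the features should not change if a star's overall flux is scaled (for example, by moving it farther away). The toy sketch below checks that and prints the interleaved ordering. The random coefficient values, the small sizes (4 BP and 5 RP coefficients, `ncoeff = 3`), and the helper name `make_features` are hypothetical.

```python
import numpy as np

def make_features(color, bp, rp, ncoeff):
    """Mimic the feature block above: color, BP[0:ncoeff]/RP[0], RP[1:ncoeff+1]/RP[0]."""
    return np.hstack((
        color[:, None],
        bp[:, 0:ncoeff] / rp[:, 0:1],           # 0:1 keeps an (n, 1) column for broadcasting
        rp[:, 1:ncoeff + 1] / rp[:, 0:1],
    ))

ncoeff = 3
rng = np.random.default_rng(1)
bp = rng.uniform(1.0, 2.0, size=(5, 4))         # hypothetical toy coefficients
rp = rng.uniform(1.0, 2.0, size=(5, 5))
color = rng.uniform(0.5, 2.0, size=5)           # a magnitude difference, flux-scale free

features = make_features(color, bp, rp, ncoeff)
dimmer = make_features(color, 0.01 * bp, 0.01 * rp, ncoeff)  # same star, 100x fainter
print(np.allclose(features, dimmer))            # True

names = np.array(['BP-RP'] + [f'BP[{i}]' for i in range(ncoeff)]
                 + [f'RP[{i}]' for i in range(1, ncoeff + 1)])
index = np.concatenate(([0], *([i, ncoeff + i] for i in range(1, ncoeff + 1))))
print(names[index])  # ['BP-RP' 'BP[0]' 'RP[1]' 'BP[1]' 'RP[2]' 'BP[2]' 'RP[3]']
```

With `pee_tree = 9`, this interleaving means the k-d tree sees the color plus BP[0] to BP[3] and RP[1] to RP[4].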

In [None]:
# Make list of labels (and label weights), aligned with the features.

# Divide by 100 mas so that schmag = parallax * 10**(0.2 G) / (100 mas) = 10**(0.2 M_G),
# i.e. absolute maggies^(-1/2) (ignoring extinction).
schmag_factor = 10 ** (0.2 * xp_apogee_tbl['phot_g_mean_mag'].value) / 100.

labels = (xp_apogee_tbl['parallax'].value * schmag_factor)[feature_mask]
print(labels.shape)

label_errors = (xp_apogee_tbl['parallax_error'].value * schmag_factor)[feature_mask]
print(label_errors.shape)

label_weights = 1. / (label_errors ** 2)
print(label_weights.shape)

label_name = '$G$-band schmag (absmgy$^{-1/2}$)'
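The schmag label is just $10^{0.2 M_G}$. With $M_G = G + 5\log_{10}(\varpi/\mathrm{mas}) - 10$ (no extinction correction), $10^{0.2 M_G} = 10^{0.2 G}\,\varpi/(100\,\mathrm{mas})$. Because it is linear in parallax, its uncertainty is the parallax error times the same factor (as in `label_errors`), and negative parallaxes give finite (negative) labels instead of undefined magnitudes. Below is a quick numerical check. The helper names `schmag` and `abs_mag` and the sample values are hypothetical.

```python
import numpy as np

def schmag(parallax_mas, g_mag):
    """Label as defined above: parallax [mas] * 10**(0.2 G) / 100."""
    return parallax_mas * 10 ** (0.2 * g_mag) / 100.0

def abs_mag(parallax_mas, g_mag):
    """Absolute magnitude from parallax in mas, ignoring extinction."""
    return g_mag + 5.0 * np.log10(parallax_mas) - 10.0

g = np.array([10.0, 14.0, 17.5])
plx = np.array([10.0, 0.8, 0.15])
print(schmag(plx, g))
print(10 ** (0.2 * abs_mag(plx, g)))            # same numbers

# Linear in parallax, so a negative (noisy) parallax still gives a finite label.
print(schmag(-0.1, 17.5))
```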

In [None]:
# check that the labels aren't wack

plt.scatter(labels, labels / label_errors, c="k", s=1., alpha=0.05)
plt.axhline(np.median(labels / label_errors), color="k")
plt.xlim(-10., 50.)
plt.ylim(-10., 200.)
plt.xlabel(label_name)
plt.ylabel("label SNR")

In [None]:
# check that the features aren't wack

for i in range(min(16, features.shape[1])):
    f = plt.figure()
    foo = np.percentile(features[:, i], [2.5, 97.5])
    lo = 0.5 * (foo[1] + foo[0]) - (foo[1] - foo[0])
    hi = 0.5 * (foo[1] + foo[0]) + (foo[1] - foo[0])
    plt.scatter(features[:, i], labels, c="k", s=1., alpha=0.05)
    plt.xlim(lo, hi)
    plt.ylim(-10., 50.)
    plt.xlabel(feature_names[i])
    plt.ylabel(label_name)

## Make training and validation samples

In [None]:
# Hold out a random eighth of the objects for validation #MAGIC
# BUG: Should fix random state more sensibly than this.

np.random.seed(17)
rando = np.random.randint(8, size=len(features))
train = rando != 0
valid = ~train
X_train, X_valid = features[train], features[valid]
Y_train, Y_valid = labels[train], labels[valid]
W_train, W_valid = label_weights[train], label_weights[valid]
print(X_train.shape, X_valid.shape)
print(Y_train.shape, Y_valid.shape)
print(W_train.shape, W_valid.shape)
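One way to address the random-state BUG is to draw the split from a local `np.random.Generator` instead of reseeding NumPy's global state. The sketch below is an assumption, not the notebook's method. `eighths_split` is a hypothetical helper, and it produces a different (but equally reproducible) split than `np.random.seed(17)` does.

```python
import numpy as np

def eighths_split(n, seed=17, nfolds=8):
    """Put each of n objects in one of nfolds folds; fold 0 is validation."""
    rng = np.random.default_rng(seed)           # local generator; global state untouched
    fold = rng.integers(nfolds, size=n)
    valid = fold == 0
    return ~valid, valid

train, valid = eighths_split(10_000)
print(train.sum(), valid.sum())                 # roughly 7/8 and 1/8 of the objects
```

The later cell that picks random validation objects to plot could draw from the same kind of generator (for example, `rng.integers(len(Y_valid))`).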

## Build a kNN model and validate it

In [None]:
# Get all possibly useful validation-set neighbors up-front.
# We'll use them in various ways below.
pee_tree = 9 # magic
maxk = 2 ** 12 # magic
tree = KDTree(X_train[:, :pee_tree], leaf_size=32) # magic
dists, inds = tree.query(X_valid[:, :pee_tree], k=maxk)
print(X_valid.shape, dists.shape, inds.shape)

In [None]:
# Look at the neighbors of a few random validation objects with schmag < 1
for jj in range(8):
    ii = np.random.randint(len(Y_valid))
    while Y_valid[ii] > 1.:
        ii = np.random.randint(len(Y_valid))
    ff = plt.figure()
    plt.axhline(Y_valid[ii], c="r")
    plt.errorbar(dists[ii], Y_train[inds[ii]], yerr=1. / np.sqrt(W_train[inds[ii]]),
                 fmt="o", color="k", ecolor="k")
    plt.xlabel("distance to neighbor")
    plt.ylabel("label (schmag) of neighbor")
    plt.title(f"validation-set object {ii}")

In [None]:
# Test weighted mean.
# BUG: this is a bad idea!
ks = 2 ** np.arange(5)
Y_hat_mean, Y_hat_mean_ivar = {}, {}
for k in ks:
    I = inds[:, :k]
    Y_hat_mean_ivar[k] = np.sum(W_train[I], axis=1)
    Y_hat_mean[k] = np.sum(W_train[I] * Y_train[I], axis=1) / Y_hat_mean_ivar[k]
    print(k, Y_hat_mean[k].shape, Y_hat_mean_ivar[k].shape)
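The weighted-mean cell relies on fancy indexing. With `I = inds[:, :k]` of shape `(n_valid, k)`, `W_train[I]` and `Y_train[I]` have that same shape, so one `sum(axis=1)` produces every validation object's estimate. The self-contained toy below reproduces the cell on synthetic data, with brute-force neighbors from `np.argsort` standing in for `KDTree.query`. The returned inverse variance treats the `k` neighbors as independent measurements of one true value, so it ignores the real spread of labels among neighbors.

```python
import numpy as np

rng = np.random.default_rng(2)
# Hypothetical toy data: 1-d feature, true label = feature, per-object label errors.
X_train = rng.uniform(0, 1, size=(500, 1))
sig = rng.uniform(0.02, 0.2, size=500)
Y_train = X_train[:, 0] + sig * rng.normal(size=500)
W_train = 1.0 / sig ** 2
X_valid = np.array([[0.25], [0.5], [0.75]])

# Brute-force neighbor indices sorted by distance (stands in for KDTree.query).
d = np.linalg.norm(X_valid[:, None, :] - X_train[None, :, :], axis=-1)
inds = np.argsort(d, axis=1)

ks = 2 ** np.arange(5)
Y_hat_mean, Y_hat_mean_ivar = {}, {}
for k in ks.tolist():
    I = inds[:, :k]                             # shape (n_valid, k)
    w = W_train[I]                              # fancy indexing keeps that shape
    Y_hat_mean_ivar[k] = w.sum(axis=1)
    Y_hat_mean[k] = (w * Y_train[I]).sum(axis=1) / Y_hat_mean_ivar[k]
    print(k, Y_hat_mean[k])
```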

In [None]:
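def get_sigma(ys, ys_true):
    """Robust fractional scatter: half the 16th-84th percentile range of
    (ys - ys_true) / ys_true, over objects with 0.5 < ys_true < 2.0."""
    I = (ys_true > 0.5) & (ys_true < 2.0)
    xs = (ys[I] - ys_true[I]) / ys_true[I]
    foo = np.percentile(xs, [16, 84])
    return 0.5 * (foo[1] - foo[0])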

for k in Y_hat_mean.keys():
    ff = plt.figure()
    plt.plot([-100, 100], [-100, 100], "k-")
    plt.plot(Y_valid, Y_hat_mean[k], "k.", alpha=0.1)
    plt.axis("equal")
    plt.xlim(-0.2, 2)
    plt.ylim(-0.2, 2)
    plt.xlabel("Gaia-measured schmag")
    plt.ylabel("weighted mean KNN predicted schmag")
    sigma = get_sigma(Y_hat_mean[k], Y_valid)
    plt.title(r"mean of KNN, $k={0}$, fractional $\sigma={1:4.2f}$".format(k, sigma))
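`get_sigma` is half the 16th-to-84th-percentile range of the fractional residuals. For Gaussian residuals those percentiles sit almost exactly at $\mu \pm 1\sigma$, so the statistic acts as a robust stand-in for a standard deviation. The quick check below uses synthetic data (an assumption). It applies the mask before dividing, which gives the same result but avoids divide-by-zero warnings when a true schmag is exactly zero.

```python
import numpy as np

def get_sigma(ys, ys_true):
    """Half the 16-84 percentile range of fractional residuals, for 0.5 < ys_true < 2."""
    I = (ys_true > 0.5) & (ys_true < 2.0)
    xs = (ys[I] - ys_true[I]) / ys_true[I]
    lo, hi = np.percentile(xs, [16, 84])
    return 0.5 * (hi - lo)

rng = np.random.default_rng(3)
truth = rng.uniform(0.0, 2.5, size=200_000)     # includes values outside the mask
pred = truth * (1.0 + 0.1 * rng.normal(size=truth.size))
print(get_sigma(pred, truth))                   # close to the input 0.1

# A few percent of wild outliers barely move it, unlike a plain standard deviation.
pred_bad = pred.copy()
pred_bad[:5_000] = 50.0
print(get_sigma(pred_bad, truth))
```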

In [None]:
# Test unweighted median
# BUG: this is a bad idea!
Y_hat_med = {}
for k in ks:
    I = inds[:, :k]
    Y_hat_med[k] = np.median(Y_train[I], axis=1)
    print(k, Y_hat_med[k].shape)
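The median cell ignores the label uncertainties. If a robust estimator that still uses the weights is wanted, a weighted median is one option. This is a swapped-in technique, not something the notebook uses, and `weighted_median` is a hypothetical helper. It works row by row on arrays shaped like `Y_train[I]` and `W_train[I]`, and it returns the lower weighted median: the first sorted value whose cumulative weight reaches half the row total.

```python
import numpy as np

def weighted_median(values, weights):
    """Lower weighted median along the last axis."""
    order = np.argsort(values, axis=-1)
    v = np.take_along_axis(values, order, axis=-1)
    cw = np.cumsum(np.take_along_axis(weights, order, axis=-1), axis=-1)
    # First sorted position whose cumulative weight reaches half the row total.
    j = np.argmax(cw >= 0.5 * cw[..., -1:], axis=-1)
    return np.take_along_axis(v, np.expand_dims(j, -1), axis=-1)[..., 0]

vals = np.array([[1.0, 2.0, 3.0],
                 [3.0, 1.0, 2.0]])
wts = np.array([[1.0, 1.0, 1.0],
                [1.0, 10.0, 1.0]])
print(weighted_median(vals, wts))               # [2. 1.]
```

In the loop above this would read `Y_hat_wmed[k] = weighted_median(Y_train[I], W_train[I])`, with `Y_hat_wmed` a hypothetical new dict.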

In [None]:
for k in Y_hat_med.keys():
    ff = plt.figure()
    plt.plot([-100, 100], [-100, 100], "k-")
    plt.plot(Y_valid, Y_hat_med[k], "k.", alpha=0.1)
    plt.axis("equal")
    plt.xlim(-0.2, 2)
    plt.ylim(-0.2, 2)
    plt.xlabel("Gaia-measured schmag")
    plt.ylabel("median KNN predicted schmag")
    sigma = get_sigma(Y_hat_med[k], Y_valid)
    plt.title(r"median of KNN, $k={0}$, fractional $\sigma={1:4.2f}$".format(k, sigma))

In [None]:
# Test a linear weighted least squares as a function of k
# BUG: UNTESTED
# BUG: DOESN'T RETURN IVARS
ks = 2 ** np.arange(9, 12)
Y_hat_wls, Y_hat_wls_ivar = {}, {}
for k in ks:
    Y_hat_wls[k] = np.zeros_like(Y_valid) + np.nan
    I = inds[:, :k]
    for i, II in enumerate(I):
        # make design matrix
        X = np.hstack((np.ones((k, 1)), X_train[II]))
        Xstar = np.append(1, X_valid[i])
        Cinv = W_train[II]
        Y = Y_train[II]
        Y_hat_wls[k][i] = Xstar @ np.linalg.lstsq(X.T @ (Cinv[:, None] * X),
                                                  X.T @ (Cinv * Y),
                                                  rcond=None)[0]
    print(k, Y_hat_wls[k].shape)
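This sketch addresses the "doesn't return ivars" BUG. If the linear model is right and the label errors are the only noise, the prediction $\hat{y}_\star = x_\star^\top \hat\beta$ has variance $x_\star^\top (X^\top C^{-1} X)^{-1} x_\star$. The code below computes both, using `lstsq` as the cell does. `wls_predict` is a hypothetical helper, and its variance ignores any scatter of the neighbors about the local linear model. An intercept-only fit reduces to the weighted mean and its inverse variance, which makes a handy consistency check.

```python
import numpy as np

def wls_predict(X, Y, Cinv, xstar):
    """WLS prediction at xstar and its inverse variance.

    X: (k, p) neighbor features, Y: (k,) labels, Cinv: (k,) inverse variances.
    A column of ones is prepended for the intercept, as in the cell above.
    """
    A = np.hstack((np.ones((len(X), 1)), X))
    a = np.append(1.0, xstar)
    ATCA = A.T @ (Cinv[:, None] * A)            # X^T C^-1 X
    beta = np.linalg.lstsq(ATCA, A.T @ (Cinv * Y), rcond=None)[0]
    var = a @ np.linalg.lstsq(ATCA, a, rcond=None)[0]
    return a @ beta, 1.0 / var

# Noise-free linear data: the prediction should be exact.
rng = np.random.default_rng(5)
X = rng.normal(size=(50, 2))
Y = 1.0 + X @ np.array([2.0, -3.0])
yhat, ivar = wls_predict(X, Y, np.full(50, 100.0), np.array([0.5, 0.5]))
print(yhat, ivar)                               # yhat = 1 + 1 - 1.5 = 0.5

# Intercept only: weighted mean of Y, with ivar = sum(Cinv).
Cinv = np.array([1.0, 4.0, 4.0])
ymean, ivar_mean = wls_predict(np.empty((3, 0)), np.array([1.0, 2.0, 3.0]), Cinv, np.empty(0))
print(ymean, ivar_mean)
```

Inside the loop, this could become `Y_hat_wls[k][i], Y_hat_wls_ivar[k][i] = wls_predict(X_train[II], Y_train[II], W_train[II], X_valid[i])`, after initializing `Y_hat_wls_ivar[k]` like `Y_hat_wls[k]` (hypothetical usage).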

In [None]:
for k in Y_hat_wls.keys():
    ff = plt.figure()
    plt.plot([-100, 100], [-100, 100], "k-")
    plt.plot(Y_valid, Y_hat_wls[k], "k.", alpha=0.1)
    plt.axis("equal")
    plt.xlim(-0.2, 2)
    plt.ylim(-0.2, 2)
    plt.xlabel("Gaia-measured schmag")
    plt.ylabel("WLS of KNN predicted schmag")
    sigma = get_sigma(Y_hat_wls[k], Y_valid)
    plt.title(r"WLS of KNN, $k={0}$, fractional $\sigma={1:4.2f}$".format(k, sigma))

In [None]:
# Test some kind of mixture model maybe??
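Here is one possible mixture model, offered only as a sketch under stated assumptions. Each neighbor label is drawn either from a "foreground" Gaussian centered on the unknown schmag $\mu$ with that neighbor's own uncertainty, or from a broad "background" Gaussian whose mean and variance are fixed to the median and variance of all the neighbor labels. Expectation-maximization (EM) then alternates between computing soft foreground memberships and taking a membership-weighted, inverse-variance-weighted mean. The fixed background, the starting values, and the helper names `mixture_mean` and `log_normal` are all assumptions.

```python
import numpy as np

def log_normal(y, mu, var):
    return -0.5 * ((y - mu) ** 2 / var + np.log(2.0 * np.pi * var))

def mixture_mean(y, ivar, n_iter=50):
    """EM for mu in (1 - f) N(y | mu, 1/ivar) + f N(y | mu_bg, V_bg)."""
    mu_bg, V_bg = np.median(y), max(np.var(y), 1e-12)  # fixed broad background (assumption)
    mu, f = np.median(y), 0.5                          # starting guesses
    for _ in range(n_iter):
        # E-step: posterior probability that each label is foreground (log space).
        lf = np.log(1.0 - f) + log_normal(y, mu, 1.0 / ivar)
        lb = np.log(f) + log_normal(y, mu_bg, V_bg)
        r = np.exp(lf - np.logaddexp(lf, lb))
        # M-step: responsibility- and ivar-weighted mean; background fraction.
        mu = np.sum(r * ivar * y) / np.sum(r * ivar)
        f = np.clip(1.0 - np.mean(r), 1e-6, 1.0 - 1e-6)
    return mu, f, r

rng = np.random.default_rng(6)
y = np.concatenate((1.0 + 0.05 * rng.normal(size=200),   # good neighbors
                    rng.uniform(5.0, 20.0, size=50)))    # outliers
ivar = np.full(y.size, 1.0 / 0.05 ** 2)
mu, f, r = mixture_mean(y, ivar)
print(mu, f)                                    # mu near 1.0, f near 0.2
```

Per validation object, this would be called as `mixture_mean(Y_train[inds[i, :k]], W_train[inds[i, :k]])`.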

## Run this model on EVERYTHING

In [None]:
# APW: We need to figure out the above tests and then run in the data center.
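Running on everything will probably require doing the neighbor query in batches. `dists` and `inds` each take `maxk` × 8 bytes per object (float64 distances and, on typical 64-bit builds, 8-byte integer indices). For `maxk = 4096` that is 32 KiB per array per object, so about 65 GB for both arrays over a million objects. The following is a minimal batching sketch. `predict_in_chunks`, `brute_query`, and `mean_predict` are hypothetical, and a brute-force numpy search stands in for the KDTree.

```python
import numpy as np

def predict_in_chunks(query, predict, X, chunk_size=10_000):
    """Apply query (returns dists, inds) and predict to X in row chunks,
    so only chunk_size rows of neighbor arrays are in memory at once."""
    out = np.empty(len(X))
    for start in range(0, len(X), chunk_size):
        stop = min(start + chunk_size, len(X))
        dists, inds = query(X[start:stop])
        out[start:stop] = predict(dists, inds)
    return out

rng = np.random.default_rng(4)
X_train = rng.normal(size=(300, 2))
Y_train = X_train.sum(axis=1)
X_all = rng.normal(size=(1000, 2))

def brute_query(Xq, k=8):
    d = np.linalg.norm(Xq[:, None, :] - X_train[None, :, :], axis=-1)
    inds = np.argsort(d, axis=1)[:, :k]
    return np.take_along_axis(d, inds, axis=1), inds

def mean_predict(dists, inds):
    return Y_train[inds].mean(axis=1)

chunked = predict_in_chunks(brute_query, mean_predict, X_all, chunk_size=128)
whole = predict_in_chunks(brute_query, mean_predict, X_all, chunk_size=len(X_all))
```

With the real tree, `query` could be `lambda Xq: tree.query(Xq[:, :pee_tree], k=maxk)`, and `predict` any of the estimators above.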