# Schlummernd: Linear Latent Variable Model
See the Text.

## TODO / questions
- 

In [None]:
import pathlib

import corner
import astropy.coordinates as coord
from astropy.stats import median_absolute_deviation as MAD
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import h5py
from tqdm import tqdm
from sklearn.neighbors import KDTree
from pyia import GaiaData
from scipy.stats import binned_statistic

# NOTE(review): recent JAX releases appear to have dropped the `jax.config`
# submodule; if this import fails, `from jax import config` is the
# replacement -- confirm against the installed JAX version.
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jaxopt
import optax

from schlummernd import LinearLVM
from schlummernd.data import load_data, Features

# Load APOGEE x Gaia data

See `Assemble-data.ipynb` for details on how this data set was assembled.

In [None]:
# Sample selection: red clump stars. Each filter value is a two-element
# tuple -- presumably (lower, upper) bounds with `None` meaning "open";
# confirm against `load_data`.
#
# Alternative selection for the upper giant branch, kept for reference:
#     TEFF=(3000, 5100), LOGG=(-0.5, 2.3), M_H=(-3, None),
#     phot_g_mean_mag=(None, 15.5*u.mag), AK_WISE=(-0.1, None)
g_all = load_data(
    filters={
        "TEFF": (4500, 5100),
        "LOGG": (2.3, 2.6),
        "M_H": (-3, None),
        "phot_g_mean_mag": (None, 15.*u.mag),
        "AK_WISE": (-0.1, None),
        # HACK: also cut on the reported label uncertainties.
        "TEFF_ERR": (0, 75),
        "M_H_ERR": (0, 0.05),
    }
)

In [None]:
# Keep only sources with |b| > 15 deg and SFD_EBV < 0.2
# (presumably Galactic latitude and SFD reddening -- confirm).
g = g_all[
    (g_all.SFD_EBV < 0.2)
    & (np.abs(g_all.b) > 15*u.deg)
]
len(g)

In [None]:
# Plain-float colour (BP - RP) and G magnitude minus distance modulus;
# negative distances are explicitly allowed when computing the modulus.
bprp = (
    g.phot_bp_mean_mag
    - g.phot_rp_mean_mag
).value
mg = (
    g.phot_g_mean_mag
    - g.get_distance(allow_negative=True).distmod
).value

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# Left panel: log-scaled 2D histogram of TEFF against LOGG, with both
# axes running from large to small values.
ax = axes[0]
H, xb, yb, _ = ax.hist2d(
    g.TEFF, g.LOGG,
    bins=(np.linspace(3000, 8000, 128), np.linspace(-0.5, 5.5, 128)),
    norm=mpl.colors.LogNorm(),
)
ax.set_xlim(xb.max(), xb.min())
ax.set_ylim(yb.max(), yb.min())
ax.set_xlabel('TEFF')
ax.set_ylabel('LOGG')

# Right panel: log-scaled 2D histogram of colour against `mg`; only the
# vertical axis is flipped.
ax = axes[1]
H, xb, yb, _ = ax.hist2d(
    bprp, mg,
    bins=(np.linspace(-0.5, 3, 128), np.linspace(-4, 10.5, 128)),
    norm=mpl.colors.LogNorm(),
)
ax.set_xlim(xb.min(), xb.max())
ax.set_ylim(yb.max(), yb.min())
ax.set_xlabel('BP-RP')
ax.set_ylabel('$M_G$')

fig.tight_layout()

# Construct features and labels

Make list of possible labels (and label weights), aligned with the features.

In [None]:
# Build the feature object from the Gaia data with `n_bp=5`, `n_rp=5`
# (presumably the number of BP / RP coefficients to keep -- confirm
# against `Features.from_gaiadata`).
#
# Variants that were tried and are kept for reference:
#   other_features = {
#       r"$G_{\rm BP}-G_{\rm RP}$": 0.1 * (g.phot_bp_mean_mag - g.phot_rp_mean_mag)
#   }
#   f_all = Features.from_gaiadata(g, n_bp=32, n_rp=32)  # , **other_features)
f_all = Features.from_gaiadata(
    g,
    n_bp=5,
    n_rp=5,
)  # , **other_features)

In [None]:
# Build normalized labels and their uncertainties, aligned row-by-row with
# the features. Each label is centred on its median and divided by its
# 16th-84th percentile range; the uncertainties are divided by the same range.

label_ys = {}
label_errs = {}
label_latex = {}
label_y_perc = {}

# The "schmag" label is currently disabled; the factor and its uncertainty
# are still computed so the lines below can be re-enabled as-is.
schmag_factor = 10 ** (0.2 * g.phot_g_mean_mag.value) / 100.
schmag_err = g.parallax_error.value * schmag_factor
# label_ys['schmag'] = g.parallax.value * schmag_factor
# label_errs['schmag'] = schmag_err
# label_latex['schmag'] = '$G$-band schmag (absmgy$^{-1/2}$)'

for name in ['TEFF', 'M_H', 'LOGG']: # , 'AK_WISE']:
    # [16th, 50th, 84th] percentiles, ignoring NaNs; kept so that
    # predictions can be mapped back to physical units later.
    label_y_perc[name] = np.nanpercentile(g[name], [16, 50, 84])
    scale = (label_y_perc[name][2] - label_y_perc[name][0])

    err_col = f'{name}_ERR'
    label_ys[name] = (g[name] - label_y_perc[name][1]) / scale
    if err_col in g.data.colnames:
        raw_err = g[err_col]
    elif name == 'AK_WISE':
        # No catalogue uncertainty column: use a constant 0.05.
        raw_err = 0.05 * np.ones_like(label_ys[name])
    else:
        # Previously this fell through to a bare KeyError on `label_errs`.
        raise KeyError(
            f"no uncertainty available for label '{name}': column "
            f"'{err_col}' is missing and no fallback is defined"
        )
    label_errs[name] = raw_err / scale

label_latex['M_H'] = r"$[{\rm M}/{\rm H}]$"
label_latex['LOGG'] = r"$\log g$"
label_latex['TEFF'] = r"$T_{\rm eff}$"
label_latex['AK_WISE'] = r"$A_K$"

# One column per label, in the insertion order of `label_ys`.
label_y = np.hstack([np.array(x)[:, None] for x in label_ys.values()])
label_err = np.hstack([np.array(x)[:, None] for x in label_errs.values()])
# Uncertainties are later inverted and squared, so they must be finite
# and strictly positive (a zero `scale` would otherwise give `inf`).
assert np.all(np.isfinite(label_err) & (label_err > 0))

In [None]:
def untransform_y(y, meta, names=None):
    """
    Map normalized labels back to their original scale.

    Inverts `(value - p50) / (p84 - p16)` column by column.

    ## inputs:
    `y`: shape `(n, q)` block of normalized labels, one column per label.
    `meta`: mapping from label name to its `[p16, p50, p84]` triple.
    `names`: optional sequence of label names giving the column order of
        `y`. Defaults to the keys of the module-level `label_ys`, which is
        the order in which the label columns were stacked.

    ## returns:
    A dict mapping each label name to its un-normalized column of `y`.
    """
    if names is None:
        # Historical behaviour: column order comes from the global dict.
        names = list(label_ys)

    new_y = {}
    for i, name in enumerate(names):
        p16, p50, p84 = meta[name][0], meta[name][1], meta[name][2]
        new_y[name] = y[:, i] * (p84 - p16) + p50
    return new_y

In [None]:
# NNN = 10
# plt.errorbar(
#     g.M_H[:NNN], 
#     label_ys['M_H'][:NNN],
#     xerr=g.M_H_ERR[:NNN],
#     yerr=label_errs['M_H'][:NNN],
#     ls='none',
#     marker='o'
# )

In [None]:
label_y.shape, label_err.shape

In [None]:
_ = corner.corner(label_y)

# Make training and validation samples

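Cut the sample into eighths at random: one eighth is held out for validation and the other seven eighths are used for training.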

In [None]:
# Fixed seed so that the split is reproducible.
rng = np.random.default_rng(seed=42)

# One integer in [0, 8) per source: zeros go to validation, the rest train.
rando = rng.integers(8, size=len(f_all))
train = rando != 0
# Further validation cuts that were tried and are currently disabled:
#     (g.LOGG < 2.2)
#     ((label_ys[label_name] * np.sqrt(label_weights[label_name])) > 4)
valid = ~train

f_train, f_valid = f_all[train], f_all[valid]

# Features and their uncertainties.
X_train, X_valid = f_train.X, f_valid.X
X_train_err, X_valid_err = f_train.X_err, f_valid.X_err

# Labels and their uncertainties, split with the same masks.
y_train, y_valid = label_y[train], label_y[valid]
y_train_err, y_valid_err = label_err[train], label_err[valid]

# Sanity check: print (train, valid) shapes for each quantity.
print(X_train.shape, X_valid.shape)
print(X_train_err.shape, X_valid_err.shape)
print(y_train.shape, y_valid.shape)
print(y_train_err.shape, y_valid_err.shape)

In [None]:
# _ = corner.corner(X_train)

In [None]:
# _ = corner.corner(X_valid)

# Find neighbors for validation sample stars

In [None]:
# Number of training-set neighbours retrieved per validation star.
K = 128  # TODO: need to assess this
# n_xp_tree = 8  # MAGIC
n_xp_tree = 5

# Neighbour search runs on the `X_tree` representation of the features
# after slicing BP and RP with `n_xp_tree`.
X_tree = (
    f_all
    .slice_bp(n_xp_tree)
    .slice_rp(n_xp_tree)
    .X_tree
)
X_train_tree, X_valid_tree = X_tree[train], X_tree[valid]

# Tree is built on training stars only and queried with validation stars,
# so `inds` indexes into the training arrays.
tree = KDTree(X_train_tree, leaf_size=32) # magic
dists, inds = tree.query(X_valid_tree, k=K)
print(X_valid.shape, dists.shape, inds.shape)

# Run LLVM model


In [None]:
class HoggLLVM:
    """
    Linear latent-variable model fit by alternating least-squares steps.

    The model predicts features as `mux + Z @ A.T` and labels as
    `muy + Z @ B.T`, where `Z` holds one row of latents per training object.
    Latents whose column of `B` is entirely zero do not touch any label;
    they are called "free" here and are the only ones that are regularized.
    """

    def __init__(self, X, Y, Wx, Wy, B, Lambda, rng=None):
        """
        ## inputs:
        `X`: shape `(n, r)` block of training features
        `Y`: shape `(n, q)` block of training labels
        `Wx`: shape `(n, r)` block of inverse variances (weights) for the features
        `Wy`: shape `(n, q)` block of inverse variances (weights) for the labels
        `B`: shape `(q, d)` matrix translating latents to labels.
        `Lambda`: regularization strength; use the source, Luke.
        `rng`: optional source of randomness for the initialization noise;
            any object with a `normal(size=...)` method (for example
            `np.random.default_rng(seed)`). Defaults to the global
            `np.random` state.

        ## bugs / comments:
        - Put `np.NaN` into the parts of `B` you want to optimize! (Fitting
          them is not implemented yet: `B_step` raises in that case.)
        - This code more-or-less assumes that your data and labels and `B` are
          normalized to reasonable ranges.
        - In the long run, `X` and `Y` don't need to be the same length.
        """
        assert Lambda > 0., "llvm: You must regularize, and strictly positively."
        self.Lambda = Lambda
        self._rng = np.random if rng is None else rng

        self.X = jnp.array(X)
        self.Y = jnp.array(Y)
        self.n, self.r = self.X.shape
        enn, self.q = self.Y.shape
        assert enn == self.n, "llvm: Right now, the training data must be rectangular."

        self.Wx = jnp.array(Wx)
        assert self.Wx.shape == self.X.shape, "llvm: Inconsistency between `X` and `Wx`."
        assert np.all(self.Wx >= 0.), "llvm: Weights can't be negative."
        self.sqrtWx = jnp.sqrt(self.Wx)

        self.Wy = jnp.array(Wy)
        assert self.Wy.shape == self.Y.shape, "llvm: Inconsistency between `Y` and `Wy`."
        assert np.all(self.Wy >= 0.), "llvm: Weights can't be negative."
        self.sqrtWy = jnp.sqrt(self.Wy)

        self.B = jnp.array(B)
        cue, self.d = self.B.shape
        assert cue == self.q, "llvm: Inconsistency between `B` and `Y`."

        # NaN entries of `B` mark elements that are meant to be fit.
        self.B_elements_to_fit = jnp.isnan(self.B)
        n_to_fit = int(jnp.sum(self.B_elements_to_fit))
        if n_to_fit == 0:
            self.B_elements_to_fit = None
            print("llvm: found no free elements of `B`.")
        else:
            # JAX arrays are immutable, so build a new array instead of
            # assigning into `self.B`.
            self.B = jnp.where(self.B_elements_to_fit, 0., self.B)
            print("llvm: found", n_to_fit,
                  "free elements of `B`.")

        # A latent is free when its whole column of `B` is zero. Summing
        # absolute values avoids mistaking a cancelling column for a free one.
        self.free_elements_of_Z = jnp.isclose(jnp.sum(jnp.abs(self.B), axis=0), 0.)
        print(self.free_elements_of_Z)
        print("llvm: found", np.sum(self.free_elements_of_Z),
              "unconstrained elements of `Z`, of", self.d)

        # Diagonal penalty: `Lambda` on the free latents, zero elsewhere.
        self.regularization_matrix = Lambda * np.diag(
            np.asarray(self.free_elements_of_Z).astype(float))
        print(self.regularization_matrix)

        self.initialize_latents()
        return

    def _predict_X(self):
        """Model prediction for the training features, shape `(n, r)`."""
        return self.mux[None, :] + self.Z @ self.A.T

    def _predict_Y(self):
        """Model prediction for the training labels, shape `(n, q)`."""
        return self.muy[None, :] + self.Z @ self.B.T

    def _chi_X(self):
        """Feature residuals multiplied by the square-root weights."""
        return (self.X - self._predict_X()) * self.sqrtWx

    def _chi_Y(self):
        """Label residuals multiplied by the square-root weights."""
        return (self.Y - self._predict_Y()) * self.sqrtWy

    def _cost(self):
        """
        Objective: half the sum of squared weighted residuals for features
        and labels, plus `0.5 * Lambda` times the sum of squares of the
        free latents.

        NOTE(review): the original author flagged the regularization term
        as provisional; its form is unchanged here.
        """
        Xchi, Ychi = self._chi_X(), self._chi_Y()
        # Masking is equivalent to selecting the free columns of `Z`.
        Zfree = jnp.where(self.free_elements_of_Z[None, :], self.Z, 0.)
        return 0.5 * jnp.sum(Xchi ** 2) \
             + 0.5 * jnp.sum(Ychi ** 2) \
             + 0.5 * self.Lambda * jnp.sum(Zfree ** 2)

    def predict_y_given_x(self, Xstar, Wxstar):
        """
        Predict labels for new objects from their features alone.

        ## inputs:
        `Xstar`: shape `(m, r)` block of features.
        `Wxstar`: shape `(m, r)` block of inverse variances for `Xstar`.

        ## returns:
        Shape `(m, q)` numpy array of predicted labels.
        """
        # should this use the regularization matrix? Hogg thinks not.
        m, arrh = Xstar.shape
        assert arrh == self.r
        A, B = np.asarray(self.A), np.asarray(self.B)
        mux, muy = np.asarray(self.mux), np.asarray(self.muy)
        Ystarhat = np.zeros((m, self.q))
        sqrtWx = np.sqrt(Wxstar)
        chi = (Xstar - mux[None, :]) * sqrtWx
        for i, x in enumerate(chi):
            # Weighted least squares for this object's latents.
            M = A * sqrtWx[i][:, None]
            z = np.linalg.lstsq(M, x, rcond=None)[0]
            Ystarhat[i] = muy + B @ z
        return Ystarhat

    def initialize_latents(self):
        """Set starting values for `mux`, `muy`, `Z` and `A`."""
        # this is a bag of hacks.

        # Zeroth hack: Take (weighted) means.
        self.mux = jnp.sum(self.X * self.Wx, axis=0) / jnp.sum(self.Wx, axis=0)
        assert self.mux.shape == (self.r, )
        self.muy = jnp.sum(self.Y * self.Wy, axis=0) / jnp.sum(self.Wy, axis=0)
        assert self.muy.shape == (self.q, )

        # First hack: Start with the pseudo-inverse of `B`.
        # BUG: Doesn't use weights.
        self.Z = jnp.linalg.lstsq(self.B, (self.Y - self.muy[None, :]).T, rcond=None)[0].T
        assert self.Z.shape == (self.n, self.d)

        # Second hack: Add noise.
        TINY = 1.e-1 # MAGIC
        sigma = np.std(self.Z) + TINY
        self.Z = self.Z + 0.1 * sigma * self._rng.normal(size=self.Z.shape) # MAGIC

        # Third hack: Run A and B steps.
        self.A = jnp.zeros((self.r, self.d))
        self.A_step()
        self.B_step()

    def optimize_step(self, ftol=0.1):
        """
        Run one sweep of Z, mu, A and B updates.

        ## returns:
        `True` when the sweep lowered the cost by more than `ftol`, so it can
        drive a `while` loop; `False` otherwise.
        """
        self._renorm_A()
        before = self._cost()
        self.Z_step()
        self.mu_step()
        self.A_step()
        self.B_step()
        after = self._cost()
        print("optimize_step(): Cost after:", after, self.n * (self.r + self.q))
        return bool(after < (before - ftol))

    def Z_step(self):
        """
        Update every row of `Z`, holding `A`, `B`, `mux` and `muy` fixed.

        For each object the update `dz` solves
        `(M.T @ M + R) dz = M.T @ resid - R @ z`, where `M` stacks the
        weighted `A` and `B`, `resid` stacks the weighted residuals and `R`
        is the regularization matrix. The `- R @ z` term makes the penalty
        act on the updated latent `z + dz` rather than on `dz` alone, so
        the step minimizes `_cost` with respect to `Z`.
        """
        A, B = np.asarray(self.A), np.asarray(self.B)
        sqrtWx, sqrtWy = np.asarray(self.sqrtWx), np.asarray(self.sqrtWy)
        Z = np.asarray(self.Z)
        R = self.regularization_matrix
        dZ = np.zeros_like(Z)
        for i, (x, y) in enumerate(zip(np.asarray(self._chi_X()),
                                       np.asarray(self._chi_Y()))):
            resid = np.append(x, y)
            matrix = np.concatenate((A * sqrtWx[i][:, None],
                                     B * sqrtWy[i][:, None]),
                                    axis=0)
            dZ[i] = np.linalg.lstsq(matrix.T @ matrix + R,
                                    matrix.T @ resid - R @ Z[i],
                                    rcond=None)[0]
        self.Z = self.Z + dZ
        return

    def A_step(self):
        """Weighted least-squares update of each row of `A`, holding `Z` fixed."""
        Z = np.asarray(self.Z)
        sqrtWx = np.asarray(self.sqrtWx)
        dA = np.zeros((self.r, self.d), dtype=Z.dtype)
        for j, x in enumerate(np.asarray(self._chi_X()).T):
            dA[j] = np.linalg.lstsq(Z * sqrtWx[:, j][:, None],
                                    x, rcond=None)[0]
        self.A = self.A + dA
        return

    def _renorm_A(self):
        """
        Rescale the free columns of `A` to unit norm and scale the matching
        columns of `Z` to compensate, leaving `Z @ A.T` unchanged.
        """
        norms = jnp.sqrt(jnp.sum(self.A ** 2, axis=0))
        # Leave constrained columns, and free columns that are exactly zero,
        # untouched (scale factor 1).
        scale = jnp.where(self.free_elements_of_Z & (norms > 0.), norms, 1.)
        # JAX updates are out-of-place: the results must be re-bound.
        self.A = self.A / scale[None, :]
        self.Z = self.Z * scale[None, :]
        return

    def B_step(self):
        """Update the free elements of `B`; a no-op when there are none."""
        # BUG: Not currently operable.
        if self.B_elements_to_fit is None:
            return
        raise NotImplementedError(
            "llvm: fitting free (NaN) elements of `B` is not implemented.")

    def mu_step(self):
        """Shift `mux` and `muy` by the weighted mean of the current residuals."""
        Xresid = self.X - self._predict_X()
        self.mux = self.mux + jnp.sum(Xresid * self.Wx, axis=0) / jnp.sum(self.Wx, axis=0)
        Yresid = self.Y - self._predict_Y()
        self.muy = self.muy + jnp.sum(Yresid * self.Wy, axis=0) / jnp.sum(self.Wy, axis=0)
        return

In [None]:
# Fit a local model on the training-set neighbours of one validation star.
valid_n = 20
idx = inds[valid_n]

rng = np.random.default_rng(42)
n_labels = y_train.shape[1]
n_latents = n_labels + 3

# `B` maps the first `n_labels` latents one-to-one onto the labels; the
# remaining latent columns are zero.
B = np.zeros((n_labels, n_latents))
B[:n_labels, :n_labels] = np.eye(n_labels)

# The model takes inverse variances, so convert the uncertainties.
llvm_hogg = HoggLLVM(
    X=X_train[idx],
    Y=y_train[idx],
    Wx=1/X_train_err[idx]**2,
    Wy=1/y_train_err[idx]**2,
    B=B,
    Lambda=0.1,
)

In [None]:
# Keep sweeping for as long as `optimize_step` returns a truthy value.
while True:
    if not llvm_hogg.optimize_step():
        break

In [None]:
# Predict the labels of the chosen validation star from its features;
# the one-row slices keep the inputs two-dimensional.
Yhat = llvm_hogg.predict_y_given_x(
    X_valid[valid_n:valid_n+1],
    1/X_valid_err[valid_n:valid_n+1]**2,
)
Yhat

In [None]:
untransform_y(Yhat, label_y_perc)

In [None]:
untransform_y(y_valid[valid_n:valid_n+1], label_y_perc)

## Per neighborhood:

## Per-source:

In [None]:
# Same set-up with the package's `LinearLVM`, for a different validation
# star and with a single extra latent.
valid_n = 10
idx = inds[valid_n]

rng = np.random.default_rng(42)
n_labels = y_train.shape[1]
n_latents = n_labels + 1

# Identity block for the label-carrying latents, zeros for the extra one.
B = np.zeros((n_labels, n_latents))
B[:n_labels, :n_labels] = np.eye(n_labels)

# Unlike the cell above, uncertainties are passed as-is (not inverted).
llvm = LinearLVM(
    X_train[idx],
    y_train[idx],
    X_train_err[idx],
    y_train_err[idx],
    B,
    alpha=1e-3,
    beta=1.,
    verbose=True,
    rng=rng,
)

In [None]:
x0 = llvm.pack_p()
print(x0.shape)

In [None]:
# Minimize the model objective with L-BFGS, starting from the packed
# initial parameters; 2**15 iterations at most.
solver = jaxopt.LBFGS(
    fun=llvm,
    maxiter=2**15,
)
res_bfgs = solver.run(x0)

# Unpack the optimized parameter vector into the model's state object.
res_state = llvm.unpack_p(res_bfgs.params)
print(res_bfgs.state.iter_num)

In [None]:
print(
    llvm(x0),
    llvm(res_bfgs.params)
)

### Self-test:

In [None]:
# Predict the labels of the neighbours the model was trained on, first with
# the model's own parameter state and then with the optimized one.
y_train_predict0 = llvm.predict_y(X_train[idx], X_train_err[idx], llvm.par_state)
y_train_predict = llvm.predict_y(X_train[idx], X_train_err[idx], res_state)

In [None]:
# One figure per label: true value against both sets of predictions,
# drawn in the same order as before (`predict0` first).
for q in range(y_train.shape[1]):
    plt.figure()
    for prediction in (y_train_predict0, y_train_predict):
        plt.scatter(y_train[idx, q], prediction[:, q])

In [None]:
x0 = llvm.pack_p()
print(x0.shape)
llvm(x0)

In [None]:
# Predict the held-out star (kept as a one-row block) with the model's own
# parameter state and with the optimized state.
y_valid_predict0 = llvm.predict_y(
    X_valid[valid_n:valid_n+1], X_valid_err[valid_n:valid_n+1], llvm.par_state)
y_valid_predict = llvm.predict_y(
    X_valid[valid_n:valid_n+1], X_valid_err[valid_n:valid_n+1], res_state)

# Truth first, then the two predictions, all in normalized units.
print(y_valid[valid_n:valid_n+1])
print(y_valid_predict0)
print(y_valid_predict)

In [None]:
untransform_y(y_valid[valid_n:valid_n+1], label_y_perc)

In [None]:
untransform_y(y_valid_predict, label_y_perc)

---

In [None]:
# Re-run L-BFGS from `x0` with a smaller iteration budget and show the
# final objective value.
solver = jaxopt.LBFGS(
    fun=llvm,
    maxiter=10_000,
)
res_bfgs = solver.run(x0)
print(res_bfgs.state.iter_num)
llvm(res_bfgs.params)

In [None]:
# Same minimization with Adam (learning rate 1) through jaxopt's Optax
# wrapper, for comparison with L-BFGS.
opt = optax.adam(learning_rate=1.)
solver = jaxopt.OptaxSolver(
    opt=opt,
    fun=llvm,
    maxiter=100_000,
)
res_adam = solver.run(x0)
print(res_adam.state.iter_num)
llvm(res_adam.params)

In [None]:
# llvm.unpack_p(res_bfgs.params)

In [None]:
# Use the L-BFGS solution; swap in the Adam one to compare:
#     res_state = llvm.unpack_p(res_adam.params)
res_state = llvm.unpack_p(res_bfgs.params)

# Whole-validation-set variants, currently disabled:
#     y_valid_predict0 = llvm.predict_y(X_valid, X_valid_err, llvm.par_state)
#     y_valid_predict = llvm.predict_y(X_valid, X_valid_err, res_state)

# Predict only the single held-out star the neighbourhood was built for.
y_valid_predict = llvm.predict_y(
    X_valid[valid_n:valid_n+1], X_valid_err[valid_n:valid_n+1], res_state)

In [None]:
y_valid[valid_n:valid_n+1]

In [None]:
y_valid_predict0

In [None]:
y_valid_predict

In [None]:
y_valid_err[valid_n:valid_n+1]

In [None]:
# infer for test-set objects
for k in range(y_valid.shape[1]):
    plt.figure()
#     plt.scatter(y_star[:, k], ystar_predict0[:, k], c="r", marker="o")
    plt.scatter(y_valid[:, k], y_valid_predict[:, k], c="k", marker="o")
    plt.plot([y_valid[:, k].min(), y_valid[:, k].max()],
             [y_valid[:, k].min(), y_valid[:, k].max()], 
             marker='', color='tab:blue')
    plt.xlabel(f"true label {k}")
    plt.ylabel(f"prediction of label {k}")
    plt.title("held-out data")