# Linear Latent Variable Model
See the accompanying text for the description of the model.

## Authors:
- **Adrian Price-Whelan** (Flatiron)
- **David W. Hogg** (NYU) (MPIA) (Flatiron)

## TODO / questions
- 

In [None]:
import pathlib
import astropy.coordinates as coord
from astropy.stats import median_absolute_deviation as MAD
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import h5py
from tqdm import tqdm
from sklearn.neighbors import KDTree
from pyia import GaiaData
from scipy.stats import binned_statistic

import jax.numpy as jnp

from helpers import load_data, Features

# Load APOGEE x Gaia data

See `Assemble-data.ipynb` for more information.

In [None]:
# Alternative selection, upper giant branch (uncomment to use instead of the
# red clump cut below):
# g = load_data(
#     filters=dict(
#         TEFF=(3000, 5100), 
#         LOGG=(-0.5, 2.3),
#         M_H=(-3, None),
#         phot_g_mean_mag=(None, 15.5*u.mag),
#         AK_WISE=(-0.1, None)
#     )
# )

# For red clump instead:
# Each filter is a (lower, upper) pair on a catalog column; None presumably
# means "unbounded on that side" (TODO confirm against helpers.load_data).
g = load_data(
    filters=dict(
        TEFF=(4500, 5100), 
        LOGG=(2.3, 2.6),
        M_H=(-3, None),
        phot_g_mean_mag=(None, 15.5*u.mag),
        AK_WISE=(-0.1, None)
    )
)

# Keep stars with Galactic latitude |b| > 15 deg and SFD_EBV < 0.2
# (presumably SFD E(B-V) reddening), i.e. away from the plane and low dust.
g = g[(np.abs(g.b) > 15*u.deg) & (g.SFD_EBV < 0.2)]

# Number of selected stars (displayed as the cell output).
len(g)

In [None]:
# Gaia BP-RP colour and absolute G magnitude as plain float arrays for
# plotting. The distance comes from the parallax; allow_negative=True is
# passed so non-positive parallaxes presumably do not raise (TODO confirm
# what distmod becomes for those stars).
bprp = (g.phot_bp_mean_mag - g.phot_rp_mean_mag).value
_distmod = g.get_distance(allow_negative=True).distmod
mg = (g.phot_g_mean_mag - _distmod).value

In [None]:
# Two diagnostic 2D histograms of the selected sample, on a log colour scale.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# Left: spectroscopic Kiel diagram (TEFF vs LOGG).
ax = axes[0]
H, xb, yb, _ = ax.hist2d(
    g.TEFF,
    g.LOGG,
    bins=(
        np.linspace(3000, 8000, 128),
        np.linspace(-0.5, 5.5, 128)
    ),
    norm=mpl.colors.LogNorm()
)
# Limits come from the bin edges; both axes are reversed (hot / low-gravity
# stars toward the upper left).
ax.set_xlim(xb.max(), xb.min())
ax.set_ylim(yb.max(), yb.min())
ax.set_xlabel('TEFF')
ax.set_ylabel('LOGG')

# Right: Gaia colour-magnitude diagram (BP-RP vs M_G).
ax = axes[1]
H, xb, yb, _ = ax.hist2d(
    bprp,
    mg,
    bins=(
        np.linspace(-0.5, 3, 128),
        np.linspace(-4, 10.5, 128)
    ),
    norm=mpl.colors.LogNorm()
)
# Only the magnitude axis is reversed (bright stars at the top).
ax.set_xlim(xb.min(), xb.max())
ax.set_ylim(yb.max(), yb.min())
ax.set_xlabel('BP-RP')
ax.set_ylabel('$M_G$')

fig.tight_layout()

# Construct features and labels

Build the feature matrix, then the list of labels (with their uncertainties), aligned row-by-row with the features.

In [None]:
# Extra non-XP feature: Gaia BP-RP colour scaled by 0.1; the scaling is
# presumably to keep it comparable in size to the other features (TODO confirm).
other_features = {
    r"$G_{\rm BP}-G_{\rm RP}$": 0.1 * (g.phot_bp_mean_mag - g.phot_rp_mean_mag)
}
# Feature object for all stars; n_bp / n_rp presumably set how many BP / RP
# spectrum coefficients are kept (see helpers.Features).
f_all = Features.from_gaiadata(g, n_bp=32, n_rp=32, **other_features)

In [None]:
# Make list of labels (and their uncertainties), aligned with the features.
# Each entry of `label_ys` becomes one column of `label_y`, and each entry of
# `label_errs` the matching column of `label_err`, so both dicts must receive
# the same keys in the same order.

label_ys = {}
label_errs = {}
label_latex = {}

# Schmag: parallax times 10^(0.2 G) / 100. It is linear in the parallax, so
# its uncertainty is the parallax error times the same factor.
schmag_factor = 10 ** (0.2 * g.phot_g_mean_mag.value) / 100.
schmag_err = g.parallax_error.value * schmag_factor
label_ys['schmag'] = g.parallax.value * schmag_factor
label_errs['schmag'] = schmag_err
label_latex['schmag'] = '$G$-band schmag (absmgy$^{-1/2}$)'

for name in ['M_H', 'LOGG', 'TEFF']: # , 'AK_WISE']:
    err_col = f'{name}_ERR'
    label_ys[name] = g[name]
    if err_col in g.data.colnames:
        label_errs[name] = g[err_col]
    elif name == 'AK_WISE':
        # No catalog uncertainty column: assume a fixed error of 0.05.
        label_errs[name] = 0.05 * np.ones_like(label_ys[name])
    else:
        # Skipping this label's error would drop a column from `label_err`
        # and silently misalign it with `label_y`, so fail loudly instead.
        raise KeyError(
            f"No uncertainty column {err_col!r} for label {name!r}"
        )

label_latex['M_H'] = r"$[{\rm M}/{\rm H}]$"
label_latex['LOGG'] = r"$\log g$"
label_latex['TEFF'] = r"$T_{\rm eff}$"
label_latex['AK_WISE'] = r"$A_K$"

# Stack into (N, Q) arrays, one column per label, in insertion order.
label_y = np.hstack([np.array(x)[:, None] for x in label_ys.values()])
label_err = np.hstack([np.array(x)[:, None] for x in label_errs.values()])

# Make training and validation samples

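Randomly assign each star to one of eight groups: the stars in one group (about 1/8 of the sample) form the validation set, and the rest form the training set.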

In [None]:
# Reproducible split: every star draws an integer in [0, 8). Stars drawing 0
# (about one eighth) are held out for validation; all others are training.
rng = np.random.default_rng(seed=42)

rando = rng.integers(8, size=len(f_all))
valid = rando == 0
train = ~valid

f_train, f_valid = f_all[train], f_all[valid]

# Feature matrices, labels and label uncertainties for each subset.
X_train, X_valid = f_train.X, f_valid.X
y_train, y_valid = label_y[train], label_y[valid]
w_train, w_valid = label_err[train], label_err[valid]

for _tr, _va in ((X_train, X_valid), (y_train, y_valid), (w_train, w_valid)):
    print(_tr.shape, _va.shape)

# Define the LLVM model

- This code more-or-less assumes that your data and labels and `B` are
  normalized to reasonable ranges.
- For now `X` and `Y` must have the same number of rows (objects); in the long run this restriction need not apply.

In [None]:
from dataclasses import dataclass
import jax.numpy as jnp
import jax

In [None]:
@dataclass
class ParameterState:
    """Bundle of LLVM parameter arrays with shape validation.

    Each array field is annotated with a tuple of dimension names. When an
    instance is created, every array's ``.shape`` is compared with the sizes
    looked up in ``sizes`` (e.g. ``A`` must have shape
    ``(sizes['R'], sizes['D'])``). The first mismatch raises ``ValueError``.
    """
    sizes: dict
    A: ('R', 'D')  # i.e. shape = (R, D)
    B: ('Q', 'D')
    z: ('N', 'D')
    mu_X: ('R', )
    mu_y: ('Q', )

    def __post_init__(self):
        all_fields = self.__dataclass_fields__
        for field_name in self.names:
            dims = all_fields[field_name].type
            expected = tuple(self.sizes[dim] for dim in dims)
            actual = getattr(self, field_name).shape
            if actual == expected:
                continue
            raise ValueError(
                f'Invalid shape for {field_name}: expected {dims}={expected}, got {actual}'
            )

    @property
    def names(self):
        """Names of the parameter arrays (every field except ``sizes``)."""
        return [key for key in self.__dataclass_fields__ if key != 'sizes']

In [None]:
def _model_linear(mu, A, x):
    # Formerly:
    # pars.mu_x[None, :] + pars.z @ pars.A.T
    # pars.mu_y[None, :] + pars.z @ pars.B.T
    return mu[None] + x @ A.T


class LinearLVM:
    """
    Linear latent variable model tying features and labels to shared latents.

    The model (see ``_chi_X`` / ``_chi_y``) is::

        X ~ mu_X + z @ A.T    (features, weighted by X_err)
        y ~ mu_y + z @ B.T    (labels, weighted by y_err)

    Parameters live in a ``ParameterState``; ``pack_p`` / ``unpack_p`` convert
    between that and the flat vector that ``cost`` (or calling the instance)
    evaluates. ``B`` is currently not part of the packed vector, so it stays
    fixed at the value given to the constructor.
    """

    def __init__(self, X, y, X_err, y_err, B, alpha, verbose=False):
        """
        Parameters
        ----------
        X : array-like
            shape `(N, R)` array of training features
        y : array-like
            shape `(N, Q)` array of training labels
        X_err : array-like
            shape `(N, R)` array of errors (standard deviations) for the features
        y_err : array-like
            shape `(N, Q)` array of errors (standard deviations) for the labels
        B : array-like
            shape `(Q, D)` matrix translating latents to labels. Elements set
            to NaN are flagged as free and replaced by 0 here.
        alpha : numeric
            regularization strength; must be strictly positive. Use the
            source, Luke.
        verbose : bool, optional
            print diagnostics about ``B``, the latents and the regularization.

        Raises
        ------
        ValueError
            If the input shapes are inconsistent with each other, or if
            ``alpha`` is not strictly positive.
        """
        self.verbose = verbose

        self.X = jnp.array(X)
        self.y = jnp.array(y)
        self.X_err = jnp.array(X_err)
        self.y_err = jnp.array(y_err)

        shp_msg = "Invalid shape for {object_name}: got {got}, expected {expected}"
        if self.X.ndim != 2:
            raise ValueError(shp_msg.format(
                object_name="training features X",
                got=self.X.shape,
                expected="(N, R)"
            ))
        if self.y.ndim != 2:
            raise ValueError(shp_msg.format(
                object_name="training labels y",
                got=self.y.shape,
                expected="(N, Q)"
            ))

        self.sizes = {}
        self.sizes['N'], self.sizes['R'] = self.X.shape
        self.sizes['Q'] = self.y.shape[1]

        if self.y.shape[0] != self.sizes['N']:
            raise ValueError(shp_msg.format(
                object_name="training labels y",
                got=self.y.shape[0],
                expected=self.sizes['N']
            ))
        if self.X_err.shape != self.X.shape:
            raise ValueError(shp_msg.format(
                object_name="X_err",
                got=self.X_err.shape,
                expected=self.X.shape
            ))
        if self.y_err.shape != self.y.shape:
            raise ValueError(shp_msg.format(
                object_name="y_err",
                got=self.y_err.shape,
                expected=self.y.shape
            ))

        self._X_ivar = 1 / self.X_err**2
        self._y_ivar = 1 / self.y_err**2

        # Work on a float NumPy copy so NaN-flagged elements can be overwritten
        # in place; it is turned into a JAX array below.
        B = np.array(B, dtype=float, copy=True)
        if B.ndim != 2 or B.shape[0] != self.sizes['Q']:
            raise ValueError(shp_msg.format(
                object_name="B",
                got=B.shape,
                expected=f"({self.sizes['Q']}, D)"
            ))
        self.sizes['D'] = B.shape[1]

        # The elements of B that we will fit for should be set to nan in the
        # input B array
        B_fit_mask = np.isnan(B)
        n_free = int(B_fit_mask.sum())
        if n_free == 0:
            if verbose:
                print("no free elements of B")
        else:
            B[B_fit_mask] = 0.
            if verbose:
                print(f"using {n_free} free elements of B")
        self._B_fit_mask = jnp.array(B_fit_mask)
        self.B = jnp.array(B)
        if verbose:
            print(f"B = {B}")
            print(f"B fit elements = {self._B_fit_mask}")

        # Now assess which latents to fit: a latent is flagged when any
        # element of its column of B is free.
        self._z_fit_mask = jnp.any(self._B_fit_mask, axis=0)
        if verbose:
            print(
                f"using {self._z_fit_mask.sum()} unconstrained elements of z, "
                f"out of {self.sizes['D']} latents"
            )

        self.alpha = float(alpha)
        # `not alpha > 0` also rejects NaN.
        if not self.alpha > 0.:
            raise ValueError("You must regularize, and strictly positively.")

        # Regularization matrix (diagonal, alpha on the flagged latents).
        # NOTE: not currently used by `cost`.
        self.Lambda = self.alpha * np.diag(np.asarray(self._z_fit_mask, dtype=int))
        if verbose:
            print(f"Lambda = {self.Lambda}")

        # TODO:
        self.par_state = self.initialize_par_state()

    def initialize_par_state(self, **state):
        """
        Build an initial ``ParameterState``.

        Any of ``mu_X``, ``mu_y``, ``z``, ``A``, ``B`` passed as keyword
        arguments are used as given; the others are initialized here:

        - ``mu_X``, ``mu_y``: inverse-variance weighted means over objects.
        - ``z``: unweighted least-squares solution of ``B z = y - mu_y``,
          plus a little Gaussian noise.
        - ``A``: error-weighted least squares of the features on ``z``
          (see ``_initialize_A``).
        - ``B``: the ``B`` given to the constructor.

        TODO: this is a little hacky
        """

        # Initialize the means using invvar weighted means
        # TODO: could do sigma-clipping here to be more robust
        if 'mu_X' not in state:
            state['mu_X'] = (
                jnp.sum(self.X * self._X_ivar, axis=0) /
                jnp.sum(self._X_ivar, axis=0)
            )

        if 'mu_y' not in state:
            state['mu_y'] = (
                jnp.sum(self.y * self._y_ivar, axis=0) /
                jnp.sum(self._y_ivar, axis=0)
            )

        if 'z' not in state:
            # First hack: Start with the pseudo-inverse of `B`.
            # BUG: Doesn't use weights.
            state['z'] = jnp.linalg.lstsq(
                self.B,
                (self.y - state['mu_y'][None, :]).T,
                rcond=None
            )[0].T

            # Second hack: Add noise.
            # NOTE: uses NumPy's global random state, so this is only
            # reproducible if np.random has been seeded.
            # TODO: magic numberz
            TINY = 1e-1 # MAGIC
            sigma = np.std(state['z']) + TINY
            state['z'] += np.random.normal(0, TINY * sigma, size=state['z'].shape)

        if 'A' not in state:
            state['A'] = self._initialize_A(state['mu_X'], state['z'])

        if 'B' not in state:
            # TODO: implement this
            state['B'] = self.B

        return ParameterState(sizes=self.sizes, **state)

    def _initialize_A(self, mu_X, z):
        """
        Error-weighted least-squares estimate of ``A`` for fixed ``mu_X`` and ``z``.

        Each feature ``j`` is fit independently by minimizing
        ``sum_n ((X[n, j] - mu_X[j] - z[n] @ A[j]) / X_err[n, j])**2``.

        Returns
        -------
        A : numpy.ndarray
            shape `(R, D)`.
        """
        X = np.asarray(self.X)
        X_err = np.asarray(self.X_err)
        mu_X = np.asarray(mu_X)
        z = np.asarray(z)

        A = np.zeros((self.sizes['R'], self.sizes['D']))
        for j in range(self.sizes['R']):
            # Design matrix and target are both divided by the same sigma
            # (not sigma**2), which makes this a proper chi-squared fit.
            sigma_j = X_err[:, j]
            design = z / sigma_j[:, None]
            target = (X[:, j] - mu_X[j]) / sigma_j
            A[j] = np.linalg.lstsq(design, target, rcond=None)[0]
        return A

    def _chi_X(self, mu_X, A, z):
        """Feature residuals ``(X - (mu_X + z @ A.T)) / X_err``, shape `(N, R)`."""
        return (self.X - _model_linear(mu_X, A, z)) / self.X_err

    def _chi_y(self, mu_y, B, z):
        """Label residuals ``(y - (mu_y + z @ B.T)) / y_err``, shape `(N, Q)`."""
        return (self.y - _model_linear(mu_y, B, z)) / self.y_err

    def unpack_p(self, p):
        """
        Turn a flat parameter vector (as from ``pack_p``) into a ``ParameterState``.

        Shapes and field order are taken from ``self.par_state``. ``B`` is not
        in the vector; it is always copied from ``self.par_state``.

        TODO: deal with some of B is frozen
        """
        i = 0
        state = {}
        for name in self.par_state.names:
            if name == 'B':
                # TODO: see note above
                state['B'] = self.par_state.B
                continue

            val = getattr(self.par_state, name)
            state[name] = p[i:i+val.size].reshape(val.shape)
            i += val.size
        return ParameterState(sizes=self.sizes, **state)

    def pack_p(self, par_state=None):
        """
        Flatten a ``ParameterState`` (default ``self.par_state``) into one vector.

        Fields are concatenated in ``par_state.names`` order, skipping ``B``.

        TODO: deal with some of B is frozen
        """
        if par_state is None:
            par_state = self.par_state

        arrs = []
        for name in par_state.names:
            if name == 'B':
                # TODO: deal with note above
                continue
            val = getattr(par_state, name).flatten()
            arrs.append(val)
        return jnp.concatenate(arrs)

    def cost(self, p):
        """
        Objective for the flat parameter vector ``p``.

        Half the chi-squared of the features plus half the chi-squared of the
        labels, plus ``0.5 * alpha * sum(z**2)`` over the latents flagged in
        ``_z_fit_mask`` (none, if ``B`` had no NaN elements).

        TODO: Regularization term is totally wrong.
        """
        pars = self.unpack_p(p)
        # TODO: set par_state??

        chi_X = self._chi_X(pars.mu_X, pars.A, pars.z)
        chi_y = self._chi_y(pars.mu_y, pars.B, pars.z)

        return (
            0.5 * jnp.sum(chi_X ** 2) +
            0.5 * jnp.sum(chi_y ** 2) +
            0.5 * self.alpha * jnp.sum(pars.z[:, self._z_fit_mask] ** 2)
        )

    def __call__(self, p):
        """Alias for ``cost(p)`` so the instance can be passed to optimizers."""
        val = self.cost(p)
        return val

    def predict_y(self, X, X_err, par_state=None):
        """
        Predict labels for new objects from their features alone.

        For each object the latents are found by error-weighted least squares
        against ``X ~ mu_X + z @ A.T`` (no regularization). The predicted
        labels are then ``mu_y + B @ z``.

        Parameters
        ----------
        X : array-like
            shape `(M, R)` features of the objects to predict.
        X_err : array-like
            shape `(M, R)` feature errors (standard deviations).
        par_state : ParameterState, optional
            parameters to use; defaults to ``self.par_state``.

        Returns
        -------
        y_hat : numpy.ndarray
            shape `(M, Q)` predicted labels.

        Raises
        ------
        ValueError
            If ``X`` is not `(M, R)` or ``X_err`` does not match its shape.
        """
        if par_state is None:
            par_state = self.par_state

        X = np.asarray(X)
        X_err = np.asarray(X_err)

        # should this use the regularization matrix? Hogg thinks not.
        if X.ndim != 2 or X.shape[1] != self.sizes['R']:
            raise ValueError("Invalid shape for input feature matrix X")
        if X_err.shape != X.shape:
            raise ValueError("Invalid shape for input feature errors X_err")

        A = np.asarray(par_state.A)
        B = np.asarray(par_state.B)
        mu_X = np.asarray(par_state.mu_X)
        mu_y = np.asarray(par_state.mu_y)

        y_hat = np.zeros((X.shape[0], self.sizes['Q']))

        chi = (X - mu_X[None, :]) / X_err
        for i, dx in enumerate(chi):
            design = A / X_err[i][:, None]
            z = np.linalg.lstsq(design, dx, rcond=None)[0]
            y_hat[i] = mu_y + B @ z

        return y_hat

Very simple fake data, generated from the model itself, for testing:

In [None]:
# Fake data generated from a linear latent model: 5 true latents, 17 features,
# and 3 labels equal to the first 3 latents. There are 191 training objects
# and 53 held-out ("star") objects. Every feature and label gets Gaussian
# noise with standard deviation `sigma`, and the reported errors are exactly
# `sigma`.
rng = np.random.default_rng(42)

Atrue = rng.normal(size=(17, 5))
Btrue = np.zeros((3,5))
Btrue[:3,:3] = np.eye(3)
Ztrue = rng.normal(size=(191, 5))
Zstartrue = rng.normal(size=(53, 5))
Xtrue = Ztrue @ Atrue.T
Xstartrue = Zstartrue @ Atrue.T
sigma = 0.1
X = Xtrue + sigma * rng.normal(size=Xtrue.shape)
Xstar = Xstartrue + sigma * rng.normal(size=Xstartrue.shape)
Ytrue = Ztrue @ Btrue.T
Ystartrue = Zstartrue @ Btrue.T
Y = Ytrue + sigma * rng.normal(size=Ytrue.shape)
Ystar = Ystartrue + sigma * rng.normal(size=Ystartrue.shape)
xerr = np.zeros_like(X) + sigma
xstarerr = np.zeros_like(Xstar) + sigma
yerr = np.zeros_like(Y) + sigma
ystarerr = np.zeros_like(Ystar) + sigma

# Model configuration: B maps 7 latents (more than the 5 used to generate the
# data) to the 3 labels, with label k = latent k for k < 3. B has no NaN
# entries, so no element of B is free.
B = np.zeros((3,7))
B[:3, :3] = np.eye(3)
alpha = 0.1

In [None]:
# Build the model on the fake training data (verbose prints B / latent diagnostics).
llvm = LinearLVM(X, Y, xerr, yerr, B, alpha, verbose=True)

In [None]:
# for name in llvm.par_state.names:
#     assert np.all(getattr(llvm.unpack_p(llvm.pack_p()), name) == getattr(llvm.par_state, name))

In [None]:
from scipy.optimize import minimize

In [None]:
# JAX gradient of the cost (LinearLVM.__call__) with respect to the packed parameter vector.
llvm_grad = jax.grad(llvm)

In [None]:
# Start from the packed initial parameter state.
x0 = llvm.pack_p()

# GLOBALS: careful!!
# `ks` (iteration count) and `vals` (cost history) are appended to by the
# callback below; `ax` shows the cost history.
ks = [0]
vals = [llvm(x0)]
fig, ax = plt.subplots(figsize=(6, 5))
ax.plot(ks, vals)

def callback(x):
    # Called by scipy after each iteration with the current parameter vector.
    # (`global` is not strictly needed: the lists are only mutated, not rebound.)
    global ks, vals, ax
    ks.append(ks[-1] + 1)
    # One extra cost evaluation per iteration, just for the history.
    vals.append(llvm(x))
    # Redraw the cost history every 4 iterations.
    if (ks[-1] % 4) == 0:
        ax.cla()
        ax.plot(ks, vals)
    
    # Progress counter every 10 iterations, overwritten in place.
    if (ks[-1] % 10) == 0:
        print(f"{ks[-1]}", end='\r')
    
# NOTE(review): JAX uses 32-bit floats unless `jax_enable_x64` is enabled
# before any arrays are created, so the cost and gradient passed to BFGS are
# presumably float32 here. Check this if BFGS reports precision loss.
res = minimize(
    llvm, 
    jac=llvm_grad, 
    x0=x0, 
    callback=callback, 
    options=dict(maxiter=512),
    method='bfgs'
)

In [None]:
# Zoom in on the last 10 recorded cost values to check convergence.
plt.plot(ks[-10:], vals[-10:])

In [None]:
# Optimized parameters as a ParameterState. `llvm.par_state` is never updated
# by the fit, so it still holds the initial state; predict held-out labels
# with both, for comparison.
res_state = llvm.unpack_p(res.x)
ystar_predict0 = llvm.predict_y(Xstar, xstarerr, llvm.par_state)
ystar_predict = llvm.predict_y(Xstar, xstarerr, res_state)

In [None]:
# infer for test-set objects
for k in range(Ystar.shape[1]):
    plt.figure()
    plt.scatter(Ystar[:, k], ystar_predict0[:, k], c="r", marker="o")
    plt.scatter(Ystar[:, k], ystar_predict[:, k], c="k", marker="o")
    plt.plot([Ystar[:, k].min(), Ystar[:, k].max()],
             [Ystar[:, k].min(), Ystar[:, k].max()], 
             marker='', color='tab:blue')
    plt.xlabel(f"true label {k}")
    plt.ylabel(f"prediction of label {k}")
    plt.title("held-out data")