# Linear Latent Variable Model
See the accompanying text for details of the model.

## Authors:
- **Adrian Price-Whelan** (Flatiron)
- **David W. Hogg** (NYU) (MPIA) (Flatiron)

## TODO / questions
- 

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm import tqdm

from scipy.optimize import minimize
import jax
import jaxopt

# Enable 64-bit (double-precision) floats in JAX; do this before creating any
# JAX arrays. `jax.config.update` is the supported spelling -- the old
# `from jax.config import config` import was deprecated and is no longer
# available in current JAX releases.
config = jax.config  # keep the `config` name bound in case later cells use it
config.update("jax_enable_x64", True)

from schlummernd import LinearLVM

# Make toy (simulated) data:

In [None]:
# Dimensions of the toy problem:
#   N - number of training stars
#   R - number of features per star
#   Q - number of labels per star
#   D - number of latent dimensions
#   M - number of held-out (test) stars
N, R, Q, D, M = 617, 17, 3, 5, 101

# Smaller alternative configuration:
# N, R, Q, D, M = 191, 17, 3, 5, 53

rng = np.random.default_rng(42)

# Ground-truth linear maps from latents to features (A) and to labels (B).
# B is zero except for an identity on its leading Q x Q block, so each
# noise-free label equals one of the first Q latent coordinates.
A_true = rng.normal(size=(R, D))
B_true = np.zeros((Q, D))
B_true[:Q, :Q] = np.eye(Q)

# Gaussian noise level applied to every observed feature and label.
sigma = 0.1


def _draw_toy_sample(n_obj):
    """Draw ``n_obj`` latent vectors plus their clean and noisy features/labels.

    The draws happen in a fixed order (latents, feature noise, label noise)
    on the shared ``rng``, so the training and test sets consume the random
    stream in the same order every time.

    Returns:
        (z, features, labels, noisy_features, noisy_labels)
    """
    z = rng.normal(size=(n_obj, D))
    feats = z @ A_true.T
    labels = z @ B_true.T
    feats_noisy = rng.normal(feats, sigma, size=feats.shape)
    labels_noisy = rng.normal(labels, sigma, size=labels.shape)
    return z, feats, labels, feats_noisy, labels_noisy


# Training set.
z_true, X_true, y_true, X, y = _draw_toy_sample(N)

# Held-out test set, drawn from the same generative model.
z_star_true, X_star_true, y_star_true, X_star, y_star = _draw_toy_sample(M)

# Homoscedastic uncertainties: every measurement gets the same sigma.
X_err, y_err, X_star_err, y_star_err = (
    np.full_like(arr, sigma) for arr in (X, y, X_star, y_star)
)

# Hyperparameters passed straight through to LinearLVM; their meaning is
# defined by that class, not here.
alpha, beta = 0.1, 1.

In [None]:
# Build the model from the training features/labels, their uncertainties,
# the true label matrix B_true, and the alpha/beta hyperparameters.
lvm_args = (X, y, X_err, y_err, B_true, alpha, beta)
llvm = LinearLVM(*lvm_args, verbose=True, rng=rng)
# Alternative run with alpha=0.1 and beta=0:
# llvm = LinearLVM(X, y, X_err, y_err, B_true, 0.1, 0., verbose=True, rng=rng)

In [None]:
# Flatten the model's initial parameter state into one vector (x0) and
# evaluate the objective there; the value is shown as the cell output.
llvm(x0 := llvm.pack_p())

In [None]:
# Disabled sanity check: unpacking the packed parameters should reproduce
# every named field of llvm.par_state exactly (pack_p/unpack_p round trip).
# for name in llvm.par_state.names:
#     assert np.all(getattr(llvm.unpack_p(llvm.pack_p()), name) == getattr(llvm.par_state, name))

In [None]:
import warnings

# Minimize the objective `llvm` (a function of the flat parameter vector)
# with L-BFGS, starting from the packed initial state x0.
solver = jaxopt.LBFGS(fun=llvm, maxiter=16384)
res = solver.run(x0)

# jaxopt stops either when the error (gradient norm) falls below `tol` or
# when `maxiter` is reached; only the former is convergence. Warn explicitly
# rather than silently continuing with a possibly unconverged fit.
if res.state.error > solver.tol:
    warnings.warn(
        f"L-BFGS did not converge: stopped after {int(res.state.iter_num)} "
        f"iterations (maxiter={solver.maxiter}) with error "
        f"{float(res.state.error):.3g} > tol={solver.tol}",
        RuntimeWarning,
    )
res.state.iter_num

In [None]:
# Objective value at the initial (unoptimized) parameter vector x0, for
# comparison with the value at the optimizer's solution.
llvm(x0)

In [None]:
# Objective value at the L-BFGS solution; it should be lower than the value
# at x0 if the optimization made progress.
llvm(res.params)

In [None]:
# Decode the optimized parameter vector, then predict held-out labels under
# two parameter states: llvm.par_state (presumably the initial state -- TODO
# confirm unpack_p does not modify it) and the optimized res_state.
res_state = llvm.unpack_p(res.params)
ystar_predict0, ystar_predict = (
    llvm.predict_y(X_star, X_star_err, state)
    for state in (llvm.par_state, res_state)
)

In [None]:
# Infer labels for the held-out (test-set) objects and compare them with the
# noisy true labels: one figure per label, with a one-to-one reference line
# spanning the range of the true values.
n_labels = y_star.shape[1]
for k in range(n_labels):
    plt.figure()
    truth = y_star[:, k]
    lo, hi = truth.min(), truth.max()
    # Uncomment to overlay the predictions stored in ystar_predict0:
    # plt.scatter(truth, ystar_predict0[:, k], c="r", marker="o")
    plt.scatter(truth, ystar_predict[:, k], c="k", marker="o")
    plt.plot([lo, hi], [lo, hi], marker='', color='tab:blue')
    plt.xlabel(f"true label {k}")
    plt.ylabel(f"prediction of label {k}")
    plt.title("held-out data")