# Linear Latent Variable Model
See the accompanying text for the model definition and notation.

## TODO / questions
- 

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm import tqdm

from scipy.optimize import minimize
import jax
import jaxopt

# `jax.config` is available directly on the top-level package; the old
# `from jax.config import config` import path was removed in newer JAX
# releases. Keep the `config` name bound for any code that uses it.
config = jax.config
# Enable 64-bit dtypes in JAX so the fits below run in double precision.
config.update("jax_enable_x64", True)

from schlummernd import LinearLVM

In [None]:
# NOTE(review): scratch snippet, not runnable as a standalone cell: it uses
# `self` outside any method and the last two lines are indented. It looks like
# a draft of the A/z column renormalization that
# `LinearLVM2.initialize_par_state` performs - TODO confirm.
# Also note that JAX `.at[...].divide(...)` / `.multiply(...)` return NEW arrays
# instead of updating in place, so as written both results are discarded and
# self.A / self.Z are left unchanged; they would need to be assigned back.
renorm = np.sqrt(np.sum(self.A[:, self.free_elements_of_Z] ** 2, axis=0))
        self.A.at[:, self.free_elements_of_Z].divide(renorm[None, :])
        self.Z.at[:, self.free_elements_of_Z].multiply(renorm[None, :])

In [None]:
# Standard library
from dataclasses import dataclass
from typing import Annotated, Any, Dict, get_args

# Third-party
import jax.numpy as jnp
import numpy as np

from schlummernd import ParameterState
from schlummernd.lvm import _model_linear


class LinearLVM2:
    """
    Linear latent-variable model in which features ``X`` and labels ``y`` are
    both modeled through a shared set of per-star latents ``z``.

    The feature model is built from ``mu_X``, ``A`` and ``z``, and the label
    model from ``mu_y``, ``B`` and ``z`` (both via ``_model_linear``).
    Instances are callable and return the cost of a packed parameter vector,
    so they can be handed directly to an optimizer such as ``jaxopt.LBFGS``.
    """

    def __init__(self, X, y, X_err, y_err, B, alpha, beta,
                 verbose=False, rng=None):
        """
        N - stars
        R - features
        Q - labels
        D - latents

        Parameters
        ----------
        X : array-like
            shape `(N, R)` array of training features
        y : array-like
            shape `(N, Q)` array of training labels
        X_err : array-like
            shape `(N, R)` array of errors (standard deviations) for the features
        y_err : array-like
            shape `(N, Q)` array of errors (standard deviations) for the labels
        B : array-like
            shape `(Q, D)` matrix translating latents to labels. Elements set
            to NaN are marked as free (to be fit) and replaced by 0.
        alpha : numeric
            strength of the squared penalty on the unconstrained latents
            (latents whose column of ``B`` is entirely zero); see `cost`.
            Must be strictly positive.
        beta : numeric
            strength of the squared penalty on the columns of ``A`` that
            belong to the unconstrained latents; see `cost`.
        verbose : bool, optional
            print diagnostic information during setup.
        rng : `numpy.random.Generator`, optional
            random number generator used for initialization; a fresh
            ``np.random.default_rng()`` is created when omitted.

        Raises
        ------
        ValueError
            if the input shapes are inconsistent, or ``alpha`` is not strictly
            positive.
        """
        self.verbose = verbose
        if rng is None:
            rng = np.random.default_rng()
        self.rng = rng

        self.X = jnp.array(X)
        self.y = jnp.array(y)
        self.X_err = jnp.array(X_err)
        self.y_err = jnp.array(y_err)

        self.sizes = {}
        self.sizes['N'], self.sizes['R'] = self.X.shape
        self.sizes['Q'] = self.y.shape[1]

        # Fail fast on inconsistent inputs instead of hitting an obscure
        # broadcasting error (or a silent broadcast) later on.
        self._check_shape("training labels y", self.y.shape[0], self.sizes['N'])
        self._check_shape("X_err", self.X_err.shape, self.X.shape)
        self._check_shape("y_err", self.y_err.shape, self.y.shape)

        self._X_ivar = 1 / self.X_err**2
        self._y_ivar = 1 / self.y_err**2

        # Work on a NumPy copy so the NaN placeholders can be overwritten in
        # place without touching the caller's array; B becomes a JAX array below.
        B = np.array(B, copy=True)
        _, self.sizes['D'] = B.shape
        self._check_shape("B", B.shape[0], self.sizes['Q'])

        # Elements of B that we will fit for should be set to nan in the input
        # B array; they are recorded in a mask and zeroed in the stored B.
        B_fit_mask = np.isnan(B)
        n_free_B = int(B_fit_mask.sum())
        B[B_fit_mask] = 0.
        self._B_fit_mask = jnp.array(B_fit_mask)
        self.B = jnp.array(B)
        if verbose:
            if n_free_B:
                print(f"using {n_free_B} free elements of B")
            else:
                print("no free elements of B")
            print(f"B = {B}")
            print(f"B fit elements = {self._B_fit_mask}")

        # Latents whose column of B is entirely zero do not feed any label;
        # these "unconstrained" latents are the ones regularized in `cost`.
        self._z_fit_mask = jnp.all(self.B == 0, axis=0)
        if verbose:
            print(
                f"using {self._z_fit_mask.sum()} unconstrained elements of z, "
                f"out of {self.sizes['D']} latents"
            )

        self.alpha = float(alpha)
        self.beta = float(beta)
        # `not alpha > 0` also rejects NaN; raise rather than `assert` so the
        # check is not stripped under `python -O`.
        if not self.alpha > 0.:
            raise ValueError("You must regularize, and strictly positively.")

        # Regularization matrix: alpha on the diagonal for unconstrained latents.
        # NOTE(review): not used by `cost`, which applies alpha directly.
        self.Lambda = self.alpha * np.diag(self._z_fit_mask.astype(int))
        if verbose:
            print(f"Lambda = {self.Lambda}")

        # TODO:
        self.par_state = self.initialize_par_state()

    @staticmethod
    def _check_shape(object_name, got, expected):
        """Raise a ValueError describing `object_name` if `got` != `expected`."""
        if got != expected:
            raise ValueError(
                f"Invalid shape for {object_name}: got {got}, expected {expected}"
            )

    def initialize_par_state(self, **state):
        """
        Build an initial `ParameterState`, filling in any parameter not given.

        N - stars
        R - features
        Q - labels
        D - latents

        mu_X : (R, )
        mu_y : (Q, )
        z : (N, D)
        A : (R, D)
        B : (Q, D)

        Any of these may be passed as keyword arguments; missing ones are
        initialized from the data. Finally, the columns of ``A`` belonging to
        unconstrained latents are rescaled to unit norm and the inverse scale
        is absorbed into the matching columns of ``z`` (which leaves
        ``z @ A.T`` unchanged). Copies are used, so caller-provided arrays are
        never modified.
        """
        z_free = np.asarray(self._z_fit_mask)
        n_free = int(z_free.sum())

        # Initialize the means using invvar weighted means
        # TODO: could do sigma-clipping here to be more robust
        if 'mu_X' not in state:
            state['mu_X'] = (
                jnp.sum(self.X * self._X_ivar, axis=0) /
                jnp.sum(self._X_ivar, axis=0)
            )

        if 'mu_y' not in state:
            state['mu_y'] = (
                jnp.sum(self.y * self._y_ivar, axis=0) /
                jnp.sum(self._y_ivar, axis=0)
            )

        if 'z' not in state:
            # First hack: per star, take the least-squares solution for z given
            # the error-weighted B (i.e. start from the pseudo-inverse of B).
            z = np.zeros((self.sizes['N'], self.sizes['D']))
            chi = (self.y - state['mu_y'][None]) / self.y_err
            for n in range(self.sizes['N']):
                z[n] = jnp.linalg.lstsq(
                    self.B / self.y_err[n][:, None],
                    chi[n],
                    rcond=None
                )[0]

            # Second hack: jitter the constrained z components, and draw the
            # unconstrained ones (which the step above cannot determine)
            # at a comparable, small scale.
            scale = 0.1  # MAGIC NUMBER
            sigma = np.std(z[:, ~z_free], axis=0)
            z[:, ~z_free] += self.rng.normal(
                0,
                scale * sigma,
                size=(self.sizes['N'], self.sizes['D'] - n_free)
            )
            # With no constrained latents there is no reference scale (the
            # mean of an empty array is NaN), so fall back to unit scale.
            ref_sigma = np.mean(sigma) if sigma.size else 1.
            z[:, z_free] = self.rng.normal(
                0,
                scale * ref_sigma,
                size=(self.sizes['N'], n_free)
            )
            state['z'] = z

        if 'A' not in state:
            # Per feature, least-squares fit of A's row given z.
            A = np.zeros((self.sizes['R'], self.sizes['D']))
            chi = self._chi_X(state['mu_X'], A, state['z'])
            for r in range(self.sizes['R']):
                A[r] = jnp.linalg.lstsq(
                    state['z'] / self.X_err[:, r:r+1],
                    chi[:, r],
                    rcond=None
                )[0]
            state['A'] = A

        # Fix the per-column scale freedom between A and z for the
        # unconstrained latents. Work on copies so that caller-provided arrays
        # (NumPy, or immutable JAX) are neither mutated nor a source of errors.
        A = np.array(state['A'], dtype=float, copy=True)
        z = np.array(state['z'], dtype=float, copy=True)
        renorm = np.sqrt(np.sum(A[:, z_free] ** 2, axis=0))
        # Leave all-zero columns untouched instead of dividing by zero.
        renorm = np.where(renorm > 0, renorm, 1.)
        A[:, z_free] /= renorm[None, :]
        z[:, z_free] *= renorm[None, :]
        state['A'] = A
        state['z'] = z

        if 'B' not in state:
            # TODO: implement this
            state['B'] = self.B

        return ParameterState(sizes=self.sizes, **state)

    def _chi_X(self, mu_X, A, z):
        """Feature residuals in units of their errors: (X - model) / X_err."""
        return (self.X - _model_linear(mu_X, A, z)) / self.X_err

    def _chi_y(self, mu_y, B, z):
        """Label residuals in units of their errors: (y - model) / y_err."""
        return (self.y - _model_linear(mu_y, B, z)) / self.y_err

    def unpack_p(self, p):
        """
        Inverse of `pack_p`: split a flat parameter vector into a `ParameterState`.

        Parameters are read in the order of ``self.par_state.names``. Each one
        takes as many entries as the corresponding array in ``self.par_state``
        and is reshaped to that array's shape. ``B`` is not stored in ``p``;
        it is copied from ``self.par_state``.

        TODO: deal with some of B is frozen
        """
        i = 0
        state = {}
        for name in self.par_state.names:
            if name == 'B':
                # TODO: see note above
                state['B'] = self.par_state.B
                continue

            val = getattr(self.par_state, name)
            state[name] = p[i:i+val.size].reshape(val.shape)
            i += val.size
        return ParameterState(sizes=self.sizes, **state)

    def pack_p(self, par_state=None):
        """
        Flatten a `ParameterState` (default: ``self.par_state``) into one vector.

        Arrays are concatenated in the order of ``par_state.names``; ``B`` is
        skipped.

        TODO: deal with some of B is frozen
        """
        if par_state is None:
            par_state = self.par_state

        arrs = []
        for name in par_state.names:
            if name == 'B':
                # TODO: deal with note above
                continue

            arrs.append(getattr(par_state, name).flatten())
        return jnp.concatenate(arrs)

    def cost(self, p):
        """
        Objective for a flat parameter vector ``p`` (layout as in `pack_p`).

        Half the sum of: the squared feature chi, the squared label chi,
        ``alpha`` times the squared unconstrained latents, and ``beta`` times
        the squared columns of ``A`` for those latents.

        TODO: Regularization term is totally wrong.
        """
        pars = self.unpack_p(p)
        # TODO: set par_state??

        chi_X = self._chi_X(pars.mu_X, pars.A, pars.z)
        chi_y = self._chi_y(pars.mu_y, pars.B, pars.z)

        return 0.5 * (
            jnp.sum(chi_X ** 2) +
            jnp.sum(chi_y ** 2) +
            self.alpha * jnp.sum(pars.z[:, self._z_fit_mask] ** 2) +
            self.beta * jnp.sum(pars.A[:, self._z_fit_mask] ** 2)
        )

    def __call__(self, p):
        """Alias for `cost`, so the model can be passed to optimizers directly."""
        return self.cost(p)

    def predict_y(self, X, X_err, par_state=None):
        """
        Predict labels for new features.

        For each row, solve the error-weighted least-squares problem for the
        latents given ``A`` (no regularization), then map the latents to
        labels with ``mu_y + B @ z``.

        Parameters
        ----------
        X : array-like
            shape `(M, R)` array of features.
        X_err : array-like
            shape `(M, R)` array of errors (standard deviations) for ``X``.
        par_state : ParameterState, optional
            parameters to use; defaults to ``self.par_state``.

        Returns
        -------
        y_hat : `numpy.ndarray`
            shape `(M, Q)` array of predicted labels.

        Raises
        ------
        ValueError
            if ``X`` does not have ``R`` columns or ``X_err`` does not match it.
        """
        if par_state is None:
            par_state = self.par_state

        X = np.asarray(X)
        X_err = np.asarray(X_err)

        # should this use the regularization matrix? Hogg thinks not.
        if X.ndim != 2 or X.shape[1] != self.sizes['R']:
            raise ValueError("Invalid shape for input feature matrix X")
        if X_err.shape != X.shape:
            raise ValueError("Invalid shape for input feature errors X_err")

        n_rows = X.shape[0]
        y_hat = np.zeros((n_rows, self.sizes['Q']))

        chi = (X - par_state.mu_X[None]) / X_err
        for i, dx in enumerate(chi):
            design = par_state.A / X_err[i][:, None]
            z = np.linalg.lstsq(design, dx, rcond=None)[0]
            y_hat[i] = par_state.mu_y + par_state.B @ z

        return y_hat

# Make toy fake data:

In [None]:
# Problem sizes:
#   N - stars, R - features, Q - labels, D - latents, M - held-out stars
N, R, Q, D, M = 1024, 32, 3, 5, 128
# A smaller configuration also tried: N, R, Q, D, M = 191, 17, 3, 5, 53

rng = np.random.default_rng(42)

# Ground truth: features use all D latents through A_true; labels are the
# first Q latents (identity in the leading Q x Q block of B_true).
A_true = rng.normal(size=(R, D))
B_true = np.zeros((Q, D))
B_true[:Q, :Q] = np.eye(Q)
z_true = rng.normal(size=(N, D))

mu_X = rng.uniform(-1, 1, size=(1, R))
mu_y = rng.uniform(-1, 1, size=(1, Q))

sigma = 0.1

# Training set: noiseless truth, then Gaussian-noisified observations.
X_true = mu_X + z_true @ A_true.T
y_true = mu_y + z_true @ B_true.T
X = rng.normal(X_true, sigma, size=X_true.shape)
y = rng.normal(y_true, sigma, size=y_true.shape)

# Held-out ("star") set drawn from the same model.
z_star_true = rng.normal(size=(M, D))
X_star_true = mu_X + z_star_true @ A_true.T
y_star_true = mu_y + z_star_true @ B_true.T
X_star = rng.normal(X_star_true, sigma, size=X_star_true.shape)
y_star = rng.normal(y_star_true, sigma, size=y_star_true.shape)

# Homoscedastic uncertainties matching the injected noise level.
X_err, y_err, X_star_err, y_star_err = (
    np.full_like(arr, sigma) for arr in (X, y, X_star, y_star)
)

alpha, beta = 0.1, 1.

In [None]:
rng = np.random.default_rng(42)
llvm = LinearLVM2(
    X, y, X_err, y_err,
    B_true, alpha, beta,
    verbose=True,
    rng=rng,
)
# Other configurations tried:
#   LinearLVM(X, y, X_err, y_err, B_true, alpha, beta, verbose=True, rng=rng)
#   LinearLVM(X, y, X_err, y_err, B_true, 0.1, 0., verbose=True, rng=rng)

In [None]:
# Per-feature spread of the initial feature residuals, (X - model) / X_err.
state0 = llvm.par_state
chix = np.asarray(llvm._chi_X(state0.mu_X, state0.A, state0.z))
np.std(chix, axis=0)

In [None]:
# Flatten the initial parameter state and evaluate the cost there.
x0 = llvm.pack_p()
llvm.cost(x0)

In [None]:
# for name in llvm.par_state.names:
#     assert np.all(getattr(llvm.unpack_p(llvm.pack_p()), name) == getattr(llvm.par_state, name))

In [None]:
# Minimize the cost with L-BFGS, starting from the initial parameters.
solver = jaxopt.LBFGS(maxiter=1000, fun=llvm)
res = solver.run(x0)
res.state.iter_num

In [None]:
# Cost at the initial parameters vs. after optimization.
cost_before, cost_after = llvm(x0), llvm(res.params)
print(cost_before, cost_after)

In [None]:
# Unpack the optimized parameters, then predict held-out labels from the
# initial and from the optimized state (in that order).
res_state = llvm.unpack_p(res.params)
ystar_predict0, ystar_predict = (
    llvm.predict_y(X_star, X_star_err, state)
    for state in (llvm.par_state, res_state)
)

In [None]:
# Predicted vs. true labels for the held-out objects, one figure per label.
# (The initial-state predictions, ystar_predict0, can be overplotted in red.)
for k in range(y_star.shape[1]):
    truth, pred = y_star[:, k], ystar_predict[:, k]
    lo, hi = truth.min(), truth.max()
    plt.figure()
    plt.scatter(truth, pred, c="k", marker="o")
    plt.plot([lo, hi], [lo, hi], marker='', color='tab:blue')  # y = x line
    plt.xlabel(f"true label {k}")
    plt.ylabel(f"prediction of label {k}")
    plt.title("held-out data")

---

In [None]:
# Second toy problem: same latent structure, no mean offsets, and a separate
# validation set.
# N - stars, R - features, Q - labels, D - latents, M - validation stars
N = 1024
R = 32
Q = 3
D = 5
M = 100

rng = np.random.default_rng(42)

A_true = rng.normal(size=(R, D))
B_true = np.zeros((Q, D))
B_true[:Q, :Q] = np.eye(Q)
z_true = rng.normal(size=(N, D))

X_true = z_true @ A_true.T
y_true = z_true @ B_true.T

sigma = 0.1
X_train = rng.normal(X_true, sigma, size=X_true.shape)  # Noisify
y_train = rng.normal(y_true, sigma, size=y_true.shape)  # Noisify


z_valid_true = rng.normal(size=(M, D))
X_valid_true = z_valid_true @ A_true.T
y_valid_true = z_valid_true @ B_true.T
X_valid = rng.normal(X_valid_true, sigma, size=X_valid_true.shape)  # Noisify
y_valid = rng.normal(y_valid_true, sigma, size=y_valid_true.shape)  # Noisify

# Each error array must match the data array it describes: use this cell's
# training arrays, not `X`/`y` left over from an earlier cell.
X_train_err = np.full_like(X_train, sigma)
y_train_err = np.full_like(y_train, sigma)
X_valid_err = np.full_like(X_valid, sigma)
y_valid_err = np.full_like(y_valid, sigma)

alpha = 0.1
beta = 1.

In [None]:
rng = np.random.default_rng(42)
n_labels = y_train.shape[1]
n_latents = n_labels + 1  # one extra latent with an all-zero column in B
# Identity in the leading (n_labels x n_labels) block, zeros elsewhere.
B = np.hstack([np.eye(n_labels), np.zeros((n_labels, n_latents - n_labels))])

llvm = LinearLVM(
    X_train, y_train, X_train_err, y_train_err, B,
    alpha=1,
    beta=1.,
    verbose=True,
    rng=rng,
)

In [None]:
x0 = llvm.pack_p()
# Show the initial cost, then the size of the packed parameter vector.
for item in (llvm(x0), x0.shape):
    print(item)

In [None]:
# Fit with L-BFGS (generous iteration cap), then report the iteration count
# and the final cost.
solver = jaxopt.LBFGS(maxiter=10000, fun=llvm)
res_bfgs = solver.run(x0)
best_params = res_bfgs.params
res_state = llvm.unpack_p(best_params)
print(res_bfgs.state.iter_num)
llvm(best_params)

### Self-test:

In [None]:
# Training-set predictions from the initial and from the optimized parameters.
y_train_predict0, y_train_predict = (
    llvm.predict_y(X_train, X_train_err, state)
    for state in (llvm.par_state, res_state)
)

In [None]:
# Self-test: true training labels vs. predictions from the initial state
# (first scatter) and the optimized state (second scatter), one figure per label.
for q in range(y_train.shape[1]):
    plt.figure()
    for predicted in (y_train_predict0, y_train_predict):
        plt.scatter(y_train[:, q], predicted[:, q])

In [None]:
# Infer labels for the held-out (validation) objects from the initial and the
# optimized parameters. Computed here so the plots do not depend on cell
# execution order (these names were otherwise only defined further down).
y_valid_predict0 = llvm.predict_y(X_valid, X_valid_err, llvm.par_state)
y_valid_predict = llvm.predict_y(X_valid, X_valid_err, res_state)

# Predicted vs. true labels: red = initial state, black = optimized state,
# blue line = perfect prediction.
for k in range(y_valid.shape[1]):
    plt.figure()
    plt.scatter(y_valid[:, k], y_valid_predict0[:, k], c="r", marker="o")
    plt.scatter(y_valid[:, k], y_valid_predict[:, k], c="k", marker="o")
    plt.plot([y_valid[:, k].min(), y_valid[:, k].max()],
             [y_valid[:, k].min(), y_valid[:, k].max()],
             marker='', color='tab:blue')
    plt.xlabel(f"true label {k}")
    plt.ylabel(f"prediction of label {k}")
    plt.title("held-out data")

In [None]:
# Validation-set predictions from the initial and from the optimized parameters.
y_valid_predict0, y_valid_predict = (
    llvm.predict_y(X_valid, X_valid_err, state)
    for state in (llvm.par_state, res_state)
)

# Dump truth, initial-state predictions and optimized predictions, in that order.
for arr in (y_valid, y_valid_predict0, y_valid_predict):
    print(arr)