# Experiments with a linear latent-variable model

## Authors:
- **Adrian Price-Whelan** (Flatiron)
- **David W. Hogg** (NYU) (MPIA) (Flatiron)

## Hyper-parameters:
- *add key hyper-parameters here.*

## TODO / questions:
- There is no reason that `X` and `Y` (training data) must contain the same number of objects, people.
- Right now this only deals with weight vectors, not full covariance matrices.

In [None]:
import numpy as np
import jax.numpy as jnp
import pylab as plt

## Make LLVM functions

In [None]:
class llvm:
    """
    Linear latent-variable model fit by alternating least squares.

    The model is `X ~ mux + Z @ A.T` and `Y ~ muy + Z @ B.T`, where `Z` is
    an `(n, d)` block of latents, `A` is `(r, d)` and `B` is `(q, d)`.
    The cost is the weighted chi-squared of both residuals plus
    `0.5 * Lambda * sum(Z_free ** 2)`, where `Z_free` are the latent columns
    not constrained by `B` (see `free_elements_of_Z`).
    """

    def __init__(self, X, Y, Wx, Wy, B, Lambda):
        """
        ## inputs:
        `X`: shape `(n, r)` block of training features
        `Y`: shape `(n, q)` block of training labels
        `Wx`: shape `(n, r)` block of inverse variances (weights) for the features
        `Wy`: shape `(n, q)` block of inverse variances (weights) for the labels
        `B`: shape `(q, d)` matrix translating latents to labels.
        `Lambda`: regularization strength (strictly positive); it multiplies
          the sum of squares of the unconstrained latents in the cost.

        ## bugs / comments:
        - Put `np.nan` into the parts of `B` you want to optimize! (But note
          that `B_step()` is not implemented yet, so doing so currently
          raises `NotImplementedError`.)
        - This code more-or-less assumes that your data and labels and `B` are
          normalized to reasonable ranges.
        - In the long run, `X` and `Y` don't need to be the same length.
        - A column of `B` counts as "unconstrained" only if *every* element
          of that column is (close to) zero.
        """
        # Validate the regularization before it is used for anything.
        assert Lambda > 0., "llvm: You must regularize, and strictly positively."
        self.Lambda = Lambda

        self.X = jnp.array(X)
        self.Y = jnp.array(Y)
        self.n, self.r = self.X.shape
        enn, self.q = self.Y.shape
        assert enn == self.n, "llvm: Right now, the training data must be rectangular."

        self.Wx = jnp.array(Wx)
        assert self.Wx.shape == self.X.shape, "llvm: Inconsistency between `X` and `Wx`."
        assert bool(jnp.all(self.Wx >= 0.)), "llvm: Weights can't be negative."
        self.sqrtWx = jnp.sqrt(self.Wx)

        self.Wy = jnp.array(Wy)
        assert self.Wy.shape == self.Y.shape, "llvm: Inconsistency between `Y` and `Wy`."
        assert bool(jnp.all(self.Wy >= 0.)), "llvm: Weights can't be negative."
        self.sqrtWy = jnp.sqrt(self.Wy)

        self.B = jnp.array(B)
        cue, self.d = self.B.shape
        assert cue == self.q, "llvm: Inconsistency between `B` and `Y`."

        # NaN entries of `B` mark the elements to be fit; they start at zero.
        self.B_elements_to_fit = jnp.isnan(self.B)
        n_free_B = int(jnp.sum(self.B_elements_to_fit))
        if n_free_B == 0:
            self.B_elements_to_fit = None
            print("llvm: found no free elements of `B`.")
        else:
            # JAX arrays are immutable, so build a new array rather than
            # assigning in place.
            self.B = jnp.where(self.B_elements_to_fit, 0., self.B)
            print("llvm: found", n_free_B, "free elements of `B`.")

        # A latent dimension is unconstrained by the labels when its whole
        # column of `B` vanishes (not merely when the column sums to zero).
        self.free_elements_of_Z = jnp.all(jnp.isclose(self.B, 0.), axis=0)
        print(self.free_elements_of_Z)
        print("llvm: found", int(jnp.sum(self.free_elements_of_Z)),
              "unconstrained elements of `Z`, of", self.d)

        # Diagonal `(d, d)` matrix with `Lambda` on the unconstrained latents.
        free_mask = np.asarray(self.free_elements_of_Z).astype(float)
        self.regularization_matrix = Lambda * np.diag(free_mask)
        print(self.regularization_matrix)

        self.initialize_latents()

    @staticmethod
    def _np(array):
        """Return `array` as a float64 NumPy array (for the row-wise solves)."""
        return np.asarray(array).astype(np.float64)

    def _predict_X(self):
        """Model prediction for the features, shape `(n, r)`."""
        return self.mux[None, :] + self.Z @ self.A.T

    def _predict_Y(self):
        """Model prediction for the labels, shape `(n, q)`."""
        return self.muy[None, :] + self.Z @ self.B.T

    def _chi_X(self):
        """Feature residuals scaled by the square root of the weights."""
        return (self.X - self._predict_X()) * self.sqrtWx

    def _chi_Y(self):
        """Label residuals scaled by the square root of the weights."""
        return (self.Y - self._predict_Y()) * self.sqrtWy

    def _cost(self):
        """
        Half the total chi-squared plus the regularization on the
        unconstrained latents.

        NOTE: the original author flagged this regularization term as
        possibly wrong; its form is preserved here.
        """
        chi_x, chi_y = self._chi_X(), self._chi_Y()
        # Zero out the constrained latents instead of boolean-indexing them.
        z_free = jnp.where(self.free_elements_of_Z[None, :], self.Z, 0.)
        return 0.5 * jnp.sum(chi_x ** 2) \
             + 0.5 * jnp.sum(chi_y ** 2) \
             + 0.5 * self.Lambda * jnp.sum(z_free ** 2)

    def predict_y_given_x(self, Xstar, Wxstar):
        """
        Predict labels for new objects from their features alone.

        ## inputs:
        `Xstar`: shape `(m, r)` block of features
        `Wxstar`: shape `(m, r)` block of inverse variances for `Xstar`

        ## returns:
        shape `(m, q)` NumPy array of predicted labels.

        ## comments:
        - Each row gets its own weighted least-squares solve for its latents,
          because each row has its own weights.
        - Should this use the regularization matrix? Hogg thinks not.
        """
        Xstar = np.asarray(Xstar)
        m, arrh = Xstar.shape
        assert arrh == self.r
        A, B = self._np(self.A), self._np(self.B)
        mux, muy = self._np(self.mux), self._np(self.muy)
        sqrtWx = np.sqrt(Wxstar)
        chi = (Xstar - mux[None, :]) * sqrtWx
        Ystarhat = np.zeros((m, self.q))
        for i, x in enumerate(chi):
            M = A * sqrtWx[i][:, None]
            z = np.linalg.lstsq(M, x, rcond=None)[0]
            Ystarhat[i] = muy + B @ z
        return Ystarhat

    def initialize_latents(self):
        """
        Initialize `mux`, `muy`, `Z` and `A` (a bag of hacks).

        Uses the global `np.random` state for the noise added to `Z`.
        """
        # Zeroth hack: Take (weighted) means.
        self.mux = jnp.sum(self.X * self.Wx, axis=0) / jnp.sum(self.Wx, axis=0)
        assert self.mux.shape == (self.r, )
        self.muy = jnp.sum(self.Y * self.Wy, axis=0) / jnp.sum(self.Wy, axis=0)
        assert self.muy.shape == (self.q, )

        # First hack: Start with the pseudo-inverse of `B`.
        # BUG: Doesn't use weights.
        self.Z = jnp.linalg.lstsq(self.B, (self.Y - self.muy[None, :]).T, rcond=None)[0].T
        assert self.Z.shape == (self.n, self.d)

        # Second hack: Add noise, so the unconstrained latents don't start
        # at exactly zero.
        TINY = 1.e-1 # MAGIC
        sigma = jnp.std(self.Z) + TINY
        self.Z = self.Z + 0.1 * sigma * np.random.normal(size=self.Z.shape) # MAGIC

        # Third hack: Run A and B steps.
        self.A = jnp.zeros((self.r, self.d))
        self.A_step()
        self.B_step()

    def optimize_step(self, ftol=0.1):
        """
        Run one round of `Z`, `mu`, `A`, `B` updates.

        ## returns:
        `True` if the cost dropped by more than `ftol`, i.e. it is worth
        taking another step.
        """
        self._renorm_A()
        before = self._cost()
        self.Z_step()
        self.mu_step()
        self.A_step()
        self.B_step()
        after = self._cost()
        print("optimize_step(): Cost after:", after, self.n * (self.r + self.q))
        return bool(after < (before - ftol))

    def Z_step(self):
        """
        Update every row of `Z` at fixed `A`, `B`, `mux`, `muy`.

        The cost is quadratic in each row of `Z`, so one solve per row lands
        on the minimum of `_cost()` with respect to that row. With `M` the
        weighted stack of `A` and `B`, and `R` the regularization matrix, the
        update solves `(M.T M + R) dz = M.T chi - R z`. The `- R z` term makes
        the regularization act on `z + dz` rather than on `dz` alone.
        """
        chi_x, chi_y = self._np(self._chi_X()), self._np(self._chi_Y())
        A, B = self._np(self.A), self._np(self.B)
        sqrtWx, sqrtWy = self._np(self.sqrtWx), self._np(self.sqrtWy)
        Z = self._np(self.Z)
        R = self.regularization_matrix
        dZ = np.zeros((self.n, self.d))
        for i in range(self.n):
            resid = np.append(chi_x[i], chi_y[i])
            matrix = np.concatenate((A * sqrtWx[i][:, None],
                                     B * sqrtWy[i][:, None]),
                                    axis=0)
            dZ[i] = np.linalg.lstsq(matrix.T @ matrix + R,
                                    matrix.T @ resid - R @ Z[i],
                                    rcond=None)[0]
        self.Z = self.Z + dZ

    def A_step(self):
        """
        Update every row of `A` (one per feature) by weighted least squares
        at fixed `Z` and `mux`.
        """
        chi_x = self._np(self._chi_X())
        Z = self._np(self.Z)
        sqrtWx = self._np(self.sqrtWx)
        dA = np.zeros((self.r, self.d))
        for j in range(self.r):
            dA[j] = np.linalg.lstsq(Z * sqrtWx[:, j][:, None],
                                    chi_x[:, j], rcond=None)[0]
        self.A = self.A + dA

    def _renorm_A(self):
        """
        Rescale the unconstrained columns of `A` to unit norm, and rescale
        the matching columns of `Z` to compensate, so `Z @ A.T` is unchanged.

        Columns with zero norm are left alone to avoid dividing by zero.
        """
        norms = jnp.sqrt(jnp.sum(self.A ** 2, axis=0))
        rescale = self.free_elements_of_Z & (norms > 0.)
        scale = jnp.where(rescale, norms, 1.)
        # JAX arrays are immutable: the results must be assigned back.
        self.A = self.A / scale[None, :]
        self.Z = self.Z * scale[None, :]

    def B_step(self):
        """
        Update the free elements of `B`.

        BUG: Not currently operable; raises `NotImplementedError` if `B` has
        any free elements.
        """
        if self.B_elements_to_fit is None:
            return
        raise NotImplementedError(
            "llvm: fitting free elements of `B` is not implemented yet.")

    def mu_step(self):
        """Shift `mux` and `muy` by the weighted mean of the current residuals."""
        Xresid = self.X - self._predict_X()
        self.mux = self.mux + jnp.sum(Xresid * self.Wx, axis=0) / jnp.sum(self.Wx, axis=0)
        Yresid = self.Y - self._predict_Y()
        self.muy = self.muy + jnp.sum(Yresid * self.Wy, axis=0) / jnp.sum(self.Wy, axis=0)

## Make fake data

In [None]:
# Fake data: 5 true latents, 17 features, 3 labels; 191 training objects and
# 53 held-out ("star") objects. The order of the random draws below is kept
# fixed so the seeded stream gives the same numbers.
np.random.seed(17) # whoops
Atrue = np.random.normal(size=(17, 5))
Btrue = np.eye(3, 5)  # the first three latents are the labels
Ztrue = np.random.normal(size=(191, 5))
Zstartrue = np.random.normal(size=(53, 5))

# Noise-free features, then noisy observations of them.
Xtrue = Ztrue @ Atrue.T
Xstartrue = Zstartrue @ Atrue.T
sigma = 0.1
X = Xtrue + sigma * np.random.normal(size=Xtrue.shape)
Xstar = Xstartrue + sigma * np.random.normal(size=Xstartrue.shape)

# Noise-free labels, then noisy observations of them.
Ytrue = Ztrue @ Btrue.T
Ystartrue = Zstartrue @ Btrue.T
Y = Ytrue + sigma * np.random.normal(size=Ytrue.shape)
Ystar = Ystartrue + sigma * np.random.normal(size=Ystartrue.shape)

# Weights are inverse variances, the same for every datum.
Wx = np.full_like(X, sigma ** -2)
Wxstar = np.full_like(Xstar, sigma ** -2)
Wy = np.full_like(Y, sigma ** -2)
Wystar = np.full_like(Ystar, sigma ** -2)

In [None]:
# Fit with 7 latents: the first three are pinned to the labels by `B`, and
# the remaining four columns of `B` are zero.
Lambda = 0.1
B = np.eye(3, 7)
LLVM = llvm(X=X, Y=Y, Wx=Wx, Wy=Wy, B=B, Lambda=Lambda)

In [None]:
# Keep stepping for as long as `optimize_step()` reports that it is worth
# continuing; all the work happens in the loop condition.
while LLVM.optimize_step():
    pass

In [None]:
# Inspect the fitted feature offsets (one entry per feature).
print(LLVM.mux)

In [None]:
# Self-test: compare the training data with the model's own predictions
# for label 0 and, in a second figure, for feature 0.
label_prediction = LLVM._predict_Y()
plt.scatter(Y[:, 0], label_prediction[:, 0], c="k", marker="o")
plt.xlabel("true label 0")
plt.ylabel("prediction of label 0")
plt.title("self test")

f = plt.figure()
feature_prediction = LLVM._predict_X()
plt.scatter(X[:, 0], feature_prediction[:, 0], c="k", marker="o")
plt.xlabel("true feature 0")
plt.ylabel("prediction of feature 0")
plt.title("self test")

In [None]:
# Held-out test: predict labels for the test-set objects from their features
# alone, and compare label 0 with the (noisy) held-out labels.
Yhat = LLVM.predict_y_given_x(Xstar, Wxstar)
held_out_label0, predicted_label0 = Ystar[:, 0], Yhat[:, 0]
plt.scatter(held_out_label0, predicted_label0, c="k", marker="o")
plt.xlabel("true label 0")
plt.ylabel("prediction of label 0")
plt.title("held-out data")