In [None]:
import numpy as np
from matplotlib import pyplot as plt

## The factor model

In [None]:
n = 4
M = 1000
# Random parameters
beta0 = np.random.normal(0, 0.5, size=n)
beta1 = np.random.normal(0, 1, size=n) + 1
sigma = abs(np.random.normal(0, 1, size=n))

In [None]:
z = np.random.normal(0, 1, size=M)
x = np.zeros(shape=(M, n))

epsilon = np.random.normal(0, 1, size=(M, n))

for k in range(n):
    x[:, k] = beta0[k] + beta1[k] * z + sigma[k] * epsilon[:, k]
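
For reference, the simulation above samples from the one-factor model below. It is read off from the code, with $\varepsilon_k$ corresponding to the array `epsilon`:
\begin{align*}
Z & \sim \mathcal{N}(0, 1), \\
X_k & = \beta_{0,k} + \beta_{1,k} Z + \sigma_k \varepsilon_k, \qquad \varepsilon_k \sim \mathcal{N}(0, 1), \quad k = 1, \ldots, n,
\end{align*}
with $Z, \varepsilon_1, \ldots, \varepsilon_n$ independent. In particular, $X_k \mid Z = z \sim \mathcal{N}(\beta_{0,k} + \beta_{1,k} z, \sigma_k^2)$.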

In [None]:
k = 0 # Try different k
print(beta0[k], beta1[k], sigma[k])
plt.scatter(z, x[:, k], alpha = 0.1)
plt.plot(np.array([-4, 4]), beta0[k] + beta1[k] * np.array([-4, 4]), 'r')     

## Learning

The function `learn_beta` below learns all the beta-parameters from complete observations using regression. You could, of course, also do this with scikit-learn or `scipy.stats.linregress`, but one benefit of the implementation below is that it fits all the targets, $X_1, \ldots, X_n$, in a single call.

In [None]:
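# Linear regression with many responses
def learn_beta(x, z):
    # Design matrix with an intercept column and the regressor z, shape (M, 2)
    U = np.stack((np.ones_like(z), z), 1)
    Ut = np.transpose(U)
    # Solve the normal equations (U^T U) beta = U^T x for all n responses at once;
    # beta has shape (2, n): row 0 holds the intercepts, row 1 the slopes
    beta = np.linalg.solve(Ut @ U, Ut @ x)
    return(beta[0, :], beta[1, :])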
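
As a cross-check of the normal-equations approach in `learn_beta`, the sketch below fits the same multi-response regression with `np.linalg.lstsq`, which also accepts a two-dimensional right-hand side with one column per target. The simulated data and the names `true_beta0` and `true_beta1` are hypothetical and only serve the comparison.

```python
import numpy as np

# Hypothetical simulated data: n targets sharing one regressor z
rng = np.random.default_rng(0)
M, n = 500, 4
z = rng.normal(size=M)
true_beta0 = rng.normal(0, 0.5, size=n)
true_beta1 = rng.normal(0, 1, size=n) + 1
x = true_beta0 + np.outer(z, true_beta1) + rng.normal(size=(M, n))

# Design matrix with an intercept column and the regressor z, shape (M, 2)
U = np.column_stack((np.ones(M), z))

# Normal equations (U^T U) beta = U^T x, solved for all n targets at once
beta_normal = np.linalg.solve(U.T @ U, U.T @ x)      # shape (2, n)

# Same fit via least squares; rcond=None selects NumPy's current default cutoff
beta_lstsq, *_ = np.linalg.lstsq(U, x, rcond=None)  # shape (2, n)

print(np.allclose(beta_normal, beta_lstsq))
print(beta_normal[0], beta_normal[1])  # estimated intercepts and slopes
```

Solving the normal equations is cheap here because $U^\top U$ is only $2 \times 2$; `lstsq` works from $U$ directly and is the more robust choice when the design matrix is badly conditioned.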

In [None]:
beta0_hat, beta1_hat = learn_beta(x, z)

In [None]:
#Comparison with true parameters
beta0_hat, beta0 

In [None]:
#Comparison with true parameters
beta1_hat, beta1

## The inference algorithm

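In information (canonical) form, the conditional distribution of $Z$ given $\mathbf{X} = \mathbf{x}$ has parameters
\begin{align*}
J & = 1 + \sum_{k=1}^n \frac{\beta_{1,k}^2}{\sigma_k^2} \\
h & = \sum_{k=1}^n \frac{(x_k - \beta_{0,k})\beta_{1,k}}{\sigma_k^2}
\end{align*}
where the term $1$ in $J$ is the precision of the prior $Z \sim \mathcal{N}(0, 1)$. The conditional mean and variance follow directly: $\mathbb{E}[Z \mid \mathbf{X} = \mathbf{x}] = h/J$ and $\operatorname{Var}[Z \mid \mathbf{X} = \mathbf{x}] = 1/J$.

In [None]:
# inference of z for a single x
def z_inf(x, beta0, beta1, sigma):
    # Precision: 1 from the prior Z ~ N(0, 1) plus one term per observed X_k
    J = 1 + np.sum((beta1 / sigma) ** 2)
    h = np.sum((x - beta0) * beta1 / (sigma ** 2))
    # Conditional mean h / J and conditional variance 1 / J
    return((h / J, 1 / J))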

In [None]:
m = 10
print(z[m])
z_inf(x[m, :], beta0, beta1, sigma)
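
As a sanity check on the information-form formulas, the sketch below compares them with conditioning the joint Gaussian distribution of $(Z, \mathbf{X})$ directly, using $\operatorname{Cov}(\mathbf{X}) = \boldsymbol\beta_1\boldsymbol\beta_1^\top + \operatorname{diag}(\sigma_1^2, \ldots, \sigma_n^2)$ and $\operatorname{Cov}(Z, \mathbf{X}) = \boldsymbol\beta_1$. The helper names `z_posterior_info` and `z_posterior_joint` and the parameter values are hypothetical. The information-form helper includes the prior precision 1 of $Z \sim \mathcal{N}(0, 1)$ and also accepts a whole batch of rows, so it could replace the per-row loop in `z_hat`.

```python
import numpy as np

def z_posterior_info(x, beta0, beta1, sigma):
    """Posterior mean and variance of Z given X = x in information form.

    The prior Z ~ N(0, 1) contributes precision 1; each X_k adds
    beta1_k^2 / sigma_k^2 to J and (x_k - beta0_k) beta1_k / sigma_k^2 to h.
    Works for a single row (shape (n,)) or a batch of rows (shape (M, n)).
    """
    J = 1.0 + np.sum((beta1 / sigma) ** 2)
    h = (x - beta0) @ (beta1 / sigma ** 2)
    return h / J, 1.0 / J

def z_posterior_joint(x, beta0, beta1, sigma):
    """The same posterior, obtained by conditioning the joint Gaussian of (Z, X)."""
    # Cov(X) = beta1 beta1^T + diag(sigma^2); Cov(Z, X) = beta1; Var(Z) = 1
    cov_x = np.outer(beta1, beta1) + np.diag(sigma ** 2)
    gain = np.linalg.solve(cov_x, beta1)   # Cov(X)^{-1} Cov(X, Z)
    mean = (x - beta0) @ gain
    var = 1.0 - beta1 @ gain
    return mean, var

# Hypothetical parameters and two observed rows
beta0 = np.array([0.2, -0.5, 0.1, 0.4])
beta1 = np.array([1.5, 0.8, 1.2, 0.3])
sigma = np.array([0.5, 1.0, 0.7, 0.9])
x = np.array([[1.0, 0.0, 0.5, 0.2],
              [-0.3, -1.2, 0.0, 0.6]])

m_info, v_info = z_posterior_info(x, beta0, beta1, sigma)
m_joint, v_joint = z_posterior_joint(x, beta0, beta1, sigma)
print(np.allclose(m_info, m_joint), np.isclose(v_info, v_joint))
```

The two routes agree by the Woodbury identity: $\operatorname{Cov}(\mathbf{X})^{-1}\boldsymbol\beta_1 = D^{-1}\boldsymbol\beta_1 / J$ with $D = \operatorname{diag}(\sigma_k^2)$, which turns the joint-Gaussian mean into $h/J$ and the variance $1 - \boldsymbol\beta_1^\top \operatorname{Cov}(\mathbf{X})^{-1}\boldsymbol\beta_1$ into $1/J$.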

In [None]:
def z_hat(x, beta0, beta1, sigma):
    M = x.shape[0]
    z_hats = np.zeros(shape=M)
    for m in range(M):
        z_hats[m] = z_inf(x[m, :], beta0, beta1, sigma)[0]
    return(z_hats)

In [None]:
# Check that the inference algorithm produces sensible predictions
z_hats = z_hat(x, beta0, beta1, sigma)
plt.scatter(z, z_hats)

## The hard-assignment EM algorithm

The algorithm alternates $N$ times between two steps: predicting $z$ for every observation with the inference algorithm under the current parameters, and then re-learning the parameters from these predictions, first the $\beta$-parameters by regression and then the $\sigma$-parameters from the residuals.

In [None]:
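def hard_EM(x, beta0_init, beta1_init, sigma_init, N=10):
    M, n = x.shape
    beta0 = beta0_init
    beta1 = beta1_init
    # Copy so that the caller's array (e.g. the true sigma) is not overwritten in place
    sigma = np.array(sigma_init, dtype=float)
    for i in range(N):
        # Inference step: predict z for every observation with the current parameters
        z_hats = z_hat(x, beta0, beta1, sigma)
        # Learning step: regress each x_k on the predicted z-values
        beta0, beta1 = learn_beta(x, z_hats)
        for k in range(n):
            # Residual standard deviation, using the predicted (not the true) z-values
            sigma[k] = np.sqrt(np.sum((x[:, k] - beta0[k] - beta1[k] * z_hats) ** 2) / M)
    return(beta0, beta1, sigma)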
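
The iteration described above can also be written without loops over rows and columns. The sketch below is a hypothetical vectorized variant (`hard_em_vectorized` is not part of the original notebook): the inference step computes $h/J$ for all rows with one matrix product, the learning step solves the normal equations for all targets at once, and the starting values are copied so that no input array is overwritten. The fixed parameter values are assumptions chosen only for illustration.

```python
import numpy as np

def hard_em_vectorized(x, beta0_init, beta1_init, sigma_init, N=10):
    """Hypothetical vectorized sketch of hard-assignment EM for the one-factor model."""
    # Copy the starting values so the caller's arrays are never modified in place
    beta0 = np.array(beta0_init, dtype=float)
    beta1 = np.array(beta1_init, dtype=float)
    sigma = np.array(sigma_init, dtype=float)
    M = x.shape[0]
    for _ in range(N):
        # Inference step: posterior mean h / J of Z for all rows at once (prior N(0, 1))
        J = 1.0 + np.sum((beta1 / sigma) ** 2)
        z_hats = (x - beta0) @ (beta1 / sigma ** 2) / J
        # Learning step: regress every column of x on [1, z_hats] in one solve
        U = np.column_stack((np.ones(M), z_hats))
        beta = np.linalg.solve(U.T @ U, U.T @ x)
        beta0, beta1 = beta[0], beta[1]
        # Residual standard deviations, computed with the predicted z_hats
        resid = x - beta0 - np.outer(z_hats, beta1)
        sigma = np.sqrt(np.mean(resid ** 2, axis=0))
    return beta0, beta1, sigma

# Hypothetical simulated data with fixed parameters (n = 4)
rng = np.random.default_rng(1)
M = 2000
beta0 = np.array([0.2, -0.5, 0.1, 0.4])
beta1 = np.array([1.5, 0.8, 1.2, 0.3])
sigma = np.array([0.5, 1.0, 0.7, 0.9])
z = rng.normal(size=M)
x = beta0 + np.outer(z, beta1) + sigma * rng.normal(size=(M, 4))

sigma_before = sigma.copy()
est_data = hard_em_vectorized(x, x.mean(axis=0), np.ones(4), x.std(axis=0), N=5)
est_truth = hard_em_vectorized(x, beta0, beta1, sigma, N=5)
print(est_data)
print(est_truth)
print(np.array_equal(sigma, sigma_before))  # True: the true sigma was not overwritten
```

The sketch makes no claim about convergence; printing the estimates for several values of `N` is a simple way to watch how the $\sigma$-parameters evolve.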

The most difficult part is actually choosing the starting values of the algorithm. Here we use that the means of the $x$-s equal the $\beta_0$-parameters and that the standard deviations of the $x$-s are upper bounds on the $\sigma$-parameters, since $\operatorname{Var}(X_k) = \beta_{1,k}^2 + \sigma_k^2$. This gives sensible starting values for those parameters. The $\beta_1$-parameters are simply (somewhat arbitrarily) set to 1.
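
The two moment facts behind these starting values can be checked numerically. The sketch below simulates a large sample with hypothetical fixed parameters and compares the column means with $\beta_0$ and the column variances with $\beta_{1,k}^2 + \sigma_k^2$.

```python
import numpy as np

# Hypothetical fixed parameters (n = 4) for a reproducible check
beta0 = np.array([0.2, -0.5, 0.1, 0.4])
beta1 = np.array([1.5, 0.8, 1.2, 0.3])
sigma = np.array([0.5, 1.0, 0.7, 0.9])

rng = np.random.default_rng(2)
M = 100_000
z = rng.normal(size=M)
x = beta0 + np.outer(z, beta1) + sigma * rng.normal(size=(M, 4))

# E[X_k] = beta0_k, so the column means estimate beta0
print(np.round(x.mean(axis=0), 2), beta0)
# Var(X_k) = beta1_k^2 + sigma_k^2 >= sigma_k^2, so column SDs bound sigma from above
print(np.round(x.var(axis=0), 2), beta1 ** 2 + sigma ** 2)
print(np.all(x.std(axis=0) >= sigma))
```

The bound is tight only when $\beta_{1,k}$ is small, as for the last column here; for strongly loaded columns the sample standard deviation overshoots $\sigma_k$ considerably.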

In [None]:
hard_EM(x, np.mean(x, 0), np.ones(n), np.sqrt(np.var(x, 0)), N=10)

In [None]:
# We should compare the above to the true parameters
beta0, beta1, sigma

In [None]:
# Another way to test if the implementation behaves reasonably is to start the 
# algorithm in the true parameters
hard_EM(x, beta0, beta1, sigma, N=10)

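If the algorithm does not drift far from the true parameters when started there, that is a good sign. However, the hard-assignment EM algorithm does not have well-understood convergence properties. Because it treats the predicted $z$-values as if they were observed and ignores their conditional variance $1/J$, its residual-based updates tend to underestimate the $\sigma$-parameters, and it may converge to slightly different parameter values depending on small changes in the starting values.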