Quickest way to set up:

- Download and install conda: https://www.anaconda.com/download
- `conda install tensorflow`
- `pip install tensorflow_datasets "jax[cpu]" chex tqdm matplotlib`

In [None]:
import functools

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
import scipy as sp
import matplotlib.pylab as plt
import chex
import tqdm

import tensorflow_datasets as tfds

## Download MNIST using tfds

In [None]:
# Load both MNIST splits; each example is a dict holding 'image' and 'label'.
train_ds, test_ds = (
    tfds.load('mnist', split=split_name, shuffle_files=True)
    for split_name in ('train', 'test')
)

In [None]:
def tfds_to_numpy(ds) -> tuple[np.ndarray, np.ndarray]:
    """Convert a tfds image split into binarized NumPy arrays.

    Each example's ``'image'`` is thresholded at half intensity
    (``pixel / 255 >= 0.5``) and its trailing channel axis is squeezed away;
    ``'label'`` is collected unchanged.

    Args:
        ds: a dataset yielding dicts with ``'image'`` and ``'label'`` entries,
            as produced by ``tfds.load`` without ``as_supervised``.

    Returns:
        ``(binary_images, labels)``: a boolean array of shape ``(n, h, w)``
        and an array of shape ``(n,)`` with the matching labels.
    """
    images, labels = [], []
    for example in tfds.as_numpy(ds):
        # Binarize each example immediately rather than first materializing a
        # float64 copy of the whole split: ``x / 255.0 >= 0.5`` is the same
        # test as ``x >= 127.5``, and a bool image needs 1 byte per pixel.
        images.append(np.squeeze(example['image'] >= 127.5, axis=-1))
        labels.append(example['label'])
    return np.stack(images), np.array(labels)

In [None]:
# Materialize both splits as (binary_images, labels) NumPy pairs.
(train_binary_images, train_labels), (test_binary_images, test_labels) = map(
    tfds_to_numpy, (train_ds, test_ds)
)

In [None]:
# Sanity-check the array shapes of both splits.
for shape_label, image_array in (
    ("train_binary_images.shape", train_binary_images),
    ("test_binary_images.shape", test_binary_images),
):
    print(shape_label, image_array.shape)

## BernoulliMixture dataclass

We define a dataclass to collect all the Bernoulli mixture parameters in a single data structure. The `@chex.dataclass` decorator registers `BernoulliMixture` as a JAX pytree, which is what allows its instances to be passed as arguments to `jax.jit`-compiled functions.

In [None]:
# Number of pixels per flattened image. Cast to a plain Python int: np.prod
# returns a NumPy scalar (and float 1.0 for an empty shape), while this value
# is used as an array dimension.
n_pixels = int(np.prod(train_binary_images.shape[1:]))

@chex.dataclass
class BernoulliMixture:
    """Container for the parameters of a Bernoulli mixture model.

    Being a chex dataclass, instances are JAX pytrees and can be passed
    straight into jitted functions.

    Attributes:
        cluster_weights: shape (n_clusters,) — one mixing weight per cluster.
        cluster_means: shape (n_clusters, n_pixels) — one row of per-pixel
            Bernoulli means per cluster.
    """

    cluster_weights: jax.Array
    cluster_means: jax.Array

def init_mixture_params(key: chex.PRNGKey, n_clusters: int, n_pixels: int) -> BernoulliMixture:
    """Initialize the mixture parameters.

    Cluster weights start uniform (1 / n_clusters each). Cluster means are
    sampled i.i.d. uniformly from [0.25, 0.75): they must be random because
    clusters with identical means would receive identical responsibilities
    and EM could never separate them, and they are kept away from 0 and 1 so
    that no pixel value starts out (nearly) impossible under any cluster.

    Args:
        key: JAX PRNG key used to sample the initial means.
        n_clusters: number of mixture components.
        n_pixels: number of pixels in a flattened image.

    Returns:
        A BernoulliMixture with ``cluster_weights`` of shape (n_clusters,) and
        ``cluster_means`` of shape (n_clusters, n_pixels).
    """
    cluster_weights = jnp.full((n_clusters,), 1.0 / n_clusters, dtype=jnp.float32)
    cluster_means = jax.random.uniform(
        key, (int(n_clusters), int(n_pixels)), minval=0.25, maxval=0.75
    )
    return BernoulliMixture(cluster_weights=cluster_weights, cluster_means=cluster_means)

## Expectation Maximization (EM) implementation

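Note: for debugging, it may help to temporarily remove the `@jax.jit` decorators (or to run the calls inside a `with jax.disable_jit():` block) so that intermediate values can be printed and inspected. Re-enable jit afterwards, since the jitted versions are much faster.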

In [None]:
@jax.jit
def log_joint_prob(
    params: BernoulliMixture,
    binary_images: jax.Array
) -> jax.Array:
    """Compute the log joint probability log p(x, z = k) for every example and cluster.

    For a mixture of independent Bernoullis with weights pi and means mu:
    ``log p(x, z=k) = log pi_k + sum_d [x_d log mu_kd + (1 - x_d) log(1 - mu_kd)]``.

    Args:
        params: the BernoulliMixture to evaluate log probability under.
        binary_images: an (n, h, w) shape binary array containing the observations.

    Returns:
        An (n, n_clusters) shape array whose entry [i, k] is log p(x_i, z_i = k).
        Taking ``argmax`` over the last axis yields each observation's most
        probable cluster.
    """
    # Clip the means so that values of exactly 0 or 1 (EM produces these for
    # pixels that never / always fire within a cluster) do not give
    # log(0) = -inf, which the matrix products below would turn into
    # 0 * -inf = nan.
    eps = 1e-6
    means = jnp.clip(params.cluster_means, eps, 1.0 - eps)  # (k, d)
    # Flatten each image to a (d,) vector of 0.0 / 1.0 values.
    x = binary_images.reshape(binary_images.shape[0], -1).astype(means.dtype)  # (n, d)
    log_px_given_z = x @ jnp.log(means).T + (1.0 - x) @ jnp.log1p(-means).T  # (n, k)
    # Add log p(z = k); broadcasts across the n examples.
    return log_px_given_z + jnp.log(params.cluster_weights)
    

@jax.jit
def log_likelihood(
    params: BernoulliMixture,
    binary_images: jax.Array
) -> jax.Array:
    """Compute the marginal log probability log p(x) for each example.

    The cluster variable is summed out in log space:
    ``log p(x) = logsumexp_k log p(x, z = k)``, which avoids underflow from
    exponentiating very negative per-image log probabilities.

    Args:
        params: the BernoulliMixture to evaluate log probability under.
        binary_images: an (n, h, w) shape binary array containing the observations.

    Returns:
        A (n,) shape array containing log p(x) for each input observation.
    """
    return jsp.special.logsumexp(log_joint_prob(params, binary_images), axis=-1)


@jax.jit
def em_step(
    params: BernoulliMixture,
    binary_images: jax.Array
) -> BernoulliMixture:
    """Run one Expectation Maximization (EM) step for the Bernoulli mixture model.

    E-step: responsibilities ``r[i, k] = p(z_i = k | x_i)`` under ``params``.
    M-step: closed-form updates
    ``pi_k = sum_i r[i, k] / n`` and ``mu_k = sum_i r[i, k] x_i / sum_i r[i, k]``.

    Args:
        params: the current BernoulliMixture model.
        binary_images: an (n, h, w) shape binary array containing the observations.

    Returns:
        The updated BernoulliMixture parameters.
    """
    # E-step: normalize log p(x, z) over clusters to get p(z | x).
    resp = jax.nn.softmax(log_joint_prob(params, binary_images), axis=-1)  # (n, k)

    # M-step.
    n_examples = binary_images.shape[0]
    x = binary_images.reshape(n_examples, -1).astype(resp.dtype)  # (n, d)
    soft_counts = resp.sum(axis=0)  # (k,) expected number of examples per cluster
    cluster_weights = soft_counts / n_examples
    # A cluster that received no responsibility mass would otherwise divide
    # 0 by 0 and become nan; such clusters keep their previous means.
    safe_counts = jnp.maximum(soft_counts, jnp.finfo(resp.dtype).tiny)
    updated_means = (resp.T @ x) / safe_counts[:, None]  # (k, d)
    cluster_means = jnp.where(soft_counts[:, None] > 0, updated_means, params.cluster_means)
    return BernoulliMixture(cluster_weights=cluster_weights, cluster_means=cluster_means)

## Run the EM algorithm

Here we run EM for 20 iterations using $k=15$ clusters.

In [None]:
n_clusters = 15
n_iters = 20
params = init_mixture_params(jax.random.PRNGKey(83832), n_clusters, n_pixels)
train_lls, test_lls = [], []
# The loop index is unused; naming it `_` also avoids rebinding the builtin
# `iter` in the notebook's global namespace for every later cell.
for _ in tqdm.tqdm(range(n_iters)):
    params = em_step(params, train_binary_images)
    # Record the mean per-image log p(x) after this EM step.
    train_lls.append(np.mean(log_likelihood(params, train_binary_images)))
    test_lls.append(np.mean(log_likelihood(params, test_binary_images)))
train_lls, test_lls = np.array(train_lls), np.array(test_lls)

## Evaluate EM results

Here we plot various evaluation metrics.

In [None]:
# Entry i of train_lls / test_lls was recorded after EM step i + 1, so plot
# against 1-based iteration numbers rather than the default 0-based indices.
iterations = np.arange(1, len(train_lls) + 1)
plt.plot(iterations, train_lls)
plt.plot(iterations, test_lls)
plt.xlabel('Iteration')
plt.ylabel('Log Likelihood')
plt.title('Bernoulli EM, n_clusters={}'.format(n_clusters))
plt.legend(('Train', 'Test'))

In [None]:
# display: one panel per cluster showing its per-pixel means as an image,
# titled with the cluster's mixing weight.
n_cols = 5
n_rows = -(-n_clusters // n_cols)  # ceil(n_clusters / n_cols); 3 for 15 clusters
# Take the image shape from the data instead of hard-coding MNIST's 28x28.
image_shape = train_binary_images.shape[1:]

# squeeze=False keeps `axs` 2-D even when the grid has a single row.
fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False)
for row_idx in range(n_rows):
    for col_idx in range(n_cols):
        ax = axs[row_idx][col_idx]
        cluster_idx = row_idx * n_cols + col_idx
        if cluster_idx >= n_clusters:
            # The last row may have more cells than remaining clusters.
            ax.axis('off')
            continue
        ax.imshow(params.cluster_means[cluster_idx].reshape(image_shape))
        ax.xaxis.set_ticklabels([])
        ax.xaxis.set_ticks([])
        ax.yaxis.set_ticklabels([])
        ax.yaxis.set_ticks([])
        ax.set_title("$p_{{{}}}={:.3f}$".format(cluster_idx, params.cluster_weights[cluster_idx]))

plt.tight_layout()

In [None]:
# Hard-assign every datapoint to the cluster with the largest log p(x, z).
train_assignments, test_assignments = (
    np.argmax(log_joint_prob(params, split_images), axis=-1)
    for split_images in (train_binary_images, test_binary_images)
)

In [None]:
def make_cluster_labels(assignments, labels, num_clusters=None):
    """Tally the true labels of the examples assigned to each cluster.

    Args:
        assignments: sequence of cluster indices, one per example.
        labels: sequence of ground-truth labels aligned with ``assignments``.
        num_clusters: length of the returned list. Defaults to the global
            ``n_clusters``; pass it explicitly to avoid depending on that
            global.

    Returns:
        A list with one ``collections.Counter`` (a dict subclass) per cluster,
        mapping each true label to how many examples with that label were
        assigned to the cluster. Clusters with no examples get an empty
        Counter.
    """
    from collections import Counter

    if num_clusters is None:
        num_clusters = n_clusters
    cluster_labels = [Counter() for _ in range(num_clusters)]
    for cluster_assignment, true_label in zip(assignments, labels):
        cluster_labels[cluster_assignment][true_label] += 1
    return cluster_labels

# Per-cluster label histograms for the train and test splits.
train_cluster_labels, test_cluster_labels = (
    make_cluster_labels(split_assignments, split_labels)
    for split_assignments, split_labels in (
        (train_assignments, train_labels),
        (test_assignments, test_labels),
    )
)

In [None]:
# display: for each cluster, the fraction of its train / test examples
# carrying each true label.
n_cols = 5
n_rows = -(-n_clusters // n_cols)  # ceil(n_clusters / n_cols); 3 for 15 clusters


def _label_distribution(label_counts):
    """Return (label values, fraction of the cluster's examples with each label)."""
    label_values = np.array(list(label_counts.keys()))
    counts = np.array(list(label_counts.values()))
    return label_values, counts / np.sum(counts)


# squeeze=False keeps `axs` 2-D even when the grid has a single row.
fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(10, 8), squeeze=False)
for row_idx in range(n_rows):
    for col_idx in range(n_cols):
        ax = axs[row_idx][col_idx]
        cluster_idx = row_idx * n_cols + col_idx
        if cluster_idx >= n_clusters:
            # The last row may have more cells than remaining clusters.
            ax.axis('off')
            continue
        # Side-by-side bars: train shifted left of each label, test right.
        label_values, fractions = _label_distribution(train_cluster_labels[cluster_idx])
        ax.bar(label_values - 0.2, fractions, width=0.4, label='train')
        label_values, fractions = _label_distribution(test_cluster_labels[cluster_idx])
        ax.bar(label_values + 0.2, fractions, width=0.4, label='test')
        ax.xaxis.set_ticks(list(range(10)))
        if cluster_idx == 0:
            ax.legend()
        ax.set_title("$p_{{{}}}={:.3f}$".format(cluster_idx, params.cluster_weights[cluster_idx]))

plt.tight_layout()