# Additional Gaussian Mixture Model

## Motivation

In the first notebook [core_jax_GMM](./core_jax_gaussian_mixture_model.ipynb), we "cheated" when initializing the clusters: we took the true centers, which are unknowable in practice, and added noise. This worked
great as a teaching and sanity-checking tool, but in reality we cannot know the true centers. So, in this notebook we initialize from many random locations and try a few different numbers of cluster centers.

For this notebook, we take the concepts and code from our first notebook [core_jax_GMM](./core_jax_gaussian_mixture_model.ipynb) and expand on them. Some new concepts used:

- `RNG` for reproducible initializations of the mus
- `PMAP` for parallelization of the GMM across the initializations

Before each newly introduced concept we briefly discuss the arguments and why the code is laid out the way it is :)

In [1]:
# Import os and set the number of simulated host devices BEFORE importing Jax
#   comment me out and run the next two lines to see what happens if we do not do this correctly
import os
os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=8"


In [2]:
import jax.numpy as jnp
import jax
import numpy as np
np.random.seed(123)

from additional_gmm import unknown_centers, make_ds, EM_GMM

In [3]:
jax.devices()

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

## Create the function to be `pmap`-ed

We create a wrapper function that takes a random key, uses it to initialize the guessed centers, and then runs the GMM.

In [4]:
def run_gmm(rk, X, ndims, K):
    # Split the key: `subkey` is consumed below and should NOT be reused;
    # `main_key` remains available for further random operations.
    main_key, subkey = jax.random.split(rk)
    mus = jax.random.normal(subkey, shape=(K, ndims))

    # Use jnp rather than np here: inside `pmap`, X is a traced value
    sigmas = jnp.stack([jnp.cov(X.T) for _ in range(K)])
    cls_probs = jnp.full((K, 1), 1 / K)  # uniform class probabilities, shape (K, 1)
    return EM_GMM(
        X,
        mus=mus,
        sigmas=sigmas,
        cls_probs=cls_probs,
        guess_num_classes=K,
        verbose=False
    )

## Setting up the `pmap`

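`pmap` is very similar to `vmap`: both map a function over an axis of the inputs. The difference is that `pmap` runs each slice of the mapped axis on its own device, so the size of the mapped axis cannot exceed the number of devices.

Here we map the random keys over the devices, so we map the first argument over its 0th axis. `None` marks an argument that is broadcast to every device rather than split.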
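
As a minimal sketch of the difference (the function `scaled_sum` and its inputs are made up for illustration and are not part of this notebook's code), the same function can be mapped with `vmap` and with `pmap`. The values should match; only where the work runs differs. The mapped axis is sized to `jax.local_device_count()` so the sketch runs on any machine:

```python
import jax
import jax.numpy as jnp


def scaled_sum(scale, x):
    # `scale` is one slice of the mapped axis; `x` is shared by every call
    return scale * jnp.sum(x)


n_dev = jax.local_device_count()
scales = jnp.arange(n_dev, dtype=jnp.float32)  # one entry per device
x = jnp.ones(4)

# in_axes=(0, None): map `scales` over axis 0, broadcast `x` unchanged
out_vmap = jax.vmap(scaled_sum, in_axes=(0, None))(scales, x)  # vectorized on one device
out_pmap = jax.pmap(scaled_sum, in_axes=(0, None))(scales, x)  # one slice per device

print(out_vmap.shape, out_pmap.shape)
```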



In [5]:
parallel_gmm = jax.pmap(
    run_gmm,
    in_axes=(0, None, None)
)

## Randomness

Jax's `random.key` returns a single key, and `random.split` returns an array of N keys. Each key overlays two underlying integer values, so N keys overlay an array of shape (N, 2). You don't access those two values directly; each key is passed to Jax as a single opaque value.

In [6]:
from jax import random

seed = 42
init_key = random.key(seed)
keys_and_subkeys = random.split(init_key, num=len(jax.devices()))

print(keys_and_subkeys)

Array((8,), dtype=key<fry>) overlaying:
[[1016697191 1792542510]
 [  21752309 3647990511]
 [ 344551668 3939928494]
 [ 861363423  169498067]
 [2390192106  167227791]
 [ 201508585 2676631123]
 [3104550939 3018605412]
 [ 775411565  603659288]]


In [7]:
try:
    keys_and_subkeys[0][0]
except Exception as e:
    print(e)

Too many indices for array: 1 non-None/Ellipsis indices for dim 0.


In [8]:
new_k, new_subk = random.split(keys_and_subkeys[0])
print(new_k)
print(new_subk)


Array((), dtype=key<fry>) overlaying:
[2733821252 4190582358]
Array((), dtype=key<fry>) overlaying:
[2048541734 4153621245]
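
The reason for threading keys around explicitly is reproducibility. As a small sketch (the variable names are illustrative only), the same key always produces the same draw, while the two keys that come out of a split produce different draws. If the raw values behind a key are ever needed, `random.key_data` exposes them:

```python
from jax import random

key = random.key(42)
k1, k2 = random.split(key)

# Reusing a key reproduces exactly the same sample
a = random.normal(k1, shape=(3,))
b = random.normal(k1, shape=(3,))

# A different key gives a different sample
c = random.normal(k2, shape=(3,))

print(bool((a == b).all()), bool((a == c).all()))

# The underlying integers are reachable via `key_data` rather than by indexing the key
raw = random.key_data(k1)
print(raw.shape)
```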


# Run the parallel training

For each candidate number of clusters, we run the GMM in parallel, with one random initialization per device.

In [9]:
(X, y), _ = make_ds(unknown_centers)

N, M = X.shape

all_results = []
for num_clusters in range(1, 5):  
    
    res = parallel_gmm(keys_and_subkeys, X, M, num_clusters)
    all_results.append(res)

ValueError: pytree structure error: different lengths of tuple at key path
    pmap in_axes[0]
At that key path, the prefix pytree pmap in_axes has a subtree of type tuple of length 3, but the full pytree has a subtree of the same type but of length 4.

The 'full pytree' here is the tuple of arguments passed positionally to the pmapped function, and the value of `in_axes` must be a tree prefix of that tuple. But it was not a prefix.

Check that the value of the `in_axes` argument to `pmap` is a tree prefix of the tuple of arguments passed positionally to the pmapped function.
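
The error says that `in_axes` has three entries while `run_gmm` takes four positional arguments (`rk`, `X`, `ndims`, `K`). One possible fix, sketched below under assumptions, is to give `in_axes` four entries and to mark `ndims` and `K` as static via `static_broadcasted_argnums`, since they are used to build an array shape and therefore need to be concrete Python integers at trace time. Here `init_centers` is a hypothetical stand-in for `run_gmm` (it only draws the initial centers and does not call `EM_GMM`), and the data is a placeholder:

```python
import jax
import jax.numpy as jnp


def init_centers(rk, X, ndims, K):
    # Hypothetical stand-in for `run_gmm`: draw K centers around the data mean
    mus = jax.random.normal(rk, shape=(K, ndims))
    return mus + jnp.mean(X, axis=0)


parallel_init = jax.pmap(
    init_centers,
    in_axes=(0, None, None, None),      # one entry per positional argument
    static_broadcasted_argnums=(2, 3),  # `ndims` and `K` stay Python ints
)

n_dev = jax.local_device_count()
keys = jax.random.split(jax.random.key(42), num=n_dev)
X = jnp.ones((10, 2))  # placeholder data: 10 points in 2 dimensions

centers = parallel_init(keys, X, 2, 3)
print(centers.shape)  # (number of devices, K, ndims)
```

The same `in_axes` and `static_broadcasted_argnums` values should carry over to `run_gmm`, though that combination has not been run in this notebook.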

# Result Investigation

In [None]:
import matplotlib.pyplot as plt

## Plot the Log-likelihood

In [None]:
plt.plot(lls)
plt.xlabel("Iteration")
plt.ylabel("Log-likelihood of points")
plt.show()

# Show the Points

In [None]:
from matplotlib.patches import Ellipse
import matplotlib.transforms as transforms
def confidence_ellipse(mu, sigma, ax, n_std=3.0, facecolor='none', **kwargs):
    """
    Modified based on function from: https://matplotlib.org/stable/gallery/statistics/confidence_ellipse.html
    Draw the covariance confidence ellipse of a two-dimensional Gaussian.

    Parameters
    ----------
    mu : array-like, shape (2, )
        Center of the ellipse (the mean).

    sigma : array-like, shape (2, 2)
        Covariance matrix that determines the ellipse's shape.

    ax : matplotlib.axes.Axes
        The Axes object to draw the ellipse into.

    n_std : float
        The number of standard deviations to determine the ellipse's radii.

    **kwargs
        Forwarded to `~matplotlib.patches.Ellipse`

    Returns
    -------
    matplotlib.patches.Ellipse
    """
    pearson = sigma[0, 1]/np.sqrt(sigma[0, 0] * sigma[1, 1])
    # Using a special case to obtain the eigenvalues of this
    # two-dimensional dataset.
    ell_radius_x = np.sqrt(1 + pearson)
    ell_radius_y = np.sqrt(1 - pearson)
    ellipse = Ellipse((0, 0), width=ell_radius_x * 2, height=ell_radius_y * 2,
                      facecolor=facecolor, **kwargs)

    # Calculating the standard deviation of x from
    # the squareroot of the variance and multiplying
    # with the given number of standard deviations.
    scale_x = np.sqrt(sigma[0, 0]) * n_std
    # calculating the standard deviation of y ...
    scale_y = np.sqrt(sigma[1, 1]) * n_std

    transf = transforms.Affine2D() \
        .rotate_deg(45) \
        .scale(scale_x, scale_y) \
        .translate(mu[0], mu[1])

    ellipse.set_transform(transf + ax.transData)
    return ax.add_patch(ellipse)
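
The "special case" in `confidence_ellipse` relies on the fact that a 2x2 correlation matrix with off-diagonal entry `p` has eigenvalues `1 + p` and `1 - p`, with eigenvectors along the two diagonals. That is why the ellipse radii are `sqrt(1 + pearson)` and `sqrt(1 - pearson)`, and why the transform starts with a 45-degree rotation. A quick numerical check, using an arbitrary covariance matrix chosen only for illustration:

```python
import numpy as np

sigma = np.array([[2.0, 0.6],
                  [0.6, 1.0]])  # arbitrary example covariance

# Normalise the covariance into a correlation matrix
std = np.sqrt(np.diag(sigma))
corr = sigma / np.outer(std, std)
pearson = corr[0, 1]

# eigvalsh returns the eigenvalues of a symmetric matrix in ascending order
eigvals = np.linalg.eigvalsh(corr)
expected = np.sort([1 - pearson, 1 + pearson])

print(np.allclose(eigvals, expected))
```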

In [None]:
fig, axs = plt.subplots(1, 1, figsize=(15, 10))

colors = ["r", "g", "b", "y"]

for i, c in enumerate(colors):
    
    # Plot the centers
    plt.scatter(unknown_centers[i, 0], unknown_centers[i, 1], c=c, marker="o", label=f"Cluster: {i} True Center")
    plt.scatter(mus[i, 0], mus[i, 1], c=c, marker="^", label=f"Cluster: {i} Inferred Center")
    
    # Plot the standard deviations
    mask = y == i
    masked_points = X[mask]
    mu_x = np.mean(masked_points, axis=0)
    sigma = np.cov(masked_points[:, 0], masked_points[:, 1])
    confidence_ellipse(mu_x, sigma,  ax=axs, n_std=1, edgecolor=c, linestyle="-")
    confidence_ellipse(mu_x, sigma, ax=axs, n_std=2, edgecolor=c, linestyle="-")
    confidence_ellipse(mu_x, sigma, ax=axs, n_std=3, edgecolor=c, linestyle="-")


    confidence_ellipse(mus[i], sigmas[i],  ax=axs, n_std=1, edgecolor=c, linestyle="--")
    confidence_ellipse(mus[i], sigmas[i], ax=axs, n_std=2, edgecolor=c, linestyle="--")
    confidence_ellipse(mus[i], sigmas[i], ax=axs, n_std=3, edgecolor=c, linestyle="--")
plt.legend(loc="best")


# Followup

For more advanced concepts, please check out []()