# Gaussian Mixture Model

Concepts used:

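- `jit`: compile our functions with XLA so that repeated calls run fast
- `vmap`: automatically vectorize the per-sample E-step, M-step, and log-likelihood functions across every row of the data matrix
- `pmap`: parallelize GMM fitting across multiple initializations (one per device)
- `RNG`: reproducible randomness through explicit PRNG keys (`jax.random`)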
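A minimal sketch of two of these building blocks on toy data: a `jax.random` key fully determines its draws, and `jax.jit` compiles a function on its first call. The function `squared_norm` is a hypothetical example and is not part of the GMM.

```python
import jax
import jax.numpy as jnp

# A PRNG key fully determines the random numbers drawn from it
key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))  # same key -> identical draws

# For fresh randomness, split the key instead of reusing it
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

@jax.jit
def squared_norm(v):
    # Traced and compiled on the first call; later calls with the same shape reuse it
    return jnp.sum(v ** 2)

print(a, b, c, squared_norm(a))
```

Passing keys around explicitly is what makes each GMM initialization reproducible.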

```python
import jax.numpy as jnp
import jax

from sklearn.datasets import make_blobs
```

The data are drawn from 4 true clusters (of 30, 50, 20, and 5 points), but we treat that as unknown: in a real-world scenario, the number of clusters is not known ahead of time.
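For reference, these are the standard EM updates for a GMM with $K$ components and $N$ samples. The correspondence to the code is an assumption of this summary: $\pi_k$ is `class_proba`, $\mu_k$ is `mu`, $\Sigma_k$ is `sigma`, and $r_{ik}$ are the responsibilities. Every sum over $i$ runs over samples, which is why the per-sample functions only ever need to see one $x_i$ at a time.

$$
\begin{aligned}
r_{ik} &= \frac{\pi_k\,\mathcal{N}(x_i \mid \mu_k, \Sigma_k)}{\sum_{j=1}^{K} \pi_j\,\mathcal{N}(x_i \mid \mu_j, \Sigma_j)} && \text{(E-step)} \\
N_k &= \sum_{i=1}^{N} r_{ik}, \qquad \pi_k = \frac{N_k}{N} && \text{(M-step)} \\
\mu_k &= \frac{1}{N_k} \sum_{i=1}^{N} r_{ik}\, x_i \\
\Sigma_k &= \frac{1}{N_k} \sum_{i=1}^{N} r_{ik}\,(x_i - \mu_k)(x_i - \mu_k)^\top
          = \frac{1}{N_k} \sum_{i=1}^{N} r_{ik}\, x_i x_i^\top - \mu_k \mu_k^\top \\
\log p(X) &= \sum_{i=1}^{N} \log \sum_{k=1}^{K} \pi_k\,\mathcal{N}(x_i \mid \mu_k, \Sigma_k)
\end{aligned}
$$

EM never decreases $\log p(X)$, so a natural stopping rule is to halt once its increase falls below a tolerance.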

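```python
import numpy as np

def make_ds():
    points_in_classes = [30, 50, 20, 5]
    centers = np.asarray([
        [1, -1],   # bottom left
        [5, 5],    # middle
        [8, 7],    # mid-right
        [10, 0],   # bottom right
    ], dtype=float)
    # random_state fixes the sample so the notebook is reproducible
    X, y = make_blobs(points_in_classes, centers=centers, random_state=0)
    return jnp.asarray(X), y

X, y = make_ds()
```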

# Batching

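Note that the per-sample functions below take a lowercase `x` rather than `X`: each one operates on a single sample (one row of `X`). `jax.vmap` then maps them over every row, so we never have to handle the batch dimension by hand.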
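A small illustration of this pattern on made-up numbers: `sq_dist` (a hypothetical helper) is written for one sample, and `jax.vmap` with `in_axes=(0, None)` maps it over the rows of the data while sharing the means across every call.

```python
import jax
import jax.numpy as jnp

def sq_dist(x, mu):
    # Squared distance from ONE sample x, shape (d,), to each of K means, shape (K, d)
    return jnp.sum((mu - x) ** 2, axis=-1)  # -> (K,)

X = jnp.array([[0.0, 0.0], [1.0, 2.0], [3.0, -1.0]])  # 3 samples, d = 2
mu = jnp.array([[0.0, 0.0], [1.0, 1.0]])              # K = 2 means

# in_axes=(0, None): map over the rows of X, pass mu unchanged to every call
batched = jax.vmap(sq_dist, in_axes=(0, None))(X, mu)  # -> (3, 2)

# The same result written with explicit broadcasting
manual = jnp.sum((X[:, None, :] - mu[None, :, :]) ** 2, axis=-1)
print(batched)
```

The per-sample version stays close to the math, and `vmap` produces the broadcasted version for us.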

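```python
from jax.scipy.special import logsumexp
from jax.scipy.stats import multivariate_normal

def _weighted_log_probs(x, mu, sigma, class_proba):
    # log(pi_k) + log N(x | mu_k, sigma_k) for every component k -> (K,)
    log_pdf = jax.vmap(lambda m, s: multivariate_normal.logpdf(x, m, s))(mu, sigma)
    return jnp.log(class_proba) + log_pdf

def _e_step(x, mu, sigma, class_proba):
    # Responsibilities p(z = k | x), normalized in log space for stability -> (K,)
    return jax.nn.softmax(_weighted_log_probs(x, mu, sigma, class_proba))

def _m_step(x, r):
    # Per-sample sufficient statistics weighted by responsibilities r; m_step sums them
    return r, r[:, None] * x, r[:, None, None] * jnp.outer(x, x)

def _log_lik(x, mu, sigma, class_proba):
    # log p(x) = logsumexp_k [log pi_k + log N(x | mu_k, sigma_k)]
    return logsumexp(_weighted_log_probs(x, mu, sigma, class_proba))
```

```python
@jax.jit
def e_step(X, mu, sigma, class_proba):
    # Map _e_step over the rows of X; parameters are shared (in_axes=None) -> (N, K)
    return jax.vmap(_e_step, in_axes=(0, None, None, None))(X, mu, sigma, class_proba)

@jax.jit
def m_step(X, responsibilities):
    # Responsibilities are per sample too, so both arguments are mapped along axis 0
    r, rx, rxx = jax.vmap(_m_step, in_axes=(0, 0))(X, responsibilities)
    n_k = r.sum(axis=0)                                  # (K,)
    mu = rx.sum(axis=0) / n_k[:, None]                   # (K, d)
    # Sigma_k = weighted E[x x^T] - mu_k mu_k^T, plus a small ridge for stability
    sigma = rxx.sum(axis=0) / n_k[:, None, None] - jnp.einsum("ki,kj->kij", mu, mu)
    sigma = sigma + 1e-6 * jnp.eye(X.shape[1])
    class_proba = n_k / X.shape[0]
    return mu, sigma, class_proba

@jax.jit
def log_lik(X, mu, sigma, class_proba):
    # Total log-likelihood of the data set: sum of per-sample log p(x)
    return jax.vmap(_log_lik, in_axes=(0, None, None, None))(X, mu, sigma, class_proba).sum()

def run_GMM(
        X,
        # Parameters (initial guesses)
        mu, sigma, class_proba,

        # Hyperparameters
        tol=1e-2, max_iter=100
):
    """
    Takes the initial guesses as arguments so that we can `pmap` over several
    initializations. Uses `jax.lax.while_loop` instead of a Python `while`,
    so the whole loop can be traced by `pmap`/`vmap`.
    """
    def not_converged(state):
        *_, ll, prev_ll, i = state
        # EM never decreases the log-likelihood; stop once the gain drops below tol
        return (ll - prev_ll > tol) & (i < max_iter)

    def em_iteration(state):
        mu, sigma, class_proba, ll, _, i = state
        responsibilities = e_step(X, mu, sigma, class_proba)
        mu, sigma, class_proba = m_step(X, responsibilities)
        return mu, sigma, class_proba, log_lik(X, mu, sigma, class_proba), ll, i + 1

    ll = log_lik(X, mu, sigma, class_proba)
    init = (mu, sigma, class_proba, ll, jnp.full_like(ll, -jnp.inf), jnp.array(0))
    mu, sigma, class_proba, *_ = jax.lax.while_loop(not_converged, em_iteration, init)
    return mu, sigma, class_proba
```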
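To fit from several starting points at once, we need a stack of initial guesses with a leading "initialization" axis. The sketch below assumes a simple scheme that the original does not specify: `init_params` (a hypothetical helper) picks `n_components` distinct data points as means, uses identity covariances, and uses uniform mixing weights. The random `X` is stand-in data for a self-contained run.

```python
import jax
import jax.numpy as jnp

def init_params(key, X, n_components):
    """Hypothetical helper: one random initial guess for a GMM."""
    n, d = X.shape
    # n_components distinct rows of X serve as the initial means
    idx = jax.random.choice(key, n, shape=(n_components,), replace=False)
    mu = X[idx]                                                  # (K, d)
    sigma = jnp.tile(jnp.eye(d), (n_components, 1, 1))           # (K, d, d)
    class_proba = jnp.full((n_components,), 1.0 / n_components)  # (K,)
    return mu, sigma, class_proba

X = jax.random.normal(jax.random.PRNGKey(42), (105, 2))  # stand-in data
n_init, K = 8, 4
keys = jax.random.split(jax.random.PRNGKey(0), n_init)

# One initialization per key, stacked along a new leading axis of size n_init
mu0, sigma0, pi0 = jax.vmap(lambda k: init_params(k, X, K))(keys)
print(mu0.shape, sigma0.shape, pi0.shape)  # (8, 4, 2) (8, 4, 2, 2) (8, 4)
```

Provided `run_GMM` uses traceable control flow (such as `jax.lax.while_loop` rather than a Python `while` on traced values), these stacks can be fed to `jax.vmap(run_GMM, in_axes=(None, 0, 0, 0))(X, mu0, sigma0, pi0)` on a single device, or to `jax.pmap` with the same `in_axes` when `jax.device_count()` is at least `n_init`. Keeping the fit with the highest log-likelihood then guards against a poor local optimum.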