## Lecture 16: Mixture Models: EM Algorithm (MIT Lecture Notes)

This is a companion notebook for the Unit 4 **project**. The detailed notes are in **Expectation Maximization.pptx**.

**EM Algorithm: GMM (Clustering)**

This is a boilerplate template showing how the Expectation and Maximization steps can be implemented using NumPy arrays.

Just to emphasize:

E Step 

**KMeans Using EM**

K-means can also be framed as a special case of the EM algorithm.

|EM: GMM|EM: K-means|
|--------|-----------|
|**E Step:** the soft assignment $p(j,i)$ of point $i$ to cluster $j$ is computed from the Gaussian mixture centered on the cluster means|**E Step:** each point is simply assigned to its nearest mean|
|**M Step:** $\mu$, $\sigma$ and $p_j$ are updated using the soft counts $n_j$ computed from the $p(j,i)$ of the E step|**M Step:** only $\mu$ is updated, using the cluster assignments from the E step|
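
To make the contrast in the table concrete, here is a minimal sketch of both E steps, assuming a spherical Gaussian mixture with one scalar variance per component. The function names `soft_estep` and `hard_estep` and the toy arrays are hypothetical and are not part of the project code.

```python
import numpy as np


def soft_estep(X, mu, var, p):
    """Soft E step for a spherical GMM: returns (n, K) posteriors."""
    n, d = X.shape
    # Squared distance from every point to every mean, shape (n, K)
    sq_dist = ((X[:, None, :] - mu[None, :, :]) ** 2).sum(axis=2)
    # log(p_j * N(x_i; mu_j, var_j * I)), kept in log space for stability
    log_joint = (np.log(p)[None, :]
                 - 0.5 * d * np.log(2 * np.pi * var)[None, :]
                 - sq_dist / (2 * var)[None, :])
    # Subtract the row maximum before exponentiating, then normalize each row
    log_joint -= log_joint.max(axis=1, keepdims=True)
    post = np.exp(log_joint)
    return post / post.sum(axis=1, keepdims=True)


def hard_estep(X, mu):
    """Hard (k-means) E step: one-hot assignment to the nearest mean."""
    sq_dist = ((X[:, None, :] - mu[None, :, :]) ** 2).sum(axis=2)
    post = np.zeros_like(sq_dist)
    post[np.arange(X.shape[0]), sq_dist.argmin(axis=1)] = 1
    return post


# Toy data (hypothetical): two points near each of two means
X_toy = np.array([[0.0, 0.0], [0.2, 0.1], [4.0, 4.0], [3.9, 4.2]])
mu_toy = np.array([[0.0, 0.0], [4.0, 4.0]])
var_toy = np.array([1.0, 1.0])
p_toy = np.array([0.5, 0.5])

soft = soft_estep(X_toy, mu_toy, var_toy, p_toy)  # rows sum to 1
hard = hard_estep(X_toy, mu_toy)                  # rows are one-hot
```

With equal mixing weights and equal variances, shrinking the variances toward zero pushes the soft posteriors toward the one-hot assignments, which is one way to see k-means as a limiting case of the GMM E step.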

Below is the code given as part of the project:

In [1]:
from typing import Tuple
import numpy as np
from common import GaussianMixture, plot, init


def estep(X: np.ndarray, mixture: GaussianMixture) -> np.ndarray:
    """E-step: Assigns each datapoint to the gaussian component with the
    closest mean

    Args:
        X: (n, d) array holding the data
        mixture: the current gaussian mixture

    Returns:
        np.ndarray: (n, K) array holding the counts for all components
            for all examples (one-hot here, since assignments are hard)

        """
    n, _ = X.shape
    K, _ = mixture.mu.shape
    post = np.zeros((n, K))

    for i in range(n):
        tiled_vector = np.tile(X[i, :], (K, 1))
        sse = ((tiled_vector - mixture.mu)**2).sum(axis=1)
        j = np.argmin(sse)
        # Hard assignment: the two lines above pick the single closest mean,
        # so each row of post is one-hot rather than a soft posterior
        post[i, j] = 1

    return post


def mstep(X: np.ndarray, post: np.ndarray) -> Tuple[GaussianMixture, float]:
    """M-step: Updates the gaussian mixture. Each cluster
    yields a component mean and variance.

    Args:
        X: (n, d) array holding the data
        post: (n, K) array holding the soft counts
            for all components for all examples

    Returns:
        GaussianMixture: the new gaussian mixture
        float: the distortion cost for the current assignment
    """
    n, d = X.shape
    _, K = post.shape

    n_hat = post.sum(axis=0)
    p = n_hat / n

    cost = 0
    mu = np.zeros((K, d))
    var = np.zeros(K)

    for j in range(K):
        # Weighted mean of the points in cluster j (a plain mean when post is one-hot)
        mu[j, :] = post[:, j] @ X / n_hat[j]
        sse = ((mu[j] - X)**2).sum(axis=1) @ post[:, j]
        cost += sse
        var[j] = sse / (d * n_hat[j])

    return GaussianMixture(mu, var, p), cost


def run(X: np.ndarray, mixture: GaussianMixture,
        post: np.ndarray) -> Tuple[GaussianMixture, np.ndarray, float]:
    """Runs the mixture model

    Args:
        X: (n, d) array holding the data
        mixture: the initial gaussian mixture
        post: (n, K) array holding the soft counts
            for all components for all examples

    Returns:
        GaussianMixture: the new gaussian mixture
        np.ndarray: (n, K) array holding the soft counts
            for all components for all examples
        float: distortion cost of the current assignment
    """

    prev_cost = None
    cost = None
    while (prev_cost is None or prev_cost - cost > 1e-4):
        prev_cost = cost
        post = estep(X, mixture)
        mixture, cost = mstep(X, post)

    return mixture, post, cost
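
The stopping rule in `run` relies on the distortion cost never increasing from one iteration to the next. The following self-contained sketch illustrates this on synthetic data without the project's `common` module. The function `kmeans_costs`, its `max_iter` and `tol` parameters, and the generated data are assumptions for illustration only.

```python
import numpy as np


def kmeans_costs(X, mu, max_iter=100, tol=1e-4):
    """Alternates hard E steps and mean-only M steps, recording the cost."""
    costs = []
    for _ in range(max_iter):
        # E step: index of the nearest mean for every point
        sq_dist = ((X[:, None, :] - mu[None, :, :]) ** 2).sum(axis=2)
        labels = sq_dist.argmin(axis=1)
        # M step: each mean becomes the average of its assigned points
        # (an empty cluster keeps its previous mean)
        mu = np.array([X[labels == j].mean(axis=0) if np.any(labels == j)
                       else mu[j] for j in range(mu.shape[0])])
        # Distortion: total squared distance to the updated assigned means
        cost = ((X - mu[labels]) ** 2).sum()
        costs.append(cost)
        # Stop once the improvement falls to tol or below
        if len(costs) > 1 and costs[-2] - costs[-1] <= tol:
            break
    return mu, labels, costs


rng = np.random.default_rng(0)
# Synthetic data (hypothetical): two blobs in 2-D
X_demo = np.vstack([rng.normal(0.0, 0.5, size=(50, 2)),
                    rng.normal(3.0, 0.5, size=(50, 2))])
mu0 = X_demo[rng.choice(len(X_demo), size=2, replace=False)]
mu_fit, labels, costs = kmeans_costs(X_demo, mu0)
```

Each E step can only lower the cost for the current means, and each M step can only lower it for the current assignments, so the recorded `costs` list is non-increasing and the loop terminates.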


In [2]:
def run_kmeans(X):
    for K in range(1, 5):
        min_cost = None
        best_seed = None
        for seed in range(0, 5):
            mixture, post = init(X, K, seed)
            mixture, post, cost = run(X, mixture, post)
            if min_cost is None or cost < min_cost:
                min_cost = cost
                best_seed = seed

        mixture, post = init(X, K, best_seed)
        mixture, post, cost = run(X, mixture, post)
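        # NOTE: `seed` is the loop variable's final value (4), not the winning
        # seed; print `best_seed` instead to report the best initialization
        print(f'cost: {cost} k:{K}, seed:{seed}')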

In [3]:
X = np.loadtxt("./data/toy_data.txt")

In [4]:
run_kmeans(X)

cost: 5462.297452340001 k:1, seed:4
cost: 1684.9079502962372 k:2, seed:4
cost: 1329.5948671544297 k:3, seed:4
cost: 1035.499826539466 k:4, seed:4
