# What is a Gaussian Mixture Model?

A Gaussian mixture model (GMM) is a latent variable model of continuous data.  It assumes that each data point comes from one of several different Gaussian distributions.  The modeler assumes she knows the total number of Gaussians in the mixture.




![GMM](https://imgur.com/bRU3R6m.png) 

The figure on the left is a directed acyclic graph (DAG).  The figure on the right is the same model represented using [plate notation](https://en.wikipedia.org/wiki/Plate_notation).  Plate notation takes a set of nodes in the DAG that repeat along one dimension and collapses them into a single node.  Each "plate" represents one such dimension.

There are two plates in our GMM: one of size N = 3 for the number of data points and one of size K = 2 for the number of components.  Next to each plate is a sampling statement that shows how to sample the vector of variables for that plate from conditional probability distributions.
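
Written out, the sampling statements take roughly the following form.  This is a sketch reconstructed from the Pyro model later in this section, where $\alpha = 0.5$, $\sigma_0 = 10$ is the prior scale on the component means, and $\sigma = 1$ is the shared noise scale (the second argument of each Normal is a standard deviation).

$$
\begin{aligned}
\theta &\sim \text{Dirichlet}(\alpha, \dots, \alpha) \\
\mu_k &\sim \text{Normal}(0, \sigma_0), \quad k = 1, \dots, K \\
Z_i \mid \theta &\sim \text{Categorical}(\theta), \quad i = 1, \dots, N \\
X_i \mid Z_i, \mu &\sim \text{Normal}(\mu_{Z_i}, \sigma)
\end{aligned}
$$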

Plate notation is ideal for a specific model class like a GMM because the number of nodes in the DAG can vary from problem to problem, while the plate notation stays the same.

## A GMM as a Causal Model

You have probably never heard a GMM described as a causal model.  Indeed, in most cases it would likely perform poorly as one.  Later, we'll discuss how well it would fare as a causal model.

However, for now, let's just note that we have a probabilistic generative model on a directed acyclic graph, so we can treat it as a causal model simply by assuming the DAG represents causality.

In this figure $X_1$, $X_2$ and $X_3$ are observed continuous random variables.  The fact that they are observed is indicated by grey.

$Z_1$, $Z_2$, and $Z_3$ are latent (unobserved) discrete random variables.  The fact that they are latent is indicated by the white color of the node.

Each observed node $X_i$ is sampled from either a Normal distribution with mean $\mu_1$ or a Normal distribution with mean $\mu_2$.

### So what is the causal generative story?
The _causal generative story_ is simply this: $Z_i$ causes $X_i$.  $Z_i$ is a switch mechanism that causes $X_i$ to have a value of either $\mu_1$ plus noise or $\mu_2$ plus noise.
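
The switch mechanism can be sketched in a few lines of plain Python.  This is only an illustration: the function name `generate` and the values chosen for `theta`, `mu`, and `sigma` are assumptions made for the example, not part of the model specification.

```python
import random

def generate(n, theta, mu, sigma, seed=0):
    """Forward-sample n (Z_i, X_i) pairs: Z_i picks a component, X_i = mu[Z_i] + noise."""
    rng = random.Random(seed)
    pairs = []
    for _ in range(n):
        # Z_i is the switch: choose a component index with probabilities theta.
        z = rng.choices(range(len(mu)), weights=theta, k=1)[0]
        # X_i is the mean selected by Z_i plus Gaussian noise.
        x = mu[z] + rng.gauss(0.0, sigma)
        pairs.append((z, x))
    return pairs

# Illustrative (assumed) parameter values.
mu = [0.0, 10.0]
samples = generate(n=3, theta=[0.5, 0.5], mu=mu, sigma=1.0)
for z, x in samples:
    print(z, round(x, 2))
```

Each printed row shows the switch value followed by the resulting observation, which lands near the mean that the switch selected.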

## Greeks vs. Romans

We see two kinds of variable names in this representation: those named with letters from the Greek alphabet, and those named with letters from the Roman alphabet.

So what should we think about these Greek letters?  They don't show up in the causal generative story.  Why are they in the graph?

Here is how to understand the differences between the Greeks and the Romans.

1. The Roman letters X and Z are the causally-related components of our data generating process.  
2. The Greek letters $\alpha$, $\theta$, $\mu$, $\sigma$, $\sigma_0$ are parameters or weights.  These are merely parameters of the **causal Markov kernels**.

A **causal Markov kernel** is just another name for the probability distribution of a variable conditional on its parents in the causal DAG.  The actual causal mechanism between the parents and the child determines (the word in the literature is "entails") this probability distribution.  If the causal model is correct, the causal Markov kernels should be invariant across data and domain.
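
For the GMM, the kernel for $X_i$ given its parent $Z_i$ is a Normal density whose mean is selected by $Z_i$.  The following is a minimal sketch of that kernel as a function; the name `kernel_x_given_z` and the parameter values are hypothetical and chosen only for illustration.

```python
import math

def kernel_x_given_z(x, z, mu, sigma):
    """Density of X = x given its parent Z = z: Normal(mu[z], sigma)."""
    var = sigma ** 2
    return math.exp(-(x - mu[z]) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)

mu = [0.0, 10.0]   # illustrative (assumed) component means
sigma = 1.0        # illustrative (assumed) noise scale

# The same x is far more plausible under one parent value than under the other.
print(kernel_x_given_z(9.5, z=1, mu=mu, sigma=sigma))
print(kernel_x_given_z(9.5, z=0, mu=mu, sigma=sigma))
```

The Greek letters appear only as arguments of the kernel, while the causal relationship is carried by the dependence of the density on `z`.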

The parameters of the causal Markov kernel are explicitly in the graph because **we are thinking like Bayesians**.  In a previous lecture, we said that we should use probability to represent any uncertainty there is in elements of our "data creation myth."  Generally, in probabilistic graphical models, random variables get their own nodes.  So a Bayesian using graphical modeling will represent parameters as random variables and thus nodes.  Explicitly modeling parameters in the graph structure allows them to use techniques from [Bayesian hierarchical modeling](https://en.wikipedia.org/wiki/Bayesian_hierarchical_modeling) to model uncertainty in these parameters.

However, from our causal perspective, explicit representations of these parameters distract us from the causal relationships we are assuming in our model.  We can get a view of those relationships by ignoring the Greek letters.

![simpler viz](https://i.imgur.com/zdRzSa5.png)



### Simple example

The following shows how to implement our GMM in Pyro.

First, let's import a bunch of things, not all of which will be needed.

In [0]:
# To install Pyro
#!pip3 install torch torchvision
#!pip3 install pyro-ppl 

import os
from collections import defaultdict
import torch
import numpy as np
import scipy.stats
from torch.distributions import constraints
from matplotlib import pyplot
%matplotlib inline

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.2.0')
pyro.enable_validation(True)



Next, let's specify the model.  The `config_enumerate` decorator is used in inference.  We don't need to worry about it for our learning purposes.

In [0]:
K = 2

@config_enumerate
def model(N):
    # Global variables.
    α = 0.5
    θ = pyro.sample('θ', dist.Dirichlet(α * torch.ones(K)))
    σ = 1.0
    σ_0 = 10.
    with pyro.plate('components', K):
        μ = pyro.sample('μ', dist.Normal(0., σ_0))

    with pyro.plate('data', N):
        # Local variables.
        Z = pyro.sample('Z', dist.Categorical(θ))
        X = pyro.sample('X', dist.Normal(μ[Z], σ))
    return {'X': X, 'Z': Z}

Notice how Pyro has a `pyro.plate` context manager that captures the "plate" abstraction in plate notation.  Also notice how the tensor representation provided by a deep generative modeling framework makes it convenient to capture variables within plates as vectors.

Now let's generate from the model.

In [0]:
model(4)

{'X': tensor([16.7073, 14.2783, 15.1040, 16.0092]), 'Z': tensor([0, 0, 0, 0])}

# Interventions

Since this is a causal model, we can apply interventions.

Pyro has a `pyro.do` function that will take in a model, and return a modified model that reflects the intervention.  It does this by replacing whatever sampling statement was used to generate the intervention target in the model with a statement that fixes that value to the intervention value.

In the following code, I set 10 values of Z to [0, 1, 1, 0, 1, 1, 1, 1, 1, 1].  Then I generate from the model.

In [0]:
intervention = torch.tensor([0, 1, 1, 0, 1, 1, 1, 1, 1, 1])
intervention_model = pyro.do(model, data={'Z': intervention})
intervention_model(10)

{'X': tensor([13.9519,  6.3355,  5.5210, 14.3321,  5.0247,  8.2380,  5.2797,  5.2364,
          6.1813,  4.7738]), 'Z': tensor([0, 1, 1, 0, 1, 1, 1, 1, 1, 1])}

Note the Z values are exactly what the intervention set them to.  The X values are forward generated from the Z values.
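
To make the mechanics concrete, here is a plain-Python sketch of what an intervention on Z does to the generative process.  It is an illustration under assumed parameter values, not a description of how `pyro.do` is implemented; `sample_gmm` and `do_z` are hypothetical names.

```python
import random

def sample_gmm(n, theta, mu, sigma, do_z=None, seed=0):
    """Sample from the GMM; if do_z is given, Z is fixed instead of sampled."""
    rng = random.Random(seed)
    zs, xs = [], []
    for i in range(n):
        if do_z is None:
            # Observational regime: Z_i is drawn from its own distribution.
            z = rng.choices(range(len(mu)), weights=theta, k=1)[0]
        else:
            # Interventional regime: the sampling statement for Z_i is replaced.
            z = do_z[i]
        # The mechanism for X_i is untouched by the intervention.
        xs.append(mu[z] + rng.gauss(0.0, sigma))
        zs.append(z)
    return {'X': xs, 'Z': zs}

intervention = [0, 1, 1, 0, 1, 1, 1, 1, 1, 1]
# theta, mu and sigma are illustrative (assumed) values.
result = sample_gmm(10, theta=[0.5, 0.5], mu=[14.0, 6.0], sigma=1.0, do_z=intervention)
print(result['Z'])
```

Only the line that produces Z changes between the two regimes; the line that produces X from Z is shared, which is what it means for the mechanism to be left intact.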

# Training the Greeks



In latent variable modeling, the modeler generally doesn't know the values of the Greek variables.  In our case, we used probability distributions to capture that uncertainty.

In practice, modelers try to infer their values from training data (i.e., observed values of the X's and, when they are available, the Z's).  In other words, we treat the Greeks as weights in a training step.

There are several ways to learn these parameters from data.  Getting maximum likelihood estimates using expectation maximization is a common way.  Here, since we are thinking as Bayesians, we use Bayesian inference.
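
For comparison, the expectation maximization route can be sketched in plain Python.  This is a minimal version that assumes a one-dimensional mixture with a known, shared $\sigma$, and it uses the same toy data as the Pyro code in this section; the function name `em_gmm_1d` and the starting values are assumptions made for the example.

```python
import math

def em_gmm_1d(data, mu, theta, sigma=1.0, steps=50):
    """EM for a 1-D GMM with a known, shared sigma. Returns (theta, mu)."""
    mu, theta = list(mu), list(theta)  # work on copies of the starting values
    k = len(mu)
    for _ in range(steps):
        # E-step: responsibility of each component for each data point.
        resp = []
        for x in data:
            w = [theta[j] * math.exp(-(x - mu[j]) ** 2 / (2 * sigma ** 2))
                 for j in range(k)]
            total = sum(w)
            resp.append([wj / total for wj in w])
        # M-step: re-estimate mixture weights and means from the responsibilities.
        for j in range(k):
            nj = sum(r[j] for r in resp)
            theta[j] = nj / len(data)
            mu[j] = sum(r[j] * x for r, x in zip(resp, data)) / nj
    return theta, mu

data = [0., 1., 10., 11., 12.]
theta_mle, mu_mle = em_gmm_1d(data, mu=[0.0, 12.0], theta=[0.5, 0.5])
print(theta_mle, mu_mle)
```

On this toy data the estimates settle near means 0.5 and 11 with weights 0.4 and 0.6.  No prior enters the calculation, which is the main difference from the Bayesian approach used here.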

A Bayesian inference algorithm will treat the probability distributions we gave to the unknown Greek letters as a prior distribution.  Given data, an inference algorithm will update these distributions.

The following uses an [approximate Bayesian algorithm](https://en.wikipedia.org/wiki/Approximate_Bayesian_computation) called [stochastic variational inference](http://pyro.ai/examples/svi_part_i.html) (SVI).  SVI makes good use of the gradient-based optimization infrastructure of a deep learning framework like PyTorch, which Pyro is built on.  The following inference implementation will find [MAP estimates](https://en.wikipedia.org/wiki/Maximum_a_posteriori_estimation) of the Greek letters $\theta$ and $\mu$ -- these are Bayesian analogs to maximum likelihood estimates.
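
As a rough guide to what is being optimized, the MAP estimate maximizes the log prior plus the log likelihood, with each latent $Z_i$ summed out.  This is a sketch of the objective for the model defined earlier, not a description of Pyro's internals:

$$
(\hat{\theta}, \hat{\mu}) = \arg\max_{\theta, \mu} \left[ \log p(\theta) + \sum_{k=1}^{K} \log p(\mu_k) + \sum_{i=1}^{N} \log \sum_{k=1}^{K} \theta_k \, \text{Normal}(x_i \mid \mu_k, \sigma) \right]
$$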

Do not be intimidated by the following code.  This is not unlike most deep learning code you see with deep learning libraries.  `TraceEnum_ELBO` and `SVI` are abstractions for stochastic variational inference.  I encourage you to learn more about Bayesian inference algorithms.  After all, knowledge of these algorithms tends to correlate with salary.  However, in these AltDeep causal modeling courses we only need a high-level understanding of inference.


In [0]:
data = torch.tensor([0., 1., 10., 11., 12.])
N = len(data)
K = 2  # Fixed number of components.
evidence_model = pyro.condition(model, data={'X': data})

optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=1)

def init_loc_fn(site):
    if site["name"] == "θ":
        # Initialize weights to uniform.
        return torch.ones(K) / K
    if site["name"] == "μ":
        return data[torch.multinomial(torch.ones(N) / N, K)]
    raise ValueError(site["name"])

def initialize(seed):
    global global_guide, svi
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()
    global_guide = AutoDelta(poutine.block(evidence_model, expose=['θ', 'μ']),
                             init_loc_fn=init_loc_fn)
    svi = SVI(evidence_model, global_guide, optim, loss=elbo)
    return svi.loss(evidence_model, global_guide, N)

# Choose the best among 100 random initializations.
loss, seed = min((initialize(seed), seed) for seed in range(100))
initialize(seed)
print('seed = {}, initial_loss = {}'.format(seed, loss))

# Register hooks to monitor gradient norms.
gradient_norms = defaultdict(list)
for name, value in pyro.get_param_store().named_parameters():
    value.register_hook(lambda g, name=name: gradient_norms[name].append(g.norm().item()))

losses = []
for i in range(200 if not smoke_test else 2):
    loss = svi.step(N)
    losses.append(loss)
    print('.' if i % 100 else '\n', end='')


map_estimates = global_guide(N)
θ = map_estimates['θ']
μ = map_estimates['μ']
print('\n')
print('θ = {}'.format(θ.data.numpy()))
print('μ = {}'.format(μ.data.numpy()))


seed = 19, initial_loss = 17.06005859375

...................................................................................................
...................................................................................................
θ = [0.62500006 0.375     ]
μ = [10.96146     0.49751246]
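
As a sanity check, these numbers can be roughly reproduced by hand.  The sketch below assumes the two clusters are separated well enough that the assignments are effectively hard, with $\{0, 1\}$ in one component and $\{10, 11, 12\}$ in the other; $n_k$ denotes the number of points assigned to component $k$.

$$
\begin{aligned}
\hat{\theta}_k &= \frac{n_k + \alpha - 1}{N + K\alpha - K}: \quad \frac{3 + 0.5 - 1}{5 + 1 - 2} = 0.625, \quad \frac{2 + 0.5 - 1}{5 + 1 - 2} = 0.375 \\
\hat{\mu}_k &= \frac{\sum_{i: Z_i = k} x_i}{n_k + \sigma^2 / \sigma_0^2}: \quad \frac{33}{3.01} \approx 10.963, \quad \frac{1}{2.01} \approx 0.4975
\end{aligned}
$$

The weights and the smaller mean match the output.  The larger mean differs in the third decimal place (10.961 reported), which may simply reflect the optimizer stopping after 200 steps.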


Now that we have estimates for the values of our Greeks, we can replace their distributions in the model with these estimated values.  An even better approach would be to sample them from distributions in `pyro.distributions` that are close in shape to the posteriors of these Greeks.

## A word of caution on inference

There is much to say about Bayesian inference.  This is not a course on inference so I don't say much and leave it to you to experiment with various inference abstractions in Pyro.

However, there are some points worth mentioning when it comes to inferring the values of "Greeks" in causal models.  Firstly, getting these Greek letters right is of supreme importance in the common causal inference task of *inferring causal effects*, meaning quantifying the degree to which a cause influences an effect.
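
In this GMM, one natural causal effect is the difference in the expected value of X under the two settings of Z, which reduces to the difference between the two component means.  The sketch below estimates it by simulation in plain Python, using rounded versions of the MAP estimates reported above; the function name `mean_x_under_do` is hypothetical, and the result is only as good as those estimates.

```python
import random

def mean_x_under_do(z, mu, sigma, n=20000, seed=0):
    """Monte Carlo estimate of E[X | do(Z = z)] for the GMM mechanism."""
    rng = random.Random(seed)
    return sum(mu[z] + rng.gauss(0.0, sigma) for _ in range(n)) / n

# Rounded MAP estimates of the component means; sigma is fixed at 1.0 in the model.
mu_hat = [10.96, 0.50]
effect = mean_x_under_do(0, mu_hat, 1.0) - mean_x_under_do(1, mu_hat, 1.0, seed=1)

# The simulated effect should sit close to the plain difference of the means.
print(effect, mu_hat[0] - mu_hat[1])
```

Any error in the estimated means carries straight through to the estimated effect, which is why getting the Greeks right matters for this task.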

The inference algorithm above assumes the Z's are latent, which is the usual case for GMMs.  Even if our causal model were a good one, trying to train model parameters when causes like Z are latent can lead to problems when trying to estimate these causal effects accurately.  We address this in the "Identification and Estimation" part of the causal modeling curriculum.

Also, as a general rule, if you want an accurate estimation of the Greek variables, you should avoid approximate Bayesian algorithms in favor of asymptotically exact ones (like MCMC approaches).  Approximate algorithms often ignore important nonlinearities in the causal mechanisms in exchange for speed and scalability.

That said, if all we care about is getting reasonably good predictions of interventions, we might be okay if we had a good causal model.  Further, we could start with a basic GMM, then apply the **iterative refutation algorithm** (see lecture notes in Model-based Inference in Machine Learning) to iterate on this model.  In each iteration we could retrain the model using new data from the intervention experiments run in the previous iteration, gradually overcoming estimation problems.