In [None]:
import numpy as np 
import scipy.stats as stats
import matplotlib.pyplot as plt 
from matplotlib import rcParams
import pandas as pd 
import seaborn as sns
from tqdm.notebook import tqdm

import ipywidgets
from ipywidgets import interact
import IPython
# If the figures are not rendered nicely in your browser, change the following line.
rcParams['font.size'] = 12

import warnings
warnings.filterwarnings('ignore')

import pyro.optim 
from pyro.infer import Predictive, SVI, Trace_ELBO, HMC, MCMC, NUTS, TraceEnum_ELBO, config_enumerate, infer_discrete
from pyro.infer.autoguide import AutoDelta, AutoDiagonalNormal
import pyro.distributions as dist

import torch 
from torch import nn
from torch.distributions import MultivariateNormal, constraints


# Take the first three colours returned by sns.color_palette() and bind them to
# names for the prior / SVI / MCMC results.
# NOTE(review): these names are not referenced in the cells visible here.
prior_c, svi_c, mcmc_c = sns.color_palette()[:3]

# Gaussian Mixture Model
Based on the Pyro Gaussian mixture model tutorial: https://pyro.ai/examples/gmm.html

In [None]:
K = 2  # Fixed number of components.


@config_enumerate
def model(data):
    """Gaussian mixture model with K components sharing a single scale.

    Sample sites:
        weights:    mixture weights, Dirichlet(0.5 * ones(K)).
        scale:      shared standard deviation, LogNormal(0, 2).
        locs:       one component mean per element of the 'components' plate,
                    Normal(0, 10).
        assignment: per-datum component index, Categorical(weights).
        obs:        observed values, Normal(locs[assignment], scale).

    Args:
        data: observations. ``len(data)`` sets the size of the 'data' plate and
            ``data`` itself is passed as the observed value of 'obs'.
    """
    # Global variables.
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
    scale = pyro.sample('scale', dist.LogNormal(0., 2.))
    with pyro.plate('components', K):
        locs = pyro.sample('locs', dist.Normal(0., 10.))

    # Use the `data` argument rather than a notebook-level global so the model
    # conditions on whatever the caller actually passes in.
    with pyro.plate('data', len(data)):
        # Local variables.
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

def summary(samples):
    """Build per-site summary statistics for a dictionary of samples.

    Args:
        samples: mapping from site name to something ``pd.DataFrame`` accepts
            (one row per draw, one column per component).

    Returns:
        dict mapping each site name to a DataFrame with one row per column of
        the input and the columns mean, std, 5%, 25%, 50%, 75%, 95%.
    """
    quantiles = [.05, 0.25, 0.5, 0.75, 0.95]
    kept_columns = ["mean", "std", "5%", "25%", "50%", "75%", "95%"]
    return {
        name: pd.DataFrame(draws).describe(percentiles=quantiles).transpose()[kept_columns]
        for name, draws in samples.items()
    }


In [None]:
# Observed data used by the rest of the notebook.
x = torch.tensor([0., 1., 10., 11., 12.])

# Grid on which the mixture densities are evaluated for plotting.
test_x = torch.linspace(-5, 25, 1000)

# Seed the RNG so the prior draws (and therefore this figure) are reproducible.
pyro.set_rng_seed(0)

# Prior samples
num_samples = 100
prior_predictive = Predictive(model, {}, num_samples=num_samples)(test_x)

# Collapse the prior draws to their means to get one representative mixture.
# NOTE(review): the extra [0] on `weights` presumably drops a singleton
# dimension in Predictive's output -- confirm against the installed Pyro.
weights = prior_predictive['weights'].mean(0)[0]
locs = prior_predictive['locs'].mean(0)
scale = prior_predictive['scale'].mean(0)
print('weights = {}'.format(weights.data.numpy()))
print('locs = {}'.format(locs.data.numpy()))
print('scale = {}'.format(scale.data.numpy()))


with torch.no_grad():
    plt.figure(figsize=(10, 4), dpi=100).set_facecolor('white')

    grid = test_x.numpy()
    # Pass loc/scale to norm.pdf so the curves include the 1/scale factor and
    # are genuine probability densities (the y-axis label says so).
    Y1 = stats.norm.pdf(grid, loc=locs[0].item(), scale=scale.item())
    Y2 = stats.norm.pdf(grid, loc=locs[1].item(), scale=scale.item())
    Y = weights[0].item() * Y1 + weights[1].item() * Y2

    plt.plot(grid, Y1, 'r-', label='First Component')
    plt.plot(grid, Y2, 'b-', label='Second Component')
    plt.plot(grid, Y, 'k--', label='Mixture Model')
    plt.plot(x.numpy(), np.zeros(len(x)), 'k*')
    plt.legend(loc='best')
    plt.title('Prior mixture model')
    plt.ylabel('probability density');


## SVI

In [None]:
# Initialization of SVI
def init_loc_fn(site):
    """Return the initial value for one latent site of the autoguide.

    Args:
        site: a site dictionary; only ``site["name"]`` is read.

    Returns:
        'weights' -> uniform weights over the K components,
        'scale'   -> sqrt(var(x) / 2),
        'locs'    -> K entries of ``x`` chosen by ``torch.multinomial`` with
                     equal probabilities.

    Raises:
        ValueError: for any other site name (the name is the message).
    """
    name = site["name"]
    if name == "weights":
        return torch.ones(K) / K
    elif name == "scale":
        return (x.var() / 2).sqrt()
    elif name == "locs":
        equal_probs = torch.ones(len(x)) / len(x)
        chosen = torch.multinomial(equal_probs, K)
        return x[chosen]
    raise ValueError(name)

def initialize(seed):
    """Seed Pyro, clear the parameter store and build a fresh autoguide.

    Args:
        seed: value passed to ``pyro.set_rng_seed``.

    Returns:
        An ``AutoDiagonalNormal`` guide over the 'weights', 'locs' and 'scale'
        sites of ``model``, initialised through ``init_loc_fn``.
    """
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()
    # Only the global sites are exposed to the autoguide; every other site of
    # the model is blocked.
    global_sites_model = pyro.poutine.block(model, expose=['weights', 'locs', 'scale'])
    return AutoDiagonalNormal(global_sites_model, init_loc_fn=init_loc_fn)

guide = initialize(7)

optim = pyro.optim.Adam({'lr': 0.01, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=1)

svi = SVI(model, guide, optim, loss=elbo)


# Evaluate the (still untrained) guide once to see where optimisation starts.
# NOTE(review): with AutoDiagonalNormal this is presumably a single random draw
# from the guide rather than a MAP estimate, despite the variable name.
map_estimates = guide(x)
weights = map_estimates['weights']
locs = map_estimates['locs']
scale = map_estimates['scale']
print('weights = {}'.format(weights.data.numpy()))
print('locs = {}'.format(locs.data.numpy()))
print('scale = {}'.format(scale.data.numpy()))

plt.figure(figsize=(10, 4), dpi=100).set_facecolor('white')

X = np.arange(-10,25,0.1)
# Pass loc/scale to norm.pdf so the curves include the 1/scale factor and are
# genuine probability densities (the y-axis label says so).
Y1 = stats.norm.pdf(X, loc=locs[0].item(), scale=scale.item())
Y2 = stats.norm.pdf(X, loc=locs[1].item(), scale=scale.item())
Y = weights[0].item() * Y1 + weights[1].item() * Y2

plt.plot(X, Y1, 'r-', label='First Component')
plt.plot(X, Y2, 'b-', label='Second Component')
plt.plot(X, Y, 'k--', label='Mixture Model')
plt.plot(x.numpy(), np.zeros(len(x)), 'k*', label='Raw Data')
plt.legend(loc='best')
plt.title('SVI Posterior at initialization')
plt.ylabel('probability density');


In [None]:
# Learning SVI
losses = []
for i in tqdm(range(200)):
    loss = svi.step(x)
    losses.append(loss)

# Evaluate the trained guide once and plot the mixture it implies.
# NOTE(review): with AutoDiagonalNormal this is presumably a single random draw
# from the guide rather than a MAP estimate, despite the variable name.
map_estimates = guide(x)
weights = map_estimates['weights']
locs = map_estimates['locs']
scale = map_estimates['scale']
print('weights = {}'.format(weights.data.numpy()))
print('locs = {}'.format(locs.data.numpy()))
print('scale = {}'.format(scale.data.numpy()))

plt.figure(figsize=(10, 4), dpi=100).set_facecolor('white')

# The density curves are computed once (the original computed them twice).
X = np.arange(-3,15,0.1)
# Pass loc/scale to norm.pdf so the curves include the 1/scale factor and are
# genuine probability densities (the y-axis label says so).
Y1 = stats.norm.pdf(X, loc=locs[0].item(), scale=scale.item())
Y2 = stats.norm.pdf(X, loc=locs[1].item(), scale=scale.item())
Y = weights[0].item() * Y1 + weights[1].item() * Y2

plt.plot(X, Y1, 'r-', label='First Component')
plt.plot(X, Y2, 'b-', label='Second Component')
plt.plot(X, Y, 'k--', label='Mixture Model')
plt.plot(x.numpy(), np.zeros(len(x)), 'k*', label='Raw Data')
plt.legend(loc='best')
plt.title('Final SVI Model')
plt.ylabel('probability density');

In [None]:
@config_enumerate
def full_guide(data):
    """Guide over both the global latents and the per-datum assignments.

    The global sites are produced by the already-trained ``guide``; its
    parameters are hidden so that they are not updated again. A learnable
    ``assignment_probs`` parameter of shape (len(data), K) parameterises the
    Categorical over component assignments.

    Args:
        data: observations; ``len(data)`` sets the size of the 'data' plate.
    """
    # Global variables.
    with pyro.poutine.block(hide_types=["param"]):  # Keep our learned values of global parameters.
        guide(data)

    # Local variables.
    with pyro.plate('data', len(data)):
        # Constrain each row to the probability simplex so the stored values
        # are valid assignment probabilities (each row sums to 1), instead of
        # independent numbers in [0, 1].
        assignment_probs = pyro.param('assignment_probs', torch.ones(len(data), K) / K,
                                      constraint=constraints.simplex)
        pyro.sample('assignment', dist.Categorical(assignment_probs))

# Fit the per-datum assignment probabilities with the full guide.
optim = pyro.optim.Adam({'lr': 0.2, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, full_guide, optim, loss=elbo)

losses = []
for i in range(200):
    loss = svi.step(x)
    losses.append(loss)

plt.figure(figsize=(10,3), dpi=100).set_facecolor('white')
plt.plot(losses)
plt.xlabel('iters')
plt.ylabel('loss')
plt.yscale('log')
plt.title('Convergence of SVI');
plt.show()


assignment_probs = pyro.param('assignment_probs')
# Categorical normalises its `probs` along the last dimension, so normalise the
# learned rows the same way before plotting them as probabilities. This is a
# no-op when the parameter is already constrained to the simplex.
raw_probs = assignment_probs.data
plotted_probs = (raw_probs / raw_probs.sum(-1, keepdim=True)).numpy()

# `locs` is whatever the SVI training cell left in the namespace.
plt.figure(figsize=(10, 3), dpi=100).set_facecolor('white')
plt.plot(x.numpy(), plotted_probs[:, 0], 'ro',
            label='component with mean {:0.2g}'.format(locs[0].item()))
plt.plot(x.numpy(), plotted_probs[:, 1], 'bo',
            label='component with mean {:0.2g}'.format(locs[1].item()))
plt.title('Mixture assignment probabilities')
plt.xlabel('data value')
plt.ylabel('assignment probability')
plt.legend(loc='center')
plt.show()

## MCMC 

In [None]:
pyro.set_rng_seed(2)

kernel = NUTS(model)
# kernel = HMC(model)

mcmc = MCMC(kernel, num_samples=20, warmup_steps=250)
mcmc.run(x)
posterior_samples = mcmc.get_samples()

# Summarise the posterior by the mean of each site over the retained draws.
weights = posterior_samples["weights"].mean(0)
scale = posterior_samples["scale"].mean()
locs = posterior_samples["locs"].mean(0)

plt.figure(figsize=(10, 4), dpi=100).set_facecolor('white')

X = np.arange(-3,15,0.1)
# Pass loc/scale to norm.pdf so the curves include the 1/scale factor and are
# genuine probability densities (the y-axis label says so).
Y1 = stats.norm.pdf(X, loc=locs[0].item(), scale=scale.item())
Y2 = stats.norm.pdf(X, loc=locs[1].item(), scale=scale.item())
Y = weights[0].item() * Y1 + weights[1].item() * Y2

plt.plot(X, Y1, 'r-', label='First Component')
plt.plot(X, Y2, 'b-', label='Second Component')
plt.plot(X, Y, 'k--', label='Mixture Model')
plt.plot(x.numpy(), np.zeros(len(x)), 'k*', label='Raw Data')
plt.legend(loc='best')
plt.title('Final Density of two-component mixture model')
plt.ylabel('probability density');

# Trace of each component's location across the retained draws.
X1, X2 = posterior_samples["locs"].t()
plt.figure(figsize=(8, 3), dpi=100).set_facecolor('white')
plt.plot(X1.numpy(), color='red',  label='Loc of component 0')
plt.plot(X2.numpy(), color='blue',  label='Loc of component 1')
plt.xlabel('NUTS step')
plt.ylabel('loc')
plt.title('Trace plot of loc parameter during MCMC sampling')
plt.tight_layout()
plt.show()