<a href="https://colab.research.google.com/github/PGM-Lab/2023-probai-private/blob/main/Day2-AfterLunch/notebooks/solution_simple_gaussian_model_pyro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src="https://github.com/PGM-Lab/2022-ProbAI/raw/main/Day2-AfterLunch/notebooks/Figures/simple_pyro_exercise.png" alt="Simple Pyro exercise" width="800">

In [None]:
!pip install -q --upgrade pyro-ppl torch 

import numpy as np
import torch
from torch.distributions import constraints
import matplotlib.pyplot as plt

import pyro
from pyro.distributions import Normal, Gamma, MultivariateNormal
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import pyro.optim as optim

## Generate some data

In [None]:
# Synthetic observations: N draws from a Gaussian with a known mean and precision.
np.random.seed(123)  # fix NumPy's RNG so the same data set is drawn on every run
N = 100
correct_mean = 5
correct_precision = 1
# NumPy parameterizes the normal by its standard deviation, i.e. sqrt(1 / precision).
correct_scale = np.sqrt(1./correct_precision)
samples = np.random.normal(loc=correct_mean, scale=correct_scale, size=N)
# Convert to a float32 tensor, which is what the Pyro model is fed below.
data = torch.tensor(samples, dtype=torch.float)

## Our model specification

In [None]:
def model(data):
    """Gaussian model with unknown mean and unknown precision.

    Generative process:
        gamma ~ Gamma(1, 1)                 (precision of the observations)
        mu    ~ Normal(0, scale=10000)      (mean of the observations)
        x_i   ~ Normal(mu, 1 / sqrt(gamma)) for each entry of ``data``

    Args:
        data: 1-D tensor of observations; the sample site "x" is conditioned
            on it.
    """
    gamma = pyro.sample("gamma", Gamma(torch.tensor(1.), torch.tensor(1.)))
    # Scalar (0-dim) loc so that "mu" has the same shape as the scalar "mu"
    # site drawn in the guide; `torch.zeros(1)` would give it shape (1,).
    mu = pyro.sample("mu", Normal(torch.tensor(0.), torch.tensor(10000.0)))
    # The plate declares the observations conditionally independent given
    # mu and gamma.
    with pyro.plate("data", len(data)):
        # Normal takes a standard deviation, hence sqrt(1 / precision).
        pyro.sample("x", Normal(loc=mu, scale=torch.sqrt(1. / gamma)), obs=data)

## Our guide specification

In [None]:
def guide(data=None):
    """Factorized variational family q(gamma) * q(mu) for the latent sites.

    q(gamma) = Gamma(alpha_q, beta_q) and q(mu) = Normal(mean_q, scale_q),
    where the four quantities are learnable Pyro parameters. ``alpha_q``,
    ``beta_q`` and ``scale_q`` are constrained to be positive.

    Args:
        data: not used in the body.
    """
    # --- q(gamma): Gamma distribution over the precision ---
    concentration = pyro.param(
        "alpha_q", torch.tensor(1.), constraint=constraints.positive
    )
    rate = pyro.param(
        "beta_q", torch.tensor(1.), constraint=constraints.positive
    )
    q_gamma = Gamma(concentration, rate)
    pyro.sample("gamma", q_gamma)

    # --- q(mu): Normal distribution over the mean ---
    loc = pyro.param("mean_q", torch.tensor(0.))
    scale = pyro.param(
        "scale_q", torch.tensor(1.), constraint=constraints.positive
    )
    q_mu = Normal(loc, scale)
    pyro.sample("mu", q_mu)

## Do learning

In [None]:
# Seed the RNGs used by Pyro so the stochastic ELBO estimates (and therefore
# the fitted parameters and the printed losses) are reproducible across runs.
pyro.set_rng_seed(123)

# setup the optimizer
adam_args = {"lr": 0.01}
optimizer = Adam(adam_args)

# Drop any parameters left over from a previous run so the guide's
# parameters start from their initial values again.
pyro.clear_param_store()
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

num_steps = 3000   # number of SVI gradient steps
log_every = 500    # print the loss every `log_every` steps

train_elbo = []
# training loop
for epoch in range(num_steps):
    # One gradient step; the returned loss is the negative ELBO estimate.
    loss = svi.step(data)
    train_elbo.append(-loss)
    if (epoch % log_every) == 0:
        # The value is the raw loss returned by svi.step, not an average.
        print("[epoch %03d] training loss: %.4f" % (epoch, loss))

In [None]:
# Print every learned variational parameter. The param store yields
# (name, value) pairs, so the value is used directly instead of being
# looked up a second time with pyro.param(name).
for name, value in pyro.get_param_store().items():
    # detach() drops the autograd graph so the tensor can be converted to NumPy.
    print(name, value.detach().numpy())

In [None]:
# Plot the ELBO trace recorded during training, one point per SVI step.
fig, ax = plt.subplots()
iterations = range(len(train_elbo))
ax.plot(iterations, train_elbo)
ax.set_xlabel("Number of iterations")
ax.set_ylabel("ELBO")
plt.show()