In [21]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.infer.autoguide import AutoDiagonalNormal

# Generate synthetic data
np.random.seed(42)
x = np.linspace(-3, 3, 100)
y = np.sin(x) + np.random.normal(0, 0.2, size=x.shape)
x_data = torch.tensor(x, dtype=torch.float32).view(-1, 1)
y_data = torch.tensor(y, dtype=torch.float32).view(-1, 1)

In [22]:
x_data.shape, y_data.shape

(torch.Size([100, 1]), torch.Size([100, 1]))

In [23]:
class BayesianRegression(PyroModule):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = PyroModule[torch.nn.Linear](input_size, output_size)
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([output_size, input_size]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([output_size]).to_event(1))

    def forward(self, x, y=None):
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean

In [24]:
from pyro.infer.autoguide import AutoDiagonalNormal

guide = AutoDiagonalNormal(BayesianRegression(1, 1))

In [25]:
# Setup the optimizer and the SVI object
optimizer = Adam({"lr": 0.03})
svi = SVI(BayesianRegression(1, 1), guide, optimizer, loss=Trace_ELBO())

# Training loop
num_iterations = 1500
for j in range(num_iterations):
    # Calculate the loss and take a gradient step
    loss = svi.step(x_data, y_data)
    if j % 100 == 0:
        print(f"Epoch {j}: loss = {loss}")

ValueError: at site "obs", invalid log_prob shape
  Expected [100], actual [100, 100]
  Try one of the following fixes:
  - enclose the batched tensor in a with pyro.plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions