# 4 Modules in Pyro

In [2]:
import os
import torch
import torch.nn as nn
import pyro
import pyro.distributions as D
import pyro.poutine as poutine
from torch.distributions import constraints as C
from pyro.nn import PyroModule, PyroParam, PyroSample
from pyro.nn.module import to_pyro_module_
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam


In [3]:
class Linear(nn.Module):
    """Plain PyTorch affine layer computing ``b + x @ w``.

    Args:
        insize: number of input features (rows of ``w``).
        outsize: number of output features (columns of ``w`` and length of ``b``).
    """

    def __init__(self, insize, outsize):
        super().__init__()
        weight_shape = (insize, outsize)
        # Both parameters start from standard-normal draws (weight drawn first).
        self.w = nn.Parameter(torch.randn(*weight_shape))
        self.b = nn.Parameter(torch.randn(outsize))

    def forward(self, x):
        """Map the trailing ``insize`` dimension of ``x`` to ``outsize`` features."""
        projected = x @ self.w
        return projected + self.b
    
    
# Push a batch of 100 five-feature rows through the plain layer.
linear = Linear(insize=5, outsize=2)
example = linear(torch.randn(100, 5))
assert tuple(example.shape) == (100, 2)

In [4]:
class PyroLinear(Linear, PyroModule):
    """``Linear`` made Pyro-aware by mixing in ``PyroModule``; the forward pass is unchanged."""

# The mixin subclass behaves exactly like the plain layer on a batch.
linear = PyroLinear(insize=5, outsize=2)
example = linear(torch.randn(100, 5))
assert tuple(example.shape) == (100, 2)

In [5]:
# `PyroModule[Linear]` yields a Pyro-aware version of `Linear` without writing a class statement.
linear = PyroModule[Linear](5, 2)
example = linear(torch.randn(100, 5))
assert tuple(example.shape) == (100, 2)

In [6]:
class BayesianLinear(PyroModule):
    """Affine layer whose bias and weight are random variables with priors.

    ``bias`` ~ LogNormal(0, 1) per output (so sampled bias entries are always
    positive) and ``weight`` ~ Normal(0, 1) per entry. ``to_event`` turns the
    expanded dimensions into event dimensions, so each prior is a single
    multivariate site rather than a batch of independent sites.

    NOTE(review): a positive-only bias prior is unusual for a linear layer;
    kept as-is to preserve the model.
    """

    def __init__(self, insize, outsize):
        super().__init__()
        bias_prior = D.LogNormal(0, 1).expand([outsize]).to_event(1)
        weight_prior = D.Normal(0, 1).expand([insize, outsize]).to_event(2)
        self.bias = PyroSample(bias_prior)
        self.weight = PyroSample(weight_prior)

    def forward(self, x):
        # Reading a PyroSample attribute is what triggers sampling; bias is read
        # before weight so the sample sites appear in that order.
        bias = self.bias
        weight = self.weight
        return bias + x @ weight
    
# Each call draws fresh bias/weight values from the priors, but the output shape is fixed.
linear = BayesianLinear(insize=5, outsize=2)
example = linear(torch.randn(100, 5))
assert tuple(example.shape) == (100, 2)

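## Accessing attributes inside plates

A `PyroSample` attribute is drawn in the plate context that is active when it is read. `GlobalModel` reads `self.loc` outside the plate and gets a single scalar. `LocalModel` reads it inside the plate and gets one value per data point.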

In [7]:
class NormalModel(PyroModule):
    """Base model with a single latent ``loc`` ~ Normal(0, 1).

    Subclasses choose *where* ``self.loc`` is read. The plate context active at
    the moment of access decides whether one global value or one value per
    plate element is drawn.
    """

    def __init__(self):
        super().__init__()
        self.loc = PyroSample(D.Normal(0, 1))
        
class GlobalModel(NormalModel):
    """Reads ``loc`` before entering the plate, so one shared scalar is drawn."""

    def forward(self, data):
        shared_loc = self.loc  # sampled outside any plate -> scalar
        assert shared_loc.shape == ()
        n_obs = len(data)
        with pyro.plate('data', n_obs):
            pyro.sample('obs', D.Normal(shared_loc, 1), obs=data)
            
            
class LocalModel(NormalModel):
    """Reads ``loc`` inside the plate, so one value is drawn per data point."""

    def forward(self, data):
        n_obs = len(data)
        with pyro.plate('data', n_obs):
            per_point_loc = self.loc  # sampled inside the plate -> shape (n_obs,)
            assert per_point_loc.shape == (n_obs, )
            pyro.sample('obs', D.Normal(per_point_loc, 1), obs=data)
            
# Run both plate placements of the `loc` access on the same ten observations;
# their internal asserts check the resulting `loc` shapes.
data = torch.randn(10)
for plate_variant in (LocalModel, GlobalModel):
    plate_variant()(data)

In [13]:
class Model(PyroModule):
    """Bayesian linear regression with a scalar noise scale.

    Latent sites: ``linear.bias`` and ``linear.weight`` (from ``BayesianLinear``)
    plus ``scale`` ~ LogNormal(0, 1). For 2-D ``x`` each row yields an
    ``outsize``-vector observation, and rows are independent under the ``data``
    plate.
    """

    def __init__(self, insize, outsize):
        super().__init__()
        self.linear = BayesianLinear(insize, outsize)
        self.scale = PyroSample(D.LogNormal(0, 1))

    def forward(self, x, y=None):
        """Sample the ``obs`` site, conditioning on ``y`` when it is given.

        Returns the value of the ``obs`` site (``y`` itself when observed).
        """
        loc = self.linear(x)   # samples linear.bias then linear.weight
        noise = self.scale     # samples the global noise scale
        likelihood = D.Normal(loc, noise).to_event(1)
        with pyro.plate('data', len(x)):
            return pyro.sample('obs', likelihood, obs=y)

In [14]:
from tqdm import trange

In [16]:
# Fresh parameter store and fixed seed so the run is reproducible.
pyro.clear_param_store()
pyro.set_rng_seed(1)

# Simulate a synthetic dataset by running the model forward with prior draws.
model = Model(5, 2)
x = torch.randn(100, 5)
y = model(x)

# AutoNormal: an independent Normal per latent site, fitted by SVI on the ELBO.
guide = AutoNormal(model)
svi = SVI(model, guide, Adam({'lr': 0.01}), Trace_ELBO())

pbar = trange(501)  # steps 0..500, so step 500 is also reported below
for step in pbar:
    # Normalise the loss per observed scalar to make it easier to read.
    loss = svi.step(x, y) / y.numel()
    if not step % 100:
        pbar.set_description(f'loss {loss:.4f}')

loss 1.6734: 100%|██████████| 501/501 [00:02<00:00, 219.04it/s]


In [17]:
# Record the sites hit by one forward pass (including the plate's index site)
# and list each site's type, name and value shape.
with poutine.trace() as tr:
    model(x)

for name, site in tr.trace.nodes.items():
    print(site['type'], name, site['value'].shape)

sample linear.bias torch.Size([2])
sample linear.weight torch.Size([5, 2])
sample scale torch.Size([])
sample data torch.Size([100])
sample obs torch.Size([100, 2])
