In [1]:
import torch
import pyro
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal

In [2]:
# Record the environment this notebook was run in (last-updated date,
# Python/IPython versions, versions of the imported packages, and the
# watermark version) so the printed outputs can be tied to a known setup.
%load_ext watermark
%watermark -n -u -v -iv -w

Last updated: Mon Jun 24 2024

Python implementation: CPython
Python version       : 3.9.19
IPython version      : 8.18.1

pyro : 1.9.1
torch: 2.3.0

Watermark: 2.4.3



In [4]:
# Simplest case: a single scalar Bernoulli. Both batch_shape and event_shape
# are empty, so a sample and its log-probability are 0-dimensional tensors.
d = Bernoulli(probs=0.5)
assert (d.batch_shape, d.event_shape) == ((), ())

x = d.sample()
assert x.shape == d.log_prob(x).shape == ()
print(x)

tensor(1.)


In [5]:
# Passing a batched parameter tensor yields a batched distribution: the
# parameter's shape becomes batch_shape while event_shape stays empty.
d = Bernoulli(torch.full((3, 4), 0.5))
assert (d.batch_shape, d.event_shape) == ((3, 4), ())

x = d.sample()
# Univariate case: samples and log-probs both carry the full batch_shape.
assert x.shape == d.log_prob(x).shape == (3, 4)
print(x)

tensor([[0., 1., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 1.]])


In [6]:
# A second way to batch a distribution is the .expand() method. Here it adds a
# new batch dimension on the left, so every row of the result shares the same
# parameters: the (4,)-shaped probabilities below are repeated along a new
# leading dimension of size 3.
batch = torch.tensor([0.1, 0.2, 0.3, 0.4])
print(batch)
d = Bernoulli(batch).expand([3, 4])
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)

# Monte Carlo check: the per-entry mean of many draws should be close to the
# probabilities, i.e. column j is close to batch[j] in every row. Passing a
# sample_shape draws all n samples in one vectorized call instead of making
# n Python-level sample() calls.
n = 100_000
estimate = d.sample((n,)).mean(dim=0)
assert estimate.shape == (3, 4)
print(estimate)

tensor([0.1000, 0.2000, 0.3000, 0.4000])
tensor([[0.1000, 0.2019, 0.2998, 0.3999],
        [0.1004, 0.2008, 0.3002, 0.3992],
        [0.1000, 0.2002, 0.2999, 0.3993]])


In [7]:
# .expand() can also broadcast an existing size-1 dimension. The parameters
# below have shape (3, 1); expanding to (3, 4) repeats each row's probability
# across the 4 columns, so parameters are identical within a row and differ
# between rows.
batch = torch.tensor([[0.1], [0.2], [0.3]])
print(batch)
d = Bernoulli(batch).expand([3, 4])
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)

# Monte Carlo check: the per-entry mean of many draws should be close to the
# probabilities, i.e. every entry of row i is close to batch[i]. Passing a
# sample_shape draws all n samples in one vectorized call instead of making
# n Python-level sample() calls.
n = 100_000
estimate = d.sample((n,)).mean(dim=0)
assert estimate.shape == (3, 4)
print(estimate)

tensor([[0.1000],
        [0.2000],
        [0.3000]])
tensor([[0.1005, 0.1000, 0.0988, 0.0991],
        [0.2007, 0.1982, 0.2000, 0.2005],
        [0.3004, 0.3008, 0.3001, 0.2987]])


In [8]:
# A multivariate distribution has a non-empty event_shape: one draw from this
# standard 3-dimensional normal is a length-3 vector, scored as a single event.
d = MultivariateNormal(loc=torch.zeros(3), covariance_matrix=torch.eye(3))
assert (d.batch_shape, d.event_shape) == ((), (3,))

x = d.sample()
log_density = d.log_prob(x)
assert x.shape == (3,)            # == batch_shape + event_shape
assert log_density.shape == ()    # == batch_shape
print(log_density)

tensor(-3.1385)


# Reshaping Distribution

In [9]:
# Pyro lets you treat a univariate distribution as multivariate via the
# .to_event(n) method, where n is the number of batch dimensions (counted from
# the right) to declare as dependent, i.e. to move into event_shape.

d = Bernoulli(torch.full((3, 4), 0.5))
assert (d.batch_shape, d.event_shape) == ((3, 4), ())

# Reinterpret the rightmost batch dimension (size 4) as an event dimension.
d = d.to_event(1)
assert (d.batch_shape, d.event_shape) == ((3,), (4,))

x = d.sample()
assert x.shape == (3, 4)            # batch_shape + event_shape
assert d.log_prob(x).shape == (3,)  # batch_shape only

In [10]:
# Pyro lets you treat a univariate distribution as multivariate via the
# .to_event(n) method, where n is the number of batch dimensions (counted from
# the right) to declare as dependent, i.e. to move into event_shape.

# batch_shape has to be controlled carefully: either trim it down with
# .to_event(n), or declare dimensions as independent with pyro.plate.
d = Bernoulli(torch.full((3, 4), 0.5))
assert (d.batch_shape, d.event_shape) == ((3, 4), ())

# Reinterpret both batch dimensions as event dimensions.
d = d.to_event(2)
assert (d.batch_shape, d.event_shape) == ((), (3, 4))

x = d.sample()
assert x.shape == (3, 4)          # batch_shape + event_shape
assert d.log_prob(x).shape == ()  # batch_shape only

\# `log_prob(x)` has shape `batch_shape`

\# samples have shape `batch_shape + event_shape`

# It is always safe to assume dependence

In [11]:
# Version 1: declare the 10 draws as a single dependent event.
# Build the three stages first, then show how the shapes move from
# (batch=[], event=[]) to (batch=[10], event=[]) to (batch=[], event=[10]).
scalar = Normal(0, 1)
batched = scalar.expand([10])
d = batched.to_event(1)
for stage in (scalar, batched, d):
    print(stage.batch_shape, stage.event_shape)
x = pyro.sample("x", d)
assert x.shape == (10,)

# Version 2: declare conditional independence with a plate.
# The plate tells Pyro it may exploit conditional independence when estimating
# gradients, whereas in version 1 Pyro must treat the draws as dependent (even
# though these normals are in fact conditionally independent).
with pyro.plate("x_plate", 10):
    # No explicit .expand([10]) needed: the plate does it automatically.
    x = pyro.sample("x", Normal(0, 1))
assert x.shape == (10,)


torch.Size([]) torch.Size([])
torch.Size([10]) torch.Size([])
torch.Size([]) torch.Size([10])


In [12]:
# Expand a scalar Normal to a (10, 3) batch, then reinterpret the rightmost
# dimension as an event: 10 independent rows, each a dependent 3-vector.
scalar = Normal(0, 1)
batched = scalar.expand([10, 3])
d = batched.to_event(1)
for stage in (scalar, batched, d):
    print(stage.batch_shape, stage.event_shape)

x = pyro.sample("x", d)
row_log_prob = d.log_prob(x)
assert x.shape == (10, 3)            # batch_shape + event_shape
assert row_log_prob.shape == (10,)   # one log-density per batch row
print(row_log_prob)

torch.Size([]) torch.Size([])
torch.Size([10, 3]) torch.Size([])
torch.Size([10]) torch.Size([3])
tensor([-4.5822, -2.9822, -3.9823, -3.1746, -2.8273, -8.0561, -7.0914, -2.9798,
        -2.8291, -3.3716])
