In [23]:
import sys
sys.path.append("../")
sys.dont_write_bytecode = True

import pyro
import torch
import pyro.distributions as dist

from vectorized_loop.ops import Index
import vectorized_loop as vec
import logging

# Grab the root logger, attach the default handler, and let DEBUG records
# through so the vectorized_loop messenger's "repeat N" lines appear in the
# cell output.
logger = logging.getLogger()
logging.basicConfig()
logger.setLevel(logging.DEBUG)

### Enumeration

In [24]:
@vec.enum
@vec.vectorize
def model(s: vec.State):
    """Chain of enumerated categorical draws inside one vectorized loop.

    Each pass builds the weights ``s.y`` from the constant ``s.z`` and the
    previous ``s.x``, then overwrites ``s.x`` with a new draw of site "x",
    so every draw depends on the one before it.

    The trailing ``# (...)`` comments are the author's shape annotations.
    NOTE(review): the ``|`` separators appear to split enumeration dims,
    loop dims and the category axis -- confirm against vectorized_loop.
    """
    s.z = torch.tensor([1, 2, 3])  # (, | 3)
    # NOTE(review): assigned as a Python int but used as a tensor below
    # (.unsqueeze); presumably vec.State converts on assignment -- confirm.
    s.x = 0  # ()
    for i in vec.range("a", 10, vectorized=True):
        # unsqueeze(-1) gives s.x a trailing axis so it broadcasts against
        # the 3 entries of s.z.
        s.y = s.z + s.x.unsqueeze(-1)  # (10 | 3) -> (3, 1 | 10 | 3)
        s.x = pyro.sample(
            "x",
            dist.Categorical(s.y),
            infer={"enumerate": "parallel"},
        )  # (3 | 10 | ,)

# Reset vec's allocators before tracing.
# NOTE(review): presumably this drops state left by earlier cells -- confirm.
vec.clear_allocators()
# Trace the model once; the bare expression displays the recorded "x" site.
vec.trace(model).get_trace().nodes["x"]

DEBUG:vectorized_loop.vectorized_loop_messenger:a (vectorized): repeat 2


{'type': 'sample',
 'name': 'x',
 'fn': BranchDistribution(site_shape=(10,), event_shape=torch.Size([]), conds=[tensor(True)], dists=[Categorical(probs: torch.Size([3, 1, 10, 3]))]),
 'is_observed': False,
 'args': (),
 'kwargs': {},
 'value': tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]]),
 'infer': {'enumerate': 'parallel', '_enumerate_dim': -2},
 'scale': 1.0,
 'mask': None,
 'cond_indep_stack': (CondIndepStackFrame(name='a', dim=-1, size=10, counter=None, full_size=10),),
 'done': True,
 'stop': False,
 'continuation': None}

In [25]:
@vec.vectorize
def model(s: vec.State, vectorized, P):
    """Two coupled categorical sites in nested loops "a" (outer) and "b" (inner).

    Site "x" is drawn once per "a" iteration from row ``(x * y) % 5`` of ``P``.
    Site "y" is drawn once per "b" iteration from row ``(x + y) % 5``. Both
    use whatever values are currently stored in ``s.x`` and ``s.y``.

    Args:
        s: loop state holding the latest ``x`` and ``y`` values.
        vectorized: forwarded unchanged to both ``vec.range`` calls.
        P: matrix indexed by row to obtain the ``Categorical`` parameters.
    """
    s.x = 2
    s.y = 1
    for i in vec.range("a", 10, vectorized=vectorized):  # (a: -1)
        s.x = pyro.sample("x", dist.Categorical(P[(s.x.long() * s.y.long()) % 5]), infer={"enumerate": "parallel"})  # (x_curr: -3 -> x_prev: -6)
        for j in vec.range("b", 10, vectorized=vectorized):  # (b: -2)
            s.y = pyro.sample("y", dist.Categorical(P[(s.x.long() + s.y.long()) % 5]), infer={"enumerate": "parallel"})  # (y_curr: -4 -> y_prev: -5)

# Random 5x5 matrix with every row rescaled to sum to 1.
P = torch.rand(5, 5)
P = P / P.sum(dim=1, keepdim=True)

# Pass 1: enumerate with vectorized=True and keep the per-site log-probs.
vec.clear_allocators()
tr = vec.trace(vec.enum(model)).get_trace(True, P)
tr.compute_log_prob()
x_log_prob_1 = tr.nodes["x"]["log_prob"]
y_log_prob_1 = tr.nodes["y"]["log_prob"]

# Pass 2: same model and same P, but with vectorized=False.
vec.clear_allocators()
tr = vec.trace(vec.enum(model)).get_trace(False, P)
tr.compute_log_prob()
x_log_prob_2 = tr.nodes["x"]["log_prob"]
y_log_prob_2 = tr.nodes["y"]["log_prob"]

# Print the vectorized shapes and whether both passes give matching log-probs.
print(x_log_prob_1.shape)  # (x_prev, y_prev, -,      x_curr | -, a)
print(torch.allclose(x_log_prob_1, x_log_prob_2))
print(y_log_prob_1.shape)  # (-     , y_prev, y_curr, x_curr | b, a)
print(torch.allclose(y_log_prob_1, y_log_prob_2))

DEBUG:vectorized_loop.vectorized_loop_messenger:b (vectorized): repeat 2
DEBUG:vectorized_loop.vectorized_loop_messenger:b (vectorized): repeat 2
DEBUG:vectorized_loop.vectorized_loop_messenger:a (vectorized): repeat 2


torch.Size([5, 5, 1, 5, 1, 10])
True
torch.Size([5, 5, 5, 10, 10])
True


In [26]:
@vec.vectorize
def model(s: vec.State, vectorized):
    """Two separate categorical chains, "x" and "y", advanced in one loop "a".

    ``P`` is a 5x5 matrix whose entries are all 1/5. Each site is drawn from
    the row of ``P`` selected by that site's own previous value, so ``x`` and
    ``y`` never read each other.

    Args:
        s: loop state holding the latest ``x`` and ``y`` values.
        vectorized: forwarded unchanged to ``vec.range``.
    """
    P = torch.ones(5, 5) / 5
    s.x = 0
    s.y = 0
    for i in vec.range("a", 10, vectorized=vectorized):  # (a: -1)
        s.x = pyro.sample("x", dist.Categorical(P[s.x.long()]), infer={"enumerate": "parallel"})  # (x_curr: -2 -> x_prev: -4)
        s.y = pyro.sample("y", dist.Categorical(P[s.y.long()]), infer={"enumerate": "parallel"})  # (y_curr: -3 -> y_prev: -5)

# Reset vec's allocators, then trace the enumerated model with vectorized=True.
vec.clear_allocators()
tr = vec.trace(vec.enum(model)).get_trace(True)
# For each site, print the shape of the enumerated value and of the log-prob
# obtained by scoring that value under the site's own distribution.
print(tr.nodes["x"]["value"].shape)                                # (-     , -     , -     , x_curr | a)
print(tr.nodes["x"]["fn"].log_prob(tr.nodes["x"]["value"]).shape)  # (-     , x_prev, -     , x_curr | a)
print(tr.nodes["y"]["value"].shape)                                # (-     , -     , y_curr, -      | a)
print(tr.nodes["y"]["fn"].log_prob(tr.nodes["y"]["value"]).shape)  # (y_prev, -     , y_curr, -      | a)

DEBUG:vectorized_loop.vectorized_loop_messenger:a (vectorized): repeat 2


torch.Size([5, 10])
torch.Size([5, 1, 5, 10])
torch.Size([5, 1, 10])
torch.Size([5, 1, 5, 1, 10])


In [27]:
@vec.vectorize
def model(s: vec.State):
    """Outer chain over "x" with an inner loop drawing "y" from the current x.

    ``P`` is drawn anew with ``torch.rand`` every time this body runs and is
    passed as ``logits``. In loop "a", site "x" is indexed by the previous
    ``s.x``. In loop "b", site "y" is indexed by the current ``s.x`` only.
    ``s.z`` is written in both loops but never read inside this function.

    NOTE(review): ``Index(P)[...]`` presumably selects rows of ``P`` by an
    enumerated value -- confirm against vectorized_loop.ops.
    """
    P = torch.rand(5, 5).log()
    s.x = 2
    for _ in vec.range("a", 3, vectorized=True):
        s.x = pyro.sample("x", dist.Categorical(logits=Index(P)[s.x]), infer={"enumerate": "parallel"})  # (x_curr: -3 -> x_prev: -5)
        s.z = s.x  # (x_curr | -, a)
        for _ in vec.range("b", 3, vectorized=True):
            s.y = pyro.sample("y", dist.Categorical(logits=Index(P)[s.x]), infer={"enumerate": "parallel"})  # (y_curr: -4)
            s.z = s.x + s.y  # (y_curr, x_curr | b, a)
        
# Reset vec's allocators before tracing.
vec.clear_allocators()

# Trace the enumerated model and attach a "log_prob" entry to each site.
tr_model = vec.trace(vec.enum(model)).get_trace()
tr_model.compute_log_prob()

# Shapes of the enumerated values and their log-probs for both sites.
print(tr_model.nodes["x"]["value"].shape)     # (-     , -     , x_curr | - , a)
print(tr_model.nodes["x"]["log_prob"].shape)  # (x_prev, -     , x_curr | - , a)
print(tr_model.nodes["y"]["value"].shape)     # (-     , y_curr, -      | b , a)
print(tr_model.nodes["y"]["log_prob"].shape)  # (-     , y_curr, x_curr | b , a)

DEBUG:vectorized_loop.vectorized_loop_messenger:b (vectorized): repeat 1
DEBUG:vectorized_loop.vectorized_loop_messenger:b (vectorized): repeat 1
DEBUG:vectorized_loop.vectorized_loop_messenger:a (vectorized): repeat 2


torch.Size([5, 1, 3])
torch.Size([5, 1, 5, 1, 3])
torch.Size([5, 1, 3, 3])
torch.Size([5, 5, 3, 3])


### Enumeration with branches

In [28]:
@vec.vectorize
def model(s: vec.State, P, is_guide, vectorized):
    """Enumerated chain over "x" whose transition depends on the loop index.

    When ``is_guide`` is true, nothing is executed. Otherwise, in each pass
    of loop "a" the current ``s.x`` is copied to ``s.y`` and site "x" is
    drawn under one of two ``vec.branch`` conditions that share the same
    site name:

    * ``s.i >= 1`` uses logits ``Index(P)[s.y]``
    * ``s.i == 0`` uses logits ``Index(P)[(s.y + 1) % 5]``

    Args:
        s: loop state; ``s.i`` receives the loop variable of "a".
        P: matrix passed as ``logits`` after indexing.
        is_guide: if true, the function body is skipped.
        vectorized: forwarded unchanged to ``vec.range``.
    """
    if not is_guide:
        s.x = 2
        for s.i in vec.range("a", 3, vectorized=vectorized):
            # Snapshot of x before either branch overwrites it.
            s.y = s.x
            with vec.branch(s.i >= 1):
                s.x = pyro.sample("x", dist.Categorical(logits=Index(P)[s.y]), infer={"enumerate": "parallel"})
            with vec.branch(s.i == 0):
                s.x = pyro.sample("x", dist.Categorical(logits=Index(P)[(s.y + 1) % 5]), infer={"enumerate": "parallel"})

# Random log-weights, shifted so that each row's logsumexp is 0.
P = torch.rand(5, 5).log()
P = P - P.logsumexp(-1, True)

# Pass 1: guide trace (is_guide=True), then replay the enumerated model
# against it with vectorized=False.
vec.clear_allocators()
tr_guide = vec.trace(model).get_trace(P, is_guide=True, vectorized=False)
tr_model = vec.trace(vec.replay(vec.enum(model), tr_guide)).get_trace(P, is_guide=False, vectorized=False)
tr_model.compute_log_prob()
log_prob_1 = tr_model.nodes["x"]["log_prob"]

# Pass 2: same procedure, but the model is run with vectorized=True.
vec.clear_allocators()
tr_guide = vec.trace(model).get_trace(P, is_guide=True, vectorized=False)
tr_model = vec.trace(vec.replay(vec.enum(model), tr_guide)).get_trace(P, is_guide=False, vectorized=True)
tr_model.compute_log_prob()
log_prob_2 = tr_model.nodes["x"]["log_prob"]

# Print the shape from pass 1 and whether both passes give matching log-probs.
print(log_prob_1.shape)
print(torch.allclose(log_prob_1, log_prob_2))

DEBUG:vectorized_loop.vectorized_loop_messenger:a (vectorized): repeat 2


torch.Size([5, 5, 3])
True


### Basic plates and markovs

In [29]:
def run(model, vectorized):
    """Trace ``model`` as a guide, replay it as a model, and print both.

    The guide trace is always taken with ``False`` as the model's argument.
    The model is then replayed against that trace and traced again with
    ``vectorized``. For each trace, every node whose ``"type"`` is
    ``"sample"`` is printed as ``<value> ~ <fn>``, guide first.

    Args:
        model: callable accepted by ``vec.trace`` / ``vec.replay``.
        vectorized: argument passed to the replayed model's ``get_trace``.
    """
    guide_trace = vec.trace(model).get_trace(False)
    replayed = vec.replay(model, guide_trace)
    model_trace = vec.trace(replayed).get_trace(vectorized)

    # One (format string, trace) pair per report, in print order.
    reports = (
        ("Guide: %s ~ %s", guide_trace),
        ("Model: %s ~ %s", model_trace),
    )
    for template, trace in reports:
        for _site_name, node in trace.nodes.items():
            if node["type"] != "sample":
                continue
            print(template % (node["value"], node["fn"]))

In [30]:
# 1-level nested plates

@vec.vectorize
def model(s: vec.State, vectorized):
    """Draw ``x ~ Normal(0, 1)`` and ``y ~ Normal(x, 1)`` in a size-3 loop "a".

    ``s.x`` is overwritten before it is read in every pass, so no value is
    carried from one iteration to the next.

    Args:
        s: loop state receiving the sampled ``x`` and ``y``.
        vectorized: forwarded unchanged to ``vec.range``.
    """
    for i in vec.range("a", size=3, vectorized=vectorized):
        s.x = pyro.sample("x", dist.Normal(0, 1))
        s.y = pyro.sample("y", dist.Normal(s.x, 1))

# Reset vec's allocators, then print the guide and replayed-model sites.
vec.clear_allocators()
# Alternate call (vectorized replay), kept for manual comparison:
# run(model, True)
run(model, False)

Guide: tensor([-0.9255,  0.7853,  0.7795]) ~ BranchDistribution(site_shape=(3,), event_shape=torch.Size([]), conds=[tensor([ True, False, False]), tensor([False,  True, False]), tensor([False, False,  True])], dists=[Normal(loc: 0.0, scale: 1.0), Normal(loc: 0.0, scale: 1.0), Normal(loc: 0.0, scale: 1.0)])
Guide: tensor([-0.1979,  0.3649,  0.4731]) ~ BranchDistribution(site_shape=(3,), event_shape=torch.Size([]), conds=[tensor([ True, False, False]), tensor([False,  True, False]), tensor([False, False,  True])], dists=[Normal(loc: -0.9254725575447083, scale: 1.0), Normal(loc: 0.7852697372436523, scale: 1.0), Normal(loc: 0.7795404195785522, scale: 1.0)])
Model: tensor([-0.9255,  0.7853,  0.7795]) ~ BranchDistribution(site_shape=(3,), event_shape=torch.Size([]), conds=[tensor([ True, False, False]), tensor([False,  True, False]), tensor([False, False,  True])], dists=[Normal(loc: 0.0, scale: 1.0), Normal(loc: 0.0, scale: 1.0), Normal(loc: 0.0, scale: 1.0)])
Model: tensor([-0.1979,  0.364

In [31]:
# 2-level nested plates

@vec.vectorize
def model(s: vec.State, vectorized):
    """Draw ``x ~ Normal(0, 1)`` and ``y ~ Normal(x, 1)`` in nested loops.

    Loop "a" (size 2) encloses loop "b" (size 3). ``s.x`` is overwritten
    before it is read in every inner pass, so nothing is carried across
    iterations.

    Args:
        s: loop state receiving the sampled ``x`` and ``y``.
        vectorized: forwarded unchanged to both ``vec.range`` calls.
    """
    for _ in vec.range("a", size=2, vectorized=vectorized):
        for _ in vec.range("b", size=3, vectorized=vectorized):
            s.x = pyro.sample("x", dist.Normal(0, 1))
            s.y = pyro.sample("y", dist.Normal(s.x, 1))

# Reset vec's allocators, then print the guide and vectorized-model sites.
vec.clear_allocators()
run(model, True)
# Alternate call (non-vectorized replay), kept for manual comparison:
# run(model, False)

DEBUG:vectorized_loop.vectorized_loop_messenger:b (vectorized): repeat 1
DEBUG:vectorized_loop.vectorized_loop_messenger:a (vectorized): repeat 1


Guide: tensor([[1.0888, 0.9823],
        [0.6738, 1.1079],
        [1.1436, 0.3730]]) ~ BranchDistribution(site_shape=(3, 2), event_shape=torch.Size([]), conds=[tensor([[ True, False],
        [False, False],
        [False, False]]), tensor([[False, False],
        [ True, False],
        [False, False]]), tensor([[False, False],
        [False, False],
        [ True, False]]), tensor([[False,  True],
        [False, False],
        [False, False]]), tensor([[False, False],
        [False,  True],
        [False, False]]), tensor([[False, False],
        [False, False],
        [False,  True]])], dists=[Normal(loc: 0.0, scale: 1.0), Normal(loc: 0.0, scale: 1.0), Normal(loc: 0.0, scale: 1.0), Normal(loc: 0.0, scale: 1.0), Normal(loc: 0.0, scale: 1.0), Normal(loc: 0.0, scale: 1.0)])
Guide: tensor([[ 2.1782,  0.6153],
        [ 1.0825,  0.0347],
        [ 1.7723, -0.9003]]) ~ BranchDistribution(site_shape=(3, 2), event_shape=torch.Size([]), conds=[tensor([[ True, False],
        [False,

In [32]:
# 1-level nested markovs

@vec.vectorize
def model(s: vec.State, vectorized):
    """Random walk: each pass draws ``x ~ Normal(previous x, 1)``.

    ``s.x`` starts at 10 and is both read and overwritten by site "x" in
    every iteration of the size-3 loop "a".

    Args:
        s: loop state carrying ``x`` between iterations.
        vectorized: forwarded unchanged to ``vec.range``.
    """
    s.x = 10
    for _ in vec.range("a", size=3, vectorized=vectorized):
        s.x = pyro.sample("x", dist.Normal(s.x, 1))
        
# Reset vec's allocators, then print the guide and vectorized-model sites.
vec.clear_allocators()
run(model, True)
# Alternate call (non-vectorized replay), kept for manual comparison:
# run(model, False)

DEBUG:vectorized_loop.vectorized_loop_messenger:a (vectorized): repeat 2


Guide: tensor([10.5045, 10.1355,  9.3305]) ~ BranchDistribution(site_shape=(3,), event_shape=torch.Size([]), conds=[tensor([ True, False, False]), tensor([False,  True, False]), tensor([False, False,  True])], dists=[Normal(loc: 10.0, scale: 1.0), Normal(loc: 10.504530906677246, scale: 1.0), Normal(loc: 10.135538101196289, scale: 1.0)])
Model: tensor([10.5045, 10.1355,  9.3305]) ~ BranchDistribution(site_shape=(3,), event_shape=torch.Size([]), conds=[tensor(True)], dists=[Normal(loc: torch.Size([3]), scale: torch.Size([3]))])


In [33]:
# 2-level nested markovs

@vec.vectorize
def model(s: vec.State, vectorized):
    """Random walk over nested loops: ``x ~ Normal(previous x, 1)``.

    ``s.x`` starts at 10 and is read and overwritten by site "x" in every
    inner pass. Loop "a" (size 2) is given ``dim=-2`` and loop "b" (size 3)
    is given ``dim=-1`` explicitly.

    Args:
        s: loop state carrying ``x`` between iterations.
        vectorized: forwarded unchanged to both ``vec.range`` calls.
    """
    s.x = 10
    for _ in vec.range("a", size=2, dim=-2, vectorized=vectorized):
        for _ in vec.range("b", size=3, dim=-1, vectorized=vectorized):
            s.x = pyro.sample("x", dist.Normal(s.x, 1))

# Reset vec's allocators, then print the guide and vectorized-model sites.
vec.clear_allocators()
run(model, True)
# Alternate call (non-vectorized replay), kept for manual comparison:
# run(model, False)

DEBUG:vectorized_loop.vectorized_loop_messenger:b (vectorized): repeat 2
DEBUG:vectorized_loop.vectorized_loop_messenger:b (vectorized): repeat 2
DEBUG:vectorized_loop.vectorized_loop_messenger:a (vectorized): repeat 2


Guide: tensor([[9.4325, 8.9606, 8.8674],
        [8.0014, 7.1650, 6.5959]]) ~ BranchDistribution(site_shape=(2, 3), event_shape=torch.Size([]), conds=[tensor([[ True, False, False],
        [False, False, False]]), tensor([[False,  True, False],
        [False, False, False]]), tensor([[False, False,  True],
        [False, False, False]]), tensor([[False, False, False],
        [ True, False, False]]), tensor([[False, False, False],
        [False,  True, False]]), tensor([[False, False, False],
        [False, False,  True]])], dists=[Normal(loc: 10.0, scale: 1.0), Normal(loc: 9.432474136352539, scale: 1.0), Normal(loc: 8.960566520690918, scale: 1.0), Normal(loc: 8.867393493652344, scale: 1.0), Normal(loc: 8.001392364501953, scale: 1.0), Normal(loc: 7.165006637573242, scale: 1.0)])
Model: tensor([[9.4325, 8.9606, 8.8674],
        [8.0014, 7.1650, 6.5959]]) ~ BranchDistribution(site_shape=(2, 3), event_shape=torch.Size([]), conds=[tensor(True)], dists=[Normal(loc: torch.Size([2, 3]), 