In [104]:
import operator
import string
from functools import reduce
from typing import Union

import funsor
import torch
from pyro import set_rng_seed as pyro_set_rng_seed
from pyro.infer import config_enumerate, TraceEnum_ELBO
from torch import Tensor

from sbb.counting import enumerate_sequences, count_elems, count_sequences

# Global configuration: funsor on the torch backend, float32 default dtype, fixed RNG seed.
funsor.set_backend("torch")
torch.set_default_dtype(torch.float32)
pyro_set_rng_seed(0)

# NOTE(review): these imports belong in the import block above.
# The bare name `pyro` is bound here to pyroapi's `pyro` object. The earlier
# `from pyro import ...` / `from pyro.infer import ...` lines never bind `pyro` itself,
# and `import pyro.distributions as dist` binds only `dist`. So every `pyro.sample`,
# `pyro.plate` and `pyro.clear_param_store` call below goes through pyroapi.
from pyroapi import pyro
import pyro.distributions as dist

# Drop any parameters left over from a previous run of the notebook.
pyro.clear_param_store()

In [105]:
# Candidate works for each slot of the programme, each identified by its catalogue index:
# 30 Haydn symphonies, 15 modern works and 9 Beethoven symphonies.
haydn_symphonies, modern_works, beethoven_symphonies = (
    torch.arange(count) for count in (30, 15, 9)
)

In [106]:
%%time
# Build every admissible programme explicitly, then count the result;
# compare with `count_sequences` in the next cell.
p = enumerate_sequences(
    haydn_symphonies,
    modern_works,
    beethoven_symphonies,
)
count_elems(p)

CPU times: user 628 µs, sys: 0 ns, total: 628 µs
Wall time: 589 µs


4050

In [107]:
%%time
# Count the programmes directly; the displayed value should equal `count_elems(p)` above.
count_sequences(
    haydn_symphonies,
    modern_works,
    beethoven_symphonies,
)

CPU times: user 19 µs, sys: 3 µs, total: 22 µs
Wall time: 25.7 µs


4050

In [108]:
@config_enumerate(default="sequential")
def model(verbose: bool = True):
    """Enumeration test model: one work per category plus a random running order.

    Draws one Haydn symphony, one modern work and one Beethoven symphony. Each
    is uniform over its catalogue (``haydn_symphonies`` / ``modern_works`` /
    ``beethoven_symphonies``) and marked for parallel enumeration. The model
    then draws a uniformly random order for the three category ids ``[0, 1, 2]``
    by picking positions from a shrinking list.
    (Presumably 0 = Haydn, 1 = modern, 2 = Beethoven; TODO confirm.)

    Args:
        verbose: print the shape and value of every sample site (the original
            debugging output, on by default). Pass ``False`` to silence it, e.g.
            ``elbo.loss(model, guide, verbose=False)``. Pyro forwards extra
            ``loss`` arguments to both model and guide.

    Returns:
        ``(piece1, piece2, piece3, haydn, modern, beethoven)``: the category ids
        in playing order, followed by the three sampled work indices.
    """
    # Catalogue sizes come from the data cell, so the model cannot drift from the data.
    catalogues = (
        ("haydn", haydn_symphonies),
        ("modern", modern_works),
        ("beethoven", beethoven_symphonies),
    )
    draws = {}
    for name, catalogue in catalogues:
        with pyro.plate(f"program_{name}"):
            draws[name] = pyro.sample(
                name,
                dist.Categorical(logits=torch.zeros(len(catalogue))),
                infer={"enumerate": "parallel"},
            )

    # `first` / `second` index a Python list, so they must be concrete scalars.
    # That is why they are left on the decorator's sequential enumeration.
    pieces = [0, 1, 2]
    with pyro.plate("program_order"):
        first = pyro.sample("first", dist.Categorical(logits=torch.zeros(len(pieces))))
        piece1 = pieces.pop(first)
        second = pyro.sample("second", dist.Categorical(logits=torch.zeros(len(pieces))))
        piece2 = pieces.pop(second)
        piece3 = pieces[0]

    if verbose:
        sites = {**draws, "first": first, "second": second}
        for name, value in sites.items():
            print(f"  model {name}.shape {value.shape}")
        for name, value in sites.items():
            print(f"  model {name} {value}")

    return piece1, piece2, piece3, draws["haydn"], draws["modern"], draws["beethoven"]

In [109]:
@config_enumerate(default="sequential")
def guide(verbose: bool = True):
    """Guide for the enumeration test model; mirrors its sample sites exactly.

    Draws one Haydn symphony, one modern work and one Beethoven symphony. Each
    is uniform over its catalogue and marked for parallel enumeration. The guide
    then draws a uniformly random order for the category ids ``[0, 1, 2]``.

    NOTE(review): there are no ``pyro.param`` sites here, and the distributions
    are the same fixed uniforms as in ``model``, so there is nothing to optimise.
    This is presumably a placeholder for exercising enumeration.

    Args:
        verbose: print the shape and value of every sample site (the original
            debugging output, on by default). Pass ``False`` to silence it.

    Returns:
        ``(piece1, piece2, piece3, haydn, modern, beethoven)``: the category ids
        in playing order, followed by the three sampled work indices.
    """
    # Catalogue sizes come from the data cell, so the guide cannot drift from the data.
    catalogues = (
        ("haydn", haydn_symphonies),
        ("modern", modern_works),
        ("beethoven", beethoven_symphonies),
    )
    draws = {}
    for name, catalogue in catalogues:
        with pyro.plate(f"program_{name}"):
            draws[name] = pyro.sample(
                name,
                dist.Categorical(logits=torch.zeros(len(catalogue))),
                infer={"enumerate": "parallel"},
            )

    # `first` / `second` index a Python list, so they must be concrete scalars.
    # That is why they are left on the decorator's sequential enumeration.
    pieces = [0, 1, 2]
    with pyro.plate("program_order"):
        first = pyro.sample("first", dist.Categorical(logits=torch.zeros(len(pieces))))
        piece1 = pieces.pop(first)
        second = pyro.sample("second", dist.Categorical(logits=torch.zeros(len(pieces))))
        piece2 = pieces.pop(second)
        piece3 = pieces[0]

    if verbose:
        sites = {**draws, "first": first, "second": second}
        for name, value in sites.items():
            print(f"  guide {name}.shape {value.shape}")
        for name, value in sites.items():
            print(f"  guide {name} {value}")

    return piece1, piece2, piece3, draws["haydn"], draws["modern"], draws["beethoven"]

In [110]:
# Run the model once outside inference to inspect its debug prints and returned tuple.
print("Model:")
model()

Model:
  model haydn.shape torch.Size([])
  model modern.shape torch.Size([])
  model beethoven.shape torch.Size([])
  model first.shape torch.Size([])
  model second.shape torch.Size([])
  model haydn 29
  model modern 10
  model beethoven 4
  model first 2
  model second 1


(2, 1, 0, tensor(29), tensor(10), tensor(4))

In [None]:
# Every plate in model/guide is a top-level sibling (never nested), hence max_plate_nesting=1.
# NOTE(review): this cell has never been executed (`In [None]`). It must run before the
# cell that redefines `model` further down. After that cell, `model` samples
# "piece1".."piece3", which `guide` does not provide.
elbo = TraceEnum_ELBO(
    max_plate_nesting=1,
)
elbo.loss(model=model, guide=guide)

In [115]:
def model():
    """Sample an ordered programme of three distinct works from the pooled catalogue.

    NOTE(review): this redefines, and therefore shadows, the enumeration-based
    ``model`` defined earlier. Any cell run after this one sees only this version.

    The Haydn, modern and Beethoven catalogues are pooled into one index range
    ``0 .. len(haydn_symphonies) + len(modern_works) + len(beethoven_symphonies) - 1``.
    Each draw is uniform over the works not yet chosen, so the result is a
    uniformly random ordered triple of distinct pooled indices. The sample
    sites are ``"piece1"``, ``"piece2"`` and ``"piece3"``.

    Returns:
        tuple: the pooled indices (Python ints) of the first, second and third piece.
    """
    # Pool size comes from the data cell instead of a hard-coded 30 + 15 + 9.
    remaining = list(
        range(len(haydn_symphonies) + len(modern_works) + len(beethoven_symphonies))
    )
    chosen = []
    for site in ("piece1", "piece2", "piece3"):
        index = pyro.sample(site, dist.Categorical(logits=torch.zeros(len(remaining))))
        # pop both reads the chosen work and removes it, so later draws cannot repeat it.
        chosen.append(remaining.pop(index))
    return tuple(chosen)

In [118]:
# One forward draw from the pooled-catalogue `model` defined in the previous cell
# (that definition shadows the enumeration model defined earlier in the notebook).
model()

(33, 25, 22)