In [83]:
from collections import OrderedDict

import funsor
import torch
from pyro import set_rng_seed as pyro_set_rng_seed

funsor.set_backend("torch")
torch.set_default_dtype(torch.float32)
pyro_set_rng_seed(0)

import pyro.contrib.funsor
from pyroapi import pyro
from pyroapi import distributions as dist

pyro.clear_param_store()

In [84]:
data = OrderedDict()
data['land'] = 11.
data['water'] = 4.

In [85]:
@pyro.infer.infer_discrete(first_available_dim=-1, temperature=1)
@pyro.infer.config_enumerate
def simple_model(total_count):
    p = pyro.sample("p", dist.Uniform(0, 1))
    return pyro.sample("obs", dist.Binomial(total_count, p))


conditioned_model = pyro.condition(simple_model, data={"obs": torch.as_tensor(data['water'])})


In [86]:
simple_model(sum(data.values()))

tensor(5.)

In [87]:
conditioned_model(sum(data.values()))

tensor(4.)

In [88]:
def guide(total_count):
    p_latent = pyro.param("p_latent", torch.tensor(0.5))
    pyro.sample("p", dist.Delta(p_latent))


In [89]:
optim = pyro.optim.ClippedAdam({"lr":.30})
loss = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0)
svi = pyro.infer.SVI(conditioned_model, guide, optim, loss)

In [90]:
for _ in range(100):
    svi.step(15)
print(pyro.param("p_latent"))

tensor(0.2666, requires_grad=True)


In [91]:
#loss.compute_marginals(conditioned_model, guide, 15)