In [1]:
import torch
from pyroapi import pyro

import pyro.distributions as dist

In [19]:
def model_with(t=None):
    l = pyro.sample("l", dist.Beta(0.5, 0.5))
    t = pyro.sample("t", dist.Bernoulli(l), obs=t)
    return t

def model_without():
    l = pyro.sample("l", dist.Beta(0.5, 0.5))
    t = pyro.sample("t", dist.Bernoulli(l))
    return t

In [25]:
def model_with(t=None):
    with pyro.plate("N"):
        l = pyro.sample("l", dist.Beta(0.5, 0.5))
        t = pyro.sample("t", dist.Bernoulli(l), obs=t)
    return t

def model_without():
    with pyro.plate("N"):
        l = pyro.sample("l", dist.Beta(0.5, 0.5))
        t = pyro.sample("t", dist.Bernoulli(l))
    return t

In [27]:
trace1 = pyro.poutine.trace(model_with).get_trace(torch.tensor(1.))
trace2 = pyro.poutine.trace(pyro.condition(model_without, {"t": torch.tensor(1.)})).get_trace()

trace1.log_prob_sum(), trace2.log_prob_sum()

(tensor(-0.0457), tensor(-1.2248))

In [24]:
cond_with = pyro.condition(model_with, {"l": torch.tensor(0.5)})
cond_without = pyro.condition(model_without, {"l": torch.tensor(0.5)})
trace1 = pyro.poutine.trace(cond_with).get_trace(torch.tensor(1.))
trace2 = pyro.poutine.trace(pyro.condition(cond_without, {"t": torch.tensor(1.)})).get_trace()

trace1.log_prob_sum(), trace2.log_prob_sum()

(tensor(-1.1447), tensor(-1.1447))

In [32]:
def model2_with(t1=None, t2=None):
    l1 = pyro.sample("l1", dist.Beta(0.5, 0.5))
    t1 = pyro.sample("t1", dist.Bernoulli(l1), obs=t1)
    l2 = pyro.sample("l2", dist.Beta(0.5, 0.5))
    t2 = pyro.sample("t2", dist.Bernoulli(l2), obs=t2)
    return t1,t2

def model2_without():
    l1 = pyro.sample("l1", dist.Beta(0.5, 0.5))
    t1 = pyro.sample("t1", dist.Bernoulli(l1))
    l2 = pyro.sample("l2", dist.Beta(0.5, 0.5))
    t2 = pyro.sample("t2", dist.Bernoulli(l2))
    return t1,t2

In [2]:
def model2_with(t1=None, t2=None):
    with pyro.plate("1"):
        l1 = pyro.sample("l1", dist.Beta(1., 1.))
        t1 = pyro.sample("t1", dist.Bernoulli(l1), obs=t1)
    with pyro.plate("2"):
        l2 = pyro.sample("l2", dist.Beta(1., 1.))
        t2 = pyro.sample("t2", dist.Bernoulli(l2), obs=t2)
    return t1,t2

def model2_without():
    with pyro.plate("1"):
        l1 = pyro.sample("l1", dist.Beta(1., 1.))
        t1 = pyro.sample("t1", dist.Bernoulli(l1))
    with pyro.plate("2"):
        l2 = pyro.sample("l2", dist.Beta(1., 1.))
        t2 = pyro.sample("t2", dist.Bernoulli(l2))
    return t1,t2

In [3]:
def model(athlete=None, paying=None):
    with pyro.plate("A"):
        latent_a = pyro.sample("latent_athlete", dist.Beta(1.,1.))
        a = pyro.sample("athlete", dist.Bernoulli(latent_a), obs=athlete)
    with pyro.plate("P"):
        latent_p = pyro.sample("latent_paying", dist.Beta(1.,1.))
        p = pyro.sample("paying", dist.Bernoulli(latent_p), obs=paying)
    return a, p

In [76]:
trace1 = pyro.poutine.trace(model2_with).get_trace(torch.tensor(1.), torch.tensor(0.))
trace2 = pyro.poutine.trace(pyro.condition(model2_without, {"t1": torch.tensor(1.), "t2": torch.tensor(0.)})).get_trace()

trace1.log_prob_sum(), trace2.log_prob_sum()

(tensor(-0.8244), tensor(-2.4242))

In [5]:
cond_with = pyro.condition(model2_with, {"l1": torch.tensor(0.25), "l2": torch.tensor(0.75)})
cond_without = pyro.condition(model2_without, {"l1": torch.tensor(0.25), "l2": torch.tensor(0.75)})
trace1 = pyro.poutine.trace(cond_with).get_trace(torch.tensor(1.), torch.tensor(0.))
trace2 = pyro.poutine.trace(pyro.condition(cond_without, {"t1": torch.tensor(1.), "t2": torch.tensor(0.)})).get_trace()

trace1.log_prob_sum(), trace2.log_prob_sum()

(tensor(-2.7726), tensor(-2.7726))

In [7]:
cond_model = pyro.condition(model, {"latent_athlete": torch.tensor(0.25), "latent_paying": torch.tensor(0.75)})

trace1 = pyro.poutine.trace(cond_model).get_trace(torch.tensor(1.), torch.tensor(0.))
trace2 = pyro.poutine.trace(cond_with).get_trace(torch.tensor(1.), torch.tensor(0.))

trace1.log_prob_sum(), trace2.log_prob_sum()

(tensor(-2.7726), tensor(-2.7726))

In [64]:
dist.Beta(torch.tensor(1.),torch.tensor(1.)).log_prob(torch.tensor(0.5))

tensor(0.)