In [None]:
# Imports and global configuration for the coin-fairness SVI comparison below.
import sys
import math
import torch

# Only affects modules imported after this point (e.g. the project-local
# `vectorized_loop` package below): Python will not write .pyc files for them.
sys.dont_write_bytecode = True

import pyro
import pyro.primitives
import pyro.distributions as dist
import pyro.distributions.constraints as constraints

from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from tqdm.auto import tqdm

# NOTE(review): duplicate of the `import sys` above; harmless but redundant.
import sys
# Make the project-local `vectorized_loop` package importable.
# NOTE(review): assumes it lives one or two directories above this notebook — confirm.
sys.path.append("../")
sys.path.append("../../")

import warnings
# Disable Pyro's validation checks and silence every Python warning.
# NOTE(review): a blanket 'ignore' also hides deprecation/numerical warnings.
pyro.primitives.enable_validation(False)
warnings.filterwarnings('ignore')
# Fix the global random seed via Pyro (set once, before any model is run).
pyro.set_rng_seed(20)

from vectorized_loop.ops import Index
import vectorized_loop as vec
# Prefer the GPU when available and make it torch's default device for new tensors.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)

In [2]:
def model_baseline_1(data):
    """Sequential coin-fairness model: one Pyro sample site per observation.

    A Beta(10, 10) prior is placed on the latent fairness, and every element
    of ``data`` is observed through its own Bernoulli site named ``obs_<i>``.

    Args:
        data: indexable sequence of 0/1 observations.
    """
    prior = dist.Beta(torch.tensor(10.0), torch.tensor(10.0))
    fairness = pyro.sample("latent_fairness", prior)
    for idx, flip in enumerate(data):
        pyro.sample("obs_{}".format(idx), dist.Bernoulli(fairness), obs=flip)

def model_baseline_2(data):
    """Vectorized coin-fairness model using a ``pyro.plate`` named "data".

    A Beta(10, 10) prior is placed on the latent fairness, and all of ``data``
    is observed at a single Bernoulli site ``obs`` inside the plate.

    Args:
        data: 0/1 observations; ``len(data)`` sets the plate size.
    """
    prior = dist.Beta(torch.tensor(10.0), torch.tensor(10.0))
    fairness = pyro.sample("latent_fairness", prior)
    n_obs = len(data)
    with pyro.plate("data", n_obs):
        pyro.sample("obs", dist.Bernoulli(fairness), obs=data)

# Coin-fairness model written with the project's `vectorized_loop` API; it is
# intended to be equivalent to `model_baseline_2` (same Beta(10, 10) prior and
# per-element Bernoulli likelihood). The Python `for` loop over `vec.range(...)`
# is requested in vectorized form (`vectorized=True`). It is fit with
# `vec.Trace_ELBO()` rather than Pyro's `Trace_ELBO()` (see the driver cell).
# NOTE(review): comments only, no docstring — `@vec.vectorize` may transform this
# function from its source, so the body is deliberately left untouched.
@vec.vectorize
def model_ours(s: vec.State, data):
    alpha0 = torch.tensor(10.0)
    beta0 = torch.tensor(10.0)
    # Values used inside the loop are stored as attributes on the state `s`
    # (presumably so the vectorizer can track them — confirm in vectorized_loop).
    s.f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # Presumably the vectorized counterpart of `pyro.plate("data", len(data))`.
    for s.i in vec.range("data", len(data), vectorized=True, device=device):
        # `Index(data)[s.i]` selects the observation(s) for the current loop
        # index — assumed to broadcast over the vectorized dimension; confirm.
        pyro.sample("obs", dist.Bernoulli(s.f), obs=Index(data)[s.i])

def guide(data):
    """Variational guide: a Beta(alpha_q, beta_q) over the latent fairness.

    ``alpha_q`` and ``beta_q`` are learnable Pyro params, each initialized to
    15.0 and constrained to be positive. ``data`` is unused; it is accepted
    only so the guide matches the models' call signature.
    """
    # Registration order (alpha_q, then beta_q) is preserved by the comprehension.
    concentrations = [
        pyro.param(param_name, torch.tensor(15.0), constraint=constraints.positive)
        for param_name in ("alpha_q", "beta_q")
    ]
    pyro.sample("latent_fairness", dist.Beta(*concentrations))

def _beta_mean_std(alpha, beta):
    """Return ``(mean, std)`` of a Beta(alpha, beta) distribution.

    mean = a / (a + b); std = mean * sqrt(b / (a * (1 + a + b))), which equals
    sqrt(a * b / ((a + b)**2 * (a + b + 1))).
    """
    mean = alpha / (alpha + beta)
    factor = beta / (alpha * (1.0 + alpha + beta))
    return mean, mean * math.sqrt(factor)


def run(model, elbo, n_step=50, lr=0.0005, data=None, seed=None):
    """Fit the module-level Beta ``guide`` to ``model`` with SVI and print the result.

    The fitted guide's mean and standard deviation are printed; nothing is returned.

    Args:
        model: Pyro model callable taking the observation tensor.
        elbo: ELBO instance handed to ``SVI`` (e.g. ``Trace_ELBO()``).
        n_step: number of SVI steps. With the default 50 steps at lr=5e-4 the
            guide barely moves from its Beta(15, 15) initialization, so the
            printed estimate stays near 0.5 instead of the posterior mean;
            use a few thousand steps for a converged answer.
        lr: Adam learning rate.
        data: observations; defaults to 100 ones followed by 400 zeros.
        seed: if given, reseed Pyro before fitting so the run does not depend
            on whatever executed before it (useful when comparing models).
    """
    pyro.clear_param_store()
    vec.clear_allocators()
    if seed is not None:
        pyro.set_rng_seed(seed)

    if data is None:
        data = torch.cat([torch.ones(100), torch.zeros(400)])

    optimizer = Adam({"lr": lr, "betas": (0.90, 0.999)})
    svi = SVI(model, guide, optimizer, elbo)

    # tqdm already reports progress. The previous `step % 100 == 0` dot printer
    # was dead weight: with fewer than 100 steps it fired only at step 0.
    for _ in tqdm(range(n_step)):
        svi.step(data)

    alpha_q = pyro.param("alpha_q").item()
    beta_q = pyro.param("beta_q").item()
    inferred_mean, inferred_std = _beta_mean_std(alpha_q, beta_q)

    print("\nBased on the data and our prior belief, the fairness " +
          "of the coin is %.3f +- %.3f" % (inferred_mean, inferred_std))

# Fit the same guide against each formulation of the model; all three should
# report matching posterior estimates. Each ELBO object is created right before
# its own run, exactly as in the unrolled version.
for banner, model_fn, elbo_cls in (
    ("Running Baseline_1", model_baseline_1, Trace_ELBO),
    ("Running Baseline_2", model_baseline_2, Trace_ELBO),
    ("Running Ours", model_ours, vec.Trace_ELBO),
):
    print(banner)
    run(model_fn, elbo_cls())

Running Baseline_1


  0%|          | 0/50 [00:00<?, ?it/s]

.
Based on the data and our prior belief, the fairness of the coin is 0.488 +- 0.090
Running Baseline_2


  0%|          | 0/50 [00:00<?, ?it/s]

.
Based on the data and our prior belief, the fairness of the coin is 0.488 +- 0.090
Running Ours


  0%|          | 0/50 [00:00<?, ?it/s]

.
Based on the data and our prior belief, the fairness of the coin is 0.488 +- 0.090
