In [None]:
import pyro
import torch
import pyro.optim as optim
import pyro.distributions as dist
import torch.distributions.constraints as constraints

from pyro.distributions.testing.fakes import NonreparameterizedBeta
from pyro.infer import SVI, TraceGraph_ELBO
from tqdm.auto import tqdm

assert pyro.__version__.startswith('1.9.1')

import sys
sys.path.append("../")
sys.path.append("../../")

sys.dont_write_bytecode = True

from vectorized_loop.ops import Index
import vectorized_loop as vec

import warnings
pyro.primitives.enable_validation(False)
warnings.filterwarnings('ignore')
pyro.set_rng_seed(20)

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)

max_steps = 10000

def param_abs_error(name, target):
    return torch.sum(torch.abs(target - pyro.param(name))).item()

In [2]:
class BernoulliBetaExample:

    def __init__(self, max_steps):
        self.max_steps = max_steps
        self.alpha0 = 10.0
        self.beta0 = 10.0
        self.data = torch.zeros(10)
        self.data[0:6] = torch.ones(6)
        self.n_data = self.data.size(0)
        self.alpha_n = self.data.sum() + self.alpha0
        self.beta_n = - self.data.sum() + torch.tensor(self.beta0 + self.n_data)
        self.alpha_q_0 = 15.0
        self.beta_q_0 = 15.0

    def model(self, use_decaying_avg_baseline):
        f = pyro.sample("latent_fairness", dist.Beta(self.alpha0, self.beta0))
        with pyro.plate("data_plate"):
            pyro.sample("obs", dist.Bernoulli(f), obs=self.data)

    def guide(self, use_decaying_avg_baseline):
        alpha_q = pyro.param("alpha_q", torch.tensor(self.alpha_q_0),
                             constraint=constraints.positive)
        beta_q = pyro.param("beta_q", torch.tensor(self.beta_q_0),
                            constraint=constraints.positive)
        baseline_dict = {'use_decaying_avg_baseline': use_decaying_avg_baseline,
                         'baseline_beta': 0.90}
        pyro.sample("latent_fairness", NonreparameterizedBeta(alpha_q, beta_q),
                    infer=dict(baseline=baseline_dict))

    def do_inference(self, use_decaying_avg_baseline, tolerance=0.80):
        pyro.clear_param_store()

        optimizer = optim.Adam({"lr": .0005, "betas": (0.93, 0.999)})
        svi = SVI(self.model, self.guide, optimizer, loss=TraceGraph_ELBO())
        print("Doing inference with use_decaying_avg_baseline=%s" % use_decaying_avg_baseline)

        for k in (epoch := tqdm(range(self.max_steps))):
            loss = svi.step(use_decaying_avg_baseline)
            epoch.set_description(f"Loss: {loss:.2f}")

            alpha_error = param_abs_error("alpha_q", self.alpha_n)
            beta_error = param_abs_error("beta_q", self.beta_n)

            if alpha_error < tolerance and beta_error < tolerance:
                break

        print("\nDid %d steps of inference." % k)
        print(("Final absolute errors for the two variational parameters " +
               "were %.4f & %.4f") % (alpha_error, beta_error))

bbe = BernoulliBetaExample(max_steps=max_steps)
bbe.do_inference(use_decaying_avg_baseline=True)
bbe.do_inference(use_decaying_avg_baseline=False)

Doing inference with use_decaying_avg_baseline=True


  0%|          | 0/10000 [00:00<?, ?it/s]


Did 156 steps of inference.
Final absolute errors for the two variational parameters were 0.7986 & 0.7642
Doing inference with use_decaying_avg_baseline=False


  0%|          | 0/10000 [00:00<?, ?it/s]


Did 253 steps of inference.
Final absolute errors for the two variational parameters were 0.7990 & 0.7468


In [3]:
class BernoulliBetaExample:
    
    def __init__(self, max_steps):
        self.max_steps = max_steps
        self.alpha0 = 10.0
        self.beta0 = 10.0
        self.data = torch.zeros(10)
        self.data[0:6] = torch.ones(6)
        self.n_data = self.data.size(0)
        self.alpha_n = self.data.sum() + self.alpha0
        self.beta_n = - self.data.sum() + torch.tensor(self.beta0 + self.n_data)
        self.alpha_q_0 = 15.0
        self.beta_q_0 = 15.0

    @vec.vectorize
    def model(s: vec.State, self, use_decaying_avg_baseline):
        s.f = pyro.sample("latent_fairness", dist.Beta(self.alpha0, self.beta0))
        for s.i in vec.range("data_plate", self.n_data, vectorized=True, device=device):
            pyro.sample("obs", dist.Bernoulli(s.f), obs=Index(self.data)[s.i])
                    
    def guide(self, use_decaying_avg_baseline):
        alpha_q = pyro.param("alpha_q", torch.tensor(self.alpha_q_0),
                             constraint=constraints.positive)
        beta_q = pyro.param("beta_q", torch.tensor(self.beta_q_0),
                            constraint=constraints.positive)
        baseline_dict = {'use_decaying_avg_baseline': use_decaying_avg_baseline,
                         'baseline_beta': 0.90}
        pyro.sample("latent_fairness", NonreparameterizedBeta(alpha_q, beta_q),
                    infer=dict(baseline=baseline_dict))

    def do_inference(self, use_decaying_avg_baseline, tolerance=0.80):
        pyro.clear_param_store()
        vec.clear_allocators()

        optimizer = optim.Adam({"lr": .0005, "betas": (0.93, 0.999)})
        svi = SVI(self.model, self.guide, optimizer, loss=vec.TraceGraph_ELBO())
        print("Doing inference with use_decaying_avg_baseline=%s" % use_decaying_avg_baseline)

        for k in (epoch := tqdm(range(self.max_steps))):
            loss = svi.step(use_decaying_avg_baseline)
            epoch.set_description(f"Loss: {loss:.2f}")

            alpha_error = param_abs_error("alpha_q", self.alpha_n)
            beta_error = param_abs_error("beta_q", self.beta_n)

            if alpha_error < tolerance and beta_error < tolerance:
                break

        print("\nDid %d steps of inference." % k)
        print(("Final absolute errors for the two variational parameters " +
               "were %.4f & %.4f") % (alpha_error, beta_error))

bbe = BernoulliBetaExample(max_steps=max_steps)
bbe.do_inference(use_decaying_avg_baseline=True)
bbe.do_inference(use_decaying_avg_baseline=False)

Doing inference with use_decaying_avg_baseline=True


  0%|          | 0/10000 [00:00<?, ?it/s]


Did 253 steps of inference.
Final absolute errors for the two variational parameters were 0.7436 & 0.7996
Doing inference with use_decaying_avg_baseline=False


  0%|          | 0/10000 [00:00<?, ?it/s]


Did 430 steps of inference.
Final absolute errors for the two variational parameters were 0.7993 & 0.7812
