In [2]:
import pyro
import pyro.infer
import pyro.optim
from pyro import poutine
import torch
from torch.distributions import constraints
import pyro.distributions as dist
from tqdm import tqdm


pyro.set_rng_seed(1234)

In [3]:
class Experiment:

    def __init__(self, prior_params):
        self.prior_params = prior_params

    def _prior_dist(self):
        """
        P(\theta)
        :return:
        """
        return pyro.sample('theta', dist.Beta(**self.prior_params))

    def _likelihood(self):
        """
        P(x | \theta)
        :param theta:
        :return:
        """
        theta = self._prior_dist()
        return pyro.sample('like_sample', dist.Bernoulli(theta))

    def _posterior_dist(self, obs, obs_ind):
        """
        P(\theta | x)
        :param likelihood:
        :param obs:
        :param obs_ind: Index of the observed data (for naming)
        :return:
        """
        return pyro.condition(self._likelihood, data = {'obs_{}'.format(obs_ind): obs})

    def _model(self, observations):
        for i in range(len(observations)):
            self._posterior_dist(observations[i], i)


    def _variational_dist(self, observations):
        alpha_q = pyro.param('alpha_q',torch.tensor(15.0),
                         constraint=constraints.positive)
        beta_q= pyro.param('beta_q', torch.tensor(15.0),
                         constraint = constraints.positive)
        return pyro.sample('posteriori_weight', dist.Beta(alpha_q, beta_q))

    def _loss_func(self):
        return pyro.infer.Trace_ELBO()

    def _loss(self, model, guide, *args, **kwargs):
        guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)

        model_trace = poutine.trace(
            poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
        # construct the elbo loss function

        return -1 * (model_trace.log_prob_sum() - guide_trace.log_prob_sum())


    def run(self,
            obervations,
            num_iter: int):

        pyro.clear_param_store()

        optimizer = pyro.optim.Adam({"lr": 0.0005, "betas": (0.90, 0.999)})

        svi = pyro.infer.SVI(model = self._model,
                             guide = self._variational_dist,
                             optim = optimizer,
                             loss = self._loss)

        for t in tqdm(range(num_iter)):
            svi.step(obervations)
        return pyro.param('alpha_q').item(), pyro.param('beta_q').item()

In [4]:
exp = Experiment({'concentration1':10., 'concentration0': 10.},)
observations = []
for _ in range(6):
    observations.append(torch.tensor(1.0))
for _ in range(4):
    observations.append(torch.tensor(0.0))

a, b = exp.run(observations, 5000)
print(a/(a+b))

100%|██████████| 5000/5000 [00:04<00:00, 1233.61it/s]

0.5024239189391284





In [None]:
guide_trace = poutine.trace(exp._variational_dist).get_trace(*args, **kwargs)

model_trace = poutine.trace(
    poutine.replay(exp._model, trace=guide_trace)).get_trace(*args, **kwargs)
# construct the elbo loss function

return -1 * (model_trace.log_prob_sum() - guide_trace.log_prob_sum())