# 05 SVI Part III: ELBO Gradient Estimators

$$ \text{REINFORCE} = \mathbb{E}_{q_\phi} [f_\phi (\bold{z}) \nabla_\phi \log q_\phi (\bold{z}) + \nabla_\phi f_\phi (\bold{z})] $$

In [6]:
import os
import torch
import torch.distributions.constraints as C
import pyro
import pyro.distributions as D
# Pyro also has a reparameterized Beta distribution so we import
# the non-reparameterized version to make our point
from pyro.distributions.testing.fakes import NonreparameterizedBeta
import pyro.optim as optim
from pyro.infer import SVI, TraceGraph_ELBO

from tqdm import trange

In [4]:
def param_abs_error(param, target):
    return torch.abs(pyro.param(param) - target).item()

In [39]:
class BernoulliBetaExample:
    def __init__(self, max_steps):
        self.max_steps = max_steps
        self.alpha0 = 10.
        self.beta0 = 10. 
        
        self.data = torch.zeros(10)
        self.data[0:6] = torch.ones(6)
        self.n = 10
        self.alpha = self.data.sum() + self.alpha0 
        self.beta = - self.data.sum() + self.beta0 + self.n 
        
        self.alphaq0 = 15. 
        self.betaq0 = 15. 
        
    def model(self, decaying_base):
        f = pyro.sample('fairness', D.Beta(self.alpha0, self.beta0))
        
        with pyro.plate('data', self.n):
            pyro.sample('obs', D.Bernoulli(f), obs=self.data)
        
    def guide(self, decaying_base):
        alpha = pyro.param('alpha', torch.tensor(self.alphaq0),
                            constraint=C.positive)
        beta = pyro.param('beta', torch.tensor(self.beta0),
                           constraint=C.positive)
        baseline_d = {'use_decaying_avg_baseline': decaying_base,
                    'baseline_beta': 0.90}
        pyro.sample('fairness', NonreparameterizedBeta(alpha, beta),
                    infer=dict(baseline=baseline_d))
        
    def inference(self, decaying_base, tol=0.80):
        pyro.clear_param_store()
        optimizer = pyro.optim.Adam({'lr': 0.005, 'betas': (0.93, 0.999)})
        svi = SVI(self.model, self.guide, optimizer, loss=TraceGraph_ELBO())
        print("Doing inference with use_decaying_avg_baseline = %s" % decaying_base)
        
        for k in trange(self.max_steps):
            svi.step(decaying_base)
            
            alpha_e = param_abs_error('alpha', self.alpha)
            beta_e = param_abs_error('beta', self.beta)
            
            if alpha_e < tol and beta_e < tol:
                break 
            
        print(f'Did {k+1} steps')
        print(f'Final errors: alpha: {alpha_e:.4f}, beta: {beta_e:.4f}')
        

In [42]:
bbe = BernoulliBetaExample(10000)

In [43]:
bbe.inference(False)

Doing inference with use_decaying_avg_baseline = False


 90%|█████████ | 9034/10000 [00:48<00:05, 186.57it/s]

Did 9035 steps
Final errors: alpha: 0.7989, beta: 0.5866





In [44]:
bbe.inference(True)

Doing inference with use_decaying_avg_baseline = True


 19%|█▉        | 1937/10000 [00:11<00:48, 165.02it/s]

Did 1938 steps
Final errors: alpha: 0.7906, beta: 0.7996



