In [1]:
# TODO: turn this into a unit test, and find an example that runs faster.

In this notebook, we examine the variances of the gradient estimators described in bernoulli_optimization_experiments.py.

In [2]:
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline  

import torch
import torch.nn as nn

import sys
sys.path.insert(0, './../libraries/')

import partial_marginalization_lib as pm_lib
import bernoulli_experiments_lib as bern_lib
import importance_sample_lib as imp_lib

import torch.optim as optim

from copy import deepcopy

from torch.distributions import Categorical

import itertools

In [3]:
# Module-form activations. sigmoid maps phi to the Bernoulli mean e_b printed
# below; softmax normalizes along the first dimension.
sigmoid = nn.Sigmoid()
softmax = nn.Softmax(dim=0)

In [4]:
# Seed both torch and numpy so the Monte Carlo estimates below are reproducible.
_ = torch.manual_seed(454)
np.random.seed(454)

In [5]:
# Fixed problem parameters: dimension d and target probabilities p0.
d = 2
p0 = torch.Tensor([0.6, 0.51])
# Guard against editing one of d / p0 without the other.
assert p0.shape == (d,), 'p0 must have exactly d entries'
print('p0: ', p0, '\n')

print('sum(p0^2): ', torch.sum(p0**2))
print('sum((1 - p0)^2): ', torch.sum((1 - p0)**2), '\n')

# The two candidate losses, built once: index 0 is sum(p0^2) (x = 0),
# index 1 is sum((1 - p0)^2) (x = 1).
candidate_losses = torch.stack([torch.sum(p0**2), torch.sum((1 - p0)**2)])

# the optima
x_optimal = torch.argmin(candidate_losses)
optimal_loss = torch.min(candidate_losses)

print('optimal loss: ', optimal_loss)
print('optimal x: ', x_optimal.numpy())

p0:  tensor([ 0.6000,  0.5100]) 

sum(p0^2):  tensor(0.6201)
sum((1 - p0)^2):  tensor(0.4001) 

optimal loss:  tensor(0.4001)
optimal x:  1


In [6]:
# Deterministic (not random) initialization: phi0 = 0, so e_b = sigmoid(phi0) = 0.5.
phi0 = torch.zeros(1).requires_grad_(True)
print('init phi0: ', phi0)
print('init e_b: ', sigmoid(phi0))

init phi0:  tensor([ 0.])
init e_b:  tensor([ 0.5000])


In [7]:
# SGD over the single variational parameter phi0. In the cells below it is only
# used for zero_grad(); optimizer.step() is never called.
params = [phi0]
optimizer = optim.SGD(params=params, lr=1.0)

# True gradient

In [8]:
# Build the experiment from the target p0, the dimension d and the initial phi0.
# NOTE(review): argument roles are inferred from the variable names; confirm
# against BernoulliExperiments in bernoulli_experiments_lib.
bern_experiment = bern_lib.BernoulliExperiments(p0, d, phi0)

In [9]:
# Hand the experiment an independent copy of phi0; the gradient is later read
# from bern_experiment.var_params['phi'] (presumably this copy), not from phi0.
bern_experiment.set_var_params(deepcopy(phi0))

In [10]:
# NOTE(review): this zeros phi0's gradient, but the gradient read below lives on
# the copy held by bern_experiment, so this looks like a no-op for that tensor.
optimizer.zero_grad()

In [11]:
# Exact loss used as the reference for the "true" gradient.
# NOTE(review): presumably the full expectation over z; see get_full_loss.
loss = bern_experiment.get_full_loss()

In [12]:
# Backpropagate to populate the gradient on the experiment's phi.
loss.backward()

In [13]:
# Snapshot the exact gradient; deepcopy keeps it independent of later backward
# passes on the experiment's parameter. The saved value (-0.055) is consistent
# with sigmoid'(0) * (0.4001 - 0.6201) = 0.25 * (-0.22).
true_grad = deepcopy(bern_experiment.var_params['phi'].grad)
print(true_grad)

tensor(1.00000e-02 *
       [-5.5000])


In [14]:
# Importance-sampled loss with a baseline and no importance weights, starting
# again from a fresh copy of phi0.
bern_experiment.set_var_params(deepcopy(phi0))
optimizer.zero_grad()

importance_weights = None  # alternative: torch.rand(log_q.shape)
ps_loss = imp_lib.get_importance_sampled_loss(
    bern_experiment.f_z,
    log_q=bern_experiment.get_log_q(),
    importance_weights=importance_weights,
    use_baseline=True,
)

In [15]:
# Backpropagate the importance-sampled loss onto the experiment's phi.
ps_loss.backward()

In [16]:
# Display the resulting gradient (the saved output shows tensor([0.])).
# NOTE(review): presumably a single stochastic estimate, so it need not match true_grad.
bern_experiment.var_params['phi'].grad

tensor([ 0.])

In [17]:
def sample_bern_gradient(phi0, bern_experiment, importance_weights,
                         use_baseline = True,
                         n_samples = 10000):
    """Draw repeated stochastic gradient estimates of the loss w.r.t. phi.

    Each iteration hands ``bern_experiment`` a fresh deep copy of ``phi0``,
    builds the loss via ``imp_lib.get_importance_sampled_loss``, backpropagates,
    and records the gradient on ``bern_experiment.var_params['phi']``.

    Args:
        phi0: parameter tensor copied for every sample; it is not modified.
        bern_experiment: object exposing ``set_var_params``, ``f_z``,
            ``get_log_q`` and a ``var_params`` dict with key ``'phi'``.
        importance_weights: forwarded unchanged to
            ``get_importance_sampled_loss`` (may be ``None``).
        use_baseline: forwarded unchanged to ``get_importance_sampled_loss``.
        n_samples: number of gradient estimates to draw.

    Returns:
        1-D float tensor of length ``n_samples``; entry ``i`` is the gradient
        from iteration ``i``. Assumes phi holds a single element.
    """
    grad_array = torch.zeros(n_samples)

    for i in range(n_samples):
        bern_experiment.set_var_params(deepcopy(phi0))
        # deepcopy may carry phi0's existing .grad over to the copy, and zeroing
        # an optimizer built on phi0 would never reach the copy. Clear the
        # gradient explicitly so each entry reflects this sample alone.
        bern_experiment.var_params['phi'].grad = None

        ps_loss = imp_lib.get_importance_sampled_loss(
            bern_experiment.f_z, bern_experiment.get_log_q(),
            importance_weights = importance_weights,
            use_baseline = use_baseline)
        ps_loss.backward()

        grad_array[i] = bern_experiment.var_params['phi'].grad

    return grad_array


In [18]:
# REINFORCE-with-baseline estimates; importance_weights is still None here.
n_samples = 10000
reinforce_grads = sample_bern_gradient(
    phi0=phi0,
    bern_experiment=bern_experiment,
    importance_weights=importance_weights,
    use_baseline=True,
    n_samples=n_samples,
)

In [19]:
print('true_grad: ', true_grad.numpy())
print('mean reinforce grad: ', torch.mean(reinforce_grads).numpy())
print('variance: ', torch.var(reinforce_grads).numpy())

# Three standard errors of the Monte Carlo mean: |mean - true_grad| should fall
# inside this band. Use the actual number of draws rather than the global
# n_samples, which a later cell reassigns.
print('scaled error: ', (torch.std(reinforce_grads) / np.sqrt(reinforce_grads.numel()) * 3).numpy())

true_grad:  [-0.05500001]
mean reinforce grad:  -0.05547201
variance:  0.008126713
scaled error:  0.0027044485


In [20]:
# Random importance weights normalized to sum to 1 along dim 1: one row, one
# weight per joint configuration of the d binary components (2**d of them).
# NOTE(review): assumes log_q enumerates all 2**d configurations; confirm
# against bern_experiment.get_log_q().shape. (d**2 would agree only at d == 2.)
raw_importance_weights = torch.rand((1, 2**d))
importance_weights = raw_importance_weights / raw_importance_weights.sum(dim=1, keepdim=True)

# Smoke check that get_importance_sampled_loss accepts these weights;
# this ps_loss is not backpropagated.
ps_loss = imp_lib.get_importance_sampled_loss(
    bern_experiment.f_z,
    log_q=bern_experiment.get_log_q(),
    importance_weights=importance_weights,
    use_baseline=True)

In [21]:
# Repeat with the random importance weights, using 100000 draws.
n_samples = 100000
reinforce_grads = sample_bern_gradient(
    phi0=phi0, bern_experiment=bern_experiment,
    importance_weights=importance_weights,
    use_baseline=True, n_samples=n_samples)

In [22]:
print('true_grad: ', true_grad.numpy())
print('mean reinforce grad: ', torch.mean(reinforce_grads).numpy())
print('variance: ', torch.var(reinforce_grads).numpy())

# Three standard errors of the Monte Carlo mean: |mean - true_grad| should fall
# inside this band. Use the actual number of draws rather than the global
# n_samples, so the band stays correct if cells are re-run out of order.
print('scaled error: ', (torch.std(reinforce_grads) / np.sqrt(reinforce_grads.numel()) * 3).numpy())

true_grad:  [-0.05500001]
mean reinforce grad:  -0.054811023
variance:  0.0061285673
scaled error:  0.00074267836
