In [1]:
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline  

import torch
import torch.nn as nn
import torch.optim as optim 
import torch.nn.functional as F

import itertools

from copy import deepcopy

import optimization_lib as optim_lib
import rao_blackwellization_lib as rb_lib
from toy_experiment_lib import ToyExperiment

In [2]:
# Single source of truth for the RNG seed so numpy and torch are always
# seeded with the same value and the notebook is reproducible end to end.
SEED = 454

np.random.seed(SEED)
# torch.manual_seed returns a Generator; bind it to `_` so the cell shows no output.
_ = torch.manual_seed(SEED)

# Fixed parameters

In [3]:
# k is the length of p0 and, further down, the number of category indices
# (0..k-1) that the sampled one-hot vector is weighted by.
k = 5
# k independent uniform draws on [0, 1); reproducible because the torch RNG
# was seeded above.
# NOTE(review): the exact role of p0 is defined in toy_experiment_lib — confirm.
p0 = torch.rand(k)

In [4]:
# Value of the parameter eta (one-element float tensor). Later cells re-apply
# this same value through toy_experiment.set_parameter(eta) before each use.
eta = torch.Tensor([5.])
# Experiment object built from eta and p0; every loss/logit below comes from it.
toy_experiment = ToyExperiment(eta, p0)

# Get true gradient 

In [5]:
# Re-apply eta, then build the full loss whose gradient serves as the
# reference ("true") gradient for the sampled estimates further down.
# NOTE(review): presumably get_full_loss() is the exact, non-sampled
# objective — confirm in toy_experiment_lib.
toy_experiment.set_parameter(eta)
full_loss = toy_experiment.get_full_loss()

In [6]:
# Backpropagate the full loss so that toy_experiment.eta.grad gets populated.
full_loss.backward()

In [7]:
# Keep the gradient of the full loss w.r.t. eta as the reference value.
# NOTE(review): this is a reference to eta.grad, not a copy; it stays valid
# only if later set_parameter() calls replace eta rather than mutate its
# .grad in place — confirm, or store a clone.
true_grad = toy_experiment.eta.grad

In [8]:
# Display the reference gradient.
true_grad

tensor(1.00000e-02 *
       [-8.1110])

In [13]:
# Re-apply eta and fetch the values that are fed to the Gumbel-softmax calls
# as logits (the displayed result has shape (1, k)).
# NOTE(review): presumably get_log_q() returns log-probabilities of q — confirm.
toy_experiment.set_parameter(eta)
logits = toy_experiment.get_log_q()

In [14]:
# Display the logits.
logits

tensor([[-4.1586, -4.3019, -2.3413, -1.6019, -0.3958]])

In [15]:
# Sanity check of the relaxed sample: at temperature 100 the displayed output
# is close to uniform over the k categories.
# NOTE(review): gumbel_softmax_sample is not defined or imported in the cells
# visible here — confirm it is defined before this cell so that
# Restart & Run All succeeds.
gumbel_softmax_sample(logits, temperature = 100)

tensor([[ 0.1855,  0.1996,  0.2036,  0.2025,  0.2088]])

In [16]:
# Sanity check: the displayed output is an exact one-hot vector even at a very
# large temperature — presumably gumbel_softmax is a hard (straight-through)
# variant; confirm against its definition.
# NOTE(review): gumbel_softmax is not defined or imported in the cells visible
# here — confirm it is defined before this cell.
gumbel_softmax(logits, temperature = 100000)

tensor([[ 0.,  0.,  0.,  0.,  1.]])

In [17]:
# Monte Carlo estimate of d loss / d eta through the Gumbel-softmax sampler:
# draw n_samples one-hot samples, backpropagate f(z) for each, and record the
# resulting gradient on eta.
n_samples = 10000

# Zero-initialised (torch.Tensor(n) would hand back uninitialised memory, so an
# interrupted run could leave garbage values in the tail of `grads`).
grads = torch.zeros(n_samples)

# Category indices 0..k-1 as floats. Loop-invariant, so built once up front;
# also avoids reusing the loop index name `i` inside a comprehension.
seq_tensor = torch.arange(k).float()

for i in range(n_samples):
    # Re-apply eta before every draw.
    # NOTE(review): this relies on set_parameter() leaving eta with a fresh
    # .grad each time; otherwise backward() would accumulate across
    # iterations. The displayed values do not grow, which suggests it does —
    # confirm in toy_experiment_lib.
    toy_experiment.set_parameter(eta)
    logits = toy_experiment.get_log_q()

    # Weight the sampled vector by the category indices and sum, turning the
    # one-hot sample into a scalar category value z.
    z = (gumbel_softmax(logits, temperature = 10000) * seq_tensor).sum()

    loss = toy_experiment.get_f_z(z = z)
    loss.backward()

    grads[i] = toy_experiment.eta.grad

In [18]:
# Display the per-sample gradient estimates (torch abbreviates long tensors).
grads

tensor(1.00000e-04 *
       [ 0.8180,  0.8181,  0.8181,  ..., -0.8180,  0.8181,  0.8181])

In [19]:
# Monte Carlo average of the sampled gradients — the quantity to compare
# against true_grad.
grads.mean()

tensor(1.00000e-05 *
       5.4133)

In [20]:
# Three standard errors of the mean (std / sqrt(n_samples) * 3): the
# half-width of a +/- 3-sigma band around grads.mean().
grads.std() / np.sqrt(n_samples) * 3

tensor(1.00000e-06 *
       1.4873)