In [1]:
%matplotlib notebook
import torch
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from torch.distributions import Normal

def plot_grad_mean_std(mu_supp, grads):
    """Plot per-mu mean and standard deviation of several gradient estimators.

    Args:
        mu_supp: 1-D tensor of mu values at which gradients were estimated.
            It is detached before plotting, so it may still require grad.
        grads: mapping from estimator label to a tensor of gradient samples;
            statistics are taken over dim 0 (the repeat dimension), and the
            remaining dimension must line up with ``mu_supp``.

    Creates a new figure with two stacked panels: the top panel shows the
    mean gradient with a +/- 1 std band, the bottom panel shows the std alone.
    """
    mu = mu_supp.detach()
    fig, ax = plt.subplots(2, figsize=(6, 8))
    for key, grad in grads.items():
        grad_mean = grad.mean(0)
        grad_std = grad.std(0)
        line, = ax[0].plot(mu, grad_mean, label=key)
        # Shade +/- one std around the mean; reuse the line colour explicitly so
        # the band always matches its curve regardless of colour-cycler state.
        ax[0].fill_between(
            mu,
            grad_mean - grad_std,
            grad_mean + grad_std,
            color=line.get_color(),
            alpha=0.2)

        ax[1].plot(mu, grad_std, color=line.get_color(), label=key)

    # Raw strings: '\m' is an invalid escape sequence in a normal string literal.
    ax[0].set_xlabel(r'$\mu$')
    ax[1].set_xlabel(r'$\mu$')
    ax[0].set_ylabel('Gradient mean')
    ax[1].set_ylabel('Gradient std')
    for panel in ax:
        panel.legend()
        panel.grid()

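# Problem

$$ \max_\mu \mathbb E_{q(\xi \mid \mu, 1)} \xi^2, \qquad q(\xi \mid \mu, 1) = \mathcal N(\xi \mid \mu, 1) $$

Since $\mathbb E_{q(\xi \mid \mu, 1)} \xi^2 = \mu^2 + 1$, the true gradient is $\nabla_\mu = 2\mu$. The estimators below are compared by how far their mean is from this line and by their standard deviation.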

In [2]:
# Fix the RNG so every figure below is reproducible under Restart & Run All.
torch.manual_seed(0)

M = 1024  # number of independent gradient estimates drawn per mu value
gradient_dict = {}  # estimator label -> tensor of shape (M, len(mu_supp))

# Grid of mu values; q_xi is a batch of Normals, one N(mu_i, 1) per grid point,
# so each mu_i is an independent parameter with its own gradient.
mu_supp = torch.linspace(-4, 4, 128).requires_grad_(True)
q_xi = Normal(loc=mu_supp, scale=torch.ones_like(mu_supp))


def f(x):
    """Objective integrand: elementwise f(x) = x ** 2."""
    return x ** 2

# REINFORCE

In [3]:
# Score-function (REINFORCE) estimator with one sample per mu:
#   grad_mu E[f(xi)] ~= f(xi) * d/dmu log q(xi | mu),  xi ~ q(. | mu).
# sample() is not differentiable, so gradients reach mu only through log_prob.
# Each log_prob entry depends only on its own mu_i, so summing before backward()
# yields the per-mu estimates in mu_supp.grad.
grads = torch.empty(M, mu_supp.numel())

for m in range(M):
    xi_samples = q_xi.sample()
    surrogate = (f(xi_samples) * q_xi.log_prob(xi_samples)).sum()
    surrogate.backward()

    # Copy this repeat's gradient out, then reset the accumulator.
    grads[m] = mu_supp.grad
    mu_supp.grad.zero_()

gradient_dict['REINFORCE'] = grads
plot_grad_mean_std(mu_supp, gradient_dict)

<IPython.core.display.Javascript object>

# REINFORCE with $K$ samples

In [4]:
# REINFORCE averaged over K i.i.d. samples per mu: the same unbiased estimator,
# with variance reduced by a factor of K.
grads = torch.empty(M, mu_supp.numel())
K = 8

for m in range(M):
    xi_samples = q_xi.sample((K, ))  # shape (K, len(mu_supp))
    # Average the per-sample surrogates over K, then sum over the independent mu's.
    surrogate = (f(xi_samples) * q_xi.log_prob(xi_samples)).mean(0).sum()
    surrogate.backward()

    grads[m] = mu_supp.grad
    mu_supp.grad.zero_()

gradient_dict['REINFORCE $K={}$'.format(K)] = grads
plot_grad_mean_std(mu_supp, gradient_dict)

<IPython.core.display.Javascript object>

# VIMCO

In [5]:
# VIMCO-style leave-one-out baseline: for sample k the baseline is the mean of f
# over the other K - 1 samples. It does not depend on xi_k, so subtracting it
# keeps the estimator unbiased while removing most of the shared signal.
grads = torch.empty(M, mu_supp.numel())
K = 8

for m in range(M):
    xi_samples = q_xi.sample((K,))  # shape (K, len(mu_supp))
    f_xi = f(xi_samples)
    # (sum over all k) - f_k = sum over the other K - 1 samples.
    loo_baseline = (f_xi.sum(0) - f_xi) / (K - 1)
    diff = f_xi - loo_baseline
    surrogate = (diff * q_xi.log_prob(xi_samples)).mean(0).sum()
    surrogate.backward()

    grads[m] = mu_supp.grad
    mu_supp.grad.zero_()

gradient_dict['VIMCO $K={}$'.format(K)] = grads
plot_grad_mean_std(mu_supp, gradient_dict)

<IPython.core.display.Javascript object>

# REINFORCE + arbitrary constant baseline ($b = 10$)

In [6]:
# Reset the comparison: keep only plain REINFORCE as the reference curve.
gradient_dict = {'REINFORCE': gradient_dict['REINFORCE']}

# An arbitrarily chosen constant baseline. Any constant b keeps the estimator
# unbiased (the score d/dmu log q has zero mean), but a poorly chosen b can
# make the variance worse than no baseline at all.
CONST_BASELINE = 10

grads = torch.empty(M, mu_supp.numel())

for m in range(M):
    xi_samples = q_xi.sample()
    diff = f(xi_samples) - CONST_BASELINE
    surrogate = (diff * q_xi.log_prob(xi_samples)).sum()
    surrogate.backward()

    grads[m] = mu_supp.grad
    mu_supp.grad.zero_()

gradient_dict['REINFORCE + random score baseline'] = grads
plot_grad_mean_std(mu_supp, gradient_dict)

<IPython.core.display.Javascript object>

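# Optimal constant baseline

Minimising the estimator variance $\mathbb E[(f(\xi) - b)^2 s(\xi)^2]$ over a constant $b$ gives $b^* = \frac{\mathbb E[f(\xi)\, s(\xi)^2]}{\mathbb E[s(\xi)^2]}$. Here the score is $s(\xi) = \nabla_\mu \log q(\xi \mid \mu, 1) = \xi - \mu$.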

In [7]:
# Optimal constant baseline, estimated separately for each mu:
#   b* = E[f(xi) s^2] / E[s^2],  with score s = d/dmu log N(xi | mu, 1) = xi - mu.
# For f = xi^2 and unit variance this is analytically mu^2 + 3. It is estimated
# here by Monte Carlo on samples independent of those used in the loop below,
# so it does not bias the gradient estimator.
grads = torch.empty(M, mu_supp.numel())

with torch.no_grad():
    # Separate name so the K used by the earlier cells is not overwritten.
    K_BASELINE = 512
    xi_samples = q_xi.sample((K_BASELINE,))
    score_sq = (xi_samples - mu_supp) ** 2
    # Denominator is the second moment E[s^2], not the sample variance: .var()
    # would subtract the sample mean and use a K-1 divisor, neither of which is
    # part of the optimal-baseline formula.
    cov_var = (f(xi_samples) * score_sq).mean(0) / score_sq.mean(0)

for m in range(M):
    xi_samples = q_xi.sample()
    diff = (f(xi_samples) - cov_var)
    surrogate = (diff * q_xi.log_prob(xi_samples)).sum()
    surrogate.backward()

    grads[m] = mu_supp.grad
    mu_supp.grad.zero_()

gradient_dict['REINFORCE + optimal score baseline'] = grads
plot_grad_mean_std(mu_supp, gradient_dict)

<IPython.core.display.Javascript object>

# Reparametrization

In [8]:
# Reparameterisation (pathwise) estimator: rsample() draws xi = mu + eps * 1 with
# eps ~ N(0, 1), so f(xi) is differentiable in mu and the per-mu gradient is
# df/dxi = 2 * xi.
grads = torch.empty(M, mu_supp.numel())

for m in range(M):
    xi_samples = q_xi.rsample()
    f_xi = f(xi_samples)
    surrogate = f_xi.sum()
    surrogate.backward()

    grads[m] = mu_supp.grad
    mu_supp.grad.zero_()

gradient_dict['Reparametrization'] = grads
plot_grad_mean_std(mu_supp, gradient_dict)

<IPython.core.display.Javascript object>