# Comparing Monte Carlo (MC) gradient estimators on probabilistic objectives

In [None]:
import torch
import matplotlib.pyplot as plt

from mc_estimators.measure_valued_gradient import MVD
from mc_estimators.reinforce_gradient import Reinforce
from mc_estimators.pathwise_gradient import Pathwise


# === EXPERIMENT PARAMETERS ===
# Optimisation schedule
episodes = 5000              # optimiser steps taken by train()
episode_size = 100           # sample budget handed to each MC estimator
learning_rate = 1e-3
optimizer_class = torch.optim.SGD


def loss_fn(x):
    """Objective passed to every gradient estimator: f(x) = x**2."""
    return x ** 2


# Reproducibility: seed the RNG before drawing the initial parameters.
seed = 4243254
torch.manual_seed(seed)

# Initial distribution parameters: a random 1-element mean and a covariance
# of 1 (train() keeps the covariance fixed at this value).
init_mean = torch.randn(1)
init_cov = 1


def train(estimator):
    """Optimise the mean of a 1-D distribution with fixed covariance.

    The mean is the single weight of a bias-free ``Linear(1, 1)`` applied to
    a constant input of one, initialised from the global ``init_mean``. The
    covariance stays at ``init_cov`` for the whole run.

    For each of the global ``episodes`` steps, ``estimator`` is called with
    ``loss_fn`` and the current ``(mean, init_cov)``. The first element of its
    result is averaged over dim 0 and backpropagated (presumably a surrogate
    whose gradient is the estimator's gradient estimate). One
    ``optimizer_class`` step with ``learning_rate`` is then taken.

    Args:
        estimator: callable invoked as
            ``estimator(torch.empty(0), loss_fn, (mean, init_cov))``.

    Returns:
        ``(means, covs)``: two lists of length ``episodes + 1``. They hold the
        mean (a 1-element tensor) and the covariance before training and after
        every step.

    NOTE: a later cell's ``import train.vae`` rebinds the name ``train`` to
    that package, shadowing this function in the notebook namespace.
    """
    model = torch.nn.Linear(1, 1, bias=False)
    with torch.no_grad():
        model.weight.copy_(init_mean.clone())

    opt = optimizer_class(model.parameters(), lr=learning_rate)

    mean_history, cov_history = [init_mean], [init_cov]

    # Constant input: model(unit_input) is just the current weight, i.e. the mean.
    unit_input = torch.ones(1)
    for _ in range(episodes):
        opt.zero_grad()
        estimates = estimator(torch.empty(0), loss_fn, (model(unit_input), init_cov))
        objective = estimates[0].mean(dim=0)
        objective.backward()
        opt.step()

        # Record the post-step parameters without tracking gradients.
        with torch.no_grad():
            mean_history.append(model(unit_input))
            cov_history.append(init_cov)

    return mean_history, cov_history


# Gradient estimators under comparison, keyed by display name.
# NOTE(review): MVD is given roughly half of `episode_size`, presumably because
# each MVD sample costs two loss evaluations (positive and negative parts), which
# keeps the per-step budget comparable. Confirm against
# mc_estimators.measure_valued_gradient.MVD.
estimators = {
    # Measure-Valued Gradient (integer ceil of episode_size / 2)
    "MVD": MVD((episode_size + 1) // 2, 1, coupled=False),
    # Log Gradient (score function)
    "REINFORCE": Reinforce(episode_size, torch.distributions.Normal),
    # Reparameterization-Trick Gradient
    "Pathwise": Pathwise(episode_size, torch.distributions.Normal),
}

# Train with every estimator from the same initial parameters and keep each
# (means, covs) trajectory for plotting.
results = {name: train(estimator) for name, estimator in estimators.items()}

In [2]:
import torch

import models.vae
import train.vae
from mc_estimators.reinforce_gradient import Reinforce


# === EXPERIMENT PARAMETERS ===
# Model architecture
latent_dim = 20              # size of the latent code (passed to Encoder/Decoder)
hidden_dim = 200             # hidden size passed to Encoder/Decoder

# Optimisation
epochs = 20                  # passed to vae.train()
learning_rate = 1e-2
optimizer_class = torch.optim.Adam

# Monte Carlo estimator
episode_size = 1000          # sample budget handed to the REINFORCE estimator

# Reproducibility
seed = 4243254
torch.manual_seed(seed)

def pp(mean):
    """Turn the encoder output into ``(loc, covariance)``.

    ``loc`` is ``mean`` with all size-1 dimensions squeezed out. The
    covariance is the fixed ``latent_dim x latent_dim`` matrix ``0.1 * I``.
    It is presumably consumed as the MultivariateNormal parameters by the
    REINFORCE estimator.
    """
    # NOTE(review): squeeze() without a dim also removes the batch dimension
    # when a batch holds a single example. Confirm that is intended.
    fixed_covariance = 0.1 * torch.eye(latent_dim)
    loc = mean.squeeze()
    return loc, fixed_covariance

import os

# Load the data
data_holder = train.vae.DataHolder()
data_holder.load_datasets()

# Build the Variational Auto Encoder from three parts:
#   - an encoder whose output `pp` turns into (loc, covariance);
#   - a decoder;
#   - a REINFORCE estimator, presumably sampling latent codes from a
#     MultivariateNormal with those parameters.
data_dim = data_holder.height * data_holder.width
encoder = models.vae.Encoder(data_dim, hidden_dim, (latent_dim,), post_processor=pp)
decoder = models.vae.Decoder(data_dim, hidden_dim, latent_dim)
vae_model = models.vae.VAE(encoder, decoder, Reinforce(episode_size, torch.distributions.MultivariateNormal))

# Train it with the configured `optimizer_class` (previously hard-coded here),
# so editing the parameter cell actually changes the optimiser.
# NOTE(review): `learning_rate` is never handed to the trainer, so training
# presumably runs at train.vae.VAE's default rate. The checkpoint name below
# still records `lr{learning_rate}`. Confirm the trainer's signature and pass
# the rate through if it is supported.
vae = train.vae.VAE(vae_model, data_holder, optimizer_class)
vae.train(epochs)

# torch.save does not create missing directories.
os.makedirs('results', exist_ok=True)
torch.save(vae_model, f'results/s{seed}_e{epochs}_es{episode_size}_lr{learning_rate}_z{latent_dim}_h{hidden_dim}.pt')

===> Epoch: 1/20, Avg Train Loss: 2.245
===> Epoch: 2/20, Avg Train Loss: 2.193
===> Epoch: 3/20, Avg Train Loss: 2.184
===> Epoch: 4/20, Avg Train Loss: 2.182
===> Epoch: 5/20, Avg Train Loss: 2.183
===> Epoch: 6/20, Avg Train Loss: 2.184
===> Epoch: 7/20, Avg Train Loss: 2.184
===> Epoch: 8/20, Avg Train Loss: 2.184
===> Epoch: 9/20, Avg Train Loss: 2.184
===> Epoch: 10/20, Avg Train Loss: 2.183
===> Epoch: 11/20, Avg Train Loss: 2.182
===> Epoch: 12/20, Avg Train Loss: 2.182
===> Epoch: 13/20, Avg Train Loss: 2.181
===> Epoch: 14/20, Avg Train Loss: 2.180
===> Epoch: 15/20, Avg Train Loss: 2.181
===> Epoch: 16/20, Avg Train Loss: 2.181
===> Epoch: 17/20, Avg Train Loss: 2.183
===> Epoch: 18/20, Avg Train Loss: 2.183
===> Epoch: 19/20, Avg Train Loss: 2.180
===> Epoch: 20/20, Avg Train Loss: 2.181


In [None]:
import matplotlib.pyplot as plt


# One row per estimator: mean trajectory on the left, covariance on the right.
# squeeze=False keeps `axes` 2-D even for a single estimator, so every row
# can be unpacked the same way.
fig, axes = plt.subplots(nrows=len(results), ncols=2,
                         figsize=(15, 4 * len(results)), squeeze=False)

for (mean_axis, cov_axis), (name, (means, covs)) in zip(axes, results.items()):
    # History entries are 1-element tensors or plain numbers. Convert them to
    # floats explicitly instead of relying on matplotlib coercing a list of tensors.
    mean_axis.plot([float(m) for m in means])
    mean_axis.set_title(f'Mean ({name})')
    mean_axis.set_xlabel('episode')
    cov_axis.plot([float(c) for c in covs])
    cov_axis.set_title(f'Covariance ({name})')
    cov_axis.set_xlabel('episode')

fig.tight_layout()
plt.show()

In [21]:
from torchvision.utils import save_image
import os


vae_model.eval()
os.makedirs('results', exist_ok=True)

# For every batch of the TRAINING loader, save a grid image. Its first row is
# up to 8 input images and its second row their reconstructions.
# (Previously the loop variable was overwritten with the first batch of a
# freshly created iterator on every iteration, discarding the loop's batch.)
with torch.no_grad():  # inference only: don't build an autograd graph
    for i, (x_batch, _) in enumerate(data_holder.train_holder):
        x_pred_batch = decoder(vae_model.probabilistic(x_batch, None, encoder(x_batch)))
        n = min(x_batch.size(0), 8)
        # Reshape using the actual batch size. The final batch can be smaller
        # than the loader's nominal batch_size unless the loader drops it.
        x_pred_images = x_pred_batch.view(x_batch.size(0), 1,
                                          data_holder.height, data_holder.width)
        comparison = torch.cat((x_batch[:n], x_pred_images[:n]))
        save_image(comparison, f'results/recon{i}.png', nrow=n)