# Comparing estimators in a VAE network

In [None]:
import os

import torch

import models.vae
import train.vae
from mc_estimators.reinforce import Reinforce


# === EXPERIMENT PARAMETERS ===
# Reproducibility.
seed = 4243254

# Training schedule and optimisation.
epochs, episode_size = 20, 1000
learning_rate = 0.01
optimizer_class = torch.optim.Adam

# Network sizes.
latent_dim, hidden_dim = 20, 200

# Seed torch's global random number generator before any model is built.
torch.manual_seed(seed)

def pp(mean, variance=.1, dim=None):
    """Turn the encoder output into a (mean, covariance) pair.

    Only the mean comes from the network. The covariance is a fixed isotropic
    matrix ``variance * I``. NOTE(review): presumably this pair is consumed as
    the parameters of the ``torch.distributions.MultivariateNormal`` handed to
    ``Reinforce`` — confirm in ``models.vae.Encoder``.

    Args:
        mean: tensor produced by the encoder. All size-1 dimensions are
            removed with ``squeeze()``. NOTE(review): this also drops a batch
            dimension of size 1 — confirm that is intended.
        variance: diagonal value of the fixed covariance (default 0.1).
        dim: side length of the covariance matrix. Defaults to the
            module-level ``latent_dim`` (looked up at call time).

    Returns:
        Tuple ``(mean.squeeze(), variance * eye(dim))``. The covariance is
        created with ``mean``'s dtype and device, so it can be combined with
        the mean on GPU or in double precision.
    """
    if dim is None:
        dim = latent_dim
    covariance = variance * torch.eye(dim, dtype=mean.dtype, device=mean.device)
    return mean.squeeze(), covariance

# Load the data
data_holder = train.vae.DataHolder()
data_holder.load_datasets()

# Train the Variational Auto Encoder
data_dim = data_holder.height * data_holder.width
encoder = models.vae.Encoder(data_dim, hidden_dim, (latent_dim,), post_processor=pp)
decoder = models.vae.Decoder(data_dim, hidden_dim, latent_dim)
vae_model = models.vae.VAE(encoder, decoder, Reinforce(episode_size, torch.distributions.MultivariateNormal))

# Pass the configured optimizer class instead of re-hard-coding Adam, so that
# editing `optimizer_class` in the parameter section actually takes effect.
# NOTE(review): `learning_rate` is never handed to train.vae.VAE here; it only
# appears in the checkpoint filename below. Confirm how train.vae.VAE chooses
# its learning rate, otherwise the `lr` tag in the filename may be misleading.
vae = train.vae.VAE(vae_model, data_holder, optimizer_class)
vae.train(epochs)

# torch.save does not create missing parent directories.
os.makedirs('results', exist_ok=True)
torch.save(vae_model, f'results/s{seed}_e{epochs}_es{episode_size}_lr{learning_rate}_z{latent_dim}_h{hidden_dim}.pt')

In [None]:
from torchvision.utils import save_image
import os


# Save side-by-side comparisons of training images (top row) and their
# reconstructions (bottom row), one PNG per batch of the training loader.
vae_model.eval()
os.makedirs('results', exist_ok=True)  # save_image does not create directories

with torch.no_grad():  # inference only: no autograd graph is needed
    for i, (x_batch, _) in enumerate(data_holder.train_holder):
        # Reconstruct the batch being iterated. Do not re-fetch a batch from a
        # fresh iterator: that would ignore `x_batch` and restart the loader.
        x_pred_batch = decoder(vae_model.probabilistic(x_batch, None, encoder(x_batch)))
        n = min(x_batch.size(0), 8)
        # Reshape using the actual batch size. A final incomplete batch (if the
        # loader keeps one) is smaller than train_holder.batch_size.
        # Assumes x_batch is (batch, 1, height, width) — TODO confirm.
        x_pred_images = x_pred_batch.reshape(x_batch.size(0), 1,
                                             data_holder.height, data_holder.width)
        comparison = torch.cat((x_batch[:n], x_pred_images[:n]))
        save_image(comparison, f'results/recon{i}.png', nrow=n)