## VAEs on MNIST - Experiments

#### Imports

In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as tdata
import torchvision
import torchvision.transforms as T

from vaes_ptorch import GaussianModel, GaussianVAE, TrainArgs, get_mlp, train
from vaes_ptorch.args import DivAnnealing
from vaes_ptorch.losses import Likelihood
from vaes_ptorch.train_vae import evaluate

#### Experiment parameters

In [9]:
num_repeats = 2
learning_rates = 10 ** np.linspace(start=-4.0, stop=-2.0, num=2)
divergence_scales = 10 ** np.linspace(start=-2.0, stop=3.0, num=2)

info_vae = False

latent_dim = 12

num_epochs = 3
batch_size = 128
eval_share = 0.7

base_args = TrainArgs(
    likelihood=Likelihood.Bernoulli,
    info_vae=info_vae,
    num_epochs=num_epochs,
    div_annealing=DivAnnealing(
        start_epochs=1, linear_epochs=1, start_scale=0.0, end_scale=0.0,
    ),
    print_every=0,
    eval_every=1,
    smoothing=0.9,
)

use_gpu = True
device = "cuda" if torch.cuda.is_available() and use_gpu else "cpu"
device

'cpu'

#### Helpers

In [10]:
def binarize(x):
    """Converts grayscale pixel values in [0, 1] to binary data in {0, 1}."""
    tensor = T.ToTensor()(x)
    mask = tensor > 0.5
    tensor[mask] = 1.0
    tensor[~mask] = 0.0
    return tensor


def get_data(batch_size: int, eval_share: float):
    dataset = torchvision.datasets.MNIST(
        root=os.path.expanduser("~/vaes_ptorch/data"),
        train=True,
        download=True,
        transform=binarize,
    )
    train_size = int(len(dataset) * eval_share)
    eval_size = len(dataset) - train_size
    train_data, eval_data = tdata.random_split(
        dataset,
        [train_size, eval_size],
        generator=torch.Generator().manual_seed(15),
    )
    train_loader = tdata.DataLoader(
        dataset=train_data, batch_size=batch_size, shuffle=True
    )
    eval_loader = tdata.DataLoader(
        dataset=eval_data, batch_size=batch_size, shuffle=True
    )
    test_set = torchvision.datasets.MNIST(
        root=os.path.expanduser("~/vaes_ptorch/data"),
        train=False,
        download=True,
        transform=binarize,
    )
    test_loader = tdata.DataLoader(
        dataset=test_set, batch_size=batch_size, shuffle=True
    )
    print(f"Train size: {train_size}, Eval size: {eval_size}")
    return train_loader, eval_loader, test_loader


def build_vae(device, latent_dim):
    encoder = GaussianModel(
        model=nn.Sequential(
            nn.Flatten(),
            get_mlp(in_dim=28 * 28, out_dim=2 * latent_dim, h_dims=[512] * 3),
        ),
        out_dim=latent_dim,
        min_var=1e-10,
    )
    decoder = GaussianModel(
        model=nn.Sequential(
            get_mlp(in_dim=latent_dim, out_dim=2 * 28 * 28, h_dims=[512] * 3),
            nn.Unflatten(1, (2, 28, 28)),
        ),
        out_dim=1,
        split_dim=1,
    )
    vae = GaussianVAE(encoder=encoder, decoder=decoder)
    vae = vae.to(device)
    return vae


def get_params(scale: float):
    annealing_params = vars(base_args.div_annealing)
    annealing_params["end_scale"] = scale
    del annealing_params["epoch"]
    params = vars(base_args)
    params["div_annealing"] = DivAnnealing(**annealing_params)
    args = TrainArgs(**params)
    return args

### Experiment loop

In [None]:
errors = []
train_loader, eval_loader, test_loader = get_data(
    batch_size=batch_size, eval_share=eval_share
)
for scale in divergence_scales:
    min_error = float("inf")
    for lr in learning_rates:
        args = get_params(scale)
        eval_errors = []
        for _ in range(num_repeats):
            vae = build_vae(device, latent_dim)
            optimizer = torch.optim.Adam(params=vae.parameters(), lr=lr)
            eval_errors.append(
                train(
                    train_data=train_loader,
                    vae=vae,
                    optimizer=optimizer,
                    args=args,
                    eval_data=eval_loader,
                    device=device,
                ).eval_ewma
            )
        avg_error = sum(eval_errors) / len(eval_errors)
        if avg_error < min_error:
            min_error = avg_error
            best_vae = vae
            best_args = args
    error = evaluate(test_loader, best_vae, args=best_args, device=device)
    errors.append(error)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Train size: 42000, Eval size: 18000
ELBO at the end of epoch #1 is 165.61446
ELBO at the end of epoch #2 is 133.95690
ELBO at the end of epoch #3 is 114.16067
ELBO at the end of epoch #1 is 171.84254
ELBO at the end of epoch #2 is 126.70070
ELBO at the end of epoch #3 is 111.06838
ELBO at the end of epoch #1 is 137.56675
ELBO at the end of epoch #2 is 157.35011
ELBO at the end of epoch #3 is 146.31366
ELBO at the end of epoch #1 is 175.53107
ELBO at the end of epoch #2 is 166.69402
ELBO at the end of epoch #3 is 166.55007
ELBO at the end of epoch #1 is 170.74951
ELBO at the end of epoch #2 is 206.93230
ELBO at the end of epoch #3 is 206.46890
ELBO at the end of epoch #1 is 207.08602
ELBO at the end of epoch #2 is 206.52991
ELBO at the end of epoch #3 is 206.37646
ELBO at the end of epoch #1 is 124.14252
ELBO at the end of epoch #2 is 207.23155
ELBO at the end of epoch #3 is 206.55199
ELBO at the end of epoch #1 is 206.62069
