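## VAEs on MNIST

This notebook trains a convolutional Gaussian VAE on MNIST, evaluates it on the held-out test set, and visualizes a few reconstructions. With `info_vae = True` (the setting used here), the divergence term is an MMD, as the training log's `MMD-div` column shows.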

#### Imports

In [1]:
import os

import torch
import torch.utils.data as tdata
import torchvision
import torchvision.transforms as T

from vaes_ptorch import (
    CNN,
    DeCNN,
    GaussianModel,
    GaussianVAE,
    TrainArgs,
    get_mlp,
    train,
)
from vaes_ptorch.args import DivAnnealing
from vaes_ptorch.train_vae import evaluate
from vaes_ptorch.utils import show

#### Experiment parameters

In [39]:
# Hardware.
use_gpu = True

# Global seed for model initialisation and DataLoader shuffling, so that
# "Restart & Run All" reproduces the same training run.
seed = 15
torch.manual_seed(seed)

# Model.
latent_dim = 120

# Optimisation.
lr = 1e-3
batch_size = 128
num_epochs = 3

# Logging / evaluation frequency.
print_every = 50
eval_every = 1

# Divergence term. With info_vae=True the training log reports an MMD
# divergence; its weight is presumably annealed by DivAnnealing from
# start_scale to end_scale. Here start_scale == end_scale, so the weight stays
# at 1.0 (the log's "Div scale" column confirms this).
# Alternative settings left over from earlier runs:
#   info_vae = False;  start_scale = end_scale = 50.0
info_vae = True
start_scale = 1.0
end_scale = 1.0
start_epochs = 0
linear_epochs = 0

# Encoder (CNN) layer settings, one entry per conv layer. The inline comments
# give the spatial size after each layer: every True halves the 28x28 image.
in_channels = 1  # MNIST images are single-channel
kernel_sizes = [5, 5, 5, 5]
downsampling = [
    True,  # 14
    False,  # 14
    True,  # 7
    False,  # 7
]
out_channels = [8, 8, 16, 16]

# Decoder (DeCNN) layer settings. The final 2 output channels are presumably
# the per-pixel mean and variance consumed by the decoder's
# GaussianModel(out_dim=1, split_dim=1) -- confirm against vaes_ptorch.
rev_out_channels = [16, 8, 8, 2]
rev_downsampling = [False, True, False, True]
rev_kernel_sizes = [5, 5, 5, 5]

# Every per-layer list must describe the same number of layers.
assert len(kernel_sizes) == len(downsampling) == len(out_channels)
assert len(rev_kernel_sizes) == len(rev_downsampling) == len(rev_out_channels)

# Check the user flag first so CUDA is not probed when GPU use is disabled.
device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
device

'cpu'

#### Getting the training data

In [40]:
# MNIST is cached under this directory (downloaded on first run).
data_root = os.path.expanduser("~/vaes_ptorch/data")
train_fraction = 0.7  # share of the official training set used for training
split_seed = 15  # fixed seed so the train/eval split is identical on every run


def load_mnist(train: bool) -> torchvision.datasets.MNIST:
    """Return the MNIST training (``train=True``) or test split as tensors.

    ``T.ToTensor()`` converts each image to a float tensor scaled to [0, 1].
    """
    return torchvision.datasets.MNIST(
        root=data_root,
        train=train,
        download=True,
        transform=T.ToTensor(),
    )


dataset = load_mnist(train=True)

# Hold out part of the official training set for evaluation during training.
train_size = int(len(dataset) * train_fraction)
eval_size = len(dataset) - train_size
train_data, eval_data = tdata.random_split(
    dataset,
    [train_size, eval_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Only the training loader is shuffled. Evaluation gains nothing from
# shuffling, and an unshuffled test loader makes the images picked for the
# reconstruction demo at the end of the notebook the same on every run.
train_loader = tdata.DataLoader(
    dataset=train_data, batch_size=batch_size, shuffle=True
)
eval_loader = tdata.DataLoader(
    dataset=eval_data, batch_size=batch_size, shuffle=False
)

test_set = load_mnist(train=False)
test_loader = tdata.DataLoader(
    dataset=test_set, batch_size=batch_size, shuffle=False
)

#### Setting up the VAE model

In [41]:
# Spatial size of the encoder's last feature map. Each True in `downsampling`
# halves the 28x28 MNIST image (28 -> 14 -> 7 with the config above).
f_map_size = 28 // 2 ** sum(downsampling)

# Encoder q(z|x): the CNN emits 2 * latent_dim features, presumably split by
# GaussianModel into a mean and a variance per latent dimension.
encoder = GaussianModel(
    model=CNN(
        in_channels=in_channels,  # use the configured value, not a literal
        out_channels=out_channels,
        kernel_sizes=kernel_sizes,
        downsampling=downsampling,
        bn=True,
        f_map_size=f_map_size,
        out_dim=latent_dim * 2,
    ),
    out_dim=latent_dim,
    min_var=1e-10,  # variance floor, presumably for numerical stability
)

# Decoder p(x|z): maps a latent vector back to an image. split_dim=1 presumably
# splits the DeCNN output along the channel axis into mean and variance.
# NOTE(review): channel_size=16 is hard-coded -- confirm whether it must match
# out_channels[-1] / rev_out_channels[0].
decoder = GaussianModel(
    model=DeCNN(
        in_dim=latent_dim,
        f_map_size=f_map_size,
        channel_size=16,
        out_channels=rev_out_channels,
        kernel_sizes=rev_kernel_sizes,
        downsampling=rev_downsampling,
        bn=True,
    ),
    out_dim=1,
    min_var=0.0,
    split_dim=1,
)

vae = GaussianVAE(encoder=encoder, decoder=decoder).to(device)

#### Initializing the optimizer and training arguments

In [42]:
# Adam over every parameter of the VAE (encoder and decoder).
optimizer = torch.optim.Adam(params=vae.parameters(), lr=lr)

# Weight schedule for the divergence term (MMD when info_vae=True).
div_annealing = DivAnnealing(
    start_scale=start_scale,
    end_scale=end_scale,
    start_epochs=start_epochs,
    linear_epochs=linear_epochs,
)

train_args = TrainArgs(
    num_epochs=num_epochs,
    info_vae=info_vae,
    div_annealing=div_annealing,
    eval_every=eval_every,
    print_every=print_every,
    smoothing=0.9,  # presumably the exponential-smoothing factor for logged losses
)

### Training

In [None]:
# Fit the VAE. Judging by the log below, losses are printed every
# `print_every` steps and the model is scored on `eval_loader` every
# `eval_every` epoch(s).
train(
    vae=vae,
    optimizer=optimizer,
    train_data=train_loader,
    eval_data=eval_loader,
    args=train_args,
    device=device,
)

Step: 0 | Loss: 0.27889 | Div scale: 1.000
NLL: 0.22497 | MMD-div: 0.05392
Step: 50 | Loss: 0.17079 | Div scale: 1.000
NLL: 0.14498 | MMD-div: 0.01666
Step: 100 | Loss: 0.13520 | Div scale: 1.000
NLL: 0.11376 | MMD-div: 0.01549
Step: 150 | Loss: 0.10655 | Div scale: 1.000
NLL: 0.08765 | MMD-div: 0.01664
Step: 200 | Loss: 0.07750 | Div scale: 1.000
NLL: 0.05475 | MMD-div: 0.01565
Step: 250 | Loss: 0.05712 | Div scale: 1.000
NLL: 0.03814 | MMD-div: 0.01653
Step: 300 | Loss: 0.04639 | Div scale: 1.000
NLL: 0.02830 | MMD-div: 0.01667
ELBO at the end of epoch #1 is 0.02645
Step: 350 | Loss: 0.04124 | Div scale: 1.000
NLL: 0.02425 | MMD-div: 0.01534
Step: 400 | Loss: 0.03764 | Div scale: 1.000
NLL: 0.02125 | MMD-div: 0.01562
Step: 450 | Loss: 0.03471 | Div scale: 1.000
NLL: 0.01915 | MMD-div: 0.01529
Step: 500 | Loss: 0.03109 | Div scale: 1.000
NLL: 0.01518 | MMD-div: 0.01473
Step: 550 | Loss: 0.02976 | Div scale: 1.000
NLL: 0.01436 | MMD-div: 0.01622
Step: 600 | Loss: 0.02908 | Div scale: 1

### Testing

In [None]:
# Score the trained model on the held-out MNIST test set. Switch to eval mode
# first so BatchNorm uses its running statistics instead of per-batch ones
# (harmless if `evaluate` already does this), and skip autograd bookkeeping.
vae.eval()
with torch.no_grad():
    test_result = evaluate(test_loader, vae, device=device)
test_result

In [None]:
# Take the first 16 images of one test batch (labels discarded) and display them.
images_batch, _labels = next(iter(test_loader))
input_images = images_batch[:16].to(device)
show(input_images)

In [None]:
# Reconstruct the displayed test images. eval() makes BatchNorm use its running
# statistics. Otherwise this 16-image batch would both drive the normalisation
# and overwrite the running stats learned during training. no_grad() returns
# tensors without autograd history (needed if `show` converts them to NumPy).
vae.eval()
with torch.no_grad():
    reconstructed_images = vae(input_images)
show(reconstructed_images.mu_x)