## VAEs on MNIST

#### Imports

In [1]:
import os

import torch
import torch.utils.data as tdata
import torchvision
import torchvision.transforms as T

from vaes_ptorch import (
    CNN,
    DeCNN,
    GaussianModel,
    GaussianVAE,
    TrainArgs,
    get_mlp,
    train,
)
from vaes_ptorch.args import DivAnnealing
from vaes_ptorch.train_vae import evaluate
from vaes_ptorch.utils import show

#### Experiment parameters

In [24]:
# Experiment configuration: every value a reader might want to tune lives here.

# --- Hardware ---------------------------------------------------------------
use_gpu = True  # set to False to force CPU even when CUDA is available

# --- Reproducibility --------------------------------------------------------
# Seed torch's global RNG so that weight initialisation, DataLoader shuffling
# and any other draw from the global generator repeat across
# "Restart Kernel -> Run All".
seed = 15
torch.manual_seed(seed)

# --- Model size -------------------------------------------------------------
latent_dim = 10

# --- Optimisation -----------------------------------------------------------
lr = 1e-3
batch_size = 128
num_epochs = 15

# --- Logging ----------------------------------------------------------------
# Passed to TrainArgs below; the training log shows a line every 50 steps and
# an evaluation message after each epoch with these values.
print_every = 50
eval_every = 1

# --- Divergence term --------------------------------------------------------
# Values forwarded to TrainArgs / DivAnnealing. With start_scale == end_scale
# and zero annealing epochs, the training log reports a constant
# "Div scale: 0.100".
# Alternative configuration previously tried for an InfoVAE run:
#   info_vae = True, start_scale = 50.0, end_scale = 50.0
info_vae = False
start_scale = 0.1
end_scale = 0.1
start_epochs = 0
linear_epochs = 0

# --- Convolutional architecture ---------------------------------------------
in_channels = 1  # ToTensor() on MNIST yields single-channel images
kernel_sizes = [5, 5, 5]
out_channels = [16, 32, 64]
# Decoder ("reverse") stack; the final 2 channels are presumably split into a
# mean and a variance channel by GaussianModel(split_dim=1) -- TODO confirm.
rev_out_channels = [16, 2]
rev_kernel_sizes = [5, 5]

# Fall back to CPU when CUDA is unavailable or explicitly disabled.
device = "cuda" if torch.cuda.is_available() and use_gpu else "cpu"
device

'cuda'

#### Getting the training data

In [25]:
# Single source of truth for where MNIST is cached and how images are decoded,
# so the train and test datasets cannot drift apart.
data_root = os.path.expanduser("~/vaes_ptorch/data")
to_tensor = T.ToTensor()

# Official MNIST training split (downloaded on first use).
dataset = torchvision.datasets.MNIST(
    root=data_root,
    train=True,
    download=True,
    transform=to_tensor,
)

# 70 / 30 train / validation split; the dedicated, seeded generator keeps the
# partition identical from run to run.
train_size = int(len(dataset) * 0.7)
eval_size = len(dataset) - train_size
train_data, eval_data = tdata.random_split(
    dataset,
    [train_size, eval_size],
    generator=torch.Generator().manual_seed(15),
)

# Only the training loader is shuffled: SGD benefits from a fresh order each
# epoch, whereas evaluation gains nothing from shuffling.
train_loader = tdata.DataLoader(
    dataset=train_data, batch_size=batch_size, shuffle=True
)
eval_loader = tdata.DataLoader(
    dataset=eval_data, batch_size=batch_size, shuffle=False
)

# Official MNIST test split, used only in the "Testing" section.
test_set = torchvision.datasets.MNIST(
    root=data_root,
    train=False,
    download=True,
    transform=to_tensor,
)
# Unshuffled so that `next(iter(test_loader))` always yields the same first
# batch, making the displayed inputs/reconstructions reproducible.
test_loader = tdata.DataLoader(
    dataset=test_set, batch_size=batch_size, shuffle=False
)

#### Setting up the VAE model

In [26]:
# Encoder: a CNN emitting `latent_dim * 2` values per image, wrapped in a
# GaussianModel with `out_dim=latent_dim` (presumably half the outputs are the
# latent mean and half the latent variance -- TODO confirm in vaes_ptorch).
encoder = GaussianModel(
    model=CNN(
        # Use the config-cell constant rather than repeating the literal, so
        # changing `in_channels` there actually reaches the network.
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_sizes=kernel_sizes,
        bn=True,
        f_map_size=3,
        out_dim=latent_dim * 2,
    ),
    out_dim=latent_dim,
    min_var=1e-6,
)
# Decoder: a DeCNN fed with a `latent_dim`-sized code. Its last layer has
# `rev_out_channels[-1]` channels and the wrapper is told to split along
# dimension 1 with `out_dim=1`.
# NOTE(review): min_var=0.0 puts no floor on the decoder variance here, unlike
# the encoder's 1e-6 -- confirm this is intentional.
decoder = GaussianModel(
    model=DeCNN(
        in_dim=latent_dim,
        f_map_size=7,
        channel_size=32,
        out_channels=rev_out_channels,
        kernel_sizes=rev_kernel_sizes,
        bn=True,
    ),
    out_dim=1,
    min_var=0.0,
    split_dim=1,
)
# Assemble the VAE and move all of its parameters to the selected device.
vae = GaussianVAE(encoder=encoder, decoder=decoder)
vae = vae.to(device)

#### Initializing the optimizer and training arguments

In [27]:
# Schedule for the weight on the divergence term, built from the values chosen
# in the experiment-parameters cell.
div_annealing = DivAnnealing(
    start_epochs=start_epochs,
    linear_epochs=linear_epochs,
    start_scale=start_scale,
    end_scale=end_scale,
)

# Everything the training loop needs besides the model, data and optimizer.
train_args = TrainArgs(
    info_vae=info_vae,
    num_epochs=num_epochs,
    div_annealing=div_annealing,
    print_every=print_every,
    eval_every=eval_every,
    smoothing=0.9,
)

# Adam over every parameter of the VAE (encoder and decoder together).
optimizer = torch.optim.Adam(vae.parameters(), lr=lr)

### Training

In [None]:
# Launch training: optimise `vae` on the training loader, using the validation
# loader for the periodic evaluation configured in `train_args`.
train(
    vae=vae,
    optimizer=optimizer,
    train_data=train_loader,
    eval_data=eval_loader,
    args=train_args,
    device=device,
)

Step: 0 | Loss: 182.38063 | Div scale: 0.100
NLL: 151.17239 | KL: 312.08234
Step: 50 | Loss: 92.97064 | Div scale: 0.100
NLL: 85.81812 | KL: 15.60812
Step: 100 | Loss: 76.20986 | Div scale: 0.100
NLL: 71.40569 | KL: 16.44229
Step: 150 | Loss: 64.01343 | Div scale: 0.100
NLL: 58.72035 | KL: 22.93576
Step: 200 | Loss: 51.30028 | Div scale: 0.100
NLL: 47.11209 | KL: 24.04874
Step: 250 | Loss: 38.58748 | Div scale: 0.100
NLL: 33.78138 | KL: 24.81089
Step: 300 | Loss: 31.77712 | Div scale: 0.100
NLL: 28.60517 | KL: 24.00307
ELBO at the end of epoch #1 is 27.35907
Step: 350 | Loss: 29.12336 | Div scale: 0.100
NLL: 25.96612 | KL: 24.32596
Step: 400 | Loss: 25.81496 | Div scale: 0.100
NLL: 22.47495 | KL: 24.46352
Step: 450 | Loss: 22.52387 | Div scale: 0.100
NLL: 19.93489 | KL: 23.43153
Step: 500 | Loss: 21.76904 | Div scale: 0.100
NLL: 20.00658 | KL: 23.75232
Step: 550 | Loss: 21.85739 | Div scale: 0.100
NLL: 19.11955 | KL: 22.24253
Step: 600 | Loss: 21.90847 | Div scale: 0.100
NLL: 19.68378 

### Testing

In [None]:
# Score the trained VAE on the official MNIST test split (`test_loader` is
# built from the train=False dataset). As the last expression of the cell, any
# value returned by `evaluate` is shown as the cell output.
evaluate(test_loader, vae, device=device)

In [None]:
# Pull one batch from the test loader; element 0 of the batch holds the images
# (the remaining element(s) are not used here).
first_batch = next(iter(test_loader))
batch_images = first_batch[0]

# Keep the first 16 images, move them to the model's device and display them.
input_images = batch_images[:16].to(device)
show(input_images)

In [None]:
# Reconstruct the displayed inputs in inference mode.
# Both networks are built with bn=True, so the forward pass must not run in
# training mode on a 16-image batch, and no autograd graph is needed for a
# visualisation.
was_training = vae.training
vae.eval()
with torch.no_grad():
    reconstructed_images = vae(input_images)
# Restore the previous mode so re-running the training cell is unaffected.
vae.train(was_training)

# Show the decoder mean for each input image.
show(reconstructed_images.mu_x)