## VAEs on MNIST

In [1]:
import os

import torch
import torch.utils.data as tdata
import torchvision
import torchvision.transforms as T

from vaes_ptorch import (
    CNN,
    DeCNN,
    GaussianModel,
    GaussianVAE,
    TrainArgs,
    get_mlp,
    train,
)
from vaes_ptorch.args import DivAnnealing

In [2]:
# Training split of MNIST, stored under the user's home directory.
# `download=True` lets torchvision fetch the files when they are not
# already present under `mnist_root`.
mnist_root = os.path.expanduser("~/vaes_ptorch/data")
to_tensor = T.ToTensor()

dataset = torchvision.datasets.MNIST(
    root=mnist_root,
    train=True,
    transform=to_tensor,
    download=True,
)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [3]:
# ---------------------------------------------------------------------------
# Experiment configuration: every value a reader may want to tune lives here.
# ---------------------------------------------------------------------------

# Reproducibility: seed torch's global RNG so that a fresh
# "Restart Kernel -> Run All" uses the same random stream on every run.
seed = 0
torch.manual_seed(seed)

# Size of the latent space.
latent_dim = 2

# Optimisation settings.
lr = 1e-3
batch_size = 128
num_epochs = 2

# Forwarded to TrainArgs(print_every=...).
print_every = 1

# Flag and divergence-annealing values forwarded to TrainArgs / DivAnnealing.
# Alternative values left by the author for reference:
#   info_vae = True, start_scale = end_scale = 0.005
info_vae = False
start_scale = 1.0
end_scale = 1.0
start_epochs = 0
linear_epochs = 0

# Convolutional architecture.
in_channels = 1
kernel_sizes = [5, 5]
out_channels = [16, 32]  # passed to the encoder CNN
rev_out_channels = [16, 2]  # passed to the decoder DeCNN

In [4]:
# Mini-batch loader over the training set, with shuffling enabled.
dataloader = tdata.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [7]:
# Encoder: a CNN wrapped in a GaussianModel.
# The CNN emits `latent_dim * 2` values while the wrapper is given
# `out_dim=latent_dim` (presumably half mean, half variance; TODO confirm
# against vaes_ptorch.GaussianModel).
encoder = GaussianModel(
    model=CNN(
        # Use the config-cell constant instead of a literal so that editing
        # `in_channels` there actually changes the network.
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_sizes=kernel_sizes,
        bn=True,
        # NOTE(review): 7 presumably is the spatial size of the last feature
        # map for 28x28 MNIST inputs; confirm against CNN.
        f_map_size=7,
        out_dim=latent_dim * 2,
    ),
    out_dim=latent_dim,
    min_var=1e-2,
)

# Decoder: a DeCNN wrapped in a GaussianModel, taking a `latent_dim`-sized input.
decoder = GaussianModel(
    model=DeCNN(
        in_dim=latent_dim,
        f_map_size=7,
        # NOTE(review): 32 equals out_channels[-1] in the config cell; confirm
        # whether the two are meant to stay tied before changing either.
        channel_size=32,
        out_channels=rev_out_channels,
        kernel_sizes=kernel_sizes,
        bn=True,
    ),
    out_dim=1,
    # NOTE(review): unlike the encoder (1e-2), the decoder uses min_var=0.0;
    # confirm that is intended.
    min_var=0.0,
    split_dim=1,
)

# Assemble the VAE and optimise all of its parameters with Adam.
vae = GaussianVAE(encoder=encoder, decoder=decoder)
optimizer = torch.optim.Adam(params=vae.parameters(), lr=lr)

# Training options; every value except `smoothing` comes from the config cell.
train_args = TrainArgs(
    info_vae=info_vae,
    num_epochs=num_epochs,
    div_annealing=DivAnnealing(
        start_epochs=start_epochs,
        linear_epochs=linear_epochs,
        start_scale=start_scale,
        end_scale=end_scale,
    ),
    print_every=print_every,
    smoothing=0.9,
)

In [8]:
# Run training with the objects built in the previous cells; the progress
# lines in this cell's output come from this call.
train(
    data=dataloader,
    vae=vae,
    optimizer=optimizer,
    args=train_args,
)

Step: 0 | Loss: 50.14255 | Div scale: 1.000
NLL: 0.15880 | KL: 49.98375
Step: 1 | Loss: 65.20184 | Div scale: 1.000
NLL: 0.15404 | KL: 200.58148
Step: 2 | Loss: 63.74879 | Div scale: 1.000
NLL: 0.15173 | KL: 50.51955
Step: 3 | Loss: 60.43594 | Div scale: 1.000
NLL: 0.14350 | KL: 30.47680
Step: 4 | Loss: 61.96195 | Div scale: 1.000
NLL: 0.14011 | KL: 75.55598
Step: 5 | Loss: 62.88298 | Div scale: 1.000
NLL: 0.13578 | KL: 71.03645
Step: 6 | Loss: 59.55508 | Div scale: 1.000
NLL: 0.13870 | KL: 29.46522
Step: 7 | Loss: 54.84325 | Div scale: 1.000
NLL: 0.13236 | KL: 12.30449
Step: 8 | Loss: 51.66627 | Div scale: 1.000
NLL: 0.12932 | KL: 22.94410
Step: 9 | Loss: 50.86641 | Div scale: 1.000
NLL: 0.12807 | KL: 43.53960
Step: 10 | Loss: 49.42676 | Div scale: 1.000
NLL: 0.12549 | KL: 36.34443
Step: 11 | Loss: 46.67629 | Div scale: 1.000
NLL: 0.12682 | KL: 21.79520
Step: 12 | Loss: 42.82414 | Div scale: 1.000
NLL: 0.12359 | KL: 8.03126
Step: 13 | Loss: 39.51317 | Div scale: 1.000
NLL: 0.12114 | K