In [5]:
import os

import torch
import torch.utils.data as tdata
import torchvision
import torchvision.transforms as T

from vaes_ptorch import (
    CNN,
    GaussianModel,
    GaussianVAE,
    TrainArgs,
    get_mlp,
    train,
)
from vaes_ptorch.args import DivAnnealing

In [6]:
dataset = torchvision.datasets.MNIST(
    root=os.path.expanduser("~/vaes_ptorch/data"),
    train=True,
    download=True,
    transform=T.ToTensor(),
)

In [7]:
data_dim = 2
latent_dim = 1

h_size = 128
h_layers = 5

lr = 1e-3
batch_size = 256
num_epochs = 500

print_every = 100

# info_vae = True
info_vae = False
start_scale = 0.005
end_scale = 0.005
# start_scale = 1.0
# end_scale = 1.0
start_epochs = 0
linear_epochs = 0

In [4]:
# Reshuffled mini-batches of `batch_size` images drawn from the MNIST training split.
dataloader = tdata.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [10]:
# Encoder maps data -> Gaussian over the latent space; decoder maps
# latent -> Gaussian over the data space. Both wrap an MLP whose hidden
# widths are `h_layers` copies of `h_size`.
hidden_dims = [h_size] * h_layers

encoder = GaussianModel(
    model=get_mlp(in_dim=data_dim, out_dim=latent_dim, h_dims=list(hidden_dims)),
    out_dim=latent_dim,
    min_var=1e-2,
)
# NOTE(review): min_var=0.0 places no floor on the decoder variance; on pixels
# that are nearly constant it may collapse toward 0 and destabilise the
# likelihood — confirm GaussianModel's semantics and consider a small floor.
decoder = GaussianModel(
    model=get_mlp(in_dim=latent_dim, out_dim=data_dim, h_dims=list(hidden_dims)),
    out_dim=data_dim,
    min_var=0.0,
)

vae = GaussianVAE(encoder=encoder, decoder=decoder)
optimizer = torch.optim.Adam(params=vae.parameters(), lr=lr)

# Divergence-weight schedule, built from the config cell's settings.
annealing = DivAnnealing(
    start_epochs=start_epochs,
    linear_epochs=linear_epochs,
    start_scale=start_scale,
    end_scale=end_scale,
)
train_args = TrainArgs(
    info_vae=info_vae,
    num_epochs=num_epochs,
    div_annealing=annealing,
    print_every=print_every,
    smoothing=0.9,
)

In [11]:
# Fit the VAE on the MNIST batches using the optimizer and TrainArgs built in
# the previous cell; `train` comes from the project package vaes_ptorch.
train(data=dataloader, vae=vae, optimizer=optimizer, args=train_args)

torch.Size([256, 1, 28, 28])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (7168x28 and 2x128)