In [None]:
import autorootcwd  # Find project root and change working directory
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from mpl_toolkits.axes_grid1 import ImageGrid

from src.data.mnist_datamodule import MNISTDataModule
from src.models.vae_components.vanilla_vae import VAE
from src.models.vae_module import VAEModule

%load_ext autoreload
%autoreload 2

In [None]:
# Seed torch's global RNG before the model is built so that weight
# initialisation (and anything else drawing from that RNG) is repeatable
# under Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dm = MNISTDataModule(num_workers=0, transform="default")
# input_dim=784 presumably corresponds to flattened 28 x 28 images -- TODO
# confirm against the datamodule's transform.
model = VAE(input_dim=784, hidden_dim=400, latent_dim=200)

# The Lightning wrapper (VAEModule) is deliberately not used here: the
# training loop is written by hand further down in this notebook.
dm.setup("fit")
model.to(device)
# Create the optimizer only after the model has been moved to its device.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Sanity-check the dataloader: pull one training batch and show a square grid
# of its first images.
dataiter = iter(dm.train_dataloader())
batch = next(dataiter)

# The grid is grid_side x grid_side; never index past the end of the batch if
# the batch happens to hold fewer images than the grid has cells.
grid_side = 5
num_samples = min(grid_side * grid_side, batch[0].shape[0])
# batch[0] holds the images; [i, 0] selects image i, channel 0.
sample_images = [batch[0][i, 0] for i in range(num_samples)]

fig = plt.figure(figsize=(5, 5))
grid = ImageGrid(fig, 111, nrows_ncols=(grid_side, grid_side), axes_pad=0.1)

# Hide every axis frame first so cells without an image stay blank.
for ax in grid:
    ax.axis("off")

for ax, im in zip(grid, sample_images):
    ax.imshow(im, cmap="gray")

plt.show()

In [None]:
# Define training routine


def loss_function(x, x_hat, mean, log_var, kld_weight=1.0):
    """Compute the VAE objective: reconstruction error plus weighted KL term.

    Adapted from
    https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py

    Args:
        x: Original input tensor.
        x_hat: Reconstruction produced by the model, same shape as ``x``.
        mean: Latent means; the KL term sums over ``dim=1`` and averages over
            ``dim=0``.
        log_var: Latent log-variances, same shape as ``mean``.
        kld_weight: Multiplier applied to the KL term. Defaults to ``1.0``,
            which reproduces the unweighted sum. The reconstruction term is
            averaged over every element while the KL term is summed over the
            latent dimension, so the two live on different scales; lower this
            weight if the KL term dominates training.

    Returns:
        Scalar tensor ``recons_loss + kld_weight * kld_loss``.
    """
    # Mean squared error over all elements (mse_loss default reduction).
    recons_loss = F.mse_loss(x_hat, x)
    # Closed-form KL(N(mean, exp(log_var)) || N(0, I)) per sample, then batch mean.
    kld_per_sample = -0.5 * torch.sum(1 + log_var - mean**2 - log_var.exp(), dim=1)
    kld_loss = torch.mean(kld_per_sample, dim=0)

    return recons_loss + kld_weight * kld_loss


def train(model, optimizer, epochs, device, dataloader=None):
    """Run a plain training loop and print the mean batch loss per epoch.

    Args:
        model: Module whose forward pass returns ``(x_hat, mean, log_var)``.
            It is expected to already live on ``device``; this function only
            moves the input batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        epochs: Number of passes over the training data.
        device: Device each input batch is moved to.
        dataloader: Optional re-iterable of ``(inputs, labels)`` batches. When
            omitted, a fresh ``dm.train_dataloader()`` is requested from the
            notebook-level datamodule at the start of every epoch. A one-shot
            iterator would be exhausted after the first epoch.

    Returns:
        Sum of the per-batch losses of the final epoch (``0.0`` when
        ``epochs`` is 0 or the loader yields no batches).
    """
    model.train()
    overall_loss = 0.0  # defined up front so epochs == 0 still returns cleanly
    for epoch in range(epochs):
        loader = dataloader if dataloader is not None else dm.train_dataloader()
        overall_loss = 0.0
        num_batches = 0
        for x, _ in loader:
            # Flatten everything but the batch dimension.
            x = x.reshape(x.size(0), -1).to(device)

            optimizer.zero_grad()

            x_hat, mean, log_var = model(x)
            loss = loss_function(x, x_hat, mean, log_var)

            overall_loss += loss.item()
            num_batches += 1

            loss.backward()
            optimizer.step()

        # Each batch loss is already a per-batch mean, so the epoch average is
        # the sum divided by the number of batches -- not by
        # (last batch index * last batch size).
        average_loss = overall_loss / num_batches if num_batches else float("nan")
        print("\tEpoch", epoch + 1, "\tAverage Loss: ", average_loss)
    return overall_loss


# Launch the hand-written training loop for 10 epochs on the selected device.
train(model=model, optimizer=optimizer, epochs=10, device=device)

In [None]:
# Bare expression: the notebook renders the model's repr as this cell's output.
model