In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

from src.vae import VAE, VaeDecoderMnist, VaeEncoderMnist, negative_elbo_loss

In [None]:
# MNIST training split, downloaded on first use into ../../data.
# ToTensor converts each PIL image into a float tensor.
dataset = datasets.MNIST(
    root='../../data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

In [None]:
# Mini-batch iterator over the training set; shuffle=True reshuffles every epoch.
batch_size = 128
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Size of the VAE latent space, passed to both the encoder and the decoder.
latent_dim = 2

# Seed torch's global RNG so that the model's weight initialisation and the
# DataLoader's shuffle order (both drawn from the global generator) repeat on
# Restart & Run All. GPU kernels may still add some nondeterminism.
seed = 0
torch.manual_seed(seed);

In [None]:
# Build the VAE from an MNIST encoder/decoder pair sharing the latent size,
# then move its parameters to the selected device (Module.to returns the module).
model = VAE(
    VaeEncoderMnist(latent_dim),
    VaeDecoderMnist(latent_dim),
)
model = model.to(device)

In [None]:
# Adam over all VAE parameters with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
# Number of full passes over the training set performed by the training loop.
num_epochs = 2_000

In [None]:
def train_epoch(epoch: int,
                model: nn.Module,
                train_loader: torch.utils.data.DataLoader,
                optimizer: optim.Optimizer,
                device: torch.device) -> float:
    """Run one optimisation pass of the VAE over ``train_loader``.

    Args:
        epoch: Epoch number; used only in the progress message.
        model: VAE whose forward pass returns ``(reconstruction, mu, logvar)``.
        train_loader: Loader yielding ``(images, labels)`` batches; labels are ignored.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The sum of the per-batch ``negative_elbo_loss`` values divided by
        ``len(train_loader.dataset)``. This is a per-sample average only if the
        loss is summed (not averaged) over the batch. NOTE(review): confirm
        against ``src.vae.negative_elbo_loss``.
    """
    model.train()
    total_loss = 0.0
    for images, _ in train_loader:
        # Flatten each sample to a vector while keeping the batch dimension.
        # Unlike view(-1, 784), this never merges or splits samples when the
        # per-sample size differs, and it also works on non-contiguous tensors.
        # For 1x28x28 MNIST batches the result is (batch, 784).
        images = images.flatten(start_dim=1).to(device)

        optimizer.zero_grad()
        reconstructed, mu, logvar = model(images)
        loss = negative_elbo_loss(reconstructed, images, mu, logvar)
        loss.backward()

        total_loss += loss.item()
        optimizer.step()

    average_loss = total_loss / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.8f}'.format(epoch, average_loss))
    return average_loss

In [None]:
# Train for epochs 1..num_epochs; train_epoch prints the average loss of each one.
for epoch in range(1, num_epochs + 1):
    train_epoch(
        epoch=epoch,
        model=model,
        train_loader=data_loader,
        optimizer=optimizer,
        device=device,
    )

In [None]:
# Persist the trained weights. torch.save does not create missing directories,
# so make sure ../../models exists first; otherwise the save raises
# FileNotFoundError at the end of a long training run.
model_path = Path('../../models') / 'mnist_vae.pth'
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_path)