In [1]:
from src.vae import VAE, VaeDecoderMnist, VaeEncoderMnist, negative_elbo_loss

import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import datasets, transforms

import matplotlib.pyplot as plt

In [2]:
dataset = datasets.MNIST(root='../../data',
                         train=True, transform=transforms.ToTensor(), download=True)

In [3]:
batch_size = 128
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
latent_dim = 2

In [6]:
encoder = VaeEncoderMnist(latent_dim)
decoder = VaeDecoderMnist(latent_dim)

In [7]:
model = VAE(encoder, decoder).to(device)

In [8]:
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [9]:
num_epochs = 2000

In [12]:
def train_epoch(epoch: int,
                model: nn.Module,
                train_loader: torch.utils.data.DataLoader,
                optimizer: optim.Optimizer,
                device: torch.device):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.view(-1, 784).to(device)
        optimizer.zero_grad()

        data_reconstructed, mu, logvar = model(data)
        loss = negative_elbo_loss(data_reconstructed, data, mu, logvar)
        loss.backward()

        train_loss += loss.item()
        optimizer.step()

    print('====> Epoch: {} Average loss: {:.8f}'.format(epoch, train_loss / len(train_loader.dataset)))

In [14]:
for epoch in range(1, num_epochs + 1):
    train_epoch(epoch, model, data_loader, optimizer, device)

====> Epoch: 1 Average loss: 188.93189795
====> Epoch: 2 Average loss: 160.30879575


KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), '../../models/mnist_vae.pth')