<img src="../../assets/AE.png" style="float:right;height:200px">

# **V**ariational **A**uto **E**ncoder

This notebook tests the implementation of a variational autoencoder.

This model is one part of the **WorldModel** included in the *WorldRewardModel*; it is responsible for encoding the observation <br> for the **MD-RNN model** so that the latter can predict the next observation. This notebook tests the **VAE** using the MNIST dataset.


In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
import matplotlib
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from AutoEncoder import *
from utils import save_reconstructed_images, image_to_vid, save_loss_plot
# Switch matplotlib to its built-in 'ggplot' style sheet for figures created
# after this point (presumably the loss plot written by save_loss_plot — confirm).
matplotlib.style.use('ggplot')

# Run on the GPU when PyTorch can see a CUDA device; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
# seed PyTorch's global RNG before anything random happens, so that weight
# initialisation and DataLoader shuffling repeat across "Restart & Run All"
# (GPU kernels may still introduce small run-to-run differences)
SEED = 42
torch.manual_seed(SEED)
# initialize the model and move its parameters to the selected device
model = ConvVAE().to(device)
# set the learning parameters
lr = 0.001
epochs = 50
batch_size = 128
optimizer = optim.Adam(model.parameters(), lr=lr)
# reconstruction term: binary cross-entropy summed (not averaged) over all
# elements of a batch
criterion = nn.BCELoss(reduction='sum')
# a list to save all the reconstructed images in PyTorch grid format
grid_images = []

In [3]:
# preprocessing applied to every sample: resize to 32x32, then convert to a tensor
transform = transforms.Compose(
    [transforms.Resize((32, 32)), transforms.ToTensor()]
)
# arguments shared by both MNIST splits
mnist_kwargs = dict(root='../input', download=True, transform=transform)
# training split, served in shuffled mini-batches
trainset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
# held-out split used for validation, served in a fixed order
testset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

In [4]:
# sanity check: shape of the first element of the first training sample
first_image = trainset[0][0]
first_image.shape

torch.Size([1, 32, 32])

In [5]:
# Train and validate for `epochs` epochs, recording the per-epoch losses and an
# image grid of the reconstructions returned by the validation pass.
train_loss = []
valid_loss = []
# start from an empty list every time this cell runs; otherwise frames from an
# earlier (e.g. interrupted) run would leak into the saved animation
grid_images = []
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss = model.trainStep(
        trainloader, trainset, optimizer, criterion
    )
    valid_epoch_loss, recon_images = model.validate(
        testloader, testset, criterion
    )
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)
    # save the reconstructed images from the validation loop
    save_reconstructed_images(recon_images, epoch+1)
    # convert the reconstructed images to PyTorch image grid format
    # (detached and moved to the CPU so the stored grids hold no graph/GPU memory)
    image_grid = make_grid(recon_images.detach().cpu())
    grid_images.append(image_grid)
    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Val Loss: {valid_epoch_loss:.4f}")

# save the reconstructions as a .gif file
image_to_vid(grid_images)
# save the loss plots to disk
save_loss_plot(train_loss, valid_loss)
print('TRAINING COMPLETE')

Epoch 1 of 50


  2%|▏         | 8/468 [00:00<00:23, 19.20it/s]

torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
torch.Size([128, 16, 16, 16])
torch.Size([128, 1, 32, 32])
torch.Size([128, 1, 32, 32]) torch.Size([128, 1, 32, 32])
torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
torch.Size([128, 16, 16, 16])
torch.Size([128, 1, 32, 32])
torch.Size([128, 1, 32, 32]) torch.Size([128, 1, 32, 32])
torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
torch.Size([128, 16, 16, 16])
torch.Size([128, 1, 32, 32])
torch.Size([128, 1, 32, 32]) torch.Size([128, 1, 32, 32])
torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
torch.Size([128, 16, 16, 16])
torch.Size([128, 1, 32, 32])
torch.Size([128, 1, 32, 32]) torch.Size([128, 1, 32, 32])
torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
to

  5%|▍         | 22/468 [00:00<00:11, 40.14it/s]

torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
torch.Size([128, 16, 16, 16])
torch.Size([128, 1, 32, 32])
torch.Size([128, 1, 32, 32]) torch.Size([128, 1, 32, 32])
torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
torch.Size([128, 16, 16, 16])
torch.Size([128, 1, 32, 32])
torch.Size([128, 1, 32, 32]) torch.Size([128, 1, 32, 32])
torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
torch.Size([128, 16, 16, 16])
torch.Size([128, 1, 32, 32])
torch.Size([128, 1, 32, 32]) torch.Size([128, 1, 32, 32])
torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
torch.Size([128, 16, 16, 16])
torch.Size([128, 1, 32, 32])
torch.Size([128, 1, 32, 32]) torch.Size([128, 1, 32, 32])
torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
to

  7%|▋         | 35/468 [00:00<00:08, 50.04it/s]

torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
torch.Size([128, 16, 16, 16])
torch.Size([128, 1, 32, 32])
torch.Size([128, 1, 32, 32]) torch.Size([128, 1, 32, 32])
torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
torch.Size([128, 16, 16, 16])
torch.Size([128, 1, 32, 32])
torch.Size([128, 1, 32, 32]) torch.Size([128, 1, 32, 32])
torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
torch.Size([128, 16, 16, 16])
torch.Size([128, 1, 32, 32])
torch.Size([128, 1, 32, 32]) torch.Size([128, 1, 32, 32])
torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
torch.Size([128, 16, 16, 16])
torch.Size([128, 1, 32, 32])
torch.Size([128, 1, 32, 32]) torch.Size([128, 1, 32, 32])
torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
to

 10%|█         | 49/468 [00:01<00:10, 40.71it/s]

torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
torch.Size([128, 16, 16, 16])
torch.Size([128, 1, 32, 32])
torch.Size([128, 1, 32, 32]) torch.Size([128, 1, 32, 32])
torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
torch.Size([128, 16, 16, 16])
torch.Size([128, 1, 32, 32])
torch.Size([128, 1, 32, 32]) torch.Size([128, 1, 32, 32])
torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
torch.Size([128, 16, 16, 16])
torch.Size([128, 1, 32, 32])
torch.Size([128, 1, 32, 32]) torch.Size([128, 1, 32, 32])
torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
torch.Size([128, 16, 16, 16])
torch.Size([128, 1, 32, 32])
torch.Size([128, 1, 32, 32]) torch.Size([128, 1, 32, 32])
torch.Size([128, 64, 1, 1])
torch.Size([128, 64])
torch.Size([128, 64, 4, 4])
torch.Size([128, 32, 8, 8])
to




KeyboardInterrupt: 