# Variational Auto Encoder

This notebook tests the **VAE** using actual observations from a gym environment.

This model is one part of the **WorldModel** included in the *WorldRewardModel*; it is responsible for encoding the observation <br> for the **MD-RNN model** so it can predict the next observation.

In [1]:
import pickle
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as T
import matplotlib
from torchvision.utils import make_grid
from AutoEncoder import *
from utils import ImageDataset
from torch.utils.data import Dataset, DataLoader
from utils import save_reconstructed_images, image_to_vid, save_loss_plot
import matplotlib.pyplot as plt

from tqdm.notebook import tnrange

# Apply matplotlib's 'ggplot' style sheet globally.
matplotlib.style.use('ggplot')
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
# Learning parameters.
lr, epochs, batch_size = 0.001, 400, 64

# Instantiate the convolutional VAE and move it to the selected device.
model = ConvVAE()
model = model.to(device)

# Adam over all model parameters.
optimizer = optim.Adam(model.parameters(), lr=lr)
# Binary cross-entropy, summed over elements rather than averaged.
criterion = nn.BCELoss(reduction='sum')

# Accumulates the reconstructed images as PyTorch image grids.
grid_images = []

In [3]:
# Build the dataset from the pickle file path.
dataset = ImageDataset("out/carRacing_cleaned.pickle")

# Serve the dataset in shuffled mini-batches of `batch_size` samples.
trainloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

In [None]:
%matplotlib inline
from IPython.display import clear_output

to_pil_image = T.ToPILImage()

train_loss = []
valid_loss = []

for epoch in tnrange(epochs, desc="Epoche"):
    train_epoch_loss = model.trainStep(
        trainloader, dataset, optimizer, criterion
    )
    valid_epoch_loss, recon_images = model.validate(
        trainloader, dataset, criterion
    )
    train_loss.append(train_epoch_loss)
    valid_loss.append(train_epoch_loss)
    # save the reconstructed images from the validation loop
    save_reconstructed_images(recon_images, epoch+1)
    # convert the reconstructed images to PyTorch image grid format
    image_grid = make_grid(recon_images.detach().cpu())
    grid_images.append(image_grid)
    
# save the reconstructions as a .gif file
image_to_vid(grid_images)
# save the loss plots to disk
save_loss_plot(train_loss, valid_loss)
print('TRAINING COMPLETE')

Epoche:   0%|          | 0/400 [00:00<?, ?it/s]