In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as Datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.models as models
import torchvision.utils as vutils
from torch.hub import load_state_dict_from_url

import os
import random
import numpy as np
import math
from IPython.display import clear_output
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import trange, tqdm

import Helpers as hf
from RES_VAE_Dynamic import VAE
from vgg19 import VGG19

In [None]:
# ---- Training configuration -------------------------------------------------
batch_size = 64
image_size = 64          # images are resized and centre-cropped to image_size x image_size
lr = 1e-4
nepoch = 100             # training runs epochs start_epoch .. nepoch - 1
start_epoch = 0          # overwritten when a checkpoint is loaded
# Dataset location: set the STL10_ROOT environment variable instead of editing the
# notebook; the default is the original author's local path.
dataset_root = os.environ.get("STL10_ROOT", "/media/luke/Quick Storage/Data")
save_dir = os.getcwd()   # Models/ and Results/ sub-directories are created here
model_name = "STL10"     # prefix for checkpoint / result file names
load_checkpoint = False  # True: resume from save_dir/Models/<model_name>_<image_size>.pt

# Seed every RNG the notebook relies on (Python, NumPy, PyTorch - including the
# DataLoader shuffling and random flips) so a Restart & Run All is reproducible.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Pick the compute device: GPU number `gpu_indx` when CUDA is available, else the CPU.
gpu_indx = 0
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda", gpu_indx)
else:
    device = torch.device("cpu")

In [None]:
def get_data_STL10(transform, batch_size, download=True, root="/data",
                   num_workers=2, test_transform=None):
    """Build the STL-10 train and test DataLoaders.

    The training loader uses STL-10's ``'unlabeled'`` split (the training loop
    discards labels) and is reshuffled every epoch. The test loader uses the
    ``'test'`` split in a fixed order.

    Args:
        transform: torchvision transform applied to the training images, and to
            the test images unless ``test_transform`` is given.
        batch_size: Number of images per batch for both loaders.
        download: If True, download the dataset into ``root`` when missing.
        root: Directory containing (or receiving) the STL-10 files.
        num_workers: Worker processes per DataLoader (default 2, as before).
        test_transform: Optional transform for the test split, e.g. one without
            random augmentation so evaluation images are deterministic.
            Defaults to ``transform`` (the previous behaviour).

    Returns:
        ``(trainloader, testloader)``
    """
    if test_transform is None:
        test_transform = transform

    print("Loading trainset...")
    trainset = Datasets.STL10(root=root, split='unlabeled', transform=transform, download=download)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    print("Loading testset...")
    testset = Datasets.STL10(root=root, split='test', download=download, transform=test_transform)
    # No shuffling: the first test batch is reused as a fixed visualisation set.
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    print("Done!")

    return trainloader, testloader

In [None]:
#OLD way of getting features and calculating loss - Not used

# #create an empty layer that will simply record the feature map passed to it.
# class GetFeatures(nn.Module):
#     def __init__(self):
#         super(GetFeatures, self).__init__()
#         self.features = None
#     def forward(self, x):
#         self.features = x
#         return x

# #download the pre-trained VGG-19 and copy its first `layers_deep` feature layers into a list,
# #inserting a GetFeatures layer after every ReLU layer so its output is recorded.
# #layers_deep controls how deep into the network we go.
# def get_feature_extractor(layers_deep = 7):
#     C_net = models.vgg19(pretrained=True).to(device)
#     C_net = C_net.eval()
    
#     layers = []
#     for i in range(layers_deep):
#         layers.append(C_net.features[i])
#         if isinstance(C_net.features[i], nn.ReLU):
#             layers.append(GetFeatures())
#     return nn.Sequential(*layers)

# #this function calculates the L2 loss (MSE) on the feature maps recorded by the GetFeatures layers
# #between the reconstructed image and the original
# def feature_loss(img, recon_data, feature_extractor):
#     img_cat = torch.cat((img, torch.sigmoid(recon_data)), 0)
#     out = feature_extractor(img_cat)
#     loss = 0
#     for i in range(len(feature_extractor)):
#         if isinstance(feature_extractor[i], GetFeatures):
#             loss += (feature_extractor[i].features[:(img.shape[0])] - feature_extractor[i].features[(img.shape[0]):]).pow(2).mean()
#     return loss/(i+1)


In [None]:
# Pre-processing pipeline: resize + centre-crop to image_size, random horizontal
# flip (augmentation), convert to a tensor in [0, 1], then map to [-1, 1] via
# (x - 0.5) / 0.5. Note the same transform (random flip included) is handed to
# get_data_STL10 for both the train and the test split.
transform_steps = [
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
]
transform = transforms.Compose(transform_steps)

trainloader, testloader = get_data_STL10(
    transform,
    batch_size,
    download=False,
    root=dataset_root,
)

In [None]:
# Grab one fixed batch from the (unshuffled) test loader; it is reused after every
# epoch to visualise reconstruction quality.
# Use the builtin next(): DataLoader iterators in recent PyTorch releases no
# longer provide a `.next()` method.
dataiter = iter(testloader)
test_images, _ = next(dataiter)
test_images.shape

In [None]:
# Show the first 8 test images as one grid. normalize=True rescales the grid to
# [0, 1] for display, since the tensors were normalised to [-1, 1].
fig, ax = plt.subplots(figsize=(20, 10))
out = vutils.make_grid(test_images[0:8], normalize=True)
ax.imshow(out.numpy().transpose((1, 2, 0)))  # CHW -> HWC for matplotlib
ax.set_title("First 8 STL-10 test images")
ax.axis("off")
plt.show()

In [None]:
# Perceptual-loss module (project class vgg19.VGG19), applied to reconstructions
# and inputs in the training loop. It is never optimised (the optimizer only holds
# vae_net parameters), so freeze its weights: gradients still flow through it to
# the reconstruction, but no weight gradients are computed or left accumulating.
# eval() pins any train-mode-only layers (dropout / batch-norm, if VGG19 has any -
# not visible here; NOTE(review): confirm against vgg19.py).
feature_extractor = VGG19().to(device).eval().requires_grad_(False)

In [None]:
# Residual VAE from the project module RES_VAE_Dynamic: 3-channel input, base
# width `ch`=64, `blocks`=(1, 2, 4, 8) (presumably blocks per stage - see the
# module), and a 512-channel latent.
vae_net = VAE(channel_in=3, ch=64, blocks=(1, 2, 4, 8), latent_channels=512)
vae_net = vae_net.to(device)

# Adam over the VAE parameters only; beta1 = 0.5 instead of the 0.9 default.
optimizer = optim.Adam(vae_net.parameters(), betas=(0.5, 0.999), lr=lr)

# Per-iteration training-loss history, appended to by the training loop and
# saved in / restored from checkpoints.
loss_log = []

In [None]:
# Create the output folders if they do not exist yet (no-op otherwise).
os.makedirs(save_dir + "/Models", exist_ok=True)
os.makedirs(save_dir + "/Results", exist_ok=True)

# One checkpoint file per model name / image size; the training loop writes to
# this same path.
checkpoint_path = save_dir + "/Models/" + model_name + "_" + str(image_size) + ".pt"

if load_checkpoint:
    checkpoint = torch.load(checkpoint_path, map_location = "cpu")
    print("Checkpoint loaded")
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    vae_net.load_state_dict(checkpoint['model_state_dict'])
    # The stored epoch is the last *completed* epoch (the training loop saves it
    # after finishing that epoch), so resume from the next one instead of
    # training it a second time.
    start_epoch = checkpoint["epoch"] + 1
    loss_log = checkpoint["loss_log"]
elif os.path.isfile(checkpoint_path):
    # Refuse to start from scratch over an existing checkpoint, which the
    # training loop would otherwise silently overwrite.
    raise ValueError("Warning Checkpoint exists")
else:
    print("Starting from scratch")

In [None]:
# Main training loop: one pass over the training set per epoch, followed by a
# reconstruction snapshot and a checkpoint at the end of every epoch.
for epoch in trange(start_epoch, nepoch, leave=False):
    # The end-of-epoch evaluation switches the VAE to eval(), so training mode
    # must be re-enabled at the start of every epoch.
    vae_net.train()
    for i, (images, _) in enumerate(tqdm(trainloader, leave=False)):
        # Labels are discarded: the VAE is trained on images only.
        images = images.to(device)

        # vae_net returns (reconstruction, mu, logvar) for the batch.
        recon_img, mu, logvar = vae_net(images)
        #VAE loss
        # KL term from Helpers.kl_loss (presumably KL(q(z|x) || N(0, I)) - see Helpers).
        kl_loss = hf.kl_loss(mu, logvar)
        # Pixel-wise reconstruction error.
        mse_loss = F.mse_loss(recon_img, images)

        #Perception loss
        # Reconstructions and inputs are stacked along the batch dimension and
        # passed through the VGG19 loss module in a single call; presumably it
        # splits the batch in half and returns a scalar feature-matching loss -
        # TODO confirm in vgg19.py.
        feat_in = torch.cat((recon_img, images), 0)
        feature_loss = feature_extractor(feat_in)

        # Unweighted sum of the three loss terms.
        loss = kl_loss + mse_loss + feature_loss

        loss_log.append(loss.item())
        # Clears the VAE's gradients (the same parameters the optimizer updates).
        vae_net.zero_grad()
        loss.backward()
        optimizer.step()

    #In eval mode the model will use mu as the encoding instead of sampling from the distribution
    vae_net.eval()
    with torch.no_grad():
        # Reconstruct the fixed test batch and stack each reconstruction above its
        # original (dim 2 is image height for NCHW tensors).
        recon_img, _, _ = vae_net(test_images.to(device))
        img_cat = torch.cat((recon_img.cpu(), test_images), 2)

        # Same file name every epoch: the image always shows the latest epoch.
        vutils.save_image(img_cat,
                          "%s/%s/%s_%d.png" % (save_dir, "Results" , model_name, image_size),
                          normalize=True)

        #Save a checkpoint
        # 'epoch' is the index of the epoch that has just finished; the file is
        # overwritten every epoch.
        torch.save({
                    'epoch'                         : epoch,
                    'loss_log'                      : loss_log,
                    'model_state_dict'              : vae_net.state_dict(),
                    'optimizer_state_dict'          : optimizer.state_dict()

                     }, save_dir + "/Models/" + model_name + "_" + str(image_size) + ".pt")