In [None]:
import math
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as Datasets
import torchvision.models as models
import torchvision.transforms as T
import torchvision.utils as vutils
from IPython.display import clear_output
from torch.utils.data import DataLoader

from RES_VAE import VAE as AE


In [None]:
# Training configuration: dataset location first, then data and optimisation settings.
root = "/data"      # directory handed to the STL10 dataset loader
imageSize = 64      # size passed to the resize transform (also used in output file names)
batchSize = 32      # samples per DataLoader batch
nepoch = 100        # number of training epochs
lr = 1e-4           # initial learning rate for the optimizer

In [None]:
# Select the compute device: GPU number GPU_indx when CUDA is available, otherwise the CPU.
GPU_indx = 0
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device(GPU_indx)
else:
    device = torch.device("cpu")

In [None]:
def get_data_STL10(transform, batch_size, download = True, root = "/data", num_workers = 8):
    """Build shuffled DataLoaders over the STL10 'unlabeled' and 'test' splits.

    Args:
        transform: transform applied to the images of both splits.
        batch_size: number of samples per batch for both loaders.
        download: forwarded to ``Datasets.STL10`` for both splits.
        root: dataset directory forwarded to ``Datasets.STL10``.
        num_workers: number of DataLoader worker processes for each loader.
            Defaults to 8, the value that used to be hard-coded; pass 0 to
            load data in the main process.

    Returns:
        Tuple ``(trainloader, testloader)``.
    """
    print("Loading trainset...")
    trainset = Datasets.STL10(root=root, split='unlabeled', transform=transform, download=download)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    print("Loading testset...")
    testset = Datasets.STL10(root=root, split='test', download=download, transform=transform)
    # The test loader is shuffled as well, so the batch drawn from it differs between runs.
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    print("Done!")

    return trainloader, testloader

def vae_loss(recon, x, mu, logvar):
    """Return the VAE objective: BCE-with-logits reconstruction term plus a weighted KL term.

    Args:
        recon: reconstruction logits (sigmoid is applied inside the BCE).
        x: target tensor the reconstruction is compared with.
        mu: latent means.
        logvar: latent log-variances.

    Returns:
        Scalar tensor ``bce + 0.01 * kl`` where
        ``kl = -0.5 * mean(1 + logvar - mu^2 - exp(logvar))``.
    """
    kl_weight = 0.01
    kl_term = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean()
    bce_term = F.binary_cross_entropy_with_logits(recon, x)
    return bce_term + kl_weight * kl_term

# Feature extractor <br>
Using a pre-trained VGG-16, we insert an empty pass-through layer after each ReLU to capture the feature maps.

In [None]:
#pass-through layer whose only job is to remember the last feature map it was given.
class GetFeatures(nn.Module):
    """Identity module that records its most recent input in ``self.features``."""

    def __init__(self):
        super().__init__()
        # Nothing recorded until the first forward pass.
        self.features = None

    def forward(self, x):
        """Store ``x`` (the same tensor object, not a copy) and return it unchanged."""
        self.features = x
        return self.features

#load the pre-trained VGG-16 and copy its first layers_deep feature layers into a list of layers.
#a GetFeatures recording layer is inserted after every ReLU layer.
#layers_deep controls how deep we go into the network
def get_feature_extractor(layers_deep = 7):
    """Build a frozen, truncated VGG-16 whose post-ReLU activations are recorded.

    Args:
        layers_deep: number of leading layers of ``vgg16().features`` to keep.

    Returns:
        ``nn.Sequential`` made of the first ``layers_deep`` VGG-16 feature
        layers, with a ``GetFeatures`` module appended after each ``nn.ReLU``.
    """
    C_net = models.vgg16(pretrained=True).to(device)
    C_net = C_net.eval()

    # Freeze the VGG weights: they are only used as a fixed feature extractor.
    # Without this, every backward pass would compute gradients for them and
    # accumulate them in .grad. Gradients still flow through the layers to
    # their input, so the loss remains differentiable w.r.t. the reconstruction.
    for param in C_net.parameters():
        param.requires_grad_(False)

    layers = []
    for i in range(layers_deep):
        layer = C_net.features[i]
        layers.append(layer)
        if isinstance(layer, nn.ReLU):
            layers.append(GetFeatures())
    return nn.Sequential(*layers)

#this function calculates the L2 loss (MSE) on the feature maps recorded by the GetFeatures layers
#between the reconstructed image and the original
def feature_loss(img, recon_data, Features):
    """Mean squared difference between recorded features of an image and its reconstruction.

    The original batch and ``sigmoid(recon_data)`` are concatenated along
    dim 0 and pushed through ``Features`` in one forward pass; each
    ``GetFeatures`` layer's recording is then split back into the two halves.

    Args:
        img: batch of original images.
        recon_data: reconstruction logits for the same batch (sigmoid is applied here).
        Features: sequence of layers containing ``GetFeatures`` recording layers.

    Returns:
        The per-layer MSE averaged over all ``GetFeatures`` layers.

    Raises:
        ValueError: if ``Features`` contains no ``GetFeatures`` layer (this
            previously surfaced as a bare ZeroDivisionError).
    """
    batch = img.shape[0]
    img_cat = torch.cat((img, torch.sigmoid(recon_data)), 0)
    # Forward pass is run only for its side effect of filling each GetFeatures.features.
    Features(img_cat)

    per_layer = []
    for layer in Features:
        if isinstance(layer, GetFeatures):
            # First `batch` rows belong to the originals, the rest to the reconstructions.
            per_layer.append((layer.features[:batch] - layer.features[batch:]).pow(2).mean())

    if not per_layer:
        raise ValueError("Features contains no GetFeatures layer; cannot compute a feature loss")
    return sum(per_layer) / len(per_layer)

In [None]:
def lr_Linear(epoch_max, epoch, lr):
    """Apply a linearly decayed learning rate through ``set_lr``.

    The rate set is ``lr`` scaled by the fraction of epochs still remaining,
    ``(epoch_max - epoch) / epoch_max``.
    """
    remaining_fraction = (epoch_max - epoch) / epoch_max
    set_lr(lr = remaining_fraction * lr)

def set_lr(lr):
    """Write ``lr`` into every parameter group of the module-level ``optimizer``."""
    for group in optimizer.param_groups:
        group.update(lr=lr)

In [None]:
# Resize each image with imageSize, then convert it to a tensor.
transform = T.Compose([
    T.Resize(imageSize),
    T.ToTensor(),
])

# Build the train and test loaders from the configured root and batch size.
trainloader, testloader = get_data_STL10(transform, batchSize, download=True, root=root)

In [None]:
#get a test image batch from the testloader to visualise the reconstruction quality
dataiter = iter(testloader)
# Use the built-in next(): DataLoader iterators in current PyTorch have no .next() method.
# Element 0 of the batch is the image tensor.
test_images = next(dataiter)[0]
test_images.shape

In [None]:
# Preview the first eight test images as a single image grid.
plt.figure(figsize=(20, 10))
out = vutils.make_grid(test_images[:8])
# Move the channel axis last, the layout imshow expects.
plt.imshow(np.transpose(out.numpy(), (1, 2, 0)))

In [None]:
# Build the VAE on the selected device.
vae_net = AE(channel_in=3)
vae_net = vae_net.to(device)

# VGG-based feature extractor used by the feature loss.
Features = get_feature_extractor()

# setup optimizer (only the VAE parameters are handed to Adam)
optimizer = optim.Adam(
    vae_net.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)

# Criterion used for the per-epoch test reconstruction loss.
BCE_Loss = nn.BCEWithLogitsLoss()

In [None]:
# Per-iteration training losses, appended to by the training loop.
loss_log = []
# Best test loss seen so far. Start at +inf so the first evaluation always
# counts as an improvement, whatever the scale of the loss.
lowest_test_loss = float("inf")

In [None]:
# Make sure the output folders exist up front; otherwise the first checkpoint
# save would fail only after a whole epoch of training.
os.makedirs("Models", exist_ok=True)
os.makedirs("Results", exist_ok=True)

for epoch in range(nepoch):
    # Linearly decay the learning rate, starting from lr at epoch 0.
    lr_Linear(nepoch, epoch, lr)
    for i, data in enumerate(trainloader, 0):
        # Move the image batch to the device once and reuse it for every loss term.
        images = data[0].to(device)

        recon_data, mu, logvar = vae_net(images)

        # Pixel-space reconstruction + KL objective.
        loss = vae_loss(recon_data, images, mu, logvar)

        # Feature-space loss on the recorded VGG feature maps.
        loss_feature = feature_loss(images, recon_data, Features)

        loss += loss_feature

        loss_log.append(loss.item())
        vae_net.zero_grad()
        loss.backward()
        optimizer.step()

        clear_output(True)
        print('Epoch: [%d/%d], Itteration: [%d/%d] loss: %.4f' 
              % (epoch, nepoch, i, len(trainloader), loss.item()))

    # Evaluate reconstruction quality on the fixed test batch after each epoch.
    with torch.no_grad():
        # NOTE(review): vae_net is not switched to eval() here; if RES_VAE contains
        # mode-dependent layers (e.g. batch norm) this may matter - confirm.
        test_batch = test_images.to(device)
        recon_data, _, _ = vae_net(test_batch, Train = False)
        test_loss = BCE_Loss(recon_data, test_batch)

    # Keep only the best checkpoint (lowest test BCE) and its reconstruction image.
    if test_loss < lowest_test_loss:
        # Store a plain float so no tensor (and its device memory) is kept alive.
        lowest_test_loss = test_loss.item()
        torch.save(vae_net.state_dict(), "Models/VAE_STL10" + str(imageSize) +".pt" )
        # Reconstructions and originals are concatenated along dim 2 into one image.
        vutils.save_image(torch.cat((torch.sigmoid(recon_data.cpu()), test_images),2),"%s/VAE_%s_%d.png" % ("Results" , "VAE_STL10", imageSize))