# VAE with the CIFAR100 dataset
Link from fiskemad...
Load data

In [4]:
import torch
import torchvision
from torchvision import transforms
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.autograd import Variable
import numpy as np
from torch import optim

In [36]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 10

trainset = torchvision.datasets.CIFAR100(root='../data/datasetCIFAR100', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='../data/datasetCIFAR100', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = trainset.classes # or class_to_idx



Files already downloaded and verified
Files already downloaded and verified


## Define model and train

Models from [here](https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py) and VAE structure from here [git](https://github.com/Jackson-Kang/Pytorch-VAE-tutorial)

In [74]:
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

class Encoder(nn.Module):
    
    def __init__(self, vgg_name, input_dim, latent_dim):

        super(Encoder, self).__init__()
        cfg[vgg_name] = [input_dim,*cfg[vgg_name]] # first layer always input size of image
        self.features = self._make_layers(cfg[vgg_name])
        self.FC_mean = nn.Linear(512, latent_dim)
        self.FC_var = nn.Linear(512, latent_dim)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1) # Flatten(?)
        mean = self.FC_mean(out)
        log_var = self.FC_var(out)
        return mean, log_var
      

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.GELU()] # changed from ReLU
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

class Decoder(nn.Module):
    def __init__(self, vgg_name, latent_dim, output_dim):

        super(Decoder, self).__init__()
        cfg[vgg_name] = [latent_dim,*cfg[vgg_name]] # first layer always input size of image
        self.features = self._make_layers(cfg[vgg_name])
        self.FC_output = nn.Linear(cfg[vgg_name][-2], output_dim)
        
    def forward(self, x):
        out     = self.features(x)
        x_hat = torch.sigmoid(self.FC_output(out))
        
        return x_hat
    
    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.GELU()] # changed from ReLU
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

class Model(nn.Module):
    def __init__(self, Encoder, Decoder):
        super(Model, self).__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder
        
    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var).to(DEVICE)        # sampling epsilon        
        z = mean + var*epsilon                          # reparameterization trick
        return z
        
                
    def forward(self, x):
        mean, log_var = self.Encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var)) # takes exponential function (log var -> var)
        x_hat            = self.Decoder(z)
        
        return x_hat, mean, log_var

In [75]:

data_size = testset[0][0].shape[1] #Fixed
latent_dim = 10 # hyperparameter


encoder = Encoder('VGG19',  input_dim=data_size,     latent_dim=latent_dim)
decoder = Decoder('VGG19',  latent_dim=latent_dim,   output_dim = data_size)

model = Model(Encoder=encoder, Decoder=decoder).to(DEVICE)

## Traning
In CIFAR100. First define loss function

In [65]:

lr = 1e-3
epochs = 30

BCE_loss = nn.BCELoss()

def loss_function(x, x_hat, mean, log_var):
    reproduction_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')
    KLD      = - 0.5 * torch.sum(1+ log_var - mean.pow(2) - log_var.exp())

    return reproduction_loss + KLD


optimizer = optim.Adam(model.parameters(), lr= 0.001)

Test of output from the encoder and decoder:

In [84]:

x = torch.randn(2,3,32,32)
print(f"size of input{x.size()}")
# Encoder test
mean, var = encoder(x)
print(f"The mean shape {mean.size()}, \nthe variance shape {var.size()}")

# Decoder
epsilon = torch.randn_like(var).to(DEVICE)        # sampling epsilon        
z = mean + var*epsilon  
print(z.size())
#x_hat = decoder(z)

# Model
#x_hat, mean, var = model(x)

size of inputtorch.Size([2, 3, 32, 32])
The mean shape torch.Size([2, 10]), 
the variance shape torch.Size([2, 10])
torch.Size([2, 10])


In [77]:
print("Start training VAE...")
model.train()

for epoch in range(epochs):
    overall_loss = 0
    for batch_idx, (x, _) in enumerate(trainloader):
        x = x.view(batch_size, data_size)
        x = x.to(DEVICE)

        optimizer.zero_grad()

        x_hat, mean, log_var = model(x)
        loss = loss_function(x, x_hat, mean, log_var)
        
        overall_loss += loss.item()
        
        loss.backward()
        optimizer.step()
        
    print("\tEpoch", epoch + 1, "complete!", "\tAverage Loss: ", overall_loss / (batch_idx*batch_size))
    
print("Finish!!")

Start training VAE...


RuntimeError: shape '[10, 32]' is invalid for input of size 30720