# VAE with the CIFAR100 dataset
Link from fiskemad...
Load data

In [10]:
import torch
import torchvision
from torchvision import transforms
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.autograd import Variable
import numpy as np
from torch import optim

In [11]:
transform = transforms.Compose(
    [transforms.ToTensor()])

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 32 # If you are using GPU, you can set the batch size to be 2, 4, 8, 16, 32..., this makes the GPUs work more effciently!

trainset = torchvision.datasets.CIFAR100(root='../data/datasetCIFAR100', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='../data/datasetCIFAR100', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = trainset.classes # or class_to_idx

# Saving of model and checkpoints.
saveModelPath = "../trainedModels/VAE_CIFAR100.pth"
firstTrain = True


Files already downloaded and verified
Files already downloaded and verified


## Define model and train

Models from [here](https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py) and VAE structure from here [git](https://github.com/Jackson-Kang/Pytorch-VAE-tutorial)

In [12]:
cfg = {
    'VGG9': [64, 'M', 128, 'M', 256, 256, 'M', 512, 'M', 512, 'M'],
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

class Encoder(nn.Module):
    
    def __init__(self, vgg_name, input_dim, latent_dim):

        super(Encoder, self).__init__()
        cfg[vgg_name] = [input_dim,*cfg[vgg_name]] # first layer always input size of image
        self.features = self._make_layers(cfg[vgg_name])
        self.FC_mean = nn.Conv2d(cfg[vgg_name][-2], latent_dim, kernel_size=1)
        self.FC_var = nn.Conv2d(cfg[vgg_name][-2], latent_dim, kernel_size=1)

    def forward(self, x):
        out = self.features(x)
        #out = out.view(out.size(0), -1) # Flatten(?)
        mean = self.FC_mean(out)
        mean = mean.reshape(mean.size(0), mean.size(1), -1)
        mean = torch.mean(mean, dim=-1, keepdim=True).unsqueeze(-1) # b, latent, 1, 1

        log_var = self.FC_var(out)
        log_var = log_var.reshape(log_var.size(0), log_var.size(1), -1)
        log_var = torch.mean(log_var, dim=-1, keepdim=True).unsqueeze(-1) # b, latent, 1, 1
        return mean, log_var
      

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.GELU()] # changed from ReLU
                in_channels = x
        #layers += [nn.AvgPool2d(kernel_size=1, stride=1)] giver ikke mening...
        return nn.Sequential(*layers)

class Decoder(nn.Module):
    def __init__(self, vgg_name, latent_dim, output_dim):

        super(Decoder, self).__init__()
        cfg[vgg_name] = [latent_dim,*cfg[vgg_name]] # first layer always input size of image
        self.latent_dim = latent_dim
        self.features = self._make_layers(cfg[vgg_name])
        self.FC_output = nn.Conv2d(cfg[vgg_name][-2], output_dim, kernel_size=1) # when kernel size is set to 1, this is indeed a FC layer:)
        # self.FC_output = nn.Linear(cfg[vgg_name][-2], output_dim)
        
    def forward(self, x):
        out  = self.features(x)
        x_hat = torch.sigmoid(self.FC_output(out))
        
        return x_hat
    
    def _make_layers(self, cfg):
        layers = []
        in_channels = self.latent_dim
        for i in range(len(cfg)):
            if cfg[i] == 'M':
                layers += [nn.ConvTranspose2d(cfg[i-1], cfg[i-1], kernel_size=2, stride=2)] # in decoder, we should upsample the image, instead of downsample it
            else:
                layers += [nn.Conv2d(in_channels, cfg[i], kernel_size=3, padding=1),
                           nn.BatchNorm2d(cfg[i]),
                           nn.LeakyReLU()] # changed from ReLU
                in_channels = cfg[i]
        # layers += [nn.AvgPool2d(kernel_size=1, stride=1)] giver ikke mening...
        return nn.Sequential(*layers)

class Model(nn.Module):
    def __init__(self, Encoder, Decoder):
        super(Model, self).__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder
        
    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var).to(DEVICE)        # sampling epsilon        
        z = mean + var*epsilon                          # reparameterization trick
        return z
        
                
    def forward(self, x):
        mean, log_var = self.Encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var)) # takes exponential function (log var -> var)
        
        # also here, unflatten()
        z = z.view(z.size(0), z.size(1), 1, 1)
        
        x_hat = self.Decoder(z)
        
        return x_hat, mean, log_var

In [13]:

channel_size = testset[0][0].shape[0] #Fixed, dim 0 is the feature channel number
latent_dim = 256 # hyperparameter
lr = 1e-4
epochs = 5

encoder = Encoder('VGG11',  input_dim=channel_size,     latent_dim=latent_dim)
decoder = Decoder('VGG11',  latent_dim=latent_dim,   output_dim = channel_size)

model = Model(Encoder=encoder, Decoder=decoder).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr = lr)#optim.Adam(model.parameters(), lr= lr)

if firstTrain == True:
    startepoch = 0
    
else:
    checkpoint = torch.load(saveModelPath)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    startepoch = checkpoint['epoch']
    loss = checkpoint['loss']
    
# Default train mode!


Test of dim

In [14]:

x = torch.randn(2,3,32,32)
print(f"size of input{x.size()}")
# Encoder test
mean, var = encoder(x)
print(f"The mean shape {mean.size()}, \nthe variance shape {var.size()}")

# Decoder
epsilon = torch.randn_like(var).to(DEVICE)        # sampling epsilon        
z = mean + var*epsilon  
x_hat = decoder(z)

print(f"Latent vector size: {z.size()}, and x_hat {x_hat.size()}")
# Model
x_hat, mean, var = model(x)

loss_test = nn.functional.binary_cross_entropy(x_hat, x)
print(loss_test)


size of inputtorch.Size([2, 3, 32, 32])
The mean shape torch.Size([2, 256, 1, 1]), 
the variance shape torch.Size([2, 256, 1, 1])
Latent vector size: torch.Size([2, 256, 1, 1]), and x_hat torch.Size([2, 3, 32, 32])
tensor(0.6956, grad_fn=<BinaryCrossEntropyBackward>)


## Traning
In CIFAR100. First define loss function

In [15]:

def loss_function(x, x_hat, mean, log_var):
    reproduction_loss = nn.functional.binary_cross_entropy(x_hat, x)
    #KLD      = - 0.5 * torch.sum(1+ log_var - mean.pow(2) - log_var.exp())
    #KLD = torch.mean( -0.5 * torch.sum(1+ log_var - mean**2 - log_var.exp(),dim=1),dim = 0)
    #scale = 0.00025
    
    #print(f"Reproduction: {reproduction_loss}, \tKLD: {KLD.item()}, \tscaled KLD: {(KLD * scale).item()}, \tlog_var: {log_var.sum()}")
    print(reproduction_loss.item())
    return reproduction_loss #+ KLD*scale #* 0.5 # change? 

Training!

In [16]:
def train(startepoch, num_epochs, model, loader, plot : bool = False):
    loss_list = []
    model.train()

    size = len(loader.dataset)
    num_batches = len(loader)
    
    for epoch in range(startepoch,num_epochs):
        
        for batch_idx, (x, _) in enumerate(loader):
        
            x = x.to(DEVICE)
            
            # clear gradients for this training step 
            optimizer.zero_grad()
                
            x_hat, mean, log_var = model(x)
            
            #print(f"min: {min(torch.flatten(x_hat)).item()}, \tmax: {max(torch.flatten(x_hat)).item()}")
            loss = loss_function(x, x_hat, mean, log_var)
            
            # backpropagation, compute gradients 
            loss.backward()
            
            # apply gradients  
            optimizer.step()
            
            current_batch_size = len(x_hat)
                      
            
            if batch_idx % (500//current_batch_size) == 0:
                loss, current = loss.item(), batch_idx * current_batch_size
                print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
                
                if plot:
                    loss_list.append(loss.item())
                
                pass
        # Save model after each epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, saveModelPath)
        print(f"Checkpoint at epoch {epoch}")
        pass
    
    if plot:
        xVals = list(range(1, len(loss_list) + 1))
        
        # subplots define number of rows and columns
        fig, ax1 = plt.subplots(1, 1)
        ax1.plot(xVals, loss_list, 'o-')
        fig.suptitle(f"Loss through training.")
        ax1.set_ylabel("Loss over training")
       
    print("Done!")            
    pass

train(startepoch, epochs, model, trainloader, True)
# Save final model 
torch.save(model.state_dict(), saveModelPath)

0.694412887096405
0.6948010325431824
0.6946728229522705
0.6949954628944397
0.6948586106300354
0.6946943402290344
0.6947937607765198
0.6950170993804932
0.6949772834777832
0.6949276924133301
0.6947660446166992
0.6946489810943604
0.6946744918823242
0.6945489048957825
0.6950672268867493
0.694856584072113
0.6949207186698914
0.6952629089355469
0.6951518058776855
0.6945661902427673
0.6939906477928162
0.6949846744537354
0.6951618194580078
0.6946007609367371
0.6953015327453613
0.6948404312133789
0.6950771808624268
0.6949770450592041
0.6950057148933411
0.6943953633308411
0.6946597695350647
0.6950089931488037
0.6946697235107422
0.6948482394218445
0.6944198608398438
0.6949440836906433
0.6946679949760437
0.6949397921562195
0.6943309307098389
0.6943848133087158
0.6949140429496765
0.6943840384483337
0.6949601769447327
0.6948098540306091
0.6947371363639832
0.6949717998504639
0.6946975588798523
0.6946940422058105
0.694706380367279
0.6947951912879944
0.6948379874229431
0.6947396397590637
0.6950621604919

KeyboardInterrupt: 

Convert to python file!

In [None]:
!jupyter nbconvert --to script VAE_CIFAR100_test.ipynb