HW2. Conditional Variational Autoencoders (CVAE)
======
Implement a deep learning model, based on the CVAE, that generates images of a given digit. (See the Lab1 PowerPoint slides.)

There are **three code blocks** to be implemented.
[Encoder, Decoder, Conditional Variational Autoencoder]



Submit a PDF report (1–2 pages) and your code.
Your report should contain the following.

- How the model architecture was modified to incorporate the class condition. (A screenshot of your code is sufficient.)
- Discuss the comparison of reconstruction errors between VAE and CVAE.
- You may observe that some generated images do not match the given condition. Discuss the reasons and propose methods to resolve this.

In [None]:
# This is for Colab users: mount Google Drive at /content/gdrive so the
# checkpoint cell can write the trained weights into Drive.
# `google.colab` only exists inside Colab; skip this cell elsewhere.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Imports. This cell only imports modules; it does not install PyTorch.
# `%matplotlib inline` is an IPython magic, so this cell must run in a notebook.
import sys
%matplotlib inline
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

Hyperparameters
-----

In [None]:
# 2-d latent space, parameter count in the same order of magnitude
# as in the original VAE paper (VAE paper has about 3x as many)
latent_dims = 2
num_epochs = 100
batch_size = 128
capacity = 64  # channels of the first conv layer in encoder/decoder (the second uses 2x)
learning_rate = 1e-3
variational_beta = 1  # weight of the KL-divergence term in the loss
use_gpu = True

# # 10-d latent space, for comparison with non-variational auto-encoder
# latent_dims = 10
# num_epochs = 100
# batch_size = 128
# capacity = 64
# learning_rate = 1e-3
# variational_beta = 1
# use_gpu = True

# Checkpoint file name, derived from the active settings (computed after the
# optional override above) so the 2-d and 10-d runs never share a file name.
savepath = 'cvae_%dd_%d.pth' % (latent_dims, num_epochs)

MNIST Data Loading
-------------------

MNIST images show digits from 0-9 in 28x28 grayscale images. We do not center them at 0, because we will be using a binary cross-entropy loss that treats pixel values as probabilities in [0,1]. We create both a training set and a test set.

In [None]:
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

# ToTensor only: pixels become floats in [0, 1] and are deliberately not
# re-centred, because the BCE loss treats them as probabilities.
img_transform = transforms.Compose([transforms.ToTensor()])

train_dataset = MNIST(root='./data/MNIST', train=True, download=True, transform=img_transform)
test_dataset = MNIST(root='./data/MNIST', train=False, download=True, transform=img_transform)

train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

In [None]:
from torchvision.transforms.functional import to_pil_image
idx = 42
# Each dataset item is an (image tensor, class label) pair.
sample_img, sample_label = train_dataset[idx]
print(sample_label)  # class
to_pil_image(sample_img)  # image data, rendered as the cell's output

CVAE Definition*
-----------------------

### Encoder*

In [None]:
class Encoder(nn.Module):
    """Convolutional CVAE encoder conditioned on the class label.

    The label is appended to the image as one extra constant-valued channel,
    and the network outputs the mean and log-variance of the latent Gaussian.
    Two stride-2 convolutions halve the spatial size twice, and the fully
    connected layers are sized for 7x7 feature maps, so inputs must be 28x28.

    Args:
        base_channels: channels of the first conv layer (the second uses twice
            as many). Defaults to the notebook-level ``capacity``.
        latent_size: latent dimensionality. Defaults to the notebook-level
            ``latent_dims``.
    """

    def __init__(self, base_channels=None, latent_size=None):
        super(Encoder, self).__init__()
        c = capacity if base_channels is None else base_channels
        n_latent = latent_dims if latent_size is None else latent_size

        # Input = 1 image channel + 1 channel carrying the class label -> 2 channels.
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=c, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=c, out_channels=c*2, kernel_size=4, stride=2, padding=1)
        self.fc_mu = nn.Linear(in_features=c*2*7*7, out_features=n_latent)
        # Predicting the log-variance keeps this output unconstrained (a variance
        # must be positive), which makes training more stable.
        self.fc_logvar = nn.Linear(in_features=c*2*7*7, out_features=n_latent)

    def forward(self, x, condition):
        """Encode images ``x`` of shape (b, 1, h, w) with labels ``condition``.

        ``condition`` holds one label per sample (shape (b,)); a scalar label is
        broadcast to the whole batch.

        Returns:
            (mu, logvar), each of shape (b, latent_size).
        """
        b, _, h, w = x.shape
        # Constant (b, 1, h, w) plane holding each sample's label value. The explicit
        # cast avoids relying on int64 -> float type promotion, and expand() is a view,
        # so no ones tensor is allocated and multiplied on every forward pass.
        label_plane = condition.to(x.dtype).reshape(-1, 1, 1, 1).expand(b, 1, h, w)
        x = torch.cat((x, label_plane), dim=1)

        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        return self.fc_mu(x), self.fc_logvar(x)


### Decoder*

In [None]:
class Decoder(nn.Module):
    """Transposed-convolution CVAE decoder conditioned on the class label.

    The label is appended to the latent vector as one extra feature, projected
    to a (2*c, 7, 7) feature map and upsampled twice (7 -> 14 -> 28) to a
    single-channel image squashed into [0, 1] by a sigmoid.

    Args:
        base_channels: channels of the last hidden layer (the first uses twice
            as many). Defaults to the notebook-level ``capacity``.
        latent_size: latent dimensionality. Defaults to the notebook-level
            ``latent_dims``.
    """

    def __init__(self, base_channels=None, latent_size=None):
        super(Decoder, self).__init__()
        c = capacity if base_channels is None else base_channels
        n_latent = latent_dims if latent_size is None else latent_size
        # Stored so forward() reshapes with the channel count the layers were built
        # with, not whatever the global `capacity` holds at call time.
        self.base_channels = c

        # Basically the reverse of the encoder, plus one input feature for the label.
        self.fc = nn.Linear(in_features=n_latent+1, out_features=c*2*7*7)

        # Transposed convolutions for upscaling (each doubles the spatial size).
        self.conv2 = nn.ConvTranspose2d(in_channels=c*2, out_channels=c, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.ConvTranspose2d(in_channels=c, out_channels=1, kernel_size=4, stride=2, padding=1)

    def forward(self, x, condition):
        """Decode latent codes ``x`` of shape (b, latent_size) with labels ``condition``.

        ``condition`` holds one label per sample (shape (b,)); a scalar label is
        broadcast to the whole batch.

        Returns:
            Images of shape (b, 1, 28, 28) with values in [0, 1].
        """
        # One extra feature column holding the label; cast explicitly instead of
        # relying on int64 -> float promotion inside torch.cat.
        label = condition.to(x.dtype).reshape(-1, 1).expand(x.size(0), 1)
        x = torch.cat((x, label), dim=1)

        x = self.fc(x)
        x = x.view(x.size(0), self.base_channels*2, 7, 7)
        x = F.relu(self.conv2(x))
        return torch.sigmoid(self.conv1(x))

### Conditional Variational Autoencoder*

> Combines the conditional encoder and decoder: the class label is passed to both.



In [None]:
class ConditionalVariationalAutoencoder(nn.Module):
    """Conditional VAE: the class label is fed to both the encoder and the decoder."""

    def __init__(self):
        super(ConditionalVariationalAutoencoder, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x, condition):
        """Return ``(reconstruction, mu, logvar)`` for images ``x`` with labels ``condition``."""
        mu, logvar = self.encoder(x, condition)
        z = self.latent_sample(mu, logvar)
        return self.decoder(z, condition), mu, logvar

    def latent_sample(self, mu, logvar):
        """Reparameterisation trick.

        In training mode returns ``mu + eps * std`` with ``eps ~ N(0, I)`` and
        ``std = exp(0.5 * logvar)``; in eval mode returns ``mu`` unchanged.
        """
        if not self.training:
            return mu
        # logvar = log(s^2), so exp(0.5 * logvar) = exp(log(s)) = s.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)  # standard normal noise
        return eps * std + mu

In [None]:
cvae = ConditionalVariationalAutoencoder()

# Use the first GPU when requested and available, otherwise the CPU.
device = torch.device("cuda:0" if use_gpu and torch.cuda.is_available() else "cpu")
cvae = cvae.to(device)

# Count trainable parameters only.
trainable = (param.numel() for param in cvae.parameters() if param.requires_grad)
num_params = sum(trainable)
print('Number of parameters: %d' % num_params)

Train CVAE
--------

### Loss function

In [None]:
def cvae_loss(recon_x, x, mu, logvar, beta=None):
    """Negative ELBO of the CVAE, summed over all samples in the batch.

    Args:
        recon_x: decoder output with values in [0, 1]; same number of elements as ``x``.
        x: target images with pixel values in [0, 1].
        mu: mean of the approximate posterior.
        logvar: log-variance of the approximate posterior.
        beta: weight of the KL term; defaults to the notebook-level ``variational_beta``.

    Returns:
        Scalar tensor: BCE(recon_x, x) summed over all pixels + beta * KL.
    """
    if beta is None:
        beta = variational_beta
    # Flatten instead of view(-1, 784): works for any image size and for
    # non-contiguous tensors; with reduction='sum' the value is the same.
    recon_loss = F.binary_cross_entropy(recon_x.reshape(-1), x.reshape(-1), reduction='sum')
    # Closed-form KL divergence between N(mu, exp(logvar)) and the N(0, I) prior.
    kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + beta * kldivergence

### Training

In [None]:
optimizer = torch.optim.Adam(params=cvae.parameters(), lr=learning_rate, weight_decay=1e-5)

# Training mode: latent_sample draws z = mu + eps * std instead of returning mu.
cvae.train()

# Mean per-batch loss for each epoch (plotted as the training curve below).
train_loss_avg = []

print('Training ...')
for epoch in range(num_epochs):
    train_loss_avg.append(0)
    num_batches = 0

    for image_batch, condition_batch in train_dataloader: # class label is needed as the CVAE condition

        image_batch = image_batch.to(device)
        condition_batch = condition_batch.to(device)

        image_batch_recon, latent_mu, latent_logvar = cvae(image_batch, condition_batch)

        # Full loss: reconstruction term + variational_beta * KL divergence.
        loss = cvae_loss(image_batch_recon, image_batch, latent_mu, latent_logvar)

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()

        train_loss_avg[-1] += loss.item()
        num_batches += 1

    train_loss_avg[-1] /= num_batches
    # This is the total loss (reconstruction + KL), not the reconstruction error alone.
    print('Epoch [%d / %d] average loss (reconstruction + KL): %f' % (epoch+1, num_epochs, train_loss_avg[-1]))

In [None]:
# Save the trained weights to Google Drive. Use the absolute mount point passed to
# drive.mount ('/content/gdrive') so the path does not depend on the current working
# directory. NOTE(review): newer Colab versions name the Drive root 'MyDrive' (no
# space); confirm the folder name if saving fails.
torch.save(cvae.state_dict(), os.path.join('/content/gdrive/My Drive/Colab Notebooks', savepath))

### Plot Training Curve

In [None]:
import matplotlib.pyplot as plt
plt.ion()

# Training curve: mean per-batch training loss for each epoch.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot(train_loss_avg)
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
plt.show()

Experiments
-------------------------

### Reconstruction

In [None]:
# Eval mode: latent_sample returns mu, so reconstructions are deterministic.
cvae.eval()

# Per-batch averages over the test set of
#   - the full loss (reconstruction + variational_beta * KL), and
#   - the reconstruction error alone (BCE summed over the batch),
# so the reconstruction error can be compared directly with the plain VAE.
test_loss_avg, test_recon_avg, num_batches = 0, 0, 0
with torch.no_grad():
    for image_batch, condition_batch in test_dataloader:

        image_batch = image_batch.to(device)
        condition_batch = condition_batch.to(device)

        # cvae reconstruction
        image_batch_recon, latent_mu, latent_logvar = cvae(image_batch, condition_batch)

        loss = cvae_loss(image_batch_recon, image_batch, latent_mu, latent_logvar)
        recon_error = F.binary_cross_entropy(image_batch_recon, image_batch, reduction='sum')

        test_loss_avg += loss.item()
        test_recon_avg += recon_error.item()
        num_batches += 1

test_loss_avg /= num_batches
test_recon_avg /= num_batches
print('average loss (reconstruction + KL): %f' % (test_loss_avg))
print('average reconstruction error: %f' % (test_recon_avg))

### Visualisation

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torchvision.utils

# Interactive plotting; eval mode makes latent_sample return mu.
plt.ion()
cvae.eval()

def to_img(x):
    """Return ``x`` with every value clamped into the displayable range [0, 1]."""
    return torch.clamp(x, min=0, max=1)

def show_image(img):
    """Draw a (C, H, W) CPU image tensor with matplotlib after clamping it to [0, 1]."""
    pixels = to_img(img).numpy()
    # matplotlib expects channels last: (C, H, W) -> (H, W, C)
    channels_last = np.transpose(pixels, (1, 2, 0))
    plt.imshow(channels_last)

def visualise_output(images, conditions, model):
    """Reconstruct ``images`` with ``model`` under labels ``conditions`` and plot them.

    Samples 1..49 of the batch are shown as a grid with 10 images per row and
    5 pixels of padding, matching the grid of original images.
    """
    with torch.no_grad():
        recon, _, _ = model(images.to(device), conditions.to(device))
        recon = to_img(recon.cpu())
        grid = torchvision.utils.make_grid(recon[1:50], 10, 5)
        plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
        plt.show()

# First batch yielded by the test loader: images plus their class labels,
# which the CVAE needs as its condition.
images, labels = next(iter(test_dataloader))

print('Original images')
show_image(torchvision.utils.make_grid(images[1:50],10,5))
plt.show()

# The model here is the conditional VAE, so label its output accordingly.
print('CVAE reconstruction:')
visualise_output(images, labels, cvae)

### Conditional Generation

In [None]:
cvae.eval()

with torch.no_grad():
    # 100 latent codes drawn from the standard-normal prior.
    latent = torch.randn(100, latent_dims, device=device)
    # Labels 0 x10, 1 x10, ..., 9 x10: with 10 images per grid row below,
    # row r shows samples conditioned on digit r.
    condition = torch.arange(10, device=device).repeat_interleave(10)

    img_recon = cvae.decoder(latent, condition).cpu()

    fig, ax = plt.subplots(figsize=(5, 5))
    show_image(torchvision.utils.make_grid(img_recon[:100], 10, 5))
    plt.show()