# Generative Models

## GANs & VAE

### HW4

In this notebook, we are going to implement a Variational AutoEncoder (VAE) and a Generative Adversarial Network (GAN) on the MNIST dataset. VAEs learn a latent-variable model of the input data; to generate new data, they sample from the latent distribution and decode the sample. GANs use a generator that turns samples from a prior distribution into images, trained against a discriminator that tries to tell generated images from real ones.


+ Complete the `TODO` parts in the code accordingly.

In [None]:
from torch.utils.data import DataLoader, random_split
from torch import optim
from tqdm.notebook import tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, datasets
import matplotlib.pyplot as plt

# Use the GPU when one is available; the training and sampling code below places models and tensors on DEVICE.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE  # bare expression: shown as the cell's output in the notebook

In [None]:

# Mount Google Drive (Colab only) so model checkpoints can be saved to and reloaded from `data_path`.
from google.colab import drive
drive.mount('/content/drive')
# NOTE(review): the notebook is titled HW4 but this points at an HW2 folder — confirm the intended location.
data_path = "/content/drive/MyDrive/DL/HW2/"

# Dataset

We will be using the MNIST dataset, whose training split consists of 60,000 images. We split them into training and validation sets of 50,000 and 10,000 images, respectively.

In [None]:
# Only convert images to tensors: ToTensor yields float pixels in [0, 1], the range the
# sigmoid outputs and binary-cross-entropy losses used below operate in. No normalisation is applied.
transformation = transforms.Compose([
    transforms.ToTensor()
])

In [None]:
mnist_data = datasets.MNIST(root='data/', download=True, transform=transformation)

# Split the 60k MNIST training images into 50k train / 10k validation.
# A fixed-seed generator makes the split reproducible across runs, so models reloaded
# from disk are evaluated on the same held-out images they never trained on.
split_generator = torch.Generator().manual_seed(42)
train_data, valid_data = random_split(mnist_data, [50000, 10000], generator=split_generator)

In [None]:
# DataLoader settings: 4 worker processes, 64 images per batch.
num_threads = 4
batch_size = 64

# Shuffle only the training set. Validation order stays fixed, so the first validation batch
# (used by train_vae as its reconstruction preview) is the same every epoch.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_threads)
val_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_threads)

# VAE

Variational Autoencoders (VAEs) are a class of generative models that aim to learn the probability distribution of a given dataset in order to generate new data points that resemble the original data. They use an encoder that maps each input to a distribution over a smaller latent space, and a decoder that tries to reconstruct the original image from a sample drawn from that latent distribution.

![](https://upload.wikimedia.org/wikipedia/commons/thumb/4/4a/VAE_Basic.png/425px-VAE_Basic.png)
---


In [None]:
class Encoder(nn.Module):
    """Variational encoder: one ReLU hidden layer feeding two linear heads.

    Given a flattened input, the heads produce the mean and the log-variance
    of a diagonal Gaussian over the latent space.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        # Shared trunk, then separate projections for mean and log-variance.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        """Return ``(z_mean, z_log_var)`` for the batch ``x``."""
        hidden = self.fc1(x).relu()
        return self.fc_mean(hidden), self.fc_log_var(hidden)

class Decoder(nn.Module):
    """Decoder: maps latent vectors to flattened images through one ReLU hidden layer."""

    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        """Return the reconstruction for latent batch ``z``, every value squashed into (0, 1)."""
        hidden = self.fc1(z).relu()
        # Sigmoid keeps outputs in the same range as the [0, 1] pixel intensities.
        return self.fc2(hidden).sigmoid()

In [None]:
class VAE(nn.Module):
    """Variational autoencoder: Encoder -> reparameterised latent sample -> Decoder."""

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def forward(self, x):
        """Encode ``x``, draw one latent sample and decode it.

        Returns:
            ``(reconstruction, z_mean, z_log_var)``. The mean and log-variance are
            returned so the caller can compute the KL term of the loss.
        """
        z_mean, z_log_var = self.encoder(x)
        # Reparameterisation trick: z = mu + sigma * eps. The randomness lives in eps,
        # so gradients flow through mu and log-variance.
        sigma = (0.5 * z_log_var).exp()
        noise = torch.randn_like(sigma)
        return self.decoder(z_mean + noise * sigma), z_mean, z_log_var

In [None]:
def reconstruct_images(vae, images):
    """Run ``images`` through ``vae`` without gradient tracking.

    The batch is flattened to 784-wide rows before the forward pass, and the
    reconstructions are reshaped back to ``(N, 1, 28, 28)``.
    """
    flat_batch = images.view(-1, 784)
    with torch.no_grad():
        reconstruction, _mean, _log_var = vae(flat_batch)
    return reconstruction.view(-1, 1, 28, 28)

def plot_images(original, reconstructed, n=4):
    """Show the first ``n`` originals (top row) above their reconstructions (bottom row)."""
    fig, axes = plt.subplots(2, n, figsize=(10, 4))
    rows = ((original, 'Original'), (reconstructed, 'Reconstructed'))
    for col in range(n):
        for row, (batch, label) in enumerate(rows):
            ax = axes[row, col]
            ax.imshow(batch[col].squeeze().cpu().numpy(), cmap='gray')
            ax.set_title(label)
            ax.axis('off')
    plt.show()

In [None]:
def loss_function(x, x_hat, mean, log_var):
    """Negative ELBO for a Bernoulli decoder and a diagonal-Gaussian posterior.

    Returns the binary cross-entropy between ``x_hat`` and ``x`` plus the closed-form
    KL(N(mean, exp(log_var)) || N(0, I)). Both terms are summed over every element
    of the batch.
    """
    kl_term = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    reconstruction_term = F.binary_cross_entropy(x_hat, x, reduction="sum")
    return reconstruction_term + kl_term

def _vae_train_epoch(model, dataloader, optimizer, input_dim):
    """Run one optimisation pass of ``model`` over ``dataloader``; return the loss summed over batches."""
    model.train()
    total_loss = 0.0
    with tqdm(dataloader, total=len(dataloader)) as pbar:
        for images, _ in pbar:
            images = images.view(-1, input_dim).to(DEVICE)
            recon_batch, z_mean, z_log_var = model(images)
            loss = loss_function(images, recon_batch, z_mean, z_log_var)

            optimizer.zero_grad()
            loss.backward()
            total_loss += loss.item()
            optimizer.step()
    return total_loss


def _vae_validation_loss(model, val_loader, input_dim):
    """Return the loss of ``model`` summed over all batches of ``val_loader``, computed without gradients.

    The VAE's forward pass draws a fresh latent sample even in eval mode, so this
    value varies slightly between calls.
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad(), tqdm(val_loader, total=len(val_loader)) as pbar:
        for x, _ in pbar:
            x = x.view(-1, input_dim).to(DEVICE)
            x_hat, mean, log_var = model(x)
            total_loss += loss_function(x, x_hat, mean, log_var).item()
    return total_loss


def _plot_vae_latent_samples(model, latent_dim, num_samples=8, nrow=4):
    """Decode ``num_samples`` standard-normal latent vectors and show them as a grid of 28x28 images."""
    with torch.no_grad():
        z = torch.randn(num_samples, latent_dim).to(DEVICE)
        samples = model.decoder(z).cpu().view(num_samples, 1, 28, 28)
    grid_img = torchvision.utils.make_grid(samples, nrow=nrow, padding=2, normalize=True)
    plt.figure(figsize=(15, 15))
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.axis('off')
    plt.show()


def train_vae(dataloader, val_loader, latent_dim, hidden_dim, input_dim, learning_rate=1e-4, num_epochs=100, sample_every=10):
    """Train a VAE with Adam on the ELBO loss, visualising progress every epoch.

    After each epoch this function:
    - prints the per-sample training loss,
    - plots reconstructions of a fixed validation batch,
    - computes the per-sample validation loss,
    - every ``sample_every`` epochs (starting at epoch 0), plots 8 images decoded
      from random latent vectors.

    Args:
        dataloader: training batches of ``(images, labels)``; labels are ignored.
        val_loader: validation batches of ``(images, labels)``. Its first batch is the
            reconstruction preview.
        latent_dim, hidden_dim, input_dim: VAE sizes (``input_dim`` is the flattened image size).
        learning_rate: Adam learning rate.
        num_epochs: number of training epochs.
        sample_every: epoch interval for latent-space sample plots; ``0``/``None`` disables them.

    Returns:
        ``(model, val_losses)``, with one per-sample validation loss per epoch.
    """
    model = VAE(input_dim, hidden_dim, latent_dim).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # A fixed validation batch whose reconstructions are plotted after every epoch.
    sample_images, _ = next(iter(val_loader))
    sample_images = sample_images.to(DEVICE)

    val_losses = []
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}: ')
        train_loss = _vae_train_epoch(model, dataloader, optimizer, input_dim)
        print('Loss: {:.4f}'.format(train_loss / len(dataloader.dataset)))

        # Visualise reconstruction progress (reconstruct_images disables gradients itself).
        model.eval()
        recon_images = reconstruct_images(model, sample_images)
        plot_images(sample_images, recon_images)

        val_loss = _vae_validation_loss(model, val_loader, input_dim) / len(val_loader.dataset)
        print('Val Loss: {:.4f}'.format(val_loss))
        val_losses.append(val_loss)

        if sample_every and epoch % sample_every == 0:
            _plot_vae_latent_samples(model, latent_dim)

    return model, val_losses

In [None]:
# VAE hyperparameters: flattened 28x28 MNIST images, one 256-unit hidden layer, 10-d latent space.
input_dim = 784
hidden_dim = 256
latent_dim = 10
# vae_val_loss holds one per-sample validation loss per epoch (plotted in the Compare section).
model, vae_val_loss = train_vae(train_loader, val_loader, latent_dim, hidden_dim, input_dim, num_epochs=100)

In [None]:
# Persist the trained VAE weights to Drive so they can be reloaded later without retraining.
torch.save(model.state_dict(), data_path+'vae_model.pth')

# GAN

GANs consist of two models: a generator and a discriminator. The generator creates new data points, and the discriminator evaluates them, trying to distinguish between real and generated (fake) data points. The training process involves updating the generator to produce more realistic data, as judged by the discriminator, and simultaneously updating the discriminator to get better at distinguishing real from fake. This adversarial process leads to improvements in both models, with the generator producing highly realistic data points as a result.

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator: latent vector -> 128x7x7 feature map -> 1x28x28 image in [0, 1].

    NOTE: the layer sizes always produce 1x28x28 images. ``img_shape`` is only used to
    reshape the output, so it must describe 784 values.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        # Layers are built in the same order as a single nn.Sequential literal, so the
        # parameter initialisation order and state_dict keys ("model.<idx>.*") are unchanged.
        project = [
            nn.Linear(latent_dim, 128 * 7 * 7),
            nn.ReLU(True),
            nn.Unflatten(1, (128, 7, 7)),
        ]
        upsample_to_14 = [  # 7x7 -> 14x14
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
        ]
        upsample_to_28 = [  # 14x14 -> 28x28, sigmoid squashes pixels into [0, 1]
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*project, *upsample_to_14, *upsample_to_28)

    def forward(self, z):
        """Map a batch of latent vectors ``z`` to images of shape ``(N, *img_shape)``."""
        images = self.model(z)
        return images.view(images.size(0), *self.img_shape)

class Discriminator(nn.Module):
    """DCGAN-style discriminator: 1x28x28 image -> probability that the image is real.

    ``img_dim`` is accepted for interface compatibility but unused; the layers assume 1x28x28 input.
    """

    def __init__(self, img_dim):
        super().__init__()
        layers = [
            # 1x28x28 -> 64x14x14
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # 64x14x14 -> 128x7x7
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 128x7x7 -> one score per image, squashed into (0, 1)
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return an ``(N, 1)`` tensor of real-vs-fake probabilities for the batch ``img``."""
        return self.model(img)

In [None]:
def train_gan(dataloader, val_loader, latent_dim, img_shape=(1, 28, 28), learning_rate=3e-4, num_epochs=100):
    """Adversarially train a Generator/Discriminator pair with BCE loss.

    Each epoch runs three stages:
    - an optimisation pass over ``dataloader``: per batch, one generator step, then
      one discriminator step;
    - a gradient-free pass over ``val_loader`` that records the mean generator and
      discriminator losses;
    - a plot of 8 generated images.

    Args:
        dataloader: training batches of ``(images, labels)``; labels are ignored.
        val_loader: validation batches of ``(images, labels)``.
        latent_dim: size of the noise vector fed to the generator.
        img_shape: image shape passed to ``Generator`` and ``Discriminator``.
        learning_rate: Adam learning rate for both networks (betas=(0.5, 0.999)).
        num_epochs: number of training epochs.

    Returns:
        ``(generator, discriminator, val_gen_losses, val_dis_losses)``. Each list holds
        one batch-averaged validation loss per epoch.
    """
    generator = Generator(latent_dim, img_shape).to(DEVICE)
    discriminator = Discriminator(img_shape).to(DEVICE)
    adversarial_loss = nn.BCELoss().to(DEVICE)
    optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

    val_gen_losses, val_dis_losses = [], []
    for epoch in range(num_epochs):
        generator.train()
        discriminator.train()
        with tqdm(dataloader, total=len(dataloader)) as pbar:
            for imgs, _ in pbar:
                real_imgs = imgs.to(DEVICE)
                batch_size = imgs.size(0)

                # Adversarial ground truths: 1 = real, 0 = fake.
                valid = torch.ones(batch_size, 1, device=DEVICE)
                fake = torch.zeros(batch_size, 1, device=DEVICE)

                # Generator step: reward the generator when D labels its images as real.
                optimizer_G.zero_grad()
                z = torch.randn(batch_size, latent_dim, device=DEVICE)
                gen_imgs = generator(z)
                g_loss = adversarial_loss(discriminator(gen_imgs), valid)
                g_loss.backward()
                optimizer_G.step()

                # Discriminator step. g_loss.backward() also left gradients on D's
                # parameters; zero_grad() discards them before D's own update.
                optimizer_D.zero_grad()
                real_loss = adversarial_loss(discriminator(real_imgs), valid)
                # detach(): the generator must not receive gradients from D's loss.
                fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
                d_loss = (real_loss + fake_loss) / 2
                d_loss.backward()
                optimizer_D.step()

                pbar.set_description("[epoch: {}/{}] [D loss: {:.4f}] [G loss: {:.4f}]".format(epoch+1, num_epochs, d_loss.item(), g_loss.item()))

        # Validation pass: same losses, no parameter updates, eval-mode BatchNorm.
        generator.eval()
        discriminator.eval()
        gen_loss_sum, dis_loss_sum = 0.0, 0.0
        with torch.no_grad(), tqdm(val_loader, total=len(val_loader)) as vbar:
            for imgs, _ in vbar:
                real_imgs = imgs.to(DEVICE)
                batch_size = imgs.size(0)
                valid = torch.ones(batch_size, 1, device=DEVICE)
                fake = torch.zeros(batch_size, 1, device=DEVICE)

                z = torch.randn(batch_size, latent_dim, device=DEVICE)
                gen_imgs = generator(z)
                # In eval mode D is deterministic, so one forward pass on the fakes
                # serves both the generator and the discriminator loss.
                fake_scores = discriminator(gen_imgs)
                g_loss = adversarial_loss(fake_scores, valid)
                real_loss_val = adversarial_loss(discriminator(real_imgs), valid)
                fake_loss_val = adversarial_loss(fake_scores, fake)
                d_loss = (real_loss_val + fake_loss_val) / 2

                vbar.set_description("[Validation] [D loss: {:.4f}] [G loss: {:.4f}]".format(d_loss.item(), g_loss.item()))
                gen_loss_sum += g_loss.item()
                dis_loss_sum += d_loss.item()

        val_gen_losses.append(gen_loss_sum / len(val_loader))
        val_dis_losses.append(dis_loss_sum / len(val_loader))
        print("[Validation avg loss] ----- [D loss: {:.4f}] [G loss: {:.4f}]".format(val_dis_losses[-1], val_gen_losses[-1]))

        # Plot 8 generated images. Sampling must run under no_grad: otherwise the grid
        # tensor requires grad and matplotlib's tensor->numpy conversion raises.
        with torch.no_grad():
            z = torch.randn(8, latent_dim, device=DEVICE)
            samples = generator(z).cpu()
        grid = torchvision.utils.make_grid(samples, nrow=4, normalize=True)
        plt.figure(figsize=(10, 10))
        plt.imshow(grid.permute(1, 2, 0))
        plt.show()

    return generator, discriminator, val_gen_losses, val_dis_losses

In [None]:
# Size of the GAN's noise vector. NOTE: this rebinds `latent_dim`, which was 10 for the VAE above.
latent_dim = 100

In [None]:
# Train the GAN for 20 epochs, keeping per-epoch validation losses of both networks for plotting.
generator, discriminator, val_gen_losses, val_dis_losses = train_gan(train_loader, val_loader, latent_dim, num_epochs=20)

In [None]:
# Persist both GAN networks to Drive so the comparison section can reload them.
torch.save(generator.state_dict(), data_path+'gen_model.pth')
torch.save(discriminator.state_dict(), data_path+'dis_model.pth')

# Compare

Use the validation dataset to plot and compare the results of your trained models.

In [None]:
# TODO
# Comparison of the generated images

In [None]:
def plot(losses, title):
    """Plot ``losses`` against their epoch index under ``title`` and display the figure."""
    epochs = np.arange(len(losses))
    plt.plot(epochs, losses)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title(title)
    plt.show()

## **VAE**

In [None]:
# VAE validation loss plot: one per-sample validation loss per epoch, as returned by train_vae.
plot(losses=vae_val_loss, title='VAE Validation Loss')

In [None]:
# Rebuild the VAE with the training-time sizes (784 input, 256 hidden, 10 latent) and load its weights.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only runtime.
vae = VAE(784, 256, 10)
vae.load_state_dict(torch.load(data_path+'vae_model.pth', map_location=DEVICE))
vae.to(DEVICE)
vae.eval()

In [None]:
def VAElatentSamplegeneration(model, num=64, latent_dim=10):
    """Decode ``num`` standard-normal latent vectors with ``model.decoder``, show them as a grid and return them.

    Args:
        model: a trained VAE. Its ``decoder`` must map ``(num, latent_dim)`` to 784 values per sample.
        num: number of images to generate.
        latent_dim: latent size the decoder expects (10 for the VAE trained above).

    Returns:
        CPU tensor of shape ``(num, 1, 28, 28)``.
    """
    model.eval()
    # Sample on whatever device the model lives on instead of assuming DEVICE.
    device = next(model.parameters()).device
    with torch.no_grad():
        z = torch.randn(num, latent_dim, device=device)
        samples = model.decoder(z).cpu().view(num, 1, 28, 28)
    # About 8 rows of num // 8 images. max(1, ...) avoids nrow=0 when num < 8,
    # which would otherwise make make_grid divide by zero.
    grid_img = torchvision.utils.make_grid(samples, nrow=max(1, num // 8), padding=2, normalize=True)
    plt.figure(figsize=(15, 15))
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.axis('off')
    plt.show()

    return samples

In [None]:
# Show a grid of 64 images decoded from random latent vectors.
VAElatentSamplegeneration(vae)

## **GAN**

In [None]:
# GAN Generator validation loss plot: one batch-averaged value per epoch, as returned by train_gan.
plot(losses=val_gen_losses, title='GAN Generator Validation Loss')

In [None]:
# GAN Discriminator validation loss plot: one batch-averaged value per epoch, as returned by train_gan.
plot(losses=val_dis_losses, title='GAN Discriminator Validation Loss')

In [None]:
# Rebuild the generator with its training-time configuration and load the saved weights.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only runtime.
generator = Generator(latent_dim=100, img_shape=(1, 28, 28))
generator.load_state_dict(torch.load(data_path+'gen_model.pth', map_location=DEVICE))
generator.to(DEVICE)
generator.eval()

In [None]:
def GANlatentSamplegeneration(model, num=64, latent_dim=100):
    """Generate ``num`` images from standard-normal noise with ``model``, show them as a grid and return them.

    Args:
        model: a trained Generator that maps ``(num, latent_dim)`` noise to 784 values per sample.
        num: number of images to generate.
        latent_dim: noise size the generator expects (100 for the GAN trained above).

    Returns:
        CPU tensor of shape ``(num, 1, 28, 28)``.
    """
    model.eval()
    # Sample on whatever device the model lives on instead of assuming DEVICE.
    device = next(model.parameters()).device
    with torch.no_grad():
        z = torch.randn(num, latent_dim, device=device)
        samples = model(z).cpu().view(num, 1, 28, 28)
    # About 8 rows of num // 8 images. max(1, ...) avoids nrow=0 when num < 8,
    # which would otherwise make make_grid divide by zero.
    grid_img = torchvision.utils.make_grid(samples, nrow=max(1, num // 8), padding=2, normalize=True)
    plt.figure(figsize=(15, 15))
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.axis('off')
    plt.show()

    return samples

In [None]:
# Show a grid of 64 images generated from random noise vectors.
GANlatentSamplegeneration(generator, num=64)

## Evaluate with GAN discriminator

In [None]:
# Load the saved discriminator weights. map_location='cpu' lets a GPU checkpoint load on a CPU-only
# runtime. The model stays on the CPU, like the CPU sample tensors it is evaluated on below.
discriminator = Discriminator(img_dim=(1, 28, 28))
discriminator.load_state_dict(torch.load(data_path+'dis_model.pth', map_location='cpu'))

In [None]:
def discriminatorEvaluation(model, generated_samples):
    """Return the fraction of ``generated_samples`` that ``model`` classifies as real.

    A sample counts as "real" when the discriminator's output is >= 0.5. The samples
    are moved to the model's device, and scoring runs without gradient tracking.
    The input tensor is never modified.

    Returns:
        Python float in [0, 1]: a fraction, not a percentage.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    samples = generated_samples if first_param is None else generated_samples.to(first_param.device)
    with torch.no_grad():
        scores = model(samples)
    return (scores >= 0.5).sum().item() / scores.shape[0]

In [None]:
# Draw 128 samples from each model on the active device. DEVICE falls back to the CPU when no GPU is
# present; the previous hard-coded 'cuda' crashed on CPU-only runtimes.
vae_generated_samples = VAElatentSamplegeneration(vae.to(DEVICE), num=128)
gan_generated_samples = GANlatentSamplegeneration(generator.to(DEVICE), num=128)

In [None]:
# Sanity check: both tensors should be (128, 1, 28, 28).
vae_generated_samples.shape, gan_generated_samples.shape

In [None]:
# discriminatorEvaluation returns a fraction in [0, 1]; format it as a percentage to match the label.
gan_real_fraction = discriminatorEvaluation(discriminator, gan_generated_samples)
vae_real_fraction = discriminatorEvaluation(discriminator, vae_generated_samples)
print(f'percentage of valid generated image to all for GAN model: {gan_real_fraction:.2%}')
print(f'percentage of valid generated image to all for VAE model: {vae_real_fraction:.2%}')