---

## Deep Learning Coding Project 3-3: Variational Auto-Encoder

Before we start, please put your **Chinese** name and student ID in the following format:

Name, 0000000000 // e.g.) 小明, 2021123123

YOUR ANSWER HERE

## Introduction

We will use Python 3, [NumPy](https://numpy.org/), and [PyTorch](https://pytorch.org/) packages for implementation. To avoid unexpected issues with PyTorch 2.0, we recommend using PyTorch version 1.x.

In this coding project, you will implement 4 generative models, i.e., an energy-based model, a flow-based model, a variational auto-encoder, and a generative adversarial network, to generate MNIST images.

**We will implement a conditional variational auto-encoder (CVAE) in this notebook.**

In some cells and files you will see code blocks that look like this:

```Python
##############################################################################
#                  TODO: You need to complete the code here                  #
##############################################################################
raise NotImplementedError()
##############################################################################
#                              END OF YOUR CODE                              #
##############################################################################
```

You should replace `raise NotImplementedError()` with your own implementation based on the context, such as:

```Python
##############################################################################
#                  TODO: You need to complete the code here                  #
##############################################################################
y = w * x + b
##############################################################################
#                              END OF YOUR CODE                              #
##############################################################################

```

When completing the notebook, please adhere to the following rules:

+ Unless otherwise stated, do not write or modify any code outside of the `TODO` code blocks.
+ Do not add or delete any cells from the notebook.
+ Run all cells before submission. We will not re-run the entire notebook during grading.

**Finally, avoid plagiarism! Any student who violates academic integrity will be seriously dealt with and receive an F for the course.**

### Task

In this problem, you need to implement a class-conditioned variational auto-encoder to generate MNIST images. We assume the prior $p(z)$ is a standard Gaussian distribution $\mathcal{N}(0, I)$, and that $q(z|x, y)$ and $p(x|z, y)$ are also Gaussian distributions.

1. **You need to complete the encoder $q(z|x, y; \phi)$ and the decoder $p(x|z, y; \theta)$, which are both MLPs**.

2. **Implement the VAE loss function.**
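
For reference, under the assumptions above the training objective is usually written as the conditional evidence lower bound (ELBO). The following is the standard textbook form, not a statement of the reference solution. The weight $\beta$ is assumed here to correspond to the `beta` argument of `compute_vae_loss`; $\beta = 1$ recovers the plain negative ELBO.

$$
\log p(x|y; \theta) \ge \mathbb{E}_{q(z|x, y; \phi)}\big[\log p(x|z, y; \theta)\big] - \mathrm{KL}\big(q(z|x, y; \phi) \,\|\, p(z)\big)
$$

$$
\mathcal{L}_{\beta}(x, y) = -\mathbb{E}_{q(z|x, y; \phi)}\big[\log p(x|z, y; \theta)\big] + \beta \, \mathrm{KL}\big(q(z|x, y; \phi) \,\|\, p(z)\big)
$$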

An example of generated images using CVAE is shown below.

If you use Colab in this coding project, please uncomment the cell below, change `GOOGLE_DRIVE_PATH` to your project folder, and run the cell to mount your Google Drive. The notebook can then find the required files (e.g., `utils.py`). If you run the notebook locally, you can skip this cell.

In [None]:
# ### uncomment this cell if you're using Google colab
# from google.colab import drive
# drive.mount('/content/drive')

# ### change GOOGLE_DRIVE_PATH to the path of your CP3 folder
# GOOGLE_DRIVE_PATH = '/content/drive/MyDrive/Colab Notebooks/DL23SP/CP3'
# %cd $GOOGLE_DRIVE_PATH

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import rcParams

%matplotlib inline

# figure size in inches (optional)
rcParams['figure.figsize'] = 6, 4
plt.imshow(mpimg.imread('./vae/sample.png'))

### Submission

You need to submit your code (this notebook), your trained VAE model (named `./vae/vae_best.pth`), your generated images, and your report:

+ **Code**

Remember to run all the cells before submission. Keep your tuned hyperparameters unchanged.

+ **Generator**

In this notebook, we select the best model based on validation loss. You can also manually select the best one, and save it as `./vae/vae_best.pth`. **Please do not submit any other checkpoints except for `./vae/vae_best.pth`!**

+ **Generated Images**

Please generate 100 images for each class (1000 in total), save them in `./vae/generated/`, and organize them in the following structure:

```
CodingProject3
├── ...
├── vae
│   ├── vae_best.pth
│   └── generated
│       ├── 0
│       │   ├── 0_000.png
│       │   ├── ...
│       │   └── 0_099.png
│       ├── 1
│       │   ├── 1_000.png
│       │   ├── ...
│       │   └── 1_099.png
│       ├── ...
│       └── 9
│           ├── 9_000.png
│           ├── ...
│           └── 9_099.png
```

Specifically, name the $j$-th generated image of class $i$ as `{i}_{j}.png` and save it in the folder `./vae/generated/{i}/`. See the `make_dataset` method of the VAE model for the exact file-name format.
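
Before submitting, it can help to confirm that every class folder holds the expected number of images. The helper below is a minimal sketch using only the standard library; the function name `count_generated_images` is hypothetical and it is not part of the provided `utils.py` or of the grading script.

```python
import os


def count_generated_images(root, n_classes=10):
    """Return {class_index: number of .png files} for the folders root/<class_index>/."""
    counts = {}
    for i in range(n_classes):
        class_dir = os.path.join(root, str(i))
        if not os.path.isdir(class_dir):
            counts[i] = 0  # a missing folder counts as zero images
            continue
        # Only count files named like "<class>_<index>.png".
        counts[i] = sum(
            1
            for name in os.listdir(class_dir)
            if name.startswith(f"{i}_") and name.endswith(".png")
        )
    return counts


counts = count_generated_images("./vae/generated")
incomplete = {i: n for i, n in counts.items() if n != 100}
print("all classes complete" if not incomplete else f"unexpected counts: {incomplete}")
```

If the folder has not been generated yet, the sketch simply reports a count of 0 for every class.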

+ **Report**

Please include the conditioned generation results (i.e., generated images in a $10\times 10$ array as above), the FID score, the standard deviation for each class, and other relevant statistics in your report. Note that you only need to write a single report for this coding project.
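
To read the $10\times 10$ array correctly, it helps to know how the labels are ordered. The `train` function in this notebook builds them with `torch.eye(10).repeat(10, 1)` and `sample_images` saves the grid with 10 images per row. The sketch below reproduces that label order in NumPy (an assumption made only to keep the example dependency-light) and shows that each column of the grid corresponds to one class.

```python
import numpy as np

n_classes = 10
# NumPy counterpart of torch.eye(10).repeat(10, 1): ten stacked identity blocks.
labels = np.tile(np.eye(n_classes), (n_classes, 1))  # shape (100, 10), one-hot rows
class_ids = labels.argmax(axis=1)                    # 0, 1, ..., 9, 0, 1, ...
grid = class_ids.reshape(n_classes, n_classes)       # 10 images per row

print(grid[0])     # [0 1 2 3 4 5 6 7 8 9]
print(grid[:, 3])  # [3 3 3 3 3 3 3 3 3 3]
```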

### Grading

We will evaluate your model by **computing the [FID score](https://en.wikipedia.org/wiki/Fr%C3%A9chet_inception_distance)**. We have provided a fine-tuned Inception-V3 model, which will be used by our evaluation script to compute FID score.
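
For intuition, the FID is the Fréchet distance between two Gaussians fitted to feature vectors of real and generated images: $\|\mu_1 - \mu_2\|^2 + \mathrm{Tr}\big(\Sigma_1 + \Sigma_2 - 2(\Sigma_1 \Sigma_2)^{1/2}\big)$. The sketch below computes this distance in NumPy from given means and covariances. It is not the provided evaluation script, and the function name `frechet_distance` is hypothetical; feature extraction with the Inception-V3 model is not shown.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # Tr((sigma1 sigma2)^{1/2}) equals the sum of the square roots of the
    # eigenvalues of sigma1 @ sigma2. For positive semi-definite inputs these
    # eigenvalues are real and non-negative up to numerical error, so the
    # imaginary parts are dropped and tiny negative values are clipped.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_sqrt = np.sum(np.sqrt(np.clip(eigvals.real, 0.0, None)))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)


# Identical Gaussians are at distance 0; shifting the mean adds the squared shift.
eye = np.eye(3)
print(frechet_distance(np.zeros(3), eye, np.zeros(3), eye))
print(frechet_distance(np.zeros(3), eye, np.array([1.0, 2.0, 2.0]), eye))
```

A lower value means the two feature distributions are closer.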

## Set Up Code

If you use Colab in this coding project, please make sure to mount your drive before running the cells below.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from utils import hello
hello()

Please run the following cell to import some base classes for implementation (whether or not you use Colab).

In [None]:
from collections import deque
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm

import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from utils import save_model, load_model, train_set, val_set

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

device = torch.device(
    "cuda") if torch.cuda.is_available() else torch.device("cpu")

## VAE Model

Complete the conditional VAE model with the structure shown in the docstrings.

**Hint**: the encoder usually outputs the logarithm of the standard deviation rather than the standard deviation itself.
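
To illustrate why a log standard deviation is convenient, here is a minimal sketch of the reparameterization trick, $z = \mu + \exp(\text{logstd}) \cdot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$. It is written in NumPy only to keep the example self-contained; the standalone function `reparameterize` and its `rng` argument are assumptions for illustration, not the signature of the method you need to complete. The same arithmetic carries over to PyTorch tensors.

```python
import numpy as np


def reparameterize(mu, logstd, rng):
    """Draw z ~ N(mu, diag(exp(logstd)^2)) as a deterministic function of noise."""
    eps = rng.standard_normal(mu.shape)  # noise that does not depend on mu or logstd
    return mu + np.exp(logstd) * eps     # exp keeps the standard deviation positive


rng = np.random.default_rng(0)
mu = np.full((100_000, 2), 3.0)
logstd = np.full((100_000, 2), np.log(0.5))
z = reparameterize(mu, logstd, rng)
print(z.mean(axis=0), z.std(axis=0))  # close to [3, 3] and [0.5, 0.5]
```

Because the network output is unconstrained, exponentiating it avoids having to force a positive standard deviation by other means.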

In [None]:
class CVAE(nn.Module):
    def __init__(self, img_size, label_size, latent_size, hidden_size):
        super(CVAE, self).__init__()
        self.img_size = img_size  # (C, H, W)
        self.label_size = label_size
        self.latent_size = latent_size
        self.hidden_size = hidden_size
        
        # Encoder.
        '''
        img   -> fc  ->                   -> fc -> mean    
                        concat -> encoder                  -> z
        label -> fc  ->                   -> fc -> logstd 
        '''
        ##############################################################################
        #                  TODO: You need to complete the code here                  #
        ##############################################################################
        # YOUR CODE HERE
        raise NotImplementedError()
        ##############################################################################
        #                              END OF YOUR CODE                              #
        ##############################################################################
        
        # Decoder.
        '''
        latent -> fc ->
                         concat -> decoder -> reconstruction
        label  -> fc ->
        '''
        ##############################################################################
        #                  TODO: You need to complete the code here                  #
        ##############################################################################
        # YOUR CODE HERE
        raise NotImplementedError()
        ##############################################################################
        #                              END OF YOUR CODE                              #
        ##############################################################################

    def encode_param(self, x, y):
        # compute mu and logstd of q(z|x, y)
        ##############################################################################
        #                  TODO: You need to complete the code here                  #
        ##############################################################################
        # YOUR CODE HERE
        raise NotImplementedError()
        ##############################################################################
        #                              END OF YOUR CODE                              #
        ##############################################################################

    def reparamaterize(self, mu: torch.Tensor, logstd: torch.Tensor):
        # compute latent z with reparameterization trick
        ##############################################################################
        #                  TODO: You need to complete the code here                  #
        ##############################################################################
        # YOUR CODE HERE
        raise NotImplementedError()
        ##############################################################################
        #                              END OF YOUR CODE                              #
        ##############################################################################

    def encode(self, x, y):
        # sample latent z from q(z|x, y)
        ##############################################################################
        #                  TODO: You need to complete the code here                  #
        ##############################################################################
        # YOUR CODE HERE
        raise NotImplementedError()
        ##############################################################################
        #                              END OF YOUR CODE                              #
        ##############################################################################

    def decode(self, z, y):
        ##############################################################################
        #                  TODO: You need to complete the code here                  #
        ##############################################################################
        # YOUR CODE HERE
        raise NotImplementedError()
        ##############################################################################
        #                              END OF YOUR CODE                              #
        ##############################################################################

    @torch.no_grad()
    def sample_images(self, label, save=True, save_dir='./vae'):
        self.eval()
        n_samples = label.shape[0]
        samples  = self.decode(torch.randn(n_samples, self.latent_size).to(label.device), label)
        imgs = samples.view(n_samples, 1, 28, 28).clamp(0., 1.)
        if save:
            os.makedirs(save_dir, exist_ok=True)
            torchvision.utils.save_image(imgs, os.path.join(save_dir, 'sample.png'), nrow=int(np.sqrt(n_samples)))
        return imgs
    
    @torch.no_grad()
    def make_dataset(self, n_samples_per_class=10, save=True, save_dir='./vae/generated/'):
        self.eval()
        device = next(self.parameters()).device
        for i in range(self.label_size):
            label = torch.zeros(n_samples_per_class, self.label_size, device=device)
            label[:, i] = 1
            samples = self.decode(torch.randn(
                n_samples_per_class, self.latent_size).to(device), label)
            imgs = samples.view(n_samples_per_class, 1, 28, 28).clamp(0., 1.)
            if save:
                os.makedirs(os.path.join(save_dir, str(i)), exist_ok=True)
                for j in range(n_samples_per_class):
                    torchvision.utils.save_image(imgs[j], os.path.join(save_dir, str(i), "{}_{:>03d}.png".format(i, j)))


## VAE Loss

Given image $x$ and corresponding label $y$, compute the VAE loss in the following function.

**Hint**: $p(x|z, y)$ is a real-valued Gaussian distribution, while images are in range $[0, 1]$. Therefore, you may want to transform $x$ when computing $p(x|z, y)$.
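
The two ingredients of a Gaussian VAE loss have closed forms. The sketch below shows them in NumPy under the assumption of diagonal Gaussians parameterized by a mean and a log standard deviation; the helper names `kl_to_standard_normal` and `gaussian_nll` are hypothetical. It deliberately does not decide how to transform $x$ or how to choose the decoder's standard deviation, which are left to your implementation.

```python
import numpy as np


def kl_to_standard_normal(mu, logstd):
    """KL(N(mu, diag(exp(logstd)^2)) || N(0, I)), summed over the last axis."""
    return 0.5 * np.sum(np.exp(2.0 * logstd) + mu ** 2 - 1.0 - 2.0 * logstd, axis=-1)


def gaussian_nll(x, mean, logstd):
    """Negative log-density of x under N(mean, diag(exp(logstd)^2)), summed over the last axis."""
    var = np.exp(2.0 * logstd)
    per_dim = 0.5 * np.log(2.0 * np.pi) + logstd + 0.5 * (x - mean) ** 2 / var
    return np.sum(per_dim, axis=-1)


# A posterior equal to the prior has zero KL; shifting the mean by 1 in one
# dimension with unit variance gives a KL of 0.5.
print(kl_to_standard_normal(np.zeros((1, 3)), np.zeros((1, 3))))
print(kl_to_standard_normal(np.ones((1, 1)), np.zeros((1, 1))))
```

Both helpers return one value per sample, which matches how `evaluate` and `train` in this notebook call `.sum()` and `.mean()` on the loss.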

In [None]:
def compute_vae_loss(vae_model, x, y, beta=1):
    # compute vae loss for input x and label y
    ##############################################################################
    #                  TODO: You need to complete the code here                  #
    ##############################################################################
    # YOUR CODE HERE
    raise NotImplementedError()
    ##############################################################################
    #                              END OF YOUR CODE                              #
    ##############################################################################


## Training & Evaluation

We have implemented the training and evaluation functions. Feel free to modify `train` if you want to monitor more information. Make sure your best model is stored in `'./vae/vae_best.pth'`.

In [None]:
@torch.no_grad()
def evaluate(vae_model, loader, device, beta):
    vae_model.eval()
    val_loss = 0
    n_batches = 0

    pbar = tqdm(total=len(loader.dataset))
    pbar.set_description('Eval')
    for batch_idx, (x, y) in enumerate(loader):
        n_batches += x.shape[0]
        x = x.view(x.shape[0], -1).to(device)
        y = y.to(device)

        val_loss += compute_vae_loss(vae_model, x, y, beta).sum().item()
        pbar.update(x.size(0))
        pbar.set_description('Val Loss: {:.6f}'.format(val_loss / n_batches))

    pbar.close()
    return val_loss / n_batches

In [None]:
def train(n_epochs, vae_model, train_loader, val_loader, optimizer, beta=1, device=torch.device('cuda'), save_interval=10):
    vae_model.to(device)
    best_val_loss = np.inf

    for epoch in range(n_epochs):
        train_loss = 0
        n_batches = 0
        pbar = tqdm(total=len(train_loader.dataset))
        for i, (x, y) in enumerate(train_loader):
            # compute loss
            vae_model.train()
            n_batches += x.shape[0]
            x = x.view(x.shape[0], -1).to(device)
            y = y.to(device)
            loss = compute_vae_loss(vae_model, x, y, beta)

            optimizer.zero_grad()
            loss.mean().backward()
            optimizer.step()

            train_loss += loss.sum().item()

            pbar.update(x.size(0))
            pbar.set_description('Train Epoch {}, Train Loss: {:.6f}'.format(epoch + 1, train_loss / n_batches))
        pbar.close()

        if (epoch + 1) % save_interval == 0:
            os.makedirs(f'./vae/{epoch + 1}', exist_ok=True)
            vae_model.eval()
            save_model(f'./vae/{epoch + 1}/vae.pth', vae_model, optimizer)

            val_loss = evaluate(vae_model, val_loader, device, beta=beta)

            # sample and save images
            label = torch.eye(10).repeat(10, 1).to(device)
            vae_model.sample_images(
                label, save=True, save_dir=f"./vae/{epoch + 1}/")
            
            if val_loss < best_val_loss:
                print(
                    f'Current validation loss: {best_val_loss} -> {val_loss}')
                best_val_loss = val_loss
                save_model('./vae/vae_best.pth', vae_model)

## Enjoy

Tune your hyperparameters and make your conditional VAE work. Good luck!

In [None]:
label_dim = 10
img_dim = (1, 28, 28)
latent_dim = 100
cvae = CVAE(img_dim, label_dim, latent_dim, 256)

train_loader = DataLoader(train_set, batch_size=128, pin_memory=True,
                          drop_last=False, shuffle=True, num_workers=8)
val_loader = DataLoader(val_set, batch_size=512, pin_memory=True,
                        drop_last=True, shuffle=True, num_workers=8)
optimizer = torch.optim.Adam(cvae.parameters(), lr=2e-4)

Let's start training! Please keep in mind that this cell may **NOT** be run when we evaluate your assignment!

In [None]:
# feel free to change training hyper-parameters!
train(50, cvae, train_loader, val_loader, optimizer, device=device, save_interval=10)

## Evaluation

Make sure your code runs fine with the following cells!

In [None]:
# collect model-generated samples
cvae.load_state_dict(load_model('./vae/vae_best.pth')[0])
cvae.make_dataset(n_samples_per_class=100)

In [None]:
# run evaluation
!python evaluate_cgen.py --vae