# Deep Learning course - LAB 10

## Generative Adversarial Networks

Generative Adversarial Networks (GANs) are a neural-based category of generative models.

Up to now, we have seen only **discriminative models**, i.e. models tasked with learning $P(Y\vert X)$, where $X$ are the observations and $Y$ is the dependent variable (category or scalars, depending on the task).

**Generative models**, on the other hand, are tasked with learning the joint distribution $P(Y, X)$ (or simply $P(X)$ when there are no labels), i.e. in simple terms they learn a _**rule**_ through which we can sample (generate) any amount of new data.

Specifically, GANs are expressed as a game in which two models are in competition:
* a **generator** $G$ generates synthetic data
* a **discriminator** $D$ distinguishes whether a piece of data is synthetic or not

![](img/gan.jpg)

$D$ is usually expressed as a *simple* network for binary classification and trained using the **binary cross-entropy loss** (BCELoss). Given an input $\hat{y}$ of the discriminator (a real or a synthetic piece of data), the corresponding output $D(\hat{y})\in[0,1]$ and the ground truth $y\in\{0,1\}$, $\mathcal{L}_D(\hat{y}, y) = -y\log(D(\hat{y})) - (1-y) \log(1-D(\hat{y}))$.

On the other hand, $G$ is a more complex entity. The input is a *latent variable* $z\in\mathbb{R}^d$, while the output is a point in the data space $\mathcal{D}$ (i.e. $x\in\mathcal{D}$). The latent variable is sampled from a fixed distribution, which usually is $\mathcal{N}(\mathbf{0}_d,\sigma^2I_d )$.

The generator, recall, is tasked with producing samples which $D$ misclassifies as real. We may then ask for the following: $G^\star = \text{argmax}_G \{ \mathcal{L}_D(G(z), 0) \}$, i.e. that the generated data $G(z)$, labelled as fake ($y=0$), *increase* the loss of $D$.

Merging the two concepts, we express the GAN training as a *minimax* game:

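$\min_G \max_D \; \mathbb{E}_{x\sim p_{data}}[\log D(x)] + \mathbb{E}_{z\sim p_z}[\log(1-D(G(z)))]$, where $p_{data}$ is the distribution of the real data and $p_z$ the latent distribution (note: the signs are switched w.r.t. $\mathcal{L}_D$, so the $\min$ becomes $\max$ and vice-versa).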
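To see what the inner $\max_D$ computes, here is a small plain-Python sketch on a finite data space with three possible values (the two distributions are made-up numbers, the helper names are hypothetical). It relies on a result from [3](https://arxiv.org/pdf/1406.2661.pdf): for a fixed generator, writing $p_{data}$ for the distribution of real data and $p_g$ for the distribution of $G(z)$, the optimal discriminator is $D^\star(x) = \frac{p_{data}(x)}{p_{data}(x)+p_g(x)}$, and when $p_g = p_{data}$ it outputs $1/2$ everywhere, giving the value $-\log 4$.

```python
import math

def gan_value(d, p_data, p_g):
    """V(D, G) = E_{x~p_data}[log D(x)] + E_{x~p_g}[log(1 - D(x))] on a finite data space."""
    return sum(pd * math.log(dx) + pg * math.log(1 - dx)
               for dx, pd, pg in zip(d, p_data, p_g))

def optimal_discriminator(p_data, p_g):
    """Pointwise maximizer of p_data(x) log D(x) + p_g(x) log(1 - D(x)).
    Assumes p_data(x) + p_g(x) > 0 for every x."""
    return [pd / (pd + pg) for pd, pg in zip(p_data, p_g)]

# made-up distributions over three possible data points
p_data = [0.5, 0.3, 0.2]
p_g = [0.2, 0.2, 0.6]

d_star = optimal_discriminator(p_data, p_g)
# another discriminator (here: D* slightly perturbed) reaches a lower value
d_other = [min(max(dx + 0.05, 0.01), 0.99) for dx in d_star]
print(gan_value(d_star, p_data, p_g) > gan_value(d_other, p_data, p_g))  # True

# at the equilibrium p_g = p_data, D* = 1/2 everywhere and V = -log(4)
d_eq = optimal_discriminator(p_data, p_data)
print(d_eq, gan_value(d_eq, p_data, p_data), -math.log(4))
```

Since each term $p_{data}\log D + p_g\log(1-D)$ is strictly concave in $D$, $D^\star$ is the unique maximizer; the generator's job is then to move $p_g$ so that even the best discriminator cannot do better than $-\log 4$.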

In practical terms, the GANs are trained this way:

1. Evaluate the discriminator on real data (usually marked with label `1`)
2. Evaluate the discriminator on synthetic/fake data (label `0`): at the first iterations, the generated data is essentially pure noise
    * the synthetic data is generated by feeding the generator a sample from the desired latent distribution
3. Update the params of the discriminator via $\mathcal{L}_D$
    * since the label is `1` for real data and `0` for fake data, it's easy to verify that the loss can be decomposed into two parts
        * $-\log(D(\hat{y}))$ for the real data
        * $-\log(1-D(\hat{y}))$ for the fake data. Notice that, since we deal with data generated by $G$, we may replace $\hat{y}$ with $G(z)$: $-\log(1-D(G(z)))$
4. **After** having trained the discriminator, train the generator:
    * we wish to **maximize** $-\log(1-D(G(z)))$ (the "fake part" of the discriminator loss)
    * in early training, however, $D$ is much stronger than $G$ and rejects the generated samples with high confidence ($D(G(z))\approx 0$): there $\log(1-D(G(z)))$ saturates and yields vanishing gradients for $G$. For this reason, [3](https://arxiv.org/pdf/1406.2661.pdf) suggests to minimize $-\log(D(G(z)))$ instead.
        * Notice that this formulation is equivalent to the **first part** of the discriminator loss, so while training $G$ we can still use the BCELoss, but labelling the synthetic data as real (label `1`).
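The effect of this switch can be checked numerically. The plain-Python sketch below (no PyTorch; it assumes the discriminator's output is $D(G(z)) = \text{sigmoid}(a)$ for some logit $a$, and the helper names are hypothetical) compares the gradient w.r.t. $a$ of the original generator objective $\log(1-D(G(z)))$ with that of the non-saturating one $-\log(D(G(z)))$, and checks that the latter is exactly the BCE of $D(G(z))$ against the label `1`.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def bce(p, y):
    """Binary cross-entropy for a prediction p in (0, 1) and a label y in {0, 1}."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def generator_gradients(a):
    """Gradients w.r.t. the discriminator logit a, where D(G(z)) = sigmoid(a)."""
    d_fake = sigmoid(a)
    # saturating objective, minimized by G: log(1 - sigmoid(a)); derivative -sigmoid(a)
    grad_saturating = -d_fake
    # non-saturating objective, minimized by G: -log(sigmoid(a)); derivative -(1 - sigmoid(a))
    grad_non_saturating = -(1.0 - d_fake)
    return grad_saturating, grad_non_saturating

# early training: D confidently rejects fakes (a << 0, D(G(z)) close to 0)
for a in (-6.0, 0.0, 3.0):
    sat, non_sat = generator_gradients(a)
    print(f"D(G(z))={sigmoid(a):.4f}  saturating grad={sat:.4f}  non-saturating grad={non_sat:.4f}")

# the non-saturating loss is the BCE of D(G(z)) with the "real" label 1
p = sigmoid(-2.0)
print(math.isclose(bce(p, 1), -math.log(p)))  # True
```

At $a=-6$ ($D(G(z))\approx 0.0025$) the saturating gradient is almost zero, while the non-saturating one is close to $-1$: the generator still receives a useful training signal when it is losing badly.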

In [None]:
import torch
from torch import nn
import torchvision

from matplotlib import pyplot as plt
import cv2
import numpy as np

from scripts.mnist import get_data
from scripts.torch_utils import use_gpu_if_possible
from scripts.train_utils import AverageMeter
from scripts.architectures import MLPCustom

#### Data

We're not using the presets from `scripts` since we need to update (a) the batch size and (b) the normalizing constants.

**Q**: why are we not using `torchvision.transforms.Normalize(...)`?

In [None]:
trainloader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        "datasets",
        train=True,
        download=True,
        transform=torchvision.transforms.ToTensor()
    ),
    batch_size=64,
    shuffle=True,
)

testloader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        "datasets",
        train=False,
        download=True,
        transform=torchvision.transforms.ToTensor()
    ),
    batch_size=64,
    shuffle=False
)

This is a free reinterpretation of DCGAN [1](https://arxiv.org/abs/1511.06434), built s.t. we can match the original shape of MNIST. We need to tweak the parameters of the transposed convolutions because the original architecture is designed for images whose spatial size is a power of two ($2^n$, e.g. 64×64), with each transposed convolution doubling the spatial size, while MNIST images are 28×28. This is one of many possible implementations leading to an image of size (28, 28) as output.
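To see how the kernel sizes and strides of the networks below lead to 28×28, we can apply the output-size formulas documented for PyTorch's `nn.ConvTranspose2d` and `nn.Conv2d` (here with dilation 1 and no output padding). This plain-Python sketch (helper names are hypothetical) tracks the spatial size through the five transposed convolutions of the `Generator` and the three strided convolutions of the `Discriminator` defined below.

```python
def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Output size of nn.ConvTranspose2d along one spatial dimension (dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

def conv_out(size, kernel, stride=1, padding=0):
    """Output size of nn.Conv2d along one spatial dimension (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel_size, stride) of the Generator's transposed convolutions, padding 0
generator_layers = [(5, 1), (5, 1), (5, 2), (5, 1), (4, 1)]
size = 1  # the latent vector is reshaped to a 1x1 "image"
generator_sizes = [size]
for kernel, stride in generator_layers:
    size = conv_transpose_out(size, kernel, stride)
    generator_sizes.append(size)
print(generator_sizes)  # [1, 5, 9, 21, 25, 28]

# Discriminator: three convolutions with kernel 4, stride 2, padding 1
size = 28
discriminator_sizes = [size]
for _ in range(3):
    size = conv_out(size, kernel=4, stride=2, padding=1)
    discriminator_sizes.append(size)
print(discriminator_sizes)  # [28, 14, 7, 3]

# For comparison, a transposed convolution with kernel 4, stride 2, padding 1
# (a common configuration in DCGAN-like generators) doubles the size at each layer
print([conv_transpose_out(s, 4, 2, 1) for s in (4, 8, 16, 32)])  # [8, 16, 32, 64]
```

The final `nn.AdaptiveAvgPool2d(1)` of the discriminator then averages each 3×3 feature map into a single value per channel.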

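Also, we need to define a custom initialization function: in [1](https://arxiv.org/abs/1511.06434) all weights are initialized from a zero-centered normal distribution with standard deviation 0.02, i.e. very small weights. The function below does this for the convolutional layers, while the BatchNorm scales are drawn around 1 (with the same standard deviation) and the BatchNorm biases are set to 0.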

As far as the architectural choice is concerned, we can either:
* put `nn.Sigmoid()` as the final activation and avoid data normalization
* put `nn.Tanh()` as the final activation and normalize images using `mean=[0.5], std=[0.5]`

In [None]:
# custom weights initialization called on discrim and generator
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
class Generator(nn.Module):
    def __init__(self, dim_latent, base_width=64, output_ch=1):
        super().__init__()
        self.layers = nn.Sequential(
            nn.ConvTranspose2d(dim_latent, base_width*8, kernel_size=5, bias=False),
            nn.BatchNorm2d(base_width*8),
            nn.ReLU(),
            nn.ConvTranspose2d(base_width*8, base_width*4, kernel_size=5, bias=False),
            nn.BatchNorm2d(base_width*4),
            nn.ReLU(),
            nn.ConvTranspose2d(base_width*4, base_width*2, kernel_size=5, stride=2, bias=False),
            nn.BatchNorm2d(base_width*2),
            nn.ReLU(),
            nn.ConvTranspose2d(base_width*2, base_width, kernel_size=5, bias=False),
            nn.BatchNorm2d(base_width),
            nn.ReLU(),
            nn.ConvTranspose2d(base_width, output_ch, kernel_size=4, bias=False),
            nn.Sigmoid() 
        )
    
    def forward(self, z):
        z = z.unsqueeze(-1).unsqueeze(-1) # append two spatial dimensions to make it image-like
        return self.layers(z)

Let's test if the generator produces an output of the desired shape:

In [None]:
dim_latent = 100
g = Generator(dim_latent=dim_latent)
g.apply(weights_init)
g(torch.rand((1,dim_latent))).shape

The discriminator is a regular CNN for classification with a single neuron as output (=> binary classification).

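Note the use of Leaky ReLU, which [1](https://arxiv.org/abs/1511.06434) recommends for all the layers of the discriminator (the paper uses a negative slope of 0.2, while `nn.LeakyReLU()` defaults to 0.01).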

In [None]:
class Discriminator(nn.Module):
    def __init__(self, base_width=64, input_ch=1):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(input_ch, base_width, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(),
            nn.Conv2d(base_width, base_width*2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(base_width*2),
            nn.LeakyReLU(),
            nn.Conv2d(base_width*2, base_width*4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(base_width*4),
            nn.LeakyReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(base_width*4, 1),
            nn.Sigmoid() # squeeze the output into the (0, 1) interval so we can apply BCELoss
        )
    
    def forward(self, X):
        return self.layers(X).flatten()

Check if shape is correct: given batch of size $B$, we expect to have a tensor of shape `[B]` as output (one probability output per image in the batch).

In [None]:
d = Discriminator()
d.apply(weights_init)
d(torch.rand(10,1,28,28)).shape


The loss function for the discriminator is the usual binary cross entropy.

We moreover use `Adam` as optimizer both for the discriminator and the generator. The hyperparameters in the cell below are the ones suggested in [1](https://arxiv.org/abs/1511.06434): a learning rate of `2e-4`, smaller than the Adam default of `1e-3`, and $\beta_1=0.5$ instead of the default `0.9`, which the authors found to help stabilize training.

In [None]:
loss_fn = nn.BCELoss()

device = use_gpu_if_possible()

real_label = 1
fake_label = 0

lr = 2e-4
beta1 = 0.5
beta2 = 0.999

# 2 Adams with same hyperparameters
optim_d = torch.optim.Adam(d.parameters(), lr=lr, betas=(beta1, beta2))
optim_g = torch.optim.Adam(g.parameters(), lr=lr, betas=(beta1, beta2))


We train for 3 epochs (enough for this kind of CNN on MNIST) and print the stats every 25 iterations.

In [None]:
ite_print = 25
num_epochs = 3

d = d.to(device)
g = g.to(device)

Moreover, at each logging step (every `ite_print` iterations and at the end of each epoch), we're going to generate some images starting from the same random sample (`fixed_noise`) from $\mathcal{N}(\mathbf{0}_{100},I_{100})$ (the latent space). We store these images in the `fakeimgs` list.

In [None]:
n_fake = 24
fixed_noise = torch.randn([n_fake, dim_latent]).to(device)

fakeimgs = []

We build the training function in blocks.

First off, we start with the discriminator. It's split in two parts:

1. train the discriminator on the real data.
2. train the discriminator on the synthetic data.

The first step is very easy and basically attaches an array of 1's (`real_label`) as large as the batch size as ground truth to the images (since they're all real).

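The second step is a tiny bit more complicated: we start from a generic `synthetic_data`, which we generate outside this function, attach to it an array of 0's as ground truth, feed it to the discriminator and calculate the loss. Note that we feed `synthetic_data.detach()`: this step must only update the discriminator, so we cut the computational graph and prevent the gradients from flowing back into the generator.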

The third function (`step_discriminator`) combines the previous two such that we first run the part on the real data, then the one on the synthetic data.

Notice the `backward`s and `step`s calls:
* we call `backward` twice: first in the real data part, then in the synthetic data part. This _accumulates_ (i.e. sums) gradients, which is what we want since differentiation is a linear operator: $\nabla(f+g)=\nabla f + \nabla g$. Conceptually, `backward` should be called on the sum of the two losses, but calling it separately on the two partial losses yields the same gradients.
* once we have accumulated twice, we can call `optimizer.step()` in the `step_discriminator` function.

Also, we return the errors (losses) just for statistical purposes.
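To convince ourselves that two `backward` calls give the same result as one call on the summed loss, here is a plain-Python sketch with a hypothetical toy discriminator $D(x) = \text{sigmoid}(wx + b)$ on scalar inputs (not the CNN of this lab). The gradient buffers are accumulated the way `backward` does, and compared with a finite-difference gradient of the total loss.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def bce(p, y):
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def bce_grad(w, b, x, y):
    """Analytic gradient of bce(sigmoid(w*x + b), y) w.r.t. (w, b)."""
    err = sigmoid(w * x + b) - y  # derivative of the BCE w.r.t. the logit
    return err * x, err

def total_loss(w, b, x_real, x_fake):
    return bce(sigmoid(w * x_real + b), 1) + bce(sigmoid(w * x_fake + b), 0)

w, b = 0.5, -0.1
x_real, x_fake = 1.2, -0.7  # hypothetical real and synthetic inputs

# optimizer.zero_grad(): the gradient buffers start at zero
grad_w, grad_b = 0.0, 0.0
# error_real.backward(): add the gradient of the real-data loss (label 1)
gw, gb = bce_grad(w, b, x_real, 1)
grad_w, grad_b = grad_w + gw, grad_b + gb
# error_fake.backward(): add the gradient of the fake-data loss (label 0)
gw, gb = bce_grad(w, b, x_fake, 0)
grad_w, grad_b = grad_w + gw, grad_b + gb

# central finite differences of the summed loss
h = 1e-6
fd_w = (total_loss(w + h, b, x_real, x_fake) - total_loss(w - h, b, x_real, x_fake)) / (2 * h)
fd_b = (total_loss(w, b + h, x_real, x_fake) - total_loss(w, b - h, x_real, x_fake)) / (2 * h)
print(math.isclose(grad_w, fd_w, abs_tol=1e-6), math.isclose(grad_b, fd_b, abs_tol=1e-6))  # True True
```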

In [None]:
def discriminator_real_data(discriminator, real_data, loss_fn, device, real_label=1):
    batch_size = real_data.shape[0]

    ground_truth_real = torch.full([batch_size], real_label, dtype=torch.float).to(device)

    discriminator_output_real = discriminator(real_data).view(-1)

    error_real = loss_fn(discriminator_output_real, ground_truth_real)
    error_real.backward()

    return error_real

def discriminator_synthetic_data(discriminator, generator, synthetic_data, loss_fn, device, fake_label=0):
    batch_size = synthetic_data.shape[0]

    ground_truth_fake = torch.full([batch_size], fake_label, dtype=torch.float).to(device)

    discriminator_output_fake = discriminator(synthetic_data.detach()).view(-1)
    error_fake = loss_fn(discriminator_output_fake, ground_truth_fake)

    error_fake.backward()

    return error_fake

def step_discriminator(discriminator, generator, real_data, synthetic_data, optimizer, loss_fn, device, real_label=1, fake_label=0):
    optimizer.zero_grad()
    error_real = discriminator_real_data(discriminator, real_data, loss_fn, device, real_label=real_label)
    error_fake = discriminator_synthetic_data(discriminator, generator, synthetic_data, loss_fn, device, fake_label=fake_label)
    optimizer.step()

    error_overall = error_real.item() + error_fake.item()

    return error_overall

Once we have advanced the discriminator, it's time for the generator. We reuse the same synthetic images used for the discriminator. This time, we attach a ground truth of 1's, since, as discussed above, we minimize $-\log(D(G(z)))$.

This function is nearly identical to `discriminator_synthetic_data`, except that we flip the ground-truth label, we do **not** detach the synthetic data (the gradients must flow back into the generator), and we call `step` at the end.

In [None]:
def step_generator(discriminator, generator, synthetic_data, optimizer, loss_fn, device, dim_latent, real_label=1):
    optimizer.zero_grad()
    batch_size = synthetic_data.shape[0]
    ground_truth_synth = torch.full([batch_size], real_label, dtype=torch.float).to(device)
    discriminator_output_synth = discriminator(synthetic_data)

    error_generator = loss_fn(discriminator_output_synth, ground_truth_synth)

    error_generator.backward()

    optimizer.step()

    return error_generator.item()

We can wrap it all up together here.

We generate the synthetic data once per iteration using the generator at the current state.

We log and save the synthetic images generated from `fixed_noise` every `ite_print` iterations and at the end of each epoch.

In [None]:
g.train()
d.train()
for epoch in range(num_epochs):
    for i, (real_data, _) in enumerate(trainloader):      

        real_data = real_data.to(device)
        noise = torch.randn([real_data.shape[0], dim_latent]).to(device)
        synthetic_data = g(noise)
        error_discrim = step_discriminator(d, g, real_data, synthetic_data, optim_d, loss_fn, device, real_label=real_label, fake_label=fake_label)
        error_generator = step_generator(d, g, synthetic_data, optim_g, loss_fn, device, dim_latent, real_label=real_label)

        if (i + 1) % ite_print == 0 or (i + 1) == len(trainloader):
            print(f"Ep. {epoch + 1}/{num_epochs} It. {i+1}/{len(trainloader)} >>> D loss {error_discrim:.3f} | G loss {error_generator:.3f}")
            g.eval()
            sample = g(fixed_noise).detach().cpu()
            fakeimgs.append(sample.reshape(sample.shape[0], 28, 28))
            g.train()
    


Let us produce a collage of the saved snapshots to visualize the results.
This function stitches a list of images into a single numpy array, arranged as a grid of `nrow` × `ncol` tiles.

In [None]:
def produce_collage(nrow, ncol, img_range, dim=(28,28)):
    img_collage = np.empty((nrow*dim[0], ncol*dim[1]))
    for i, img in enumerate(img_range):
        index_row = (i // ncol) * dim[0]
        index_col = (i % ncol) * dim[1]
        img_collage[index_row:index_row+dim[0], index_col:index_col+dim[1]] = img.reshape(dim)
    
    return img_collage

Now, we visualize the first image generated at each logging step (every 25 iterations, plus the last iteration of each epoch: 38 snapshots per epoch, hence the 3 × 38 grid). We enlarge the image 4 times using cv2, since 28×28 is too small.

In [None]:
collage = produce_collage(3, 38, [imgs_epoch[0] for imgs_epoch in fakeimgs])
collage = cv2.resize(collage, (collage.shape[1]*4, collage.shape[0]*4), interpolation=cv2.INTER_NEAREST)

Remember, the image is still a float array in the range defined by the transforms we applied before.
If we want to save it, we first need to

1. rescale it in the 0-255 interval
2. convert it to `uint8`

In [None]:
def array_to_image(array):
    img = array - array.min()
    img = img * 255 / img.max()
    return img.astype("uint8")

Now we can convert and save:

In [None]:
collage = array_to_image(collage)
cv2.imwrite("img/GAN_collage.jpg", collage)

Takeaway: training vanilla GANs is **hard** and often we don't reach a real **convergence**: the training is a process of finding an equilibrium between generator and discriminator and you must fine-tune both architectures (and their optimizers) to find a good result. The literature on GANs is very extensive and a lot of different variants have been proposed to try solving these issues. One of the variants which we will present is called **Wasserstein GAN**.

## Wasserstein GAN

The Wasserstein GAN (WGAN) ([2](https://arxiv.org/pdf/1701.07875.pdf)) replaces the discriminator with a **critic**, which is still an ANN, but with a different loss function: while the discriminator minimizes the binary cross-entropy loss, the critic is trained to estimate the **Wasserstein Distance** (also called Earth Mover's Distance, EMD) between the real distribution of data ($P_r$) and the *model* distribution ($P_\theta$), which is the distribution being learned by our generator. The generator, in turn, is trained to minimize this estimate.

![](img/emd.jpg)

$\text{EMD}(P_r, P_\theta) = \inf_{\gamma\in\Pi(P_r,P_\theta)} \text{E}_{(x,y)\sim\gamma}(\vert\vert x-y \vert\vert)$

Here, $\Pi(P_r,P_\theta)$ represents the set of all joint distributions of $(x, y)$ whose marginals are $P_r$ (for $x$) and $P_\theta$ (for $y$), and $\gamma$ is one of these joint distributions (a *transport plan*).

More specifically, we look for the *minimal cost* (amount of probability mass times the distance it travels) needed to transport $P_\theta$ such that $P_\theta$ *becomes* $P_r$.
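In one dimension, and for two empirical distributions with the same number of equally weighted samples, the optimal transport plan simply matches the $i$-th smallest sample of one with the $i$-th smallest sample of the other. The plain-Python sketch below (limited to this special case; `emd_1d` is a hypothetical helper, not part of the lab's code) uses this to compute the EMD.

```python
def emd_1d(xs, ys):
    """EMD between two 1D empirical distributions with equally many, equally weighted samples.
    The optimal plan pairs the sorted samples, so the cost is the mean absolute difference."""
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes samples of equal size")
    return sum(abs(x - y) for x, y in zip(sorted(xs), sorted(ys))) / len(xs)

print(emd_1d([0.0, 1.0, 2.0], [5.0, 3.0, 4.0]))  # 3.0: each unit of mass moves by 3

# the EMD shrinks smoothly as the second distribution moves towards the first one
for shift in (4.0, 2.0, 1.0, 0.0):
    print(shift, emd_1d([0.0, 1.0], [0.0 + shift, 1.0 + shift]))
```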

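Computing the infimum above directly is intractable, but the Kantorovich-Rubinstein duality rewrites it as $\text{EMD}(P_r, P_\theta) = \sup_{\lVert f\rVert_L\le 1} \mathbb{E}_{x\sim P_r}[f(x)] - \mathbb{E}_{x\sim P_\theta}[f(x)]$, where the supremum is taken over all 1-Lipschitz functions $f$. It turns out that the critic can be an ANN $f_\omega$ approximating this $f$, trained by back-propagating the gradient according to the following loss function: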

$\mathcal{L}_f = - \frac{1}{b}\sum_{i=1}^{b} f_\omega(x^{(i)}) + \frac{1}{b}\sum_{i=1}^{b}f_\omega(g_\theta(z^{(i)}))$

The notation is easily inferable from before:
* $b$ is the batch size
* $g$ is the generator parametrized by $\theta$
* $z$ is a sample from the *latent distribution* (usually, a Gaussian)
* $f$ is the critic parametrized by $\omega$

$\omega$ (parameters of $f$) is then updated using RMSProp (in the original paper), and the resulting new parameters are clipped to the interval $[-0.01, 0.01]$.

The loss of $g$ is instead

$\mathcal{L}_g = - \frac{1}{b}\sum_{i=1}^{b}f_\omega(g_\theta(z^{(i)}))$

Note that:

1. The above losses don't require labels. We can then get rid of the "real" and "fake" labels we defined in the vanilla GAN
2. The approximation of EMD with the above loss functions is valid only when the critic is Lipschitz continuous (K-Lipschitz for some constant K, which only rescales the estimate by K). The paper enforces this, crudely, by clipping the weights $\omega$ to a pre-defined interval ($[-0.01, 0.01]$)
3. The critic and the generator are not updated in strict alternation: $\omega$ is updated $k$ times for each update of $\theta$. In the paper, $k=5$.
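As a plain-Python sketch of the three points above (the critic outputs and weights are made-up numbers, the helper names are hypothetical): the two losses are simple averages of critic outputs, the clipping is an elementwise clamp, and with $k=5$ the generator is updated on one iteration out of five, as in the condition `i % n_critic_training == 0` of the training loop below.

```python
def critic_loss(f_real, f_fake):
    """L_f = -mean f(x) + mean f(g(z)); its negative is the critic's (scaled) EMD estimate."""
    return -sum(f_real) / len(f_real) + sum(f_fake) / len(f_fake)

def generator_loss(f_fake):
    """L_g = -mean f(g(z))."""
    return -sum(f_fake) / len(f_fake)

def clip_weights(weights, c=0.01):
    """Clamp every weight to [-c, c] after each critic update."""
    return [max(-c, min(c, w)) for w in weights]

def generator_update_steps(num_iterations, k=5):
    """Iterations at which the generator is updated, the critic being updated at every iteration."""
    return [i for i in range(num_iterations) if i % k == 0]

f_real = [0.8, 1.2, 1.0]    # made-up critic outputs on real images
f_fake = [-0.5, -0.1, 0.0]  # made-up critic outputs on generated images
print(critic_loss(f_real, f_fake))        # approximately -1.2
print(generator_loss(f_fake))             # approximately 0.2
print(clip_weights([0.5, -0.003, -2.0]))  # [0.01, -0.003, -0.01]
print(generator_update_steps(12))         # [0, 5, 10]
```

No labels appear anywhere: the critic simply pushes its outputs up on real data and down on generated data, while the generator pushes them up on its own samples.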

In this case, we're going to use MLPs for both critic (discriminator) and generator.

In [None]:
device = use_gpu_if_possible()

d = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 64),
    nn.ReLU(),
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.Linear(128, 1)
).to(device)

g = nn.Sequential(
    nn.Linear(100, 64),
    nn.LeakyReLU(),
    nn.BatchNorm1d(64),
    nn.Linear(64, 256),
    nn.LeakyReLU(),
    nn.BatchNorm1d(256),
    nn.Linear(256, 512),
    nn.LeakyReLU(),
    nn.BatchNorm1d(512),
    nn.Linear(512, 28*28),
    nn.Tanh()
).to(device)

The paper specifically instructs to use RMSProp as optimizer, with a very small learning rate ($5\cdot 10^{-5}$).

In [None]:
optim_d = torch.optim.RMSprop(d.parameters(), lr=5e-5)
optim_g = torch.optim.RMSprop(g.parameters(), lr=5e-5)

The idea is that we train the critic $k$ times before updating the generator. We call $k$ `n_critic_training`.

Note also that we need a lot of training epochs before obtaining a quality result.

In [None]:
# defines the ratio of critic update vs generator update
n_critic_training = 5
num_epochs = 200
fakeimgs = []

Onto the training process: we can simply redefine the training loop used before, replacing the loss functions with the WGAN ones. Note that we don't need the labels anymore, since the loss functions operate merely on the outputs of the critic for real and synthetic data.

Note also that we can fuse the parts concerning real and synthetic data into a single function, since each of them is now a one-liner.

In [None]:
def step_critic(critic, generator, real_data, synthetic_data, optimizer, critic_clamp_value=0.01):
    optimizer.zero_grad()
    error_critic = -critic(real_data).mean() + critic(synthetic_data.detach()).mean()
    error_critic.backward()
    optimizer.step()

    for param in critic.parameters():
        param.data.clamp_(-critic_clamp_value, critic_clamp_value)

    return error_critic.item()

def step_generator(critic, generator, synthetic_data, optimizer):
    optimizer.zero_grad()

    error_generator = -critic(synthetic_data).mean()
    error_generator.backward()

    optimizer.step()

    return error_generator.item()

In [None]:
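# optional: resume from previously saved weights (skip this cell to train from scratch)
d.load_state_dict(torch.load("models_push/wgan/discriminator.pt", map_location=device))
g.load_state_dict(torch.load("models_push/wgan/generator.pt", map_location=device))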

In [None]:
g.train()
d.train()
for epoch in range(num_epochs):
    for i, (real_data, _) in enumerate(trainloader):
        real_data = real_data.to(device)
        noise = torch.randn([real_data.shape[0], dim_latent])
        synthetic_data = g(noise.to(device))

        error_critic = step_critic(d, g, real_data, synthetic_data, optim_d)

        if i % n_critic_training == 0:
            # note: the WGAN paper specifies to re-sample data in G training
            noise = torch.randn([real_data.shape[0], dim_latent])
            synthetic_data = g(noise.to(device))
            error_generator = step_generator(d, g, synthetic_data, optim_g)
        
    print(f"Ep. {epoch + 1}/{num_epochs} >>> C loss {error_critic:.3f} | G loss {error_generator:.3f}")
    g.eval()
    sample = g(fixed_noise).detach().cpu()
    fakeimgs.append(sample.reshape(fixed_noise.shape[0], 28, 28))
    g.train()




Since it's a long training, let's save the params.

In [None]:
torch.save(d.state_dict(), "models_push/wgan/discriminator.pt")
torch.save(g.state_dict(), "models_push/wgan/generator.pt")

Let's visualize a couple of results

In [None]:
plt.imshow(fakeimgs[197][0],cmap="gray")

In [None]:
collage = produce_collage(5, 5, [fakeimgs[int(i)][15] for i in np.linspace(0, 199, 25)])
collage = cv2.resize(collage, (collage.shape[1]*4, collage.shape[0]*4), interpolation=cv2.INTER_NEAREST)
collage = array_to_image(collage)
cv2.imwrite("img/WGAN_collage4.jpg", collage)

## Variational AutoEncoder

A Variational AutoEncoder (VAE) can be seen as a special case of AutoEncoder (AE) in which the bottleneck layer is seen as a collection of random variables instead of a fixed array of scalars.

![](https://miro.medium.com/max/3110/0*uq2_ZipB9TqI9G_k)

While the aim of the AE is to perform dimensionality reduction, finding a latent space which is able to maximally "capture" the variability of the data space, VAEs instead try to find a _**probabilistic space**_ (whose dimension is hopefully smaller than the dimension of the data) which can approximate the data distribution starting from a family of known distributions (usually Gaussians).

This space is represented by the **bottleneck** (_latent_) **layer** in the image above.

AEs are a stack of two structures: an **encoder**, which is a cascade of layers of decreasing size leading to the bottleneck, and a **decoder**, which starts from the bottleneck and increases the size of each hidden layer, leading to the output layer, which has the **same size as the input layer**. The aim is to train the network such that the input and the output activations are **as close as possible**.

To achieve this, AEs use a **reconstruction loss**, e.g. the squared error $\sum_{i=1}^{N} \lVert \mathbf{\hat{x}}_i - \mathbf{x}_i \rVert^2$, to find a deterministic mapping $e:\mathbb{R}^p \rightarrow \mathbb{R}^u$ and another $d:\mathbb{R}^u \rightarrow \mathbb{R}^p$, with $u \ll p$ and $\mathbf{\hat{x}}_i = d(e(\mathbf{x}_i))$. $e$ is the **encoder** and $d$ the **decoder**.

VAEs still retain the structure of encoder/decoder, but this time they are viewed as two probability distributions (mappings):
* encoder: $q_\theta(z\vert x)$. $z$ is the latent (random) variable represented by the bottleneck. We're given an input (i.e., an image), and we _encode_ it in a latent space which is a vector of random variables of a specific distribution family (see later)
* decoder: $p_\varphi(x\vert z)$. We're given a sample from the latent space and we want to re-construct the input from it.

VAEs train encoders such that the latent representation is _as close as possible_ to a collection of known distributions, and decoders such that the reconstruction of the input is _as close as possible_ to the original input.

VAEs employ a variation of the reconstruction loss, which is the reconstruction log-likelihood: $\text{RLL} = \log(p_\varphi(x\vert z))$. Given a sample $z$ from the **latent space**, we rebuild an input through the learnt probability distribution $p_\varphi(x\vert z)$, then evaluate how _likely_ the original input $x$ is under it.

Alongside the RLL, we add a regularization term, which forces the encoder to map the data onto the desired family of distributions. This regularization term is the Kullback-Leibler divergence, a measure of "distance" between probability distributions (not a true distance, since it is not symmetric): $\text{KL}[q_\theta(z\vert x)\vert\vert p(z)]$. $p(z)$ is the desired distribution for the latent space, usually $\mathcal{N}(\mathbf{0}_d,I_d)$.

The loss for the training of VAEs is hence the negative RLL (negative because we want to maximize the reconstruction likelihood) plus the KL regularization term. The penalty is necessary to push the encoder to produce encodings "close" to the $\mathcal{N}(0,1)$ density. Wrapping things up:

$l_i(\theta,\varphi) = - \mathbb{E}_{z\sim q_\theta(z|x_i)} \{\log[p_\varphi (x_i\vert z)] \} + \text{KL}[q_\theta(z\vert x_i) \vert\vert p(z)]$

where $i$ indexes a single datum, and $z\sim q_\theta(z|x_i)$ represents a sample from the encoder given the input datum $x_i$.

![](https://www.jeremyjordan.me/content/images/2018/06/Screen-Shot-2018-06-20-at-2.51.06-PM.png)



### Practical implementations

#### Architecture

To make the encoder output a probability distribution, we build it this way:
* let $u$ be the latent dimension of $z$
    * we'd expect the bottleneck layer to have dimension $u$
* we replicate the bottleneck layer $k$ times, where $k$ is the number of parameters defining the distribution of each component of $Z$
    * e.g. a Gaussian distribution is defined by $(\mu, \sigma)$, so we replicate the bottleneck layer two times
* the penultimate layer in the encoder will then feed its activations to $k$ different layers running in parallel

If we have a bottleneck layer with 16 neurons, each neuron $j$ will represent a Gaussian distribution with mean $\mu_j$ and std $\sigma_j$. Practically, we'll have 16 neurons from a layer producing $\mu$'s and 16 neurons from a layer producing $\sigma$'s.

![](https://miro.medium.com/max/2540/1*R0irE3x0tXIYndqRLprFmw.png)

#### Sampling of latent variable

To produce $z$ and feed it to the decoder, we hence sample from each of these 16 Gaussians by using the values of the $\mu$'s and $\sigma$'s.

Also, note that a linear output layer naturally produces values in the range $(-\infty, +\infty)$ (unless you use some specific activation function to limit its range), so, for a normal distribution in which $\sigma\in(0,+\infty)$, the encoder usually produces $\log\sigma$, which is then exponentiated while sampling $z$ to obtain a strictly positive value.

#### Reparametrization trick

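Sampling $z$ directly from $\mathcal{N}(\mu(x), \Sigma(x))$ is not a differentiable operation w.r.t. $\mu$ and $\Sigma$, so the gradient could not flow back into the encoder. We hence reproduce the sampling via a reparametrization: $z = \mu(x) + \Sigma(x)^{1/2} \cdot\varepsilon$, where $\varepsilon \sim \mathcal{N}(\mathbf{0},I)$ is the only source of randomness. With a diagonal covariance, $\Sigma(x)^{1/2}$ is simply the vector of standard deviations $\sigma(x)$, and the product is elementwise.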

![](https://www.jeremyjordan.me/content/images/2018/03/Screen-Shot-2018-03-18-at-4.36.34-PM.png)
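A one-dimensional plain-Python sketch of the trick, parametrized by $\log\sigma$ as described above (the function name `reparametrize` is hypothetical): once $\varepsilon$ is drawn, $z$ is a deterministic, differentiable function of $\mu$ and $\log\sigma$, with $\partial z/\partial\mu = 1$ and $\partial z/\partial\log\sigma = \sigma\varepsilon$, while the samples still follow $\mathcal{N}(\mu, \sigma^2)$.

```python
import math
import random

def reparametrize(mu, log_sigma, eps):
    """z = mu + sigma * eps with sigma = exp(log_sigma); eps is drawn outside, from N(0, 1)."""
    return mu + math.exp(log_sigma) * eps

rng = random.Random(0)
mu, log_sigma = 2.0, math.log(0.5)

# the randomness lives only in eps, so z still follows N(mu, sigma^2) ...
samples = [reparametrize(mu, log_sigma, rng.gauss(0.0, 1.0)) for _ in range(100_000)]
mean = sum(samples) / len(samples)
std = math.sqrt(sum((s - mean) ** 2 for s in samples) / len(samples))
print(f"mean {mean:.3f}, std {std:.3f}")  # close to 2.0 and 0.5

# ... and, for a fixed eps, z can be differentiated w.r.t. mu and log_sigma
eps, h = 0.7, 1e-6
dz_dmu = (reparametrize(mu + h, log_sigma, eps) - reparametrize(mu - h, log_sigma, eps)) / (2 * h)
dz_dlogsigma = (reparametrize(mu, log_sigma + h, eps) - reparametrize(mu, log_sigma - h, eps)) / (2 * h)
print(dz_dmu, dz_dlogsigma, math.exp(log_sigma) * eps)  # approximately 1.0, 0.35, 0.35
```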

#### $\text{RLL}$

If we model each pixel $x_j$ of the original image as a Bernoulli random variable whose (unknown) parameter $\pi_j$ is estimated by the corresponding pixel produced by the decoder, then the reconstruction likelihood of that pixel is $p_\varphi(x_j\vert z) = \pi_j^{x_j} (1-\pi_j)^{1-x_j}$.

Applying $\log$, we have $x_j\log(\pi_j) + (1-x_j)\log(1-\pi_j)$.

Finally, changing sign, we obtain the Binary Cross Entropy (BCE) loss.

Note that:
* $\pi_j$ is the "reconstructed pixel", i.e. the output of the decoder (the input of `nn.BCELoss`)
* $x_j$ is the ground truth, i.e. the "original pixel" (the target of `nn.BCELoss`)
    * we're then penalizing the reconstruction $\pi_j$ when it _goes too far off_ of $x_j$
* since the pixels are assumed independent, the log-likelihood of an image (and of a batch of images) is the **sum** of the per-pixel terms, so we sum the BCE terms instead of averaging them. We communicate that to PT by specifying the argument `reduction="sum"` to the loss; the default behaviour (`reduction="mean"`) would average them instead.
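A plain-Python sketch of this correspondence on a made-up 4-pixel image (the helper names are hypothetical): the negated Bernoulli log-likelihood of the original pixels, given the decoder outputs, equals the summed BCE, which is what `nn.BCELoss(reduction="sum")` computes when its input is the reconstruction and its target is the original image; the `"mean"` reduction divides the same quantity by the number of pixels.

```python
import math

def bernoulli_log_likelihood(x, pi):
    """log p(x | pi) = x log(pi) + (1 - x) log(1 - pi), for an original pixel x in [0, 1]
    and a reconstructed pixel (decoder output) pi in (0, 1)."""
    return x * math.log(pi) + (1 - x) * math.log(1 - pi)

def bce(reconstruction, target, reduction="sum"):
    """Binary cross-entropy between reconstructed and original pixels."""
    terms = [-bernoulli_log_likelihood(x, pi) for pi, x in zip(reconstruction, target)]
    if reduction == "sum":
        return sum(terms)
    return sum(terms) / len(terms)  # "mean"

original = [0.0, 1.0, 1.0, 0.0]       # made-up original pixels x_j
reconstructed = [0.1, 0.8, 0.9, 0.3]  # made-up decoder outputs pi_j

# pixels are assumed independent: the image log-likelihood is a sum over pixels
image_log_likelihood = sum(bernoulli_log_likelihood(x, pi)
                           for x, pi in zip(original, reconstructed))
print(math.isclose(bce(reconstructed, original), -image_log_likelihood))  # True
print(math.isclose(bce(reconstructed, original, "mean") * 4, bce(reconstructed, original)))  # True
```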


In [None]:
class VAEEncoder(nn.Module):
    def __init__(self, dim_input:int, dim_latent:int):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(dim_input, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128)
        )
        self.encode_mu = nn.Linear(128, dim_latent)
        self.encode_sigma = nn.Linear(128, dim_latent)
    
    def forward(self, X:torch.Tensor) -> tuple:
        h = self.encoder(X)
        latent_mu = self.encode_mu(h)
        latent_log_sigma = self.encode_sigma(h)
        return latent_mu, latent_log_sigma

class VAEDecoder(nn.Module):
    def __init__(self, dim_input:int, dim_latent:int):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Linear(dim_latent, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, dim_input),
            nn.Sigmoid()
        )
    
    def forward(self, z:torch.Tensor) -> torch.Tensor:
        reconstruction = self.decoder(z)
        return reconstruction

class VAE(nn.Module):
    def __init__(self, dim_input:int, dim_latent:int):
        super().__init__()
        self.encoder = VAEEncoder(dim_input, dim_latent)
        self.decoder = VAEDecoder(dim_input, dim_latent)
    
    def sample_latent(self, mu, log_sigma):
        # reparametrization trick: z = mu + sigma * eps, with sigma = exp(log_sigma),
        # consistent with the closed-form KL term used in the loss
        # white_noise is the epsilon (~N(0,1)), created on the same device as mu
        white_noise = torch.randn_like(mu)
        return mu + log_sigma.exp() * white_noise

    def forward(self, X):
        mu, log_sigma = self.encoder(X)
        z = self.sample_latent(mu, log_sigma)
        reconstruction = self.decoder(z)
        return reconstruction, mu, log_sigma


### Setting the stage for the training

#### Loss

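As said before, the loss is a composition of the `BCELoss` (which we can use "for free" since it's already available in PT) and the KL part, which we can calculate in closed form since it involves two normal distributions: for each latent component, $\text{KL}[\mathcal{N}(\mu,\sigma^2)\vert\vert\mathcal{N}(0,1)] = \frac{1}{2}(\sigma^2+\mu^2-1)-\log\sigma$.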
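As a sanity check of the closed form, the plain-Python sketch below (one latent component; helper names hypothetical) compares it with a midpoint-rule numerical integration of $\int q(z)\log\frac{q(z)}{p(z)}\,dz$, with $q = \mathcal{N}(\mu,\sigma^2)$ and $p=\mathcal{N}(0,1)$.

```python
import math

def kl_closed_form(mu, log_sigma):
    """KL[N(mu, sigma^2) || N(0, 1)] with sigma = exp(log_sigma)."""
    return 0.5 * (math.exp(2 * log_sigma) + mu ** 2 - 1) - log_sigma

def kl_numerical(mu, log_sigma, lo=-20.0, hi=20.0, n=20_000):
    """Midpoint-rule approximation of the integral of q(z) * log(q(z) / p(z)) over [lo, hi]."""
    sigma = math.exp(log_sigma)
    dz = (hi - lo) / n
    total = 0.0
    for i in range(n):
        z = lo + (i + 0.5) * dz
        q = math.exp(-0.5 * ((z - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))
        # log q(z) - log p(z), written out to avoid taking the log of tiny densities
        log_ratio = -0.5 * ((z - mu) / sigma) ** 2 - log_sigma + 0.5 * z ** 2
        total += q * log_ratio * dz
    return total

for mu, sigma in [(0.0, 1.0), (1.0, 1.0), (0.5, 0.7), (-2.0, 1.5)]:
    print(mu, sigma, kl_closed_form(mu, math.log(sigma)), kl_numerical(mu, math.log(sigma)))
```

The KL is zero only for $\mu=0,\sigma=1$ and grows when the encoder moves the mean away from 0 or shrinks/inflates the standard deviation, which is exactly the pressure the regularization term exerts.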

In [None]:
def reconstruction_loss(output, ground_truth):
    reconstruction_loss_fn = nn.BCELoss(reduction="sum")
    return reconstruction_loss_fn(output, ground_truth)

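def kl_vae(mu, log_sigma):
    # closed form of KL[N(mu, sigma^2) || N(0, 1)] with sigma = exp(log_sigma),
    # summed over the latent components and over the batch
    kl = .5 * (log_sigma.exp() ** 2 + mu ** 2 - 1) - log_sigma
    return kl.sum()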

def vae_loss(output, mu, log_sigma, ground_truth):
    rec_loss_val = reconstruction_loss(output, ground_truth)
    
    kl_loss_val = kl_vae(mu, log_sigma)
    return rec_loss_val + kl_loss_val

In [None]:
latent_dim = 32
vae = VAE(28*28, latent_dim)

In [None]:
num_epochs = 10

optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
device = use_gpu_if_possible()
ite_print = len(trainloader)

### Training

In [None]:
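# optional: resume from previously saved weights (skip this cell to train from scratch)
vae.load_state_dict(torch.load("models_push/vae/vae.pt", map_location="cpu"))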

In [None]:
vae.train()
vae = vae.to(device)
flatten = nn.Flatten()
for epoch in range(num_epochs):
    loss_meter = AverageMeter()
    for i, (X, _) in enumerate(trainloader):
        X = X.to(device)

        optimizer.zero_grad()

        X_hat, mu, log_sigma = vae(X)
        loss = vae_loss(X_hat, mu, log_sigma, flatten(X))
        loss.backward()
        optimizer.step()

        loss_meter.update(loss.item(), X.shape[0])

        if (i + 1) % ite_print == 0 or (i + 1) == len(trainloader):
            print(f"Epoch {epoch+1} | Loss {loss_meter.avg}")

In [None]:
torch.save(vae.state_dict(), "models_push/vae/vae.pt")

### Generating images from our VAE

#### Image generation starting from a test image

Let us consider the first test image

In [None]:
imgs, labels = next(iter(testloader))
img, label = imgs[0], labels[0]
plt.imshow(img.reshape(28,28,1), cmap="gray")
print(label)

Let's first check how it reconstructs the image:

In [None]:
vae.eval()
vae.cpu()
img_hat, mu, log_sigma = vae(img)
plt.imshow(img_hat.reshape(28,28).detach().numpy(), cmap="gray")

In [None]:
z = vae.sample_latent(mu, log_sigma)

Let us try to _tweak_ the latent distribution by adding some uniform noise (in $[0, 1)$, via `torch.rand_like`) to its mean...

In [None]:
z_tweaked = vae.sample_latent(mu.add(torch.rand_like(mu)), log_sigma)
img_hat_tweaked = vae.decoder(z_tweaked)
plt.imshow(img_hat_tweaked.reshape(28,28).detach().numpy(), cmap="gray")

Let us see the values of $\mu$ and $\sigma$...

In [None]:
mu, log_sigma.exp()

Let's build a new sequence of images by varying the nineteenth value of this latent distribution on a scale going from $\mu[18] - 3\sigma[18]$ to $\mu[18] + 3\sigma[18]$

In [None]:
img_range = []
for j in torch.linspace((mu[0,18] - 3*log_sigma.exp()[0,18]).item(), (mu[0,18] + 3*log_sigma.exp()[0,18]).item(), 20):
    z2 = z.clone()
    z2[0,18] = j
    img_range.append(vae.decoder(z2).detach())

In [None]:
c = produce_collage(5, 4, img_range)
c = array_to_image(c)
c = cv2.resize(c, (c.shape[1]*4, c.shape[0]*4), interpolation=cv2.INTER_NEAREST)

In [None]:
cv2.imwrite("img/VAE_collage0.jpg", c)

Let's try to vary the first component on a larger scale (from -100 to 100)

In [None]:
img_range = []
for j in torch.linspace(-100, 100, 25):
    z2 = z.clone()
    z2[0,0] = j
    img_range.append(vae.decoder(z2).detach())
c = produce_collage(5, 5, img_range)
c = array_to_image(c)
c = cv2.resize(c, (c.shape[1]*4, c.shape[0]*4), interpolation=cv2.INTER_NEAREST)
cv2.imwrite("img/VAE_collage1.jpg", c)

Let's do the same with the sixth component...

In [None]:
img_range = []
for j in torch.linspace(-100, 100, 25):
    z2 = z.clone()
    z2[0,5] = j
    img_range.append(vae.decoder(z2).detach())
c = produce_collage(5, 5, img_range)
c = array_to_image(c)
c = cv2.resize(c, (c.shape[1]*4, c.shape[0]*4), interpolation=cv2.INTER_NEAREST)
cv2.imwrite("img/VAE_collage2.jpg", c)

Finally, we can check which digit each image of this last sequence resembles by feeding it to a pretrained MNIST classifier (the `MLPCustom` network loaded from `models_push/mlp_custom_mnist`). The images are first normalized with the usual MNIST constants (mean `0.1307`, std `0.3081`), which we assume match the preprocessing used when training the classifier.

In [None]:
layers = [
    {"n_in": 784, "n_out": 16, "batchnorm": False},
    {"n_out": 32, "batchnorm": True},
    {"n_out": 64, "batchnorm": True},
    {"n_out": 10, "batchnorm": True}
]
net = MLPCustom(layers)
net.load_state_dict(torch.load("models_push/mlp_custom_mnist/mlp_custom_mnist.pt"))

In [None]:
normalize = lambda x: (x - 0.1307) / 0.3081
img_range_norm = [normalize(img) for img in img_range]

In [None]:
net.eval()
net(torch.cat(img_range_norm)).argmax(dim=1)

The part on VAE has been mainly inspired by this [tutorial by Jeremy Jordan](https://www.jeremyjordan.me/variational-autoencoders/).



#### References

[1](https://arxiv.org/abs/1511.06434) Radford, Alec, Luke Metz, and Soumith Chintala. "Unsupervised representation learning with deep convolutional generative adversarial networks."

[2](https://arxiv.org/pdf/1701.07875.pdf) Arjovsky, Martin, et al. "Wasserstein GAN."

[3](https://arxiv.org/pdf/1406.2661.pdf) Goodfellow, Ian, et al. "Generative Adversarial Nets."