# Adversarial AutoEncoder
![Adversarial AutoEncoder architecture](https://i.imgur.com/sgsfLwQ.png)

In [None]:
# Install necessary packages
!pip install matplotlib numpy torch torchvision 

In [None]:
# Import necessary modules
from datetime import datetime
from itertools import chain

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In this cell we initialize the data loader for the MNIST dataset, which will provide data to the training loop later.

Here, we also specify the size of the minibatch.

In [None]:
# Number of images per minibatch.
batch_size = 100

# MNIST training split, stored under data/mnist (downloaded if missing);
# every image is passed through ToTensor.
train_data = datasets.MNIST(
    root='data/mnist',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Shuffled minibatch iterator over the training split, using 4 worker
# processes and pinned host memory.
train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In this cell we specify dimensions of vectors for:
* `X_len` - linearized images (MNIST contains images of size $28 \times 28 = 784$)
* `z_len` - encoding vector. Here, 64 is used, but you can try smaller or larger vectors


In [None]:
X_len = 28 ** 2  # length of a flattened 28x28 image (784 values)
z_len = 64       # length of the encoding vector

Model of an Encoder $E: X \rightarrow z$, that takes linearized images $X \in [0, 1]^{784}$ and produces encoding $z \in \mathbb{R}^{h}$ of length $h$ (specified by `z_len` from previous cell).



In [None]:
class Encoder(nn.Module):
    """Two-layer perceptron that maps flattened images to encodings.

    Args:
        X_dim: Number of features in each flattened input image.
        z_dim: Number of features in the produced encoding.
    """

    def __init__(self, X_dim, z_dim):
        super().__init__()

        hidden_dim = 128
        layers = [
            nn.Linear(X_dim, hidden_dim),
            nn.ReLU(),
            # No activation after the last layer: the encoding is unbounded.
            nn.Linear(hidden_dim, z_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, X):
        """Return the encoding produced by the network for ``X``."""
        return self.model(X)

Model of a Generator $G: z \rightarrow X$, that takes encoding $z \in \mathbb{R}^{h}$ of length $h$ (specified by `z_len` from 2 cells ago) and produces linearized images $X \in [0, 1]^{784}$.


In [None]:
class Generator(nn.Module):
    """Two-layer perceptron that maps encodings back to flattened images.

    Args:
        X_dim: Number of features in each generated flattened image.
        z_dim: Number of features in each input encoding.
    """

    def __init__(self, X_dim, z_dim):
        super().__init__()

        hidden_dim = 128
        layers = [
            nn.Linear(z_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, X_dim),
            # Sigmoid keeps every output value inside [0, 1].
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Return the image generated by the network for encoding ``z``."""
        return self.model(z)

Model of a Discriminator $D: z \rightarrow [0, 1]$ - module that should give high probability $p$ for samples that come from the desired unit gaussian distribution $z \sim \mathcal{N}(0, I)$ and low probability $p$ for samples $\hat{z} \in \mathbb{R}^{h}$ that come from encoder $E$.


In [None]:
class Discriminator(nn.Module):
    """Two-layer perceptron that scores an encoding with a value in [0, 1].

    Args:
        z_dim: Number of features in each input encoding.
    """

    def __init__(self, z_dim):
        super().__init__()

        hidden_dim = 128
        layers = [
            nn.Linear(z_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            # Sigmoid turns the single output into a probability-like score.
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Return one score per encoding in ``z``."""
        return self.model(z)

In [None]:
# Build the three networks and move their parameters to the selected device.
E = Encoder(X_len, z_len).to(device)
G = Generator(X_len, z_len).to(device)
D = Discriminator(z_len).to(device)

Weight initialization - we use the weight initialization from "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification" by He et al.

In [None]:
def weights_init(m):
    """Re-initialize a linear or 1-D convolution layer in place.

    Weights are drawn with Kaiming-normal initialization for ReLU and the
    bias, when the layer has one, is set to zero. Any other module is left
    untouched, which makes the function suitable for ``Module.apply``.

    Args:
        m: The module to (possibly) re-initialize.
    """
    # isinstance instead of a class-name comparison: subclasses of the two
    # layer types are handled, and unrelated classes that merely share the
    # name "Linear" or "Conv1d" are not touched.
    if not isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
        return
    torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0)

# Module.apply visits every submodule and returns the module itself, so the
# networks are re-initialized in place and need no re-binding.
for net in (E, G, D):
    net.apply(weights_init)

Define optimizers that will calculate optimization steps for our weights. Note that the Encoder and Generator share the same optimizer. Here we use Adam from "Adam: A Method for Stochastic Optimization" by Kingma and Ba.

In [None]:
# Adam hyper-parameters shared by both optimizers.
learning_rate = 3e-4
betas = (0.9, 0.999)

# A single optimizer updates the encoder and generator weights together...
EG_optimizer = optim.Adam(
    [*E.parameters(), *G.parameters()],
    lr=learning_rate,
    betas=betas,
)
# ...and a separate one updates the discriminator weights.
D_optimizer = optim.Adam(
    D.parameters(),
    lr=learning_rate,
    betas=betas,
)

Define the losses for AAE here. Remember to add an epsilon ($10^{-8}$ is enough) wherever numerical instability might occur (e.g. small values passed to the logarithm function).

Losses for AAE are:

$L_D = \log(D(z)) + \log(1 - D(E(X)))$

$L_{EG} = \log(D(E(X)))$

$L_{reconstruction} = \text{bce}(X, G(E(X)))$

where $\log$ means the natural logarithm, $z \sim \mathcal{N}(0, I)$, and bce is the binary cross-entropy loss implemented in the `torch.nn` module

In [None]:
def loss_fn_eg(p_gen):
    """Element-wise encoder objective ``log(D(E(X)))``.

    The result is not reduced and not negated; the training loop takes the
    negative mean of it to obtain the value that is minimized.

    Args:
        p_gen: Discriminator outputs for encoder-produced codes, in [0, 1].

    Returns:
        Tensor shaped like ``p_gen`` holding ``log(p_gen + eps)``.
    """
    # eps keeps the logarithm finite when the discriminator outputs exactly 0.
    eps = 1e-8
    return torch.log(p_gen + eps)
  
def loss_fn_d(p_real, p_gen):
    """Element-wise discriminator objective ``log(D(z)) + log(1 - D(E(X)))``.

    The result is not reduced and not negated; the training loop takes the
    negative mean of it to obtain the value that is minimized.

    Args:
        p_real: Discriminator outputs for samples drawn from the prior.
        p_gen: Discriminator outputs for encoder-produced codes.

    Returns:
        Tensor with the broadcast shape of the two inputs.
    """
    # eps keeps both logarithms finite when an output saturates at 0 or 1.
    eps = 1e-8
    real_term = torch.log(p_real + eps)
    gen_term = torch.log((1.0 - p_gen) + eps)
    return real_term + gen_term

def loss_fn_reconstruction(X, X_):
    """Mean binary cross-entropy of a prediction against its target.

    The first argument is treated as the prediction and the second as the
    target, which matches how the training loop calls this function
    (reconstruction first, original batch second).

    Args:
        X: Predicted values in [0, 1]; gradients flow through this argument.
        X_: Target values in [0, 1], same shape as ``X``.

    Returns:
        Scalar tensor with the cross-entropy averaged over all elements.
    """
    return nn.functional.binary_cross_entropy(X, X_)

Training procedure for AAE - fill the training steps, that are currently `None`

In [None]:
def _show_image_row(images, title):
    """Plot ``images`` (a sequence of 2-D arrays) side by side under ``title``."""
    # squeeze=False keeps `axes` two-dimensional even for a single image.
    fig, axes = plt.subplots(1, len(images), figsize=(5, 1), squeeze=False)
    fig.suptitle(title)
    for axis, image in zip(axes[0], images):
        axis.imshow(image)
        axis.axis('off')
    plt.show()


max_epochs = 300
n_preview = 10  # upper bound on the images shown per visualization row

for epoch_n in range(1, max_epochs + 1):

    D.train()
    E.train()
    G.train()

    # Running sums of the per-batch losses, averaged when the epoch ends.
    reconstruction_losses = 0.0
    d_regularization_losses = 0.0
    eg_regularization_losses = 0.0
    n_batches = 0

    start = datetime.now()
    for X, _ in train_dataloader:  # labels are not used
        n_batches += 1

        # Flatten every image in the batch into a single vector.
        X = X.to(device)
        X = X.view(X.size(0), -1)

        #
        # Reconstruction stage: encoder + generator learn to reproduce X.
        #
        z_ = E(X)
        X_ = G(z_)

        loss_rec = loss_fn_reconstruction(X_, X)

        EG_optimizer.zero_grad()
        loss_rec.backward()
        reconstruction_losses += loss_rec.item()
        EG_optimizer.step()

        #
        # Regularization stage
        #
        # The encoder weights were just updated in place, so the encodings
        # are recomputed to get a graph that matches the current weights.
        z = torch.randn(X.size(0), z_len, device=device)
        z_ = E(X)

        # Discriminator update. The encodings are detached so that this loss
        # only produces gradients for the discriminator, not for the encoder.
        p_real = D(z)
        p_gen = D(z_.detach())

        loss_d = -torch.mean(loss_fn_d(p_real, p_gen))
        D_optimizer.zero_grad()
        loss_d.backward()
        d_regularization_losses += loss_d.item()
        D_optimizer.step()

        # Encoder update. The discriminator is evaluated again after its
        # step: back-propagating through a graph built before the step would
        # need the weights that the step has modified in place.
        p_gen = D(z_)

        loss_eg = -torch.mean(loss_fn_eg(p_gen))
        EG_optimizer.zero_grad()
        loss_eg.backward()
        eg_regularization_losses += loss_eg.item()
        EG_optimizer.step()

    if n_batches == 0:
        raise RuntimeError('train_dataloader produced no batches')

    print(f'Epoch {epoch_n:03d}: Z mean/std: {z_.mean():.4f}/{z_.std():.4f} '
          f'Loss_EG_REC: {reconstruction_losses / n_batches:.4f} Loss_EG_REG: {eg_regularization_losses / n_batches:.4f} '
          f'Loss_D_REG: {d_regularization_losses / n_batches:.4f}  Time: {datetime.now() - start}')

    # Visualize learning on (up to) the first n_preview images of the last
    # batch; the batch may be smaller than n_preview.
    with torch.no_grad():
        X = X[:n_preview]
        n_shown = X.size(0)

        reals = X.view(n_shown, 28, 28).cpu().numpy()
        reconstructions = G(E(X)).view(n_shown, 28, 28).cpu().numpy()

        z = torch.randn(n_shown, z_len, device=device)
        samples = G(z).view(n_shown, 28, 28).cpu().numpy()

    _show_image_row(reals, f'Real: {epoch_n}')
    _show_image_row(reconstructions, f'Reconstructions: {epoch_n}')
    _show_image_row(samples, f'Synthetic: {epoch_n}')