# BiGAN / ALI
![Bigan_architecture](https://i.imgur.com/FglUXHR.png)

In [None]:
!pip install torch torchvision matplotlib

In [None]:
from datetime import datetime
from itertools import chain

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In this cell we initialize the data loader for the MNIST dataset, which will provide data to the training loop later.

Here, we also specify the size of the minibatch.

In [None]:
# Number of images per minibatch.
batch_size = 100

# MNIST training split, downloaded to data/mnist on first use; every image
# goes through ToTensor before it reaches the model.
train_data = datasets.MNIST(
    root='data/mnist',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Shuffled minibatches, loaded by 4 worker processes into pinned memory.
train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=4,
)

In this cell we specify dimensions of vectors for:
* `X_len` - linearized images (MNIST contains images of size $28 \times 28 = 784$)
* `z_len` - encoding vector. Here, 64 is used, but you can try smaller or larger vectors


In [None]:
# Length of the encoding (latent) vector.
z_len = 64
# Length of a linearized 28 x 28 image.
X_len = 28 ** 2

Model of an Encoder $E: X \rightarrow z$ - takes linearized images $X \in [0, 1]^{784}$ as input and returns an encoding $z \in \mathbb{R}^{h}$, where $h$ is the length of the encoding vector, specified by `z_len` in the previous cell.







In [None]:
class Encoder(nn.Module):
    """Encoder E: X -> z, a two-layer perceptron.

    The network is ``Linear(X_dim, 128) -> ReLU -> Linear(128, z_dim)``;
    the output layer has no activation.

    Args:
        X_dim: size of the last dimension of the input.
        z_dim: size of the last dimension of the returned encoding.
    """

    def __init__(self, X_dim, z_dim):
        super().__init__()

        hidden_layer = nn.Linear(X_dim, 128)
        output_layer = nn.Linear(128, z_dim)
        self.model = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, X):
        """Return the encoding of ``X`` (last dimension ``z_dim``)."""
        return self.model(X)

Model of a Generator $G: z \rightarrow X$ - takes a feature vector $z \in \mathbb{R}^{h}$ sampled from the unit Gaussian distribution, $z \sim \mathcal{N}(0, I)$, and produces a linearized image $X \in [0, 1]^{784}$.







In [None]:
class Generator(nn.Module):
    """Generator G: z -> X, a two-layer perceptron with a sigmoid output.

    The network is ``Linear(z_dim, 128) -> ReLU -> Linear(128, X_dim) ->
    Sigmoid``, so every output element lies in (0, 1).

    Args:
        X_dim: size of the last dimension of the generated output.
        z_dim: size of the last dimension of the input feature vector.
    """

    def __init__(self, X_dim, z_dim):
        super().__init__()

        hidden_layer = nn.Linear(z_dim, 128)
        output_layer = nn.Linear(128, X_dim)
        self.model = nn.Sequential(
            hidden_layer,
            nn.ReLU(),
            output_layer,
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Return the output generated from ``z`` (last dimension ``X_dim``)."""
        return self.model(z)

Model of a Discriminator $D: (X, z) \rightarrow [0, 1]$ - takes a linearized image $X$ (taken from the dataset or generated by the Generator) and a feature vector $z$ (sampled or inferred by the Encoder) and returns the probability that the input is a pair of type $(X, E(X))$.







In [None]:
class Discriminator(nn.Module):
    """Discriminator D: (X, z) -> probability, a two-layer perceptron.

    The inputs are concatenated along dimension 1 and passed through
    ``Linear(x_dim + z_dim, 128) -> ReLU -> Linear(128, 1) -> Sigmoid``.

    Args:
        x_dim: size of dimension 1 of ``X``.
        z_dim: size of dimension 1 of ``z``.
    """

    def __init__(self, x_dim, z_dim):
        super().__init__()

        hidden_layer = nn.Linear(x_dim + z_dim, 128)
        output_layer = nn.Linear(128, 1)
        self.model = nn.Sequential(
            hidden_layer,
            nn.ReLU(),
            output_layer,
            nn.Sigmoid(),
        )

    def forward(self, X, z):
        """Return one value in (0, 1) per row of the joint input ``(X, z)``."""
        joint = torch.cat((X, z), dim=1)
        return self.model(joint)

In [None]:
# Build the three BiGAN networks and move each one to the selected device.
E = Encoder(X_len, z_len).to(device)
G = Generator(X_len, z_len).to(device)
D = Discriminator(X_len, z_len).to(device)

Weight initialization - we use the weight initialization from "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification" by He et al.

In [None]:
def weights_init(m):
    """Apply He (Kaiming) normal initialization to a Linear or Conv1d layer.

    Meant to be passed to ``Module.apply``, which calls it on every
    submodule. Modules that are not ``Linear`` or ``Conv1d`` are left
    untouched.

    The layer type is checked with ``isinstance`` rather than by comparing
    class names, so subclasses of ``Linear``/``Conv1d`` are initialized too
    and unrelated classes that merely share the name are not.

    Args:
        m: the module to initialize in place.
    """
    if not isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
        return

    torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
    # Layers created with bias=False have no bias tensor to reset.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0)

# Module.apply visits every submodule and returns the module itself, so the
# three names are rebound to the same, now re-initialized, networks.
E, G, D = (net.apply(weights_init) for net in (E, G, D))

Define the optimizers that will compute the update steps for our weights. Note that the Encoder and the Generator share the same optimizer. Here we use Adam from "Adam: A Method for Stochastic Optimization" by Kingma and Ba.

In [None]:
# Adam hyper-parameters shared by both optimizers.
learning_rate = 3e-4
betas = (0.9, 0.999)

# The discriminator has its own optimizer.
D_optimizer = optim.Adam(D.parameters(), lr=learning_rate, betas=betas)
# A single optimizer updates the encoder and the generator together.
EG_optimizer = optim.Adam(
    chain.from_iterable(net.parameters() for net in (E, G)),
    lr=learning_rate,
    betas=betas,
)

Define the losses for BiGAN here. Remember to add an epsilon ($10^{-6}$ is enough) wherever numerical instability may occur (e.g. small values passed to the logarithm function).

Losses for BiGAN are:

$L_D = \log(D(X, E(X))) + \log(1 - D(G(z), z))$

$L_{EG} = \log(D(G(z), z)) + \log(1 - D(X, E(X)))$

where $\log$ means natural logarithm

In [None]:
def loss_fn_eg(p_enc, p_gen, eps=1e-6):
    """Element-wise encoder/generator objective of BiGAN.

    Computes ``log(p_gen) + log(1 - p_enc)``, i.e.
    ``log(D(G(z), z)) + log(1 - D(X, E(X)))``. The value is an objective to
    maximize; the training loop minimizes the negated mean.

    Args:
        p_enc: discriminator output for ``(X, E(X))`` pairs, in [0, 1].
        p_gen: discriminator output for ``(G(z), z)`` pairs, in [0, 1].
        eps: small constant added inside each logarithm so that outputs of
            exactly 0 or 1 do not produce ``-inf``.

    Returns:
        Tensor with the broadcast shape of ``p_enc`` and ``p_gen``.
    """
    return torch.log(p_gen + eps) + torch.log(1.0 - p_enc + eps)

def loss_fn_d(p_enc, p_gen, eps=1e-6):
    """Element-wise discriminator objective of BiGAN.

    Computes ``log(p_enc) + log(1 - p_gen)``, i.e.
    ``log(D(X, E(X))) + log(1 - D(G(z), z))``. The value is an objective to
    maximize; the training loop minimizes the negated mean.

    Args:
        p_enc: discriminator output for ``(X, E(X))`` pairs, in [0, 1].
        p_gen: discriminator output for ``(G(z), z)`` pairs, in [0, 1].
        eps: small constant added inside each logarithm so that outputs of
            exactly 0 or 1 do not produce ``-inf``.

    Returns:
        Tensor with the broadcast shape of ``p_enc`` and ``p_gen``.
    """
    return torch.log(p_enc + eps) + torch.log(1.0 - p_gen + eps)

Training procedure for BiGAN - fill the training steps, that are currently `None`

In [None]:
def _show_image_row(images, title):
    """Plot a sequence of 2-D images side by side in one titled row."""
    # squeeze=False keeps `axes` two-dimensional even for a single image.
    fig, axes = plt.subplots(1, len(images), figsize=(5, 1), squeeze=False)
    fig.suptitle(title)
    for axis, image in zip(axes[0], images):
        axis.imshow(image)
        axis.axis('off')
    plt.show()


# BiGAN training: each minibatch performs one discriminator update followed
# by one joint encoder/generator update; after every epoch the running losses
# are printed and real / reconstructed / sampled digits are plotted.
max_epochs = 300
for epoch_n in range(1, max_epochs + 1):

    D.train()
    E.train()
    G.train()

    d_losses = 0.0
    eg_losses = 0.0

    start = datetime.now()
    # The class labels are not used anywhere in the training step.
    for i, (X, _) in enumerate(train_dataloader, 1):
        X = X.to(device)
        # Linearize every image: (batch, ...) -> (batch, features).
        X = X.view(X.size(0), -1)

        z_ = E(X)                                          # inferred encoding E(X)
        z = torch.randn(X.size(0), z_len, device=device)   # z ~ N(0, I)
        X_ = G(z)                                          # generated image G(z)

        # --- Discriminator step -------------------------------------------
        # E and G outputs are detached so this backward pass only reaches D.
        D_optimizer.zero_grad()
        p_enc = D(X, z_.detach())
        p_gen = D(X_.detach(), z)
        loss_d = -torch.mean(loss_fn_d(p_enc, p_gen))
        loss_d.backward()
        d_losses += loss_d.item()
        D_optimizer.step()

        # --- Encoder / generator step -------------------------------------
        # The optimizer step above changed D's weights in place, so the
        # discriminator is evaluated again instead of reusing the old graph
        # (autograd rejects a backward pass through modified weights).
        # Gradients that this pass accumulates in D are cleared by
        # D_optimizer.zero_grad() at the start of the next iteration.
        EG_optimizer.zero_grad()
        p_enc = D(X, z_)
        p_gen = D(X_, z)
        loss_eg = -torch.mean(loss_fn_eg(p_enc, p_gen))
        loss_eg.backward()
        eg_losses += loss_eg.item()
        EG_optimizer.step()

    print(f'Epoch {epoch_n:03d}: Z mean/std: {z_.mean():.4f}/{z_.std():.4f} '
          f'Loss_EG: {eg_losses / i:.4f} Loss_D: {d_losses / i:.4f} '
          f'Time: {datetime.now() - start}')

    # Visualize learning
    D.eval()
    E.eval()
    G.eval()

    with torch.no_grad():
        X = X[:10]  # take up to 10 elements from the last minibatch
        n_shown = X.size(0)  # may be below 10 if the last batch is short
        reconstructions = G(E(X)).view(n_shown, 28, 28).cpu().numpy()
        reals = X.view(n_shown, 28, 28).cpu().numpy()

        z = torch.randn(10, z_len, device=device)
        samples = G(z).view(10, 28, 28).cpu().numpy()

    _show_image_row(reals, f'Real: {epoch_n}')
    _show_image_row(reconstructions, f'Reconstructions: {epoch_n}')
    _show_image_row(samples, f'Synthetic: {epoch_n}')