<img src="img/dsci572_header.png" width="600">

# Lecture 8: Advanced Convolutional Models

<br><br><br>

## Lecture Learning Objectives


- Describe what few-shot learning is and how it can be useful

- Describe what a generative adversarial network is and what it can be useful for

<br><br><br>

## Imports


In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, TensorDataset, Subset
from torchvision import transforms, datasets, utils, models
from torchsummary import summary
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from PIL import Image
from statistics import mean

plt.rcParams.update({'axes.grid': False})

## Few-Shot Learning

- We have so far discussed how CNNs can be used to classify images

- The assumption here is that we have access to relatively large datasets that contain 100s or 1000s of images per class
- However, this is not always the case: more data is often hard, or even impossible, to obtain
- Take the example of a **face recognition** system for a company's employees:

  - Does it make sense to ask each employee to provide, say, 1000 images of their face so as to register them in the system?

  - Even if we have 1000 images of each employee, we'd have to retrain the model every time a new employee is added to the system

<br><br><br>

- For cases like these, we can use a technique called **few-shot** learning. Instead of training the model on a new class of images, we use the similarities and differences of an image with other images to decide the class of that image.

- Let me explain this with some examples (credit: most images in this section are adapted from this [website](https://zzaebok.github.io/machine_learning/FSL/) and [YouTube video](https://www.youtube.com/watch?v=hE7eGew4eeg)):

For training a conventional CNN, we use a dataset like this:

<img src="https://imgur.com/6MAisQL.png" width="600">

<br><br><br>

When we want to predict the class of a new image, that class should already exist in the training set of the CNN:

<img src="https://imgur.com/H1rP0Aw.png" width="600">

- But what if the class of the test image does not exist among the classes on which the CNN was trained?

- This is the kind of problem we're interested in solving with few-shot learning

- **Few-shot learning is about learning to classify a new test image, with only few examples of that new class**

<br><br><br>

- Here, for example, the rabbit class does not exist in the training set. It would be desirable to have a model that can learn to classify this **query** image based on a few images of rabbits

<img src="https://imgur.com/StTYXay.png" width="600">

<br><br><br>

- The dataset of images on which the model wasn't trained is called the **support set**

- This problem is called **k-way n-shot** learning, when we have **k classes** and **n samples per class** in the **support set**

<img src="https://imgur.com/UsmR4Ow.png" width="600">
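
To make the **k-way n-shot** terminology concrete, here is a minimal sketch of how one such task (often called an episode) could be sampled from a labelled dataset. The function `sample_episode`, its `n_query` argument and the toy class names are assumptions made for illustration; they are not part of the lecture code.

```python
import random

def sample_episode(labels, k, n, n_query=1, seed=None):
    """Sample a k-way n-shot episode from a list of class labels.

    Returns (support, query): two lists of (index, label) pairs.
    """
    rng = random.Random(seed)

    # Group sample indices by class
    by_class = {}
    for idx, label in enumerate(labels):
        by_class.setdefault(label, []).append(idx)

    # Keep only classes with enough samples for support + query
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= n + n_query]
    classes = rng.sample(eligible, k)

    support, query = [], []
    for c in classes:
        picked = rng.sample(by_class[c], n + n_query)
        support += [(i, c) for i in picked[:n]]   # n labelled examples per class
        query += [(i, c) for i in picked[n:]]     # held-out images to classify
    return support, query


# Hypothetical toy dataset: 4 classes with 5 images each
labels = [c for c in ["fox", "squirrel", "rabbit", "hamster"] for _ in range(5)]
support, query = sample_episode(labels, k=3, n=2, seed=0)
print(len(support), len(query))  # 6 3
```

With `k=3` and `n=2`, the support set holds 3 x 2 = 6 images, and each query image has to be assigned to one of the 3 sampled classes.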


<br><br><br>

- **Now comes the key idea of few-shot learning:** instead of directly learning how to classify, learn how to find similarities between samples belonging to the same class, and differences between samples belonging to different classes

- In other words, instead of **learning the classification itself**, we'd like to **learn how to learn the classification**!

- This is why few-shot learning is said to be an example of **meta learning**


<br><br><br>

### Siamese Networks for Few-Shot Learning

- As mentioned above, the goal is to learn **similarities** and **differences** between images in the same and different classes, rather than the corresponding classes

- Therefore, it seems natural to think that a CNN could be used for feature extraction

- An interesting model for few-shot learning is a Siamese (or twin) network ([image source](https://people.kth.se/~rosun/deep-learning/figures/siamese-arch.svg)):

<img src="https://people.kth.se/~rosun/deep-learning/figures/siamese-arch.svg" width="700">

<br><br><br>

- The Siamese network is supposed to learn similarities and differences

- This is why the first step is to **construct a new dataset from an existing one** as follows:

<img src="img/pos_neg.png" width="700">

- Positive samples are pairs of images that belong to the same class

- Negative samples are pairs of images that don't belong to the same class

- The Siamese network is a parallel neural network architecture with two streams, one for each image in a pair that forms either a positive or a negative sample

- The two streams share exactly the same architecture and model parameters. This is why this particular architecture is called a **Siamese** or **twin** network

<br><br><br>

**How does a Siamese network learn?**

A Siamese network uses a particular type of loss function called **contrastive loss**, with the following form:

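$$
(1-Y)\, D^2 + Y \left\{\max \left(0,\, m-D\right)\right\}^2
$$

where $Y$ is the label of the generated pair (0 for similar images, i.e. a positive sample, and 1 for dissimilar images, i.e. a negative sample), $D$ is the Euclidean distance between the embeddings of the two images, and $m$ is the margin: a dissimilar pair contributes to the loss only while its distance is smaller than $m$.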

- The contrastive loss tries to decrease distance between embeddings (feature vectors) of similar images, and to increase distance between embeddings of dissimilar images.

- In other words, through the contrastive loss, a Siamese network tries to pull similar images together, and push dissimilar images away, in the embedding space
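
A small worked example may help to see the pull and push behaviour. The sketch below evaluates the contrastive loss for a single pair, given the distance between the two embeddings and the pair label (0 for similar, 1 for dissimilar); the helper name `contrastive_loss` and the distances are made up for illustration, and the margin of 2.0 is just a convenient value.

```python
def contrastive_loss(distance, y, margin=2.0):
    """Contrastive loss for one pair; y = 0 for similar, 1 for dissimilar."""
    return (1 - y) * distance ** 2 + y * max(0.0, margin - distance) ** 2

# Similar pair (y = 0): the loss grows with the distance
print(contrastive_loss(0.5, y=0))   # 0.25
print(contrastive_loss(1.5, y=0))   # 2.25

# Dissimilar pair (y = 1): penalised only while closer than the margin
print(contrastive_loss(0.5, y=1))   # 2.25
print(contrastive_loss(2.5, y=1))   # 0.0
```

Minimizing this quantity therefore shrinks the distance for similar pairs, and grows it for dissimilar pairs until it reaches the margin, after which the pair no longer contributes.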

<br><br><br>

- I'll explain the implementation in the lecture (the code below is adapted from [here](https://github.com/harveyslash/Facial-Similarity-with-Siamese-Networks-in-Pytorch/blob/master/Siamese-networks-medium.ipynb) with some minor changes).

In [None]:
class SiameseNetworkDataset(Dataset):
    def __init__(self, imageFolderDataset, transform=None):
        self.imageFolderDataset = imageFolderDataset
        self.transform = transform

    def __getitem__(self, index):
        idx = np.random.choice(len(self.imageFolderDataset.imgs))
        img0_tuple = self.imageFolderDataset.imgs[idx]

        # 50% chance of images being in the same class
        should_get_same_class = np.random.randint(0, 2)
        if should_get_same_class:
            while True:
                # loop until the same class is found
                idx = np.random.choice(len(self.imageFolderDataset.imgs))
                img1_tuple = self.imageFolderDataset.imgs[idx]
                if img0_tuple[1] == img1_tuple[1]:
                    break
        else:
            while True:
                # loop until a different class is found
                idx = np.random.choice(len(self.imageFolderDataset.imgs))
                img1_tuple = self.imageFolderDataset.imgs[idx]
                if img0_tuple[1] != img1_tuple[1]:
                    break

        img0 = Image.open(img0_tuple[0])
        img1 = Image.open(img1_tuple[0])

        # convert to gray-scale
        img0 = img0.convert("L")
        img1 = img1.convert("L")

        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)

        return (
            img0,
            img1,
            torch.from_numpy(
                np.array([int(img1_tuple[1] != img0_tuple[1])], dtype=np.float32)
            ),
        )

    def __len__(self):
        return 500

In [None]:
class SiameseNetwork(nn.Module):

    def __init__(self):
        super().__init__()

        self.cnn = nn.Sequential(

            nn.Conv2d(1, 96, kernel_size=5,stride=2),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2),
            
            nn.Conv2d(96, 64, kernel_size=3, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),

            nn.Flatten()
        )

        self.fc = nn.Sequential(

            nn.Linear(3136, 256),
            nn.ReLU(inplace=True),

            nn.Linear(256, 128),
            nn.ReLU(inplace=True),

            nn.Linear(128, 32),
            nn.ReLU(inplace=True),
            
            nn.Linear(32, 2)
        )
        
    def forward_each(self, x):
        output = self.cnn(x)
        output = self.fc(output)
        
        return output

    def forward(self, input1, input2):
        output1 = self.forward_each(input1)
        output2 = self.forward_each(input2)

        return output1, output2


model = SiameseNetwork()
summary(model, [(1, 128, 128,), (1, 128, 128,)]);
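
The input size of the first fully connected layer, `nn.Linear(3136, 256)`, follows from the spatial size of the feature maps leaving the convolutional part. As a rough check, the sketch below applies the standard output-size formula for convolution and pooling layers (no padding, no dilation) to a `128x128` input; the helper `conv_out` is a hypothetical name introduced here.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output height/width of nn.Conv2d or nn.MaxPool2d (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 128                                  # input images are 128x128
size = conv_out(size, kernel=5, stride=2)   # Conv2d(1, 96, 5, stride=2)  -> 62
size = conv_out(size, kernel=3, stride=2)   # MaxPool2d(3, stride=2)      -> 30
size = conv_out(size, kernel=3, stride=2)   # Conv2d(96, 64, 3, stride=2) -> 14
size = conv_out(size, kernel=2, stride=2)   # MaxPool2d(2, stride=2)      -> 7

flattened = 64 * size * size                # 64 channels of 7x7 feature maps
print(flattened)  # 3136
```

If the input resolution or any kernel or stride changes, this number changes as well and the first `nn.Linear` layer has to be adjusted accordingly.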

In [None]:
class ContrastiveLoss(torch.nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # Calculate the Euclidean distance and calculate the contrastive loss
        euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True)

        loss_contrastive = torch.mean(
        (1 - label) * torch.pow(euclidean_distance, 2)
        + (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        )
        
        return loss_contrastive

In [None]:
def gridshow(img, text=None):
    npimg = img.numpy()
    plt.axis("off")
    if text:
        plt.text(75, 8, text, style='italic',fontweight='bold',
            bbox={'facecolor':'white', 'alpha':0.8, 'pad':10})
        
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show() 

In [None]:
folder_dataset = datasets.ImageFolder(root="data/faces/training")

transformation = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

siamese_dataset = SiameseNetworkDataset(
    imageFolderDataset=folder_dataset, transform=transformation
)

train_loader = DataLoader(siamese_dataset, shuffle=True, batch_size=8)

example_batch = next(iter(train_loader))
concatenated = torch.cat((example_batch[0], example_batch[1]), axis=0)

gridshow(utils.make_grid(concatenated, nrow=8))
print(example_batch[2].numpy().reshape(-1))

In [None]:
device = torch.device('cuda' if torch.cuda.is_available()
                      else 'mps' if torch.backends.mps.is_available() else 'cpu')

model = SiameseNetwork().to(device)
criterion = ContrastiveLoss(margin=3.0)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset, transform=transformation)
# siamese_dataset = Subset(siamese_dataset, range(500))

train_loader = DataLoader(siamese_dataset, shuffle=True, batch_size=64)

loss_history = [] 

for epoch in range(200):

    for i, (img0, img1, label) in enumerate(train_loader, 0):

        if device.type in ['cuda', 'mps']:
            img0, img1, label = img0.to(device), img1.to(device), label.to(device)

        optimizer.zero_grad()
        output1, output2 = model(img0, img1)
        loss_contrastive = criterion(output1, output2, label)
        loss_contrastive.backward()
        optimizer.step()

        if i % 10 == 0 :
            print(f"Epoch {epoch}: Training batch loss = {loss_contrastive.item():g}")
            loss_history.append(loss_contrastive.item())

In [None]:
plt.loglog(loss_history)

In [None]:
model.eval()

folder_dataset_test = datasets.ImageFolder(root="data/faces/testing/")
siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset_test, transform=transformation)
                                        
# siamese_dataset = Subset(siamese_dataset, range(2000))

test_dataloader = DataLoader(siamese_dataset, batch_size=1, shuffle=True)

# Take one image to test on
dataiter = iter(test_dataloader)
x0, _, _ = next(dataiter)

if device.type in ['cuda', 'mps']: x0 = x0.to(device)

for i in range(10):
    _, x1, label2 = next(dataiter)
    if device.type in ['cuda', 'mps']:
        x1, label2 = x1.to(device), label2.to(device)

    concatenated = torch.cat((x0, x1), 0)
    
    output1, output2 = model(x0, x1)
    euclidean_distance = F.pairwise_distance(output1, output2)
    gridshow(utils.make_grid(concatenated.cpu()), f'Distance: {euclidean_distance.item():.4g}')

<br><br><br>
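
The distances printed above can be turned into a few-shot prediction. One common approach, sketched here as an assumption rather than as part of the lecture code, is to embed every image in the support set and assign the query image the label of the closest support embedding. The function `classify_query`, the names and the 2-D embeddings below are hypothetical; in practice the embeddings would come from something like `model.forward_each`.

```python
import math

def classify_query(query_emb, support_embs):
    """Return the label of the support embedding closest to the query.

    support_embs: list of (label, embedding) pairs.
    """
    best_label, best_dist = None, math.inf
    for label, emb in support_embs:
        dist = math.dist(query_emb, emb)  # Euclidean distance
        if dist < best_dist:
            best_label, best_dist = label, dist
    return best_label, best_dist

# Hypothetical 2-D embeddings of three registered people
support = [("alice", (0.1, 0.2)), ("bob", (2.0, 1.9)), ("carol", (-1.5, 2.2))]
label, dist = classify_query((1.8, 2.1), support)
print(label)  # bob
```

Adding a new person then only requires adding their embedding to the support set; the network itself does not need to be retrained.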

## Generative Adversarial Networks (GANs)


### What are GANs?

GANs are a type of neural network model used to generate new data that is, ideally, indistinguishable from the data in an existing dataset.

For example, suppose that we have a dataset of 10,000 images. The question is: can we somehow generate images that are so real-looking that we can't tell if they are **real** (could potentially come from the dataset) or **fake** (generated by some algorithm)?

Here, there aren't any real labels. We just want to be able to produce images that are as real-looking as possible; we don't classify them. This is why **GAN modeling** is regarded as an **unsupervised learning** task. In other words, we just need a bunch of images (or input data); no labels (or target data or outputs) are required.

<br><br><br>

GANs were invented in 2014 by Ian Goodfellow and colleagues (see the original paper [here](https://arxiv.org/abs/1406.2661)), and have been called "_the most interesting idea in the last 10 years in ML_" by Yann LeCun, Facebook’s AI research director.

<br><br><br>

Now take a look at the following image:

<img src="img/fake-face.jpeg" width="400">

Believe it or not, **this is not a real person!**

The image above is produced by a GAN that is trained on human faces. If you want to see more, visit [www.thispersondoesnotexist.com](https://www.thispersondoesnotexist.com). The website is connected to a GAN model living on the cloud, and each time the page refreshes, it generates a new image of a person who **does not exist!**

<br><br><br>

### Structure of a GAN

In this section, I describe how GANs work for image data, but remember that the idea of GANs is generalizable to any kind of data, not just images.

Here is the visualization of the structure of a GAN:

<img src="img/gan-1.png" width="900"><br>
[(image source)](https://freecontent.manning.com/practical-applications-of-gans-part-1/)

The structure of a GAN consists of a **discriminator** and a **generator**:

- A discriminator is just a **typical CNN** that receives an image as the input and outputs the probability of the input belonging to some class (here, the probability that the image is real)

- A generator is an **inverted CNN** that receives a vector of random numbers and generates an image in the output

<br><br><br>

The word "adversarial" comes from the fact that we actually have two networks battling each other:

- The generator: tries to generate fake images that look as realistic as possible such that it can **fool** the discriminator

- The discriminator: takes in real data and fake data and tries to correctly determine whether an input was real or fake

<br><br><br>

**An analogy:**

Think of the "Generator" as a new counterfeit artist trying to produce realistic-looking famous artworks to sell.

The "Discriminator" is an art critic, trying to determine if a piece of art is "real" or "fake".

At first, the "Generator" produces poor art-replicas which the "Discriminator" can easily tell are fake. But over time, the "Generator" learns ways to produce art that fools the "Discriminator". Eventually, the "Generator" becomes so good that the "Discriminator" can't tell if a piece of art is real or fake.

<br><br><br>

#### Convolution Layers

Convolution layers are used in the discriminator of a GAN to do a binary classification (real vs. fake). Their structure is very much the same as what we have seen so far in the course for CNNs.

Convolution layers **downsample** input features. In other words, the goal of convolution layers is to go from **larger features (images)** to **smaller features (images)**.

Here is an animation of how kernels are applied to input features:

<img src="img/conv-padded.gif" width="200"><img src="img/conv-strided-padded.gif" width="200">
<br>
[(image source)](https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md#transposed-convolution-animations)

<br><br><br>

#### Transposed Convolution Layers

This is the first time we see transposed convolution layers. These layers do the opposite of what convolution layers do; that is, instead of downsampling images, they upsample. Transposed convolutions are used in the generator of a GAN to generate an image from some random vector.

The goal of transposed convolution layers is to go from **smaller features (images)** to **larger features (images)**.

<img src="img/conv-trans-strided.gif" width="200"><img src="img/conv-trans-not-strided.gif" width="200"><br>
[(image source)](https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md#transposed-convolution-animations)

While downsampling seems very intuitive, upsampling might look like doing magic: we try to generate information (pixel values) that did not exist before. But it's practically not hard to do.

We do something similar to what we did with convolution layers: we pass the kernel over the inputs, and multiply each input element by all kernel elements. Each resulting array becomes part of the larger output image, and values that land on the same output pixel are added together:

<img src="img/conv-trans-how.svg" width="500"><br>
[(image source)](https://d2l.ai/chapter_computer-vision/transposed-conv.html)

<img src="img/conv-trans-how2.svg" width="500"><br>
[(image source)](https://d2l.ai/chapter_computer-vision/transposed-conv.html)
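
The operation pictured above can be written in a few lines of plain Python. This is a minimal sketch for a single channel without padding, not how PyTorch implements `nn.ConvTranspose2d` internally; the function name `trans_conv2d` and the toy `2x2` input and kernel are chosen for illustration.

```python
def trans_conv2d(x, kernel, stride=1):
    """Transposed 2-D convolution of nested lists (no padding, one channel)."""
    n_rows, n_cols = len(x), len(x[0])
    k_rows, k_cols = len(kernel), len(kernel[0])
    out_rows = (n_rows - 1) * stride + k_rows
    out_cols = (n_cols - 1) * stride + k_cols
    out = [[0] * out_cols for _ in range(out_rows)]

    for i in range(n_rows):
        for j in range(n_cols):
            # Each input element scales the whole kernel; contributions
            # that overlap in the output are summed
            for a in range(k_rows):
                for b in range(k_cols):
                    out[i * stride + a][j * stride + b] += x[i][j] * kernel[a][b]
    return out

x = [[0, 1], [2, 3]]
kernel = [[0, 1], [2, 3]]
print(trans_conv2d(x, kernel))            # [[0, 0, 1], [0, 4, 6], [4, 12, 9]]
print(trans_conv2d(x, kernel, stride=2))  # [[0, 0, 0, 1], [0, 0, 2, 3], [0, 2, 0, 3], [4, 6, 6, 9]]
```

A `2x2` input becomes a `3x3` output with a stride of 1, and a `4x4` output with a stride of 2, where the scaled kernel copies are placed two pixels apart and no longer overlap.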

If we repeat this operation in several layers, we can progressively increase the size of the input images. This is exactly what the generator does; it starts from some random noise, and progressively expands that into larger and larger images.

<br><br><br>

### Training GANs

Training a GAN happens in two iterative phases:

1. **Train the Discriminator:**

    - Generate some fake images with the generator
    
    - Show the discriminator real images and fake images and get it to classify them correctly (a simple binary classification problem)

2. **Train the Generator:**

    - Generate fake images with the generator but label them as "real"
    
    - Pass these fake images through the discriminator, and ask it for its judgment, i.e. the probability of this image being real
    - Pass this judgment to a loss function, and see how far it is from the ideal output. The ideal output is that the generator was so good that it has fooled the discriminator to give it the label of "real".
    - Do backpropagation based on the gradients of this loss value to adjust the parameters of the generator, such that it can better and better fool the discriminator.

3. **Repeat**.
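
The labelling trick in the two phases is easier to see with numbers. The sketch below assumes binary cross-entropy as the loss and uses made-up discriminator outputs; `bce`, `d_real` and `d_fake` are illustrative names only.

```python
import math

def bce(p, target):
    """Binary cross-entropy for one predicted probability p in (0, 1)."""
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

# Hypothetical discriminator outputs (probability of "real")
d_real, d_fake = 0.9, 0.2

# Phase 1: real images are labelled 1, fake images are labelled 0
loss_dis = bce(d_real, 1) + bce(d_fake, 0)

# Phase 2: the same fake image is now labelled 1 ("real")
loss_gen = bce(d_fake, 1)

print(f"loss_dis = {loss_dis:.3f}, loss_gen = {loss_gen:.3f}")
```

Here the discriminator is doing well, so its loss is small while the generator's loss is large. As the generator improves and `d_fake` moves toward 1, `loss_gen` shrinks and `loss_dis` grows, which is the tug of war described above.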

<br><br><br>

<img src="img/gan-train.png" width="600"><br>
[(image source)](https://sthalles.github.io/intro-to-gans/)

<br><br><br>

### PyTorch Implementation

Alright, now's the time to implement a GAN in PyTorch. Since training GANs is a very resource-intensive job, we need to do our computations on a GPU. Here, I'll write the code **so that you can take this notebook and run it directly on [Kaggle](https://www.kaggle.com)**. If you want to run it on your own computer, you need to change the folder paths of the dataset that we're going to use.

In [None]:
# Prefer CUDA, then Apple's MPS backend, then fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available()
                      else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device.type}")

For the purpose of demonstrating how training a GAN works in PyTorch, I have chosen to use the [_Face Recognition Dataset_](https://www.kaggle.com/stoicstatic/face-recognition-dataset) from Kaggle, which contains face images of celebrities.

This dataset contains two folders: `Face Dataset` and `Extracted Faces`. We'll use the images in `Extracted Faces`, which are already `128x128` pixels. Therefore, there is no need to resize them, which is why I've commented out `transforms.Resize(IMAGE_SIZE)` in the code below. This speeds up the computations significantly, as resizing is done on the CPU and would have been the bottleneck of our computations. Fortunately, we don't need to do this here:

In [None]:
DATA_DIR = "../input/face-recognition-dataset/Extracted Faces"

BATCH_SIZE = 64
IMAGE_SIZE = (128, 128)

data_transforms = transforms.Compose([
#     transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = datasets.ImageFolder(root=DATA_DIR, transform=data_transforms)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
        
# Plot samples
sample_batch = next(iter(data_loader))
plt.figure(figsize=(10, 8)); plt.axis("off"); plt.title("Sample Training Images")
plt.imshow(np.transpose(utils.make_grid(sample_batch[0], padding=1, normalize=True), (1, 2, 0)));

print(f'Size of dataset: {len(dataset)}')

Example output:

<img src="img/faces.png" width="500">
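
A note on `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` in the cell above: `transforms.ToTensor()` scales pixel values to the range [0, 1], and the normalization then maps them to [-1, 1], which is also the output range of the `nn.Tanh()` activation at the end of the generator defined later. The sketch below shows the per-pixel arithmetic; `normalize` and `denormalize` are hypothetical helper names.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Per-channel operation applied by transforms.Normalize: (x - mean) / std."""
    return (pixel - mean) / std

def denormalize(value, mean=0.5, std=0.5):
    """Inverse mapping, back to the [0, 1] range used for plotting."""
    return value * std + mean

for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))  # -1.0, 0.0 and 1.0 respectively
```

This way the real images and the generated images live on the same scale when they are passed to the discriminator.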

<br><br><br>

#### Creating the Generator

The generator takes in a random vector called **latent vector**, which can be thought of as a `1x1` pixel image having an arbitrary number of **channels** (specified in the code by `LATENT_SIZE`).

Through the generator, we pass this latent vector through the **deconvolution** layers (known as **transposed convolution** layers in PyTorch), and progressively expand its size, such that in the output we'll have an image similar in dimensions to the images in our dataset. Here for example, our images in the dataset are `128x128`, so the goal of the generator is to start from an image of `1x1` pixel and generate an image of `128x128` pixels.

**Details:**

- In the following code, I've used `nn.BatchNorm2d()` for all layers, `nn.LeakyReLU()` as activation for intermediate layers, and `nn.Tanh()` as activation for the output of the generator. These are suggested to be used based on empirical evidence in training GANs.

- We usually do in-place modification of tensors in `nn.LeakyReLU()` by setting `inplace=True` to save some memory.
- We set `bias=False` because the batch normalization layer contains a bias term, so we don't want to do it twice.

Note that we mainly play with the strides to progressively expand the size of the input latent vector.

Here is the code for the generator:

In [None]:
class Generator(nn.Module):
    
    def __init__(self, LATENT_SIZE):
        super(Generator, self).__init__()
        
        self.main = nn.Sequential(
            
            # input dim: [-1, LATENT_SIZE, 1, 1]
            
            nn.ConvTranspose2d(LATENT_SIZE, 1024, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            
            # output dim: [-1, 1024, 4, 4]

            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            
            # output dim: [-1, 512, 8, 8]

            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            # output dim: [-1, 256, 16, 16]

            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            # output dim: [-1, 128, 32, 32]
            
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            
            # output dim: [-1, 64, 64, 64]

            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(3),
            
            # output dim: [-1, 3, 128, 128]
            
            nn.Tanh()
            
            # output dim: [-1, 3, 128, 128]
        )
        
    def forward(self, input):
        output = self.main(input)
        return output

<br><br><br>

The details of how exactly a transposed convolution layer works in PyTorch can be fairly confusing at first, but here are some further remarks to help you feel more comfortable with it:

The parameters `stride` and `padding` in `nn.ConvTranspose2d` are (unfortunately?) not what we’re used to in `nn.Conv2d`. For example, `stride=2` doesn’t mean that the kernel in the transposed convolution moves in steps of 2 pixels each time. These parameters are, instead, designed such that if you use the same kernel size, stride and padding for a `ConvTranspose2d` as in a `Conv2d`, and apply it to the output of that `Conv2d`, it will in general give you an image with the same shape as the original input (but not the same pixel values). In other words, if

```python
out_conv = Conv2d(img, stride=s, padding=p)  # pseudocode
x = img.shape
```

and

```python
out_convT = ConvTranspose2d(out_conv, stride=s, padding=p)
y = out_convT.shape
```

then

```python
x == y --> True
```

In fact, the `stride` and `padding` work this way to make writing code easier.
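
The shape relationship can be checked with the output-size formulas given in the PyTorch documentation for `nn.Conv2d` and `nn.ConvTranspose2d` (shown here for a dilation of 1). This is a sketch of the arithmetic only, with no actual layers involved; the helper names are made up.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of nn.Conv2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride, padding, output_padding=0):
    """Output size of nn.ConvTranspose2d (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

k, s, p = 4, 2, 1  # settings used by most layers in this lecture

down = conv_out(128, k, s, p)            # 128 -> 64
up = conv_transpose_out(down, k, s, p)   # 64 -> 128
print(down, up)  # 64 128

# Caveat: the floor in conv_out can lose a pixel for some sizes
down = conv_out(8, kernel=3, stride=2, padding=1)             # 8 -> 4
print(conv_transpose_out(down, 3, 2, 1))                      # 7
print(conv_transpose_out(down, 3, 2, 1, output_padding=1))    # 8
```

With `kernel_size=4, stride=2, padding=1`, each transposed convolution exactly doubles the image size, which is how the generator goes from `4x4` to `128x128`. When the round trip is off by one pixel, as in the second case, the `output_padding` argument of `nn.ConvTranspose2d` can make up the difference.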

If you’re wondering about the mechanics of computing a transposed convolution in various scenarios, **make sure to check out [this blog post](https://numbersmithy.com/understanding-transposed-convolutions-in-pytorch/)**—the author has done an excellent job of explaining all the relevant details of `nn.ConvTranspose2d` and has some very useful diagrams as well.

<br><br><br>

#### Creating the Discriminator

As discussed before, this is a conventional CNN that receives an image (`128x128` in our case here) and outputs the probability of this image being real:

In [None]:
class Discriminator(nn.Module):
    
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.main = nn.Sequential(
        
            # input dim: [-1, 3, 128, 128]
            
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            # output dim: [-1, 64, 64, 64]

            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            # output dim: [-1, 64, 32, 32]

            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            # output dim: [-1, 128, 16, 16]

            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            # output dim: [-1, 256, 8, 8]

            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            # output dim: [-1, 512, 4, 4]

            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
            
            # output dim: [-1, 1, 1, 1]

            nn.Flatten(),
            
            # output dim: [-1, 1]

            nn.Sigmoid()
            
            # output dim: [-1, 1]
        )

    def forward(self, input):
        output = self.main(input)
        return output

<br><br><br>

#### Instantiating and Initializing our GAN

Let's create the discriminator and generator objects, as well as the loss function and optimizers:

In [None]:
LATENT_SIZE = 200

generator = Generator(LATENT_SIZE)
discriminator = Discriminator()

generator.to(device)
discriminator.to(device)

criterion = nn.BCELoss()

optimizerG = optim.Adam(generator.parameters(), lr=0.001, betas=(0.5, 0.999))
optimizerD = optim.Adam(discriminator.parameters(), lr=0.001, betas=(0.5, 0.999))

<br><br><br>

We explored earlier how the starting point of the optimization can affect the final results. It is recommended to initialize the weights of a GAN with values drawn randomly from a normal distribution. In PyTorch, we can define the initialization function as we like and apply it to every layer of a model using the `.apply()` method:

In [None]:
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
        
generator.apply(weights_init)
discriminator.apply(weights_init);

<br><br><br>

#### Training our GAN

We use the following cell to keep track of how a fixed noise (latent vector) is transformed to a generated image in each epoch. We will see that the generations become better and better throughout the epochs:

In [None]:
img_list = []
fixed_noise = torch.randn(BATCH_SIZE, LATENT_SIZE, 1, 1).to(device)

Finally, here is the training loop (you can also put everything inside a function):

In [None]:
NUM_EPOCHS = 50

print('Training started:\n')

D_real_epoch, D_fake_epoch, loss_dis_epoch, loss_gen_epoch = [], [], [], []

for epoch in range(NUM_EPOCHS):
    
    D_real_iter, D_fake_iter, loss_dis_iter, loss_gen_iter = [], [], [], []
    
    for real_batch, _ in data_loader:

        # STEP 1: train discriminator
        # ==================================
        # Train with real data
        discriminator.zero_grad()
        
        real_batch = real_batch.to(device)
        real_labels = torch.ones((real_batch.shape[0],), dtype=torch.float).to(device)
        
        output = discriminator(real_batch).view(-1)
        loss_real = criterion(output, real_labels)
        
        # Iteration book-keeping
        D_real_iter.append(output.mean().item())
        
        # Train with fake data
        noise = torch.randn(real_batch.shape[0], LATENT_SIZE, 1, 1).to(device)
        
        fake_batch = generator(noise)
        fake_labels = torch.zeros_like(real_labels)
        
        output = discriminator(fake_batch.detach()).view(-1)
        loss_fake = criterion(output, fake_labels)
        
        # Update discriminator weights
        loss_dis = loss_real + loss_fake
        loss_dis.backward()
        optimizerD.step()
        
        # Iteration book-keeping
        loss_dis_iter.append(loss_dis.mean().item())
        D_fake_iter.append(output.mean().item())
        
        # STEP 2: train generator
        # ==================================
        generator.zero_grad()
        output = discriminator(fake_batch).view(-1)
        loss_gen = criterion(output, real_labels)
        loss_gen.backward()
        
        # Book-keeping
        loss_gen_iter.append(loss_gen.mean().item())
        
        # Update generator weights and store loss
        optimizerG.step()
        
    print(f"Epoch ({epoch + 1}/{NUM_EPOCHS})\t",
          f"Loss_G: {mean(loss_gen_iter):.4f}",
          f"Loss_D: {mean(loss_dis_iter):.4f}\t",
          f"D_real: {mean(D_real_iter):.4f}",
          f"D_fake: {mean(D_fake_iter):.4f}")
    
    # Epoch book-keeping
    loss_gen_epoch.append(mean(loss_gen_iter))
    loss_dis_epoch.append(mean(loss_dis_iter))
    D_real_epoch.append(mean(D_real_iter))
    D_fake_epoch.append(mean(D_fake_iter))
    
    # Keeping track of the evolution of a fixed noise latent vector
    with torch.no_grad():
        fake_images = generator(fixed_noise).detach().cpu()
        img_list.append(utils.make_grid(fake_images, normalize=True, nrow=10))
        
print("\nTraining ended.")

Example output:

```
Epoch (1/50)	 Loss_G: 16.1643 Loss_D: 2.7680	 D_real: 0.8808 D_fake: 0.1760
Epoch (2/50)	 Loss_G: 5.1395 Loss_D: 1.2464	 D_real: 0.7797 D_fake: 0.2442
Epoch (3/50)	 Loss_G: 2.3695 Loss_D: 1.3886	 D_real: 0.6162 D_fake: 0.3758
Epoch (4/50)	 Loss_G: 2.3272 Loss_D: 1.3902	 D_real: 0.6065 D_fake: 0.3902
Epoch (5/50)	 Loss_G: 2.4110 Loss_D: 1.2522	 D_real: 0.6205 D_fake: 0.3870
Epoch (6/50)	 Loss_G: 2.6082 Loss_D: 1.2582	 D_real: 0.6301 D_fake: 0.3727
Epoch (7/50)	 Loss_G: 2.3240 Loss_D: 1.2784	 D_real: 0.6152 D_fake: 0.3882
Epoch (8/50)	 Loss_G: 2.3167 Loss_D: 1.3681	 D_real: 0.6032 D_fake: 0.3983
Epoch (9/50)	 Loss_G: 2.3370 Loss_D: 1.2700	 D_real: 0.6068 D_fake: 0.3959
Epoch (10/50)	 Loss_G: 2.3772 Loss_D: 1.2815	 D_real: 0.6077 D_fake: 0.3872
```

<br><br><br>

#### Visualizing Training Progress

The following plots will help you see how the loss values of the generator and the discriminator, as well as the probabilities generated by the discriminator on real and fake images, evolve during the training of our GAN:

In [None]:
plt.plot(np.array(loss_gen_epoch), label='loss_gen')
plt.plot(np.array(loss_dis_epoch), label='loss_dis')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend();

Example output:

<img src="img/loss_epoch.png" width="350">

In [None]:
plt.plot(np.array(D_real_epoch), label='D_real')
plt.plot(np.array(D_fake_epoch), label='D_fake')
plt.xlabel("Epoch")
plt.ylabel("Probability")
plt.legend();

Example output:

<img src="img/D_epoch.png" width="350">

<br><br><br>

The following code cells help to see the evolution of one fixed noise vector throughout the epochs. The generator is applied on this fixed random noise in each epoch, and the results are saved as batches of generated images.

In [None]:
%%capture

fig = plt.figure(figsize=(10, 10))
ims = [[plt.imshow(np.transpose(i,(1, 2, 0)), animated=True)] for i in img_list[::1]]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
ani.save('GAN.gif', writer='imagemagick', fps=2)

In [None]:
HTML(ani.to_jshtml()) # run this in a new cell to produce the below animation

These are my results after running the GAN model for around 100 epochs:

<img src="img/GAN.gif" width="600"><br>

**Note:** You might have noticed the checkerboard patterns that appear in the generated images, especially early in the training process. This is a known issue with transposed convolutions. This problem and potential solutions are discussed in great detail in [this article](https://distill.pub/2016/deconv-checkerboard/).