We're going to create a very basic GAN that generates images of handwritten digits (0-9) using the PyTorch framework. If you're not familiar with PyTorch, you may find the [PyTorch documentation](https://pytorch.org/docs/stable/index.html) useful.

Read my [blog on Medium](https://medium.com/@Mustafa77/gans-specialization-part1-8d03c64d42ad) for a deeper understanding of how a GAN actually works.

# 1. Loading our Toolkit.


In [None]:
import torch
from torch import nn
from tqdm.auto import tqdm # tqdm provides a progress bar for loops and other iterative tasks
from torchvision import transforms # For Data Augmentation and Transformation
from torchvision.datasets import MNIST # Training dataset
from torchvision.utils import make_grid # make_grid is a utility function to visualize a batch of images at once.
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt  # Data Visualization

# 2. Loading our Dataset (MNIST Dataset).
For the sake of simplicity, the training images we'll be using come from the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). Its training set contains 60,000 images of handwritten digits, from 0 to 9.

In [None]:
# ---------> Load MNIST dataset as tensors
dataloader = DataLoader( MNIST('.', download= True, transform=transforms.ToTensor()), batch_size= 128, shuffle=True)
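
To make the batching concrete, here is a minimal sketch of what a `DataLoader` with `batch_size=128` yields. It uses a hypothetical stand-in dataset of 1,000 blank images (an assumption, chosen so the example runs without downloading MNIST); the real loader behaves the same way, just with 60,000 images.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for MNIST: 1,000 blank 1x28x28 "images" with dummy labels.
fake_images = torch.zeros(1000, 1, 28, 28)
fake_labels = torch.zeros(1000, dtype=torch.long)
loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=128, shuffle=True)

# Each iteration yields an (images, labels) pair; we only record the image shapes.
batch_shapes = [tuple(images.shape) for images, _ in loader]
first_shape = batch_shapes[0]    # a full batch
last_shape = batch_shapes[-1]    # the leftover batch, smaller than 128
n_batches = len(loader)

# The same arithmetic for the 60,000-image MNIST training set.
mnist_batches = math.ceil(60000 / 128)
print(first_shape, last_shape, n_batches, mnist_batches)
```

Because the last batch can be smaller than `batch_size`, code that consumes the loader should read the batch size from the data rather than assume 128.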

# 3. Building The Generator.
We'll start by creating the generator's neural network, composed of 4 blocks. Each block includes a [linear transformation](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) to map to another shape, a [batch normalization](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html) to stabilize and accelerate training, and finally a non-linear activation function (we use [ReLU](https://pytorch.org/docs/master/generated/torch.nn.ReLU.html) here) so the output can be transformed in complex ways. After the 4 blocks, a final linear layer followed by a sigmoid maps the result to the image size. You will learn more about activations and batch normalization later in part 2.

In [None]:
class Generator(nn.Module):
    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
        super(Generator, self).__init__() # Run nn.Module's own setup so that layers and parameters get registered
        self.gen = nn.Sequential(
            self.generator_block(z_dim, hidden_dim),
            self.generator_block(hidden_dim, hidden_dim * 2),
            self.generator_block(hidden_dim * 2, hidden_dim * 4),
            self.generator_block(hidden_dim * 4, hidden_dim * 8),
            nn.Linear(hidden_dim * 8, im_dim),
            nn.Sigmoid()
        )

    def generator_block(self, input_dim, output_dim):
        # One block: Linear -> BatchNorm1d -> ReLU
        return nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.BatchNorm1d(output_dim),
            nn.ReLU(inplace=True))

    def forward(self, noise):
        '''
        Forward pass: maps a noise tensor of shape (n_samples, z_dim)
        to generated images of shape (n_samples, im_dim).
        nn.Module calls this method when you write gen(noise), so it is the
        single entry point for running the model, which keeps the model
        modular and easy to test.
        '''
        return self.gen(noise)

    def get_gen(self):
        return self.gen
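
As a quick sanity check on the block structure described above, the following sketch pushes a batch of noise through a single Linear -> BatchNorm1d -> ReLU block and then through a Linear + Sigmoid head. The sizes (16 samples, 64 noise dimensions, one block instead of four) are assumptions picked for illustration, not the settings used for training.

```python
import torch
from torch import nn

torch.manual_seed(0)

# One block in the style described above: Linear -> BatchNorm1d -> ReLU.
block = nn.Sequential(
    nn.Linear(64, 128),
    nn.BatchNorm1d(128),
    nn.ReLU(inplace=True),
)

noise = torch.randn(16, 64)   # 16 noise vectors of dimension 64 (assumed sizes)
hidden = block(noise)         # shape (16, 128), non-negative because of ReLU

# A final Linear + Sigmoid squashes values into the range [0, 1],
# the same range ToTensor() produces for real MNIST pixels.
head = nn.Sequential(nn.Linear(128, 784), nn.Sigmoid())
images = head(hidden)         # shape (16, 784), one flattened 28x28 image per row

print(hidden.shape, images.shape)
```

Note that `BatchNorm1d` in training mode needs more than one sample per batch, which is why the sketch uses a batch of 16 rather than a single noise vector.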

# 4. Building The Discriminator.
The second component that you need to construct is the discriminator.
We use LeakyReLU to prevent the "dying ReLU" problem: when a ReLU consistently receives negative inputs, its gradient is zero, so the parameters feeding it stop changing. LeakyReLU keeps a small slope (0.2 here) for negative inputs, so some gradient always flows.
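
The difference is easy to see by comparing gradients directly. This small sketch (the input values are arbitrary) backpropagates through both activations for one negative and one positive input.

```python
import torch
from torch import nn

x = torch.tensor([-2.0, 3.0], requires_grad=True)

# ReLU: the gradient is 0 for the negative input and 1 for the positive one.
nn.ReLU()(x).sum().backward()
relu_grad = x.grad.clone()

# Reset the accumulated gradient before the second backward pass.
x.grad.zero_()

# LeakyReLU(0.2): the negative input still receives a gradient of 0.2.
nn.LeakyReLU(0.2)(x).sum().backward()
leaky_grad = x.grad.clone()

print(relu_grad, leaky_grad)
```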




In [None]:
class Discriminator(nn.Module):
    def __init__(self, im_dim=784, hidden_dim=128):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            self.discriminator_block(im_dim, hidden_dim * 4),
            self.discriminator_block(hidden_dim * 4, hidden_dim * 2),
            self.discriminator_block(hidden_dim * 2, hidden_dim),
            nn.Linear(hidden_dim, 1),
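            # No nn.Sigmoid() here: the model outputs raw logits, and nn.BCEWithLogitsLoss applies the sigmoid itself
        )

    def discriminator_block(self, input_dim, output_dim):
        # One block: Linear -> LeakyReLU with a negative slope of 0.2
        return nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.LeakyReLU(0.2, inplace=True))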

    def forward(self, image):
        return self.disc(image)

    def get_disc(self):
        # Returns the sequential model
        return self.disc

# 5. Random Noise Vector.
The noise vector has the important role of ensuring that the generated images don't all look the same.
<br> We'll generate it randomly by sampling random numbers from the normal distribution.
<br> Since multiple images will be processed per pass, we'll generate all the noise vectors at once.

Note that whenever you create a new tensor using torch.ones, torch.zeros, or torch.randn, you either need to create it on the target device, e.g. `torch.ones(3, 3, device=device)`, or move it onto the target device using `torch.ones(3, 3).to(device)`. You do not need to do this if you're creating a tensor by manipulating another tensor or by using a variation that defaults the device to the input, such as `torch.ones_like`. In general, use `torch.ones_like` and `torch.zeros_like` instead of `torch.ones` or `torch.zeros` where possible.

In [None]:
def get_noise(n_samples, z_dim, device='cpu'):
    # Returns a tensor of shape (n_samples, z_dim) sampled from the standard normal distribution
    return torch.randn(n_samples, z_dim, device=device)
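
The device note above can be sketched as follows. The `device` variable and the sizes are assumptions for illustration; the snippet falls back to the CPU when no GPU is available, so it runs anywhere.

```python
import torch

# Pick a GPU if one is available, otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def get_noise(n_samples, z_dim, device='cpu'):
    return torch.randn(n_samples, z_dim, device=device)

noise = get_noise(4, 64, device=device)

# ones_like inherits shape, dtype, and device from its input,
# so no explicit device argument is needed.
ones = torch.ones_like(noise)

# torch.ones builds a fresh tensor, so the device must be passed explicitly.
explicit = torch.ones(4, 64, device=device)

print(noise.shape, ones.device, explicit.device)
```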

# 6. Initialize the generator and discriminator and their Optimizers.
Now we can initialize our generator and discriminator along with their optimizers. Note that each optimizer only takes the parameters of one particular model.

In [None]:
z_dim = 64

gen = Generator(z_dim)
gen_opt = torch.optim.Adam(gen.parameters(), lr= 0.0001)
disc = Discriminator()
disc_opt = torch.optim.Adam(disc.parameters(), lr= 0.0001)
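
To illustrate why it matters that each optimizer holds only one model's parameters, here is a minimal sketch with two hypothetical toy models (`model_a` and `model_b` are stand-ins, not part of the GAN). Both receive gradients, but only the model registered with the optimizer is updated.

```python
import torch
from torch import nn

torch.manual_seed(0)
model_a = nn.Linear(3, 1)
model_b = nn.Linear(3, 1)

# The optimizer is given only model_a's parameters.
opt_a = torch.optim.Adam(model_a.parameters(), lr=0.1)

a_before = model_a.weight.detach().clone()
b_before = model_b.weight.detach().clone()

# A loss that depends on both models, so both receive gradients.
x = torch.randn(8, 3)
loss = (model_a(x) + model_b(x)).pow(2).mean()
loss.backward()
opt_a.step()

a_changed = not torch.equal(model_a.weight, a_before)
b_changed = not torch.equal(model_b.weight, b_before)
print(a_changed, b_changed)
```

This is what lets the training loop update the discriminator and the generator separately even though both appear in the same loss computation.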

# 7. Create Loss Functions.
We need functions to calculate the discriminator's loss and the generator's loss. This is how the discriminator and generator will know how they are doing and improve themselves. Since the generator is needed when calculating the discriminator's loss, you will need to call `.detach()` on the generator result to ensure that only the discriminator is updated!

Note: we define a loss function (`criterion`), and the loss helpers build its target labels with `torch.ones_like` and `torch.zeros_like` instead of `torch.ones` or `torch.zeros`. If you use `torch.ones` or `torch.zeros`, you'll need to pass `device=device` to them.

In [None]:
criterion = nn.BCEWithLogitsLoss()
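
Since the discriminator ends in a plain linear layer, its outputs are raw logits, and `nn.BCEWithLogitsLoss` applies the sigmoid internally. The sketch below, with arbitrary example values, checks that this matches applying `torch.sigmoid` followed by `nn.BCELoss`.

```python
import torch
from torch import nn

# Arbitrary example logits and their target labels (1 = real, 0 = fake).
logits = torch.tensor([[-1.5], [0.0], [2.0]])
targets = torch.tensor([[0.0], [1.0], [1.0]])

# One step: sigmoid and binary cross-entropy combined.
with_logits = nn.BCEWithLogitsLoss()(logits, targets)

# Two steps: explicit sigmoid, then plain binary cross-entropy.
two_step = nn.BCELoss()(torch.sigmoid(logits), targets)

print(with_logits.item(), two_step.item())
```

The combined version is generally preferred because it is more numerically stable for large-magnitude logits.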

In [None]:
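def get_disc_loss(gen, disc, criterion, real, num_images, z_dim= 64):
    # Generate a batch of fake images from random noise
    fake_noise = get_noise(num_images, z_dim)
    fake = gen(fake_noise)
    # Detach the fakes so that no gradients flow back into the generator
    disc_fake_pred = disc(fake.detach())

    # Fake images should be classified as 0, real images as 1
    disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
    disc_real_pred = disc(real)
    disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))

    # Average the two losses
    return (disc_fake_loss + disc_real_loss) / 2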
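
The effect of `.detach()` can be verified on a toy example. In this sketch, `toy_gen` and `toy_disc` are hypothetical single-layer stand-ins for the real networks. With `.detach()`, the backward pass reaches only the discriminator; without it, the generator receives gradients too.

```python
import torch
from torch import nn

torch.manual_seed(0)
toy_gen = nn.Linear(4, 6)     # hypothetical stand-in for the generator
toy_disc = nn.Linear(6, 1)    # hypothetical stand-in for the discriminator
criterion = nn.BCEWithLogitsLoss()

# Discriminator-style loss: the fake batch is detached before scoring.
fake = toy_gen(torch.randn(8, 4))
pred = toy_disc(fake.detach())
criterion(pred, torch.zeros_like(pred)).backward()

gen_has_grad = toy_gen.weight.grad is not None     # False: blocked by detach
disc_has_grad = toy_disc.weight.grad is not None   # True

# Generator-style loss: no detach, so gradients flow back into the generator.
pred = toy_disc(toy_gen(torch.randn(8, 4)))
criterion(pred, torch.ones_like(pred)).backward()
gen_has_grad_after = toy_gen.weight.grad is not None   # True

print(gen_has_grad, disc_has_grad, gen_has_grad_after)
```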

In [None]:
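def get_gen_loss(gen, disc, criterion, num_images, z_dim= 64):
    # Generate a batch of fake images from random noise
    fake_noise = get_noise(num_images, z_dim)
    fake = gen(fake_noise)
    # No detach here: gradients must flow back into the generator
    disc_fake_pred = disc(fake)
    # The generator wants the discriminator to label its fakes as real (1)
    return criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))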
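
Written out, the two functions above compute the following quantities. This is a sketch in notation introduced here for illustration: $D(\cdot)$ is the discriminator's raw logit output, $G(z)$ is the generator's output for noise $z$, $\sigma$ is the sigmoid applied inside `BCEWithLogitsLoss`, and the expectations stand for averages over the batch.

$$
L_D = -\frac{1}{2}\Big(\mathbb{E}_{x}\big[\log \sigma(D(x))\big] + \mathbb{E}_{z}\big[\log\big(1 - \sigma(D(G(z)))\big)\big]\Big),
\qquad
L_G = -\mathbb{E}_{z}\big[\log \sigma(D(G(z)))\big]
$$

The first term of $L_D$ comes from the real images with target 1, the second from the fake images with target 0, and $L_G$ comes from scoring the fake images against target 1.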

# 8. Visualize a Grid of Images.
We'll create a visualizer function (`show_tensor_images`) that displays a grid of images, to help inspect the generated images and keep track of the generator's progress.

In [None]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Visualizes images in a uniform grid, given a tensor of images, the number of images, and the size per image.
    Parameters:
        image_tensor: the input tensor containing the image data.
        num_images: the total number of images to display in the grid.
        size: the dimensions of each image (default `(1, 28, 28)`: grayscale, 28 pixels high and wide).
    Functionality:
        - Detaches the tensor from the computation graph, moves it to the CPU, and reshapes it
          into individual images with view(-1, *size), where -1 infers the number of images.
        - Arranges the first `num_images` images into a grid with 5 images per row using `make_grid`.
        - Displays the grid with `plt.imshow`, using `permute` to move the channel dimension last
          (height, width, channels) and `squeeze` to remove any singleton dimensions.
    '''
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

# 9. Training (put everything together)

For each epoch, we'll process the entire dataset in batches. For every batch, we'll need to update the discriminator and generator weights. Note that you may see a loss greater than 1. This is okay, since binary cross-entropy loss can be any positive number for a sufficiently confident wrong guess.
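
A short numeric sketch shows how the loss grows with confidence in the wrong answer. The logit values are arbitrary examples: a logit of 0 means the discriminator is undecided, while a logit of -5 on a real image is a confident wrong guess.

```python
import torch
from torch import nn

criterion = nn.BCEWithLogitsLoss()
real_label = torch.ones(1)   # the image is real, so the target is 1

# Undecided prediction: sigmoid(0) = 0.5, loss = -log(0.5), about 0.693.
unsure = criterion(torch.tensor([0.0]), real_label).item()

# Confident wrong prediction: loss = log(1 + e^5), about 5.007.
confident_wrong = criterion(torch.tensor([-5.0]), real_label).item()

print(unsure, confident_wrong)
```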

It’s also often the case that the discriminator will outperform the generator, especially at the start, because its job is easier. It's important that neither one gets too good (near-perfect accuracy), which would cause the entire model to stop learning. Balancing the 2 models is actually remarkably hard to do in a standard GAN.


In [None]:
n_epochs = 200
display_step = 500
batch_size = 128
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0

In [None]:
for epoch in range(n_epochs):
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        # Flatten each 1x28x28 image into a 784-dimensional vector
        real = real.view(cur_batch_size, -1)

        # Update the discriminator
        disc_opt.zero_grad()
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim)
        disc_loss.backward()  # retain_graph is not needed: the generator loss runs its own forward pass
        disc_opt.step()

        # Update the generator
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim)
        gen_loss.backward()
        gen_opt.step()

        # Keep track of the average losses over the last display_step steps
        mean_discriminator_loss += disc_loss.item() / display_step
        mean_generator_loss += gen_loss.item() / display_step

        # Print the losses and show a grid of fake and real images
        if cur_step % display_step == 0 and cur_step > 0:
            print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            fake_noise = get_noise(cur_batch_size, z_dim)
            fake = gen(fake_noise)
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        cur_step += 1
