# Handwritten Digit Generation with a GAN in PyTorch

In [57]:
import torch
from torch import nn
from tqdm.auto import tqdm # the name comes from the Arabic word 'taqaddum' ("progress")
# tqdm provides a fast, extensible progress bar for loops and other iterables. The 'auto' submodule automatically chooses the appropriate tqdm implementation for the current environment (e.g. notebook vs. console).
from torchvision import transforms
# for various data transformations on images
from torchvision.datasets import MNIST # dataset of handwritten digits (60,000 training images)
from torchvision.utils import make_grid
# used to arrange a collection of images into a single grid image, for visualization purposes
from torch.utils.data import DataLoader
# for loading and managing batches of data when training neural networks, e.g. on large datasets
import matplotlib.pyplot as plt

torch.manual_seed(0)  # seed torch's global RNG so weight initialization and noise sampling are reproducible

<torch._C.Generator at 0x7f6078f7be90>

In [58]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), nrow=5):
    """Display the first ``num_images`` images of a batch as a grid.

    Args:
        image_tensor: tensor holding a batch of images, either flattened
            (e.g. ``(batch, 784)``) or already shaped; it is reshaped to
            ``(-1, *size)``.
        num_images: maximum number of images to display.
        size: ``(channels, height, width)`` of a single image.
        nrow: number of images per grid row (default 5, i.e. a 5x5 grid
            for the default 25 images).
    """
    # detach() drops the autograd graph and cpu() moves the data off any
    # accelerator so matplotlib can read it. reshape (unlike view) also works
    # on non-contiguous tensors; -1 lets torch infer the number of images.
    image_unflat = image_tensor.detach().cpu().reshape(-1, *size)
    # Tile the selected images into a single (C, H, W) grid tensor.
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    # matplotlib expects channels last: (C, H, W) -> (H, W, C).
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

## Generator Block

In [59]:
def get_generator_block(input_dim, output_dim):
    """Build one hidden block of the generator: Linear -> BatchNorm1d -> ReLU.

    Args:
        input_dim: number of input features.
        output_dim: number of output features, i.e. neurons in the linear layer.

    Returns:
        An ``nn.Sequential`` that maps ``(batch, input_dim)`` inputs to
        ``(batch, output_dim)`` outputs.
    """
    layers = [
        nn.Linear(input_dim, output_dim),  # fully connected projection
        nn.BatchNorm1d(output_dim),        # normalize each feature over the batch
        nn.ReLU(inplace=True),             # in-place activation saves memory
    ]
    return nn.Sequential(*layers)

In [60]:
get_generator_block(25, 12)

Sequential(
  (0): Linear(in_features=25, out_features=12, bias=True)
  (1): BatchNorm1d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
)

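Generator — the network is parameterized by:
* the noise vector dimension (`z_dim`)
* the flattened image dimension (`im_dim`; 28 × 28 = 784 for MNIST)
* the initial hidden dimension (`hidden_dim`), which is doubled in each subsequent block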


In [61]:
class Generator(nn.Module):
    """Fully connected GAN generator that maps noise vectors to flattened images.

    Architecture: four ``get_generator_block`` stages that widen the hidden
    size ``hidden_dim -> 2x -> 4x -> 8x``, followed by a Linear layer to
    ``im_dim`` and a Sigmoid so every output pixel lies in (0, 1).

    Args:
        z_dim: dimension of the input noise vector.
        im_dim: number of pixels in a flattened output image (784 = 28 * 28).
        hidden_dim: width of the first hidden block.
    """

    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
        # Initialize nn.Module's bookkeeping before registering submodules.
        super().__init__()
        widths = [z_dim, hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8]
        stages = [
            get_generator_block(width_in, width_out)
            for width_in, width_out in zip(widths[:-1], widths[1:])
        ]
        # Output head: project to image size; Sigmoid matches the [0, 1] pixel
        # range produced by transforms.ToTensor on the real images.
        stages.append(nn.Linear(widths[-1], im_dim))
        stages.append(nn.Sigmoid())
        self.gen = nn.Sequential(*stages)

    def forward(self, noise):
        """Map a ``(n_samples, z_dim)`` noise tensor to ``(n_samples, im_dim)`` images."""
        return self.gen(noise)

    def get_gen(self):
        """Return the underlying ``nn.Sequential`` model."""
        return self.gen
    

In [62]:
def get_noise(n_samples, z_dim, device='cpu'):
    """Sample a batch of standard-normal noise vectors for the generator.

    Args:
        n_samples: number of noise vectors (the batch size).
        z_dim: dimension of each noise vector.
        device: device on which to allocate the tensor.

    Returns:
        A tensor of shape ``(n_samples, z_dim)`` drawn from N(0, 1).
    """
    shape = (n_samples, z_dim)
    return torch.randn(shape, device=device)


Discriminator:

We use Leaky ReLUs to prevent the "dying ReLU" problem.

The dying ReLU problem is the phenomenon where parameters stop changing because a ReLU consistently receives negative inputs. It then outputs zero and passes back a zero gradient.

In [63]:
def get_disctiminator_block(input_dim, output_dim):
    """Build one hidden block of the discriminator: Linear -> LeakyReLU(0.2).

    LeakyReLU keeps a small, non-zero gradient for negative inputs, which
    avoids the "dying ReLU" problem.

    Args:
        input_dim: number of input features.
        output_dim: number of output features.

    Returns:
        An ``nn.Sequential`` that maps ``(batch, input_dim)`` inputs to
        ``(batch, output_dim)`` outputs.
    """
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.LeakyReLU(negative_slope=0.2),
    )


# Correctly spelled alias. The misspelled name above is kept because existing
# code calls it.
get_discriminator_block = get_disctiminator_block

In [64]:
class Discriminator(nn.Module):
    """Fully connected GAN discriminator that scores flattened images as real or fake.

    Three ``get_disctiminator_block`` stages narrow the width
    ``im_dim -> 4*hidden_dim -> 2*hidden_dim -> hidden_dim``. A final Linear
    layer then produces one raw logit per image. No Sigmoid is applied here.

    Args:
        im_dim: number of pixels in a flattened input image (784 = 28 * 28).
        hidden_dim: width of the last hidden block.
    """

    def __init__(self, im_dim=784, hidden_dim=128):
        super().__init__()
        widths = [im_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim]
        layers = [
            get_disctiminator_block(d_in, d_out)
            for d_in, d_out in zip(widths, widths[1:])
        ]
        # Single output unit: the binary real/fake logit.
        layers.append(nn.Linear(hidden_dim, 1))
        self.disc = nn.Sequential(*layers)

    def forward(self, image):
        """Map a batch of flattened images to a ``(batch, 1)`` tensor of logits."""
        return self.disc(image)

    def get_disc(self):
        """Return the underlying ``nn.Sequential`` model."""
        return self.disc
        

## Training

In [65]:
# Set your parameters
criterion = nn.BCEWithLogitsLoss()  # the discriminator outputs raw logits, so use the logit-aware BCE
n_epochs = 200
z_dim = 64          # dimension of the generator's input noise vector
display_step = 500  # print averaged losses and show samples every this many steps
batch_size = 128
lr = 0.00001

# Load MNIST dataset as tensors (ToTensor scales pixels to [0, 1]).
# download=True only fetches the data when it is not already present under '.',
# so a fresh kernel/machine can Restart & Run All instead of failing here.
dataloader = DataLoader(
    MNIST('.', download=True, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True)

# Use the GPU when one is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Initializing the generator and discriminator and optimizers

In [66]:
# Build both networks on the selected device, each with its own Adam optimizer.
# The generator is created before the discriminator. Under a fixed seed, this
# creation order determines which random draws initialize each network.
gen = Generator(z_dim=z_dim)
gen.to(device)  # nn.Module.to moves the parameters in place
gen_opt = torch.optim.Adam(params=gen.parameters(), lr=lr)

disc = Discriminator()
disc.to(device)
disc_opt = torch.optim.Adam(params=disc.parameters(), lr=lr)

Calculate the discriminator's loss and generator's loss.

Since the generator is needed to calculate the discriminator's loss, we call `.detach()` on the generator's output. This ensures that only the discriminator is updated.

In [67]:
def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    """Compute the discriminator's loss on one batch.

    The function generates ``num_images`` fake images and scores them against
    label 0. The fakes are detached, so no gradient reaches the generator.
    It then scores ``real`` against label 1 and returns the mean of the two
    losses.

    Args:
        gen: generator model, called on a noise batch.
        disc: discriminator model returning logits.
        criterion: loss taking ``(predictions, targets)``.
        real: batch of real (flattened) images.
        num_images: number of fake images to generate.
        z_dim: noise vector dimension.
        device: device for the noise and target tensors.

    Returns:
        The averaged loss (a scalar tensor when ``criterion`` reduces to a mean).
    """
    fake = gen(get_noise(n_samples=num_images, z_dim=z_dim, device=device))

    # Fake images: detach so backward() stops at the discriminator; targets are zeros.
    fake_logits = disc(fake.detach())
    fake_term = criterion(fake_logits, torch.zeros_like(fake_logits, device=device))

    # Real images: targets are ones.
    real_logits = disc(real)
    real_term = criterion(real_logits, torch.ones_like(real_logits, device=device))

    return (fake_term + real_term) / 2
    

In [68]:
def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    """Compute the generator's loss on one batch of freshly generated fakes.

    Args:
        gen: generator model, called on a noise batch.
        disc: discriminator model returning logits.
        criterion: loss taking ``(predictions, targets)``.
        num_images: number of fake images to generate.
        z_dim: noise vector dimension.
        device: device for the noise tensor.

    Returns:
        The loss of the discriminator's predictions on the fakes against
        all-ones ("real") targets.
    """
    fake_logits = disc(gen(get_noise(n_samples=num_images, z_dim=z_dim, device=device)))
    # The generator wants its fakes classified as real, hence all-ones targets.
    # No detach here: gradients must flow back through disc into gen.
    return criterion(fake_logits, torch.ones_like(fake_logits))

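Finale

The generator and the discriminator should improve at a similar pace. Early on, the discriminator usually does better because the fakes are easy to spot. If one model overpowers the other, training stalls. For example, a discriminator that is too confident gives the generator almost no useful gradient. The two models therefore need to be kept in balance.

This balance is hard to maintain in a vanilla GAN.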

In [None]:
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
# When True, check after every generator update that the first generator
# layer's weights actually changed; `error` records whether that check failed.
test_generator = True
error = False
for epoch in range(n_epochs):

    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)

        # Flatten each image to a vector to match the fully connected models.
        real = real.view(cur_batch_size, -1).to(device)

        # --- Discriminator update ---
        disc_opt.zero_grad()
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)
        # The fakes are detached inside get_disc_loss, and get_gen_loss builds a
        # fresh graph, so this graph is never reused: no retain_graph needed.
        disc_loss.backward()
        disc_opt.step()

        # Snapshot the generator's first-layer weights to verify the update below.
        if test_generator:
            old_generator_weights = gen.gen[0][0].weight.detach().clone()

        # --- Generator update ---
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
        gen_loss.backward()
        gen_opt.step()

        # Sanity check: the generator step must have changed its weights.
        if test_generator and not torch.any(gen.gen[0][0].weight.detach() != old_generator_weights):
            if not error:
                print(f"Runtime test failed at step {cur_step}: generator weights did not change")
            error = True

        # Running averages over the current display window.
        mean_discriminator_loss += disc_loss.item() / display_step
        mean_generator_loss += gen_loss.item() / display_step

        if cur_step % display_step == 0 and cur_step > 0:
            print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            # Samples are only displayed, so skip building an autograd graph.
            with torch.no_grad():
                fake = gen(fake_noise)
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        cur_step += 1
