In [None]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

In [None]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), nrow=5):
    """Plot the first ``num_images`` images of a batch as a single grid.

    Args:
        image_tensor: batch of images, either flattened to (batch, C*H*W) or
            already shaped (batch, C, H, W). It is detached and copied to the
            CPU before plotting, so it may carry gradients or live on a GPU.
        num_images: how many images from the start of the batch to show.
        size: per-image (C, H, W) shape used to un-flatten the batch.
        nrow: number of images per grid row (default 5, i.e. a 5x5 grid for 25 images).
    """
    # Un-flatten to (N, C, H, W), the layout make_grid expects. reshape (unlike
    # view) also works when the incoming tensor is not contiguous.
    image_unflat = image_tensor.detach().cpu().reshape(-1, *size)
    # Tile the selected images into one (C, H_grid, W_grid) image.
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    # matplotlib expects channels last (H, W, C); squeeze drops a size-1 channel axis if present.
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

In [None]:
# Image geometry used throughout the notebook.
# n_channels is 1 for grayscale images and 3 for RGB images.
n_channels = 1
# MNIST images are 28x28 pixels.
basewidth = 28
baseheight = 28

# Report the available accelerator and expose it as `device`.
# NOTE: torch.cuda.set_device only selects which GPU is "current" for CUDA calls.
# It does NOT make new tensors or models live on the GPU. Those must be moved
# explicitly, e.g. model.to(device) / tensor.to(device). The is_available() guard
# keeps the notebook runnable on CPU-only machines, where the CUDA calls would raise.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
    torch.cuda.set_device(0)
    device = torch.device('cuda', 0)
else:
    print('CUDA not available; using CPU')
    device = torch.device('cpu')

In [None]:
# The generator subclasses nn.Module so PyTorch registers and tracks its parameters.
class Generator(nn.Module):
    """Dense GAN generator mapping noise vectors to flattened images.

    Args:
        dist_size: length of each input noise vector.
        base_hidden_units: width of the first hidden layer. The four hidden
            layers are 1x, 2x, 4x and 8x this value.
        basewidth, baseheight, n_channels: image dimensions. Each output row has
            basewidth * baseheight * n_channels values.

    A final Sigmoid keeps every output value between 0 and 1. Tanh could be
    tried instead; it outputs [-1, 1], so the data would need matching
    normalisation.
    """

    def __init__(self, dist_size = 10, base_hidden_units = 128, basewidth = 28, baseheight = 28, n_channels = 1):
        super().__init__()
        self.dist_size = dist_size
        self.base_hidden_units = base_hidden_units
        # Flattened image length = product of all image dimensions.
        self.final_img_size = basewidth * baseheight * n_channels

        # Layer widths: noise -> h -> 2h -> 4h -> 8h, one dense hidden block per consecutive pair.
        widths = [dist_size] + [base_hidden_units * factor for factor in (1, 2, 4, 8)]
        hidden_blocks = [
            self.make_genblock(width_in, width_out)
            for width_in, width_out in zip(widths[:-1], widths[1:])
        ]
        self.gen = nn.Sequential(
            *hidden_blocks,
            nn.Linear(widths[-1], self.final_img_size),
            nn.Sigmoid(),
        )

    def forward(self, noise_vecs):
        """Map a (batch, dist_size) noise batch to (batch, final_img_size) images."""
        images = self.gen(noise_vecs)
        return images

    def make_genblock(self, input_units, output_units):
        """Build one hidden stage: Linear -> BatchNorm1d -> LeakyReLU(0.01)."""
        layers = [
            nn.Linear(input_units, output_units),
            nn.BatchNorm1d(output_units),
            # LeakyReLU beat plain ReLU in the author's runs (at least over 200 iterations).
            nn.LeakyReLU(0.01, inplace=True),
        ]
        return nn.Sequential(*layers)

In [None]:
# The discriminator also subclasses nn.Module and is built from dense blocks
# similar to the Generator's, but its widths shrink instead of growing.
class Discriminator(nn.Module):
    """Dense GAN discriminator that scores flattened images.

    Args:
        base_hidden_units: width of the narrowest hidden layer. The three
            hidden layers are 4x, 2x and 1x this value.
        basewidth, baseheight, n_channels: image dimensions. Each input row
            must have basewidth * baseheight * n_channels features.

    ``forward`` returns raw, unbounded logits of shape (batch, 1). Higher means
    "more likely real", since real images are trained towards target 1.
    No Sigmoid is applied on purpose. The notebook trains with
    nn.BCEWithLogitsLoss, which applies the sigmoid itself in a numerically
    stable form, so adding a Sigmoid (or any other squashing to [0, 1]) here
    would apply it twice and distort the loss.
    """

    def __init__(self, base_hidden_units = 128, basewidth = 28, baseheight = 28, n_channels = 1):
        super(Discriminator, self).__init__()
        self.base_hidden_units = base_hidden_units
        self.input_img_size = basewidth * baseheight * n_channels
        # This is a classifier, so the widths shrink layer by layer down to a
        # single score. The author chose the narrowing shape to keep training fast.
        self.disc = nn.Sequential(
            self.make_discblock(self.input_img_size, base_hidden_units * 4),
            self.make_discblock(base_hidden_units * 4, base_hidden_units * 2),
            self.make_discblock(base_hidden_units * 2, base_hidden_units),
            # The last layer emits a logit, not a probability (see class docstring).
            nn.Linear(base_hidden_units, 1)
        )

    def forward(self, images):
        """Score a batch of flattened images; returns logits of shape (batch, 1)."""
        return self.disc(images)

    def make_discblock(self, input_units, output_units):
        """Build one hidden stage: Linear -> LeakyReLU(0.2).

        LeakyReLU (rather than ReLU) keeps a small gradient for negative
        inputs, so units that go negative do not stop contributing to training.

        BatchNorm is deliberately omitted. In training mode, BatchNorm
        normalises with the current batch's mean and variance while
        accumulating running averages, and those averages are used in eval
        mode. The training loop feeds this network separate all-fake and
        all-real batches. The per-batch statistics therefore differ between
        the two kinds of batch, and both differ from the running averages,
        which blend real and fake. The normalisation the discriminator is
        trained with would not match what it applies elsewhere, which makes
        its gradients to the generator less meaningful.
        In short: keep BatchNorm statistics consistent between train and test,
        or don't use it at all.
        """
        return nn.Sequential(
            nn.Linear(input_units, output_units),
            nn.LeakyReLU(0.2, inplace = True)
        )

In [None]:
# Sample generator inputs from the standard normal distribution.
def generate_noise(num_images, dist_size, device=None):
    """Return a batch of noise vectors drawn from N(0, 1).

    Args:
        num_images: number of noise vectors (the batch size).
        dist_size: length of each vector; must match the generator's ``dist_size``.
        device: optional device for the result, e.g. the model's device. The
            default ``None`` uses torch's default device, as before.

    Returns:
        Tensor of shape (num_images, dist_size).
    """
    return torch.randn(num_images, dist_size, device=device)

In [None]:
# Training hyperparameters - tune them here.
lr = 0.00005        # Adam learning rate, shared by both networks
dist_size = 64      # length of each generator noise vector
batch_size = 128    # images per DataLoader batch
epochs = 200        # number of passes over the dataset
cur_step = 0        # global batch counter; the training loop uses it to schedule previews

# Seed torch's global RNG so weight initialisation, noise sampling and DataLoader
# shuffling repeat across runs. Some GPU kernels may still be nondeterministic.
seed = 0
torch.manual_seed(seed)

In [None]:
# Load MNIST (downloading it into the current directory if needed). ToTensor
# turns each image into a float tensor with values in [0, 1].
# pin_memory=True didn't work on the author's GPU, and num_workers > 0 gave no
# significant speed-up, so both are left at their defaults.
mnist_train = MNIST('.', download=True, transform=transforms.ToTensor())
dataloader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

In [None]:
# Loss: nn.BCEWithLogitsLoss takes raw logits x, applies p = sigmoid(x) internally,
# then binary cross-entropy against the target y:
#     loss = -(y * log(p) + (1 - y) * log(1 - p))
# averaged over the batch (default reduction='mean').
bceLoss = nn.BCEWithLogitsLoss()

# The two networks trained against each other (generator built first, then discriminator).
generator = Generator(dist_size = dist_size)
discriminator = Discriminator()

# One Adam optimiser per network; each updates only its own network's parameters.
# Plain SGD would also work here.
generator_opt = torch.optim.Adam(params=generator.parameters(), lr=lr)
discriminator_opt = torch.optim.Adam(params=discriminator.parameters(), lr=lr)

In [None]:
# Alternate one discriminator update and one generator update per batch.
# Data and noise are moved to whatever device the models live on, so the loop
# works on CPU as built above, or on a GPU if both models are moved there first.
train_device = next(generator.parameters()).device

for epoch in range(epochs):

    # The DataLoader yields (images, labels) batches; labels are unused in a GAN.
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)

        # Flatten each image into a vector for the dense discriminator.
        real = real.view(cur_batch_size, -1).to(train_device)

        # ---- Discriminator step ----
        # PyTorch accumulates gradients, so clear what the previous iteration left.
        # https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch
        discriminator_opt.zero_grad()

        # Score generated images. Detaching the generator output cuts the graph,
        # so this backward pass stops at the discriminator and never computes
        # gradients for the generator's weights.
        # http://www.bnikolic.co.uk/blog/pytorch-detach.html
        disc_inputs = generate_noise(cur_batch_size, dist_size).to(train_device)
        disc_generator_outputs = generator(disc_inputs)
        disc_fake_outputs = discriminator(disc_generator_outputs.detach())
        # Penalise the discriminator unless it predicts 0 ("fake") for generated images.
        disc_fake_loss = bceLoss(disc_fake_outputs, torch.zeros_like(disc_fake_outputs))

        # Score real images; penalise unless it predicts 1 ("real").
        disc_real_outputs = discriminator(real)
        disc_real_loss = bceLoss(disc_real_outputs, torch.ones_like(disc_real_outputs))
        # Average the two halves of the loss.
        discriminator_loss = (disc_real_loss + disc_fake_loss) / 2

        # No retain_graph: this graph is never backpropagated through again (the
        # generator step builds its own), so keeping it alive would only waste memory.
        discriminator_loss.backward()
        discriminator_opt.step()

        # ---- Generator step ----
        generator_opt.zero_grad()
        gen_inputs = generate_noise(cur_batch_size, dist_size).to(train_device)
        gen_outputs = generator(gen_inputs)
        gen_disc_outputs = discriminator(gen_outputs)
        # The generator is penalised when the discriminator calls its images fake,
        # so its target is 1 ("real").
        gen_loss = bceLoss(gen_disc_outputs, torch.ones_like(gen_disc_outputs))
        # This backward pass also writes gradients into the discriminator's parameters.
        # Only generator_opt steps here, so they are never applied, and
        # discriminator_opt.zero_grad() clears them at the start of the next iteration.
        gen_loss.backward()
        generator_opt.step()

        # Every 150 batches, show a fresh generated sample next to the current real batch.
        if cur_step % 150 == 0:
            print(f"Epoch {epoch}, step {cur_step}")
            # Preview only: no_grad avoids building an autograd graph for this sample.
            with torch.no_grad():
                fake_noise = generate_noise(cur_batch_size, dist_size).to(train_device)
                fake = generator(fake_noise)
            show_tensor_images(fake)
            show_tensor_images(real)
        cur_step += 1

In [None]:
# Persist the trained generator to disk.
# `with` guarantees the file is closed even if pickling raises.
# NOTE(review): pickling the whole nn.Module ties the file to this exact class
# definition. torch.save(generator.state_dict(), path) is the more portable
# PyTorch convention.
import pickle
filename = 'generator_local_leakyRelu_pickled.pkl'
with open(filename, 'wb') as outfile:
    pickle.dump(generator, outfile)

In [None]:
# Reload the pickled generator to confirm the saved file round-trips.
# SECURITY: pickle.load can execute arbitrary code. Only load files you created
# yourself or otherwise trust.
with open(filename, 'rb') as infile:
    generator_local_pickled = pickle.load(infile)

In [None]:
# Visual sanity check: sample 25 images from the reloaded generator.
# eval() makes BatchNorm use the running statistics accumulated during training,
# instead of this batch's statistics, and stops it updating them. no_grad()
# skips building an autograd graph. The noise length comes from the model itself
# rather than a hard-coded 64, and the noise is created on the model's device.
generator_local_pickled.eval()
sample_device = next(generator_local_pickled.parameters()).device
with torch.no_grad():
    sample_noise = generate_noise(25, generator_local_pickled.dist_size).to(sample_device)
    show_tensor_images(generator_local_pickled(sample_noise))