In [None]:
# Import Pytorch libraries
import torch
import torch.nn as nn
# Import the generic Dataset base class (MNIST itself is downloaded via torchvision.datasets below)
from torch.utils.data import Dataset
# Import torchvision libraries
import torchvision
# Import transforms class for image processing
from torchvision import transforms 
# Import dataloader class to load the data from datasets
from torch.utils.data import DataLoader
# Import numpy for numeric computation
import numpy as np
# Import matplotlib for displaying plots
import matplotlib.pyplot as plt

In [None]:
# Prefer the CUDA device when one is available; otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Keep the device as the cell's last expression so the notebook displays it
device

In [None]:
# Function to generate random noise for the generator input
def create_noise(batch_size, z_size, mode_z):
    """Sample a batch of latent noise vectors.

    Args:
        batch_size: Number of noise vectors to draw.
        z_size: Number of channels (latent dimensions) per vector.
        mode_z: 'uniform' for values drawn uniformly from [-1, 1), or
            'normal' for values drawn from a standard normal distribution.

    Returns:
        A tensor of shape (batch_size, z_size, 1, 1).

    Raises:
        ValueError: If mode_z is neither 'uniform' nor 'normal'.
    """
    noise_shape = (batch_size, z_size, 1, 1)
    if mode_z == 'uniform':
        # torch.rand samples from [0, 1); rescale it to [-1, 1)
        return torch.rand(noise_shape) * 2 - 1
    if mode_z == 'normal':
        return torch.randn(noise_shape)
    # Fail loudly instead of hitting an unbound local variable on return
    raise ValueError(
        f"Unknown mode_z {mode_z!r}; expected 'uniform' or 'normal'")

In [None]:
# Build the MNIST data pipeline

# Directory (relative to the notebook) used as the MNIST download root
image_path = './'

# Preprocessing: convert each image to a tensor, then normalize it
# with mean 0.5 and standard deviation 0.5
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=0.5, std=0.5),
])

# Training split of MNIST, downloaded into image_path when requested
mnist_dataset = torchvision.datasets.MNIST(
    root=image_path,
    train=True,
    transform=transform,
    download=True,
)

# Number of images per batch
batch_size = 64

# Seed the torch and numpy random number generators
torch.manual_seed(1)
np.random.seed(1)

# Shuffled loader; drop_last discards a final batch smaller than batch_size
mnist_dl = DataLoader(
    mnist_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Function to create the generator network
def make_generator_network(input_size, n_filters):
    """Build the generator as a stack of transposed convolutions.

    Args:
        input_size: Number of channels of the input noise tensor.
        n_filters: Base filter count; the hidden layers use 4x, 2x and 1x
            this value, in that order.

    Returns:
        An nn.Sequential that ends in Tanh and has a single output channel.
        For a 1x1 spatial input the layer settings upsample
        1 -> 4 -> 7 -> 14 -> 28.
    """
    def upsample_block(in_channels, out_channels, kernel_size, stride, padding):
        # Transposed convolution (no bias) -> batch norm -> leaky ReLU
        return [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                               stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        ]

    layers = []
    layers += upsample_block(input_size, n_filters*4, 4, 1, 0)
    layers += upsample_block(n_filters*4, n_filters*2, 3, 2, 1)
    layers += upsample_block(n_filters*2, n_filters, 4, 2, 1)
    # Output layer: project to one channel and squash with tanh
    layers.append(nn.ConvTranspose2d(n_filters, 1, 4, 2, 1, bias=False))
    layers.append(nn.Tanh())
    return nn.Sequential(*layers)

In [None]:
# Class to define the discriminator network
class Discriminator(nn.Module):
    """Convolutional discriminator that scores single-channel images.

    The network is four strided convolutions (batch norm on the two middle
    ones, leaky ReLU in between) followed by a sigmoid, so every output
    value lies in [0, 1].

    Args:
        n_filters: Base filter count; the hidden layers use 1x, 2x and 4x
            this value, in that order.
    """

    def __init__(self, n_filters):
        super().__init__()
        # Create a neural network model
        self.network = nn.Sequential(
            # Perform 2D Convolution
            nn.Conv2d(1, n_filters, 4, 2, 1, bias=False),
            # Use the leaky ReLU
            nn.LeakyReLU(0.2),
            # Perform 2D Convolution
            nn.Conv2d(n_filters, n_filters*2, 4, 2, 1, bias=False),
            # Perform batch normalization
            nn.BatchNorm2d(n_filters * 2),
            # Use the leaky ReLU
            nn.LeakyReLU(0.2),
            # Perform 2D Convolution
            nn.Conv2d(n_filters*2, n_filters*4, 3, 2, 1, bias=False),
            # Perform batch normalization
            nn.BatchNorm2d(n_filters*4),
            # Use the leaky ReLU
            nn.LeakyReLU(0.2),
            # Perform 2D Convolution
            nn.Conv2d(n_filters*4, 1, 4, 1, 0, bias=False),
            # Use the sigmoid activation function
            nn.Sigmoid())

    # Forward pass
    def forward(self, input):
        """Return one probability per image as a (batch, 1) tensor.

        Args:
            input: Image batch with one channel. With 28x28 images the
                convolutions reduce the spatial size 28 -> 14 -> 7 -> 4 -> 1.

        Returns:
            Tensor of shape (batch, 1) with values in [0, 1].
        """
        output = self.network(input)
        # Always keep the (batch, 1) layout. The previous trailing
        # .squeeze(0) turned a batch of one into shape (1,), which no longer
        # matched (batch, 1) label tensors.
        return output.view(-1, 1)

In [None]:
# Hyperparameters shared by the generator and the discriminator
# Number of channels of the latent noise vector
z_size = 100
# Height and width of the generated images
image_size = (28, 28)
# Base number of convolution filters
n_filters = 32

# Build the generator network and move it to the selected device
gen_model = make_generator_network(input_size=z_size, n_filters=n_filters)
gen_model = gen_model.to(device)
# Display the generator architecture
print(gen_model)

# Build the discriminator network and move it to the selected device
disc_model = Discriminator(n_filters=n_filters).to(device)
# Display the discriminator architecture
print(disc_model)

In [None]:
# Loss function and optimizers

# Binary cross-entropy between predicted probabilities and 0/1 labels
loss_fn = nn.BCELoss()
# Adam optimizer over the generator parameters (learning rate 0.0003)
g_optimizer = torch.optim.Adam(gen_model.parameters(), lr=0.0003)
# Adam optimizer over the discriminator parameters (learning rate 0.0002)
d_optimizer = torch.optim.Adam(disc_model.parameters(), lr=0.0002)

In [None]:
# Train the discriminator
def d_train(x):
    """Run one discriminator update on a real batch and a generated batch.

    Uses the module-level disc_model, gen_model, loss_fn, d_optimizer,
    device, z_size and mode_z.

    Args:
        x: Batch of real images; its first dimension is the batch size.

    Returns:
        A tuple (d_loss, d_proba_real, d_proba_fake): the summed real + fake
        loss as a Python float, and the detached discriminator outputs for
        the real and the generated batch.
    """
    disc_model.zero_grad()

    # Get the batch size
    batch_size = x.size(0)
    # Move the real images to the selected device
    x = x.to(device)
    # Label the real data with ones
    d_labels_real = torch.ones(batch_size, 1, device=device)
    # Get the probability of real data
    d_proba_real = disc_model(x)
    # Get the loss of real data
    d_loss_real = loss_fn(d_proba_real, d_labels_real)

    # Train discriminator on a fake batch
    input_z = create_noise(batch_size, z_size, mode_z).to(device)
    # Detach the generated images: this step only updates the discriminator,
    # so back-propagating into the generator would be wasted work.
    g_output = gen_model(input_z).detach()
    # Get the probability of fake data
    d_proba_fake = disc_model(g_output)
    # Label the fake data with zeros
    d_labels_fake = torch.zeros(batch_size, 1, device=device)
    # Get the loss of fake data
    d_loss_fake = loss_fn(d_proba_fake, d_labels_fake)

    # Perform backpropagation on the combined real and fake loss
    d_loss = d_loss_real + d_loss_fake
    d_loss.backward()
    d_optimizer.step()

    return d_loss.item(), d_proba_real.detach(), d_proba_fake.detach()

In [None]:
# Train the generator
def g_train(x):
    """Run one generator update and return its loss as a Python float.

    Uses the module-level gen_model, disc_model, loss_fn, g_optimizer,
    device, z_size and mode_z. The argument x is read only for its batch
    size (x.size(0)).
    """
    gen_model.zero_grad()

    n_samples = x.size(0)
    # Target labels are all ones: the generator is scored against "real"
    real_targets = torch.ones((n_samples, 1), device=device)

    # Draw fresh noise, generate images and let the discriminator score them
    noise = create_noise(n_samples, z_size, mode_z).to(device)
    fake_scores = disc_model(gen_model(noise))

    # Loss between the discriminator's scores for the fakes and the "real"
    # labels, followed by backpropagation and an optimizer step
    g_loss = loss_fn(fake_scores, real_targets)
    g_loss.backward()
    g_optimizer.step()

    return g_loss.item()

In [None]:
# Create various samples using generator
def create_samples(g_model, input_z):
    """Generate images from noise and rescale them from [-1, 1] to [0, 1].

    Args:
        g_model: Generator model, called as g_model(input_z).
        input_z: Batch of noise vectors fed to the generator.

    Returns:
        Tensor of shape (batch, *image_size), where batch is taken from the
        generator output and image_size is the module-level image size.
    """
    g_output = g_model(input_z)
    # Use the actual batch size of the output instead of the global
    # batch_size, so any number of noise vectors can be sampled.
    images = torch.reshape(g_output, (g_output.size(0), *image_size))
    return (images+1)/2.0

In [None]:
# Choose the noise distribution used throughout training
mode_z = 'normal'
# Draw one fixed batch of noise and move it to the selected device
fixed_z = create_noise(batch_size=batch_size, z_size=z_size, mode_z=mode_z)
fixed_z = fixed_z.to(device)

In [None]:
# List that collects the images generated from fixed_z after every epoch
epoch_samples = []

# Set the number of epochs
num_epochs = 50

# Set the seed for generating random numbers
torch.manual_seed(1)

# Start training the model
for epoch in range(1, num_epochs+1):
    gen_model.train()
    d_losses, g_losses = [], []
    # Labels are not needed for GAN training, so they are discarded
    for x, _ in mnist_dl:
        # The discriminator outputs stay bound to the same names as before
        # so they can still be inspected after the loop
        d_loss, d_proba_real, d_proba_fake = d_train(x)
        d_losses.append(d_loss)
        g_losses.append(g_train(x))

    print(f'Epoch {epoch:03d} | Avg Losses >>'
          f' G/D {torch.FloatTensor(g_losses).mean():.4f}'
          f'/{torch.FloatTensor(d_losses).mean():.4f}')

    # Sample from the fixed noise in eval mode. No gradients are needed
    # here, so skip building the autograd graph.
    gen_model.eval()
    with torch.no_grad():
        samples = create_samples(gen_model, fixed_z)
    epoch_samples.append(samples.cpu().numpy())

In [None]:
# Epochs (1-based) whose generated samples are shown, one row per epoch
selected_epochs = [1, 2, 4, 10, 25, 50]
# Figure holding a 6 x 6 grid of generated digits
fig = plt.figure(figsize=(10, 14))
for row, epoch_no in enumerate(selected_epochs):
    # epoch_samples is 0-based, so epoch N is stored at index N-1
    samples = epoch_samples[epoch_no-1]
    for col in range(6):
        ax = fig.add_subplot(6, 6, row*6 + col + 1)
        # Hide the axis ticks
        ax.set_xticks([])
        ax.set_yticks([])
        # Label each row once, to the left of its first image
        if col == 0:
            ax.text(
                -0.06, 0.5, f'Epoch {epoch_no}',
                rotation=90, size=18, color='red',
                horizontalalignment='right',
                verticalalignment='center',
                transform=ax.transAxes)

        ax.imshow(samples[col], cmap='gray_r')
plt.show()