In [5]:
import os
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm


class Generator(nn.Module):
    """Fully connected image generator.

    Maps a batch of noise vectors to a batch of flattened single-channel
    images.

    The network has three hidden stages of widths 256, 512 and 1024, each
    using a LeakyReLU(0.2) activation. The final Linear layer is followed by
    Tanh, so every output value lies in [-1, 1].
    """

    def __init__(self, input_dims, output_dims):
        """Build the layer stack described in the CW2 guidance.

        Arguments:
            input_dims {int} -- length of each input noise vector
            output_dims {int} -- length of each flattened output image
        """
        super(Generator, self).__init__()

        def hidden(n_in, n_out):
            # One hidden stage: affine map followed by LeakyReLU(0.2).
            return nn.Sequential(nn.Linear(n_in, n_out), nn.LeakyReLU(0.2))

        self.fc1 = hidden(input_dims, 256)
        self.fc2 = hidden(256, 512)
        self.fc3 = hidden(512, 1024)
        # Output stage: Tanh bounds pixels to [-1, 1], the same range that
        # Normalize(mean=0.5, std=0.5) maps the real training images into.
        self.fc4 = nn.Sequential(nn.Linear(1024, output_dims), nn.Tanh())

    def forward(self, x):
        """Run the noise batch through every stage in order.

        Arguments:
            x {Tensor} -- noise batch of shape (<batch_size>x<input_dims>)

        Returns:
            Tensor -- flattened images of shape (<batch_size>x<output_dims>)
        """
        out = x
        for stage in (self.fc1, self.fc2, self.fc3, self.fc4):
            out = stage(out)
        return out


class Discriminator(nn.Module):
    """Fully connected image discriminator.

    Takes a batch of images and predicts, per image, the probability that it
    is a real dataset sample rather than a generator output.
    """

    def __init__(self, input_dims, output_dims=1):
        """Build the layer stack described in the CW2 guidance.

        There are three hidden stages (1024 -> 512 -> 256 units). Each is
        Linear + LeakyReLU(0.2) + Dropout(0.3). A Linear + Sigmoid output
        layer follows them.

        Arguments:
            input_dims {int} -- number of pixels in one flattened image

        Keyword Arguments:
            output_dims {int} -- number of probabilities predicted per image
                (default: {1})
        """
        super(Discriminator, self).__init__()

        self.fc1 = nn.Sequential(
            nn.Linear(input_dims, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.fc3 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        # Honour output_dims; it was previously ignored (hard-coded to 1).
        self.fc4 = nn.Sequential(
            nn.Linear(256, output_dims),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Score a batch of images.

        Arguments:
            x {Tensor} -- either a flattened batch (<batch_size>x<input_dims>)
                or an image batch such as (<batch_size>xHxW) or
                (<batch_size>x1xHxW); the trailing dimensions must multiply
                to input_dims.

        Returns:
            Tensor -- probabilities in [0, 1] of shape
                (<batch_size>x<output_dims>)
        """
        # Flatten image-shaped input; rank-1/rank-2 input passes unchanged.
        if x.dim() > 2:
            x = torch.flatten(x, start_dim=1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        return x


def show_result(G_net, z_, num_epoch, show=False, save=False, path='result.png'):
    """Result visualisation

    Generate images from the given noise vectors and lay them out on a
    5x5 grid. Only the first 25 images are drawn; unused cells stay blank.
    Each flattened image is reshaped to a square (e.g. 784 -> 28x28).

    Arguments:
        G_net {[nn.Module]} -- The generator instance
        z_ {[Tensor]} -- Input noise vectors, already on G_net's device
        num_epoch {[int]} -- How many epochs the generator has been trained

    Keyword Arguments:
        show {bool} -- If to display the images (default: {False})
        save {bool} -- If to store the images (default: {False})
        path {str} -- path to store the images (default: {'result.png'})
    """
    fig, axs = plt.subplots(5, 5, figsize=(5, 5))
    fig.suptitle('Epoch #{}'.format(num_epoch))

    # Inference only: no_grad avoids building an autograd graph for G_net.
    with torch.no_grad():
        images = G_net(z_).cpu().numpy()

    # Hide ticks/frames on every cell, including any left empty.
    for ax in axs.flat:
        ax.set_axis_off()
    for ax, image in zip(axs.flat, images):
        side = int(round(np.sqrt(image.size)))
        ax.imshow(image.reshape((side, side)), cmap='gray')

    # Save first, then either show or release the figure, so that
    # save+show works and figures are never leaked when not shown.
    if save:
        fig.savefig(path)
    if show:
        plt.show()
    else:
        plt.close(fig)


def show_train_hist(hist, show=False, save=False, path='Train_hist.png'):
    """Loss tracker

    Plot the per-epoch losses of the generator and discriminator on one
    set of axes to compare their trends.

    Arguments:
        hist {[dict]} -- Tracking variables; must contain the lists
            'D_losses' and 'G_losses'

    Keyword Arguments:
        show {bool} -- If to display the figure (default: {False})
        save {bool} -- If to store the figure (default: {False})
        path {str} -- path to store the figure (default: {'Train_hist.png'})
    """
    d_losses = hist['D_losses']
    g_losses = hist['G_losses']

    # Use a dedicated figure so the curves never land on whatever figure
    # happens to be current (e.g. a result grid left open for display).
    fig, ax = plt.subplots()
    # Each curve gets its own x range, so unequal lengths cannot raise.
    ax.plot(range(len(d_losses)), d_losses, label='D_loss')
    ax.plot(range(len(g_losses)), g_losses, label='G_loss')

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')

    ax.legend(loc=4)  # 4 == 'lower right'
    ax.grid(True)
    fig.tight_layout()

    if save:
        fig.savefig(path)

    if show:
        plt.show()
    else:
        plt.close(fig)


def create_noise(num, dim):
    """Sample a batch of latent vectors.

    Every entry is drawn independently from the standard normal
    distribution N(0, 1), using torch's default device and dtype.

    Arguments:
        num {int} -- how many vectors to draw (rows)
        dim {int} -- length of every vector (columns)

    Returns:
        [Tensor] -- tensor of shape (num, dim)
    """
    shape = (num, dim)
    return torch.randn(shape)


if __name__ == '__main__':
    # initialise the device for training: 'cuda' if a GPU is available, else 'cpu'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data_dir = './MNIST_data/'
    save_dir = './MNIST_GAN_results/'
    image_save_dir = os.path.join(save_dir, 'results')

    # create the output folders (including missing parents) if they do not exist
    os.makedirs(image_save_dir, exist_ok=True)

    # training parameters
    batch_size = 100
    learning_rate = 0.0002
    epochs = 100

    # parameters for Models
    image_size = 28
    G_input_dim = 100
    G_output_dim = image_size * image_size
    D_input_dim = image_size * image_size
    D_output_dim = 1

    # construct the dataset and data loader; Normalize(0.5, 0.5) maps pixels
    # to [-1, 1], matching the generator's Tanh output range
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=(0.5,), std=(0.5,))])
    train_data = datasets.MNIST(root=data_dir, train=True, transform=transform, download=True)
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

    # declare the generator and discriminator networks
    G_net = Generator(G_input_dim, G_output_dim).to(device)
    D_net = Discriminator(D_input_dim, D_output_dim).to(device)

    # Binary Cross Entropy Loss function
    criterion = nn.BCELoss().to(device)

    # Initialise the Optimizers
    G_optimizer = torch.optim.Adam(G_net.parameters(), lr=learning_rate)
    D_optimizer = torch.optim.Adam(D_net.parameters(), lr=learning_rate)

    # A fixed noise batch makes the per-epoch sample grids comparable:
    # each grid shows the generator's output for the same latent inputs.
    fixed_noise = create_noise(25, G_input_dim).to(device)

    # tracking variables
    train_hist = {}
    train_hist['D_losses'] = []
    train_hist['G_losses'] = []
    train_hist['per_epoch_ptimes'] = []
    train_hist['total_ptime'] = []

    start_time = time.time()
    # training loop
    for epoch in range(epochs):
        G_net.train()
        D_net.train()
        Loss_G = []
        Loss_D = []
        epoch_start_time = time.time()
        for (image, _) in tqdm(train_loader):
            image = image.to(device)
            b_size = len(image)
            # create real and fake labels (shape matches D output: b_size x 1)
            real_label = torch.ones(b_size, 1, device=device)
            fake_label = torch.zeros(b_size, 1, device=device)

            # generate fake images; detach them so the discriminator loss
            # does not backpropagate through (and compute grads for) G_net
            data_fake = G_net(create_noise(b_size, G_input_dim).to(device)).detach()
            data_real = image.view(b_size, D_input_dim)

            # --------train the discriminator network----------
            # compute the loss for real and fake images
            output_real = D_net(data_real)
            output_fake = D_net(data_fake)
            loss_real = criterion(output_real, real_label)
            loss_fake = criterion(output_fake, fake_label)
            loss_d = loss_real + loss_fake

            # back propagation
            D_optimizer.zero_grad()
            loss_d.backward()
            D_optimizer.step()

            # -------- train the generator network-----------
            # fresh, non-detached samples so gradients reach G_net
            data_fake = G_net(create_noise(b_size, G_input_dim).to(device))

            # generator wants D to label its samples as real
            output_fake = D_net(data_fake)
            loss_g = criterion(output_fake, real_label)

            # back propagation (stray grads left on D_net are cleared by
            # D_optimizer.zero_grad() before the next discriminator step)
            G_optimizer.zero_grad()
            loss_g.backward()
            G_optimizer.step()

            # store the loss of each iteration
            Loss_D.append(loss_d.item())
            Loss_G.append(loss_g.item())

        epoch_loss_g = np.mean(Loss_G)  # mean generator loss for the epoch
        epoch_loss_d = np.mean(Loss_D)  # mean discriminator loss for the epoch
        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time

        print("Epoch %d of %d with %.2f s" % (epoch + 1, epochs, per_epoch_ptime))
        print("Generator loss: %.8f, Discriminator loss: %.8f" % (epoch_loss_g, epoch_loss_d))

        path = os.path.join(image_save_dir, 'MNIST_GAN_' + str(epoch + 1) + '.png')
        show_result(G_net, fixed_noise, (epoch + 1), save=True, path=path)

        # record the loss for every epoch
        train_hist['G_losses'].append(epoch_loss_g)
        train_hist['D_losses'].append(epoch_loss_d)
        train_hist['per_epoch_ptimes'].append(per_epoch_ptime)

    end_time = time.time()
    total_ptime = end_time - start_time
    train_hist['total_ptime'].append(total_ptime)

    print('Avg per epoch ptime: %.2f, total %d epochs ptime: %.2f' % (
        np.mean(train_hist['per_epoch_ptimes']), epochs, total_ptime))
    print("Training finish!... save training results")
    with open(os.path.join(save_dir, 'train_hist.pkl'), 'wb') as f:
        pickle.dump(train_hist, f)
    show_train_hist(train_hist, save=True, path=os.path.join(save_dir, 'MNIST_GAN_train_hist.png'))
    

100%|██████████| 235/235 [00:12<00:00, 19.14it/s]


Epoch 1 of 50 with 12.28 s
Generator loss: 2.13003122, Discriminator loss: 1.08465548


100%|██████████| 235/235 [00:11<00:00, 19.67it/s]


Epoch 2 of 50 with 11.95 s
Generator loss: 3.25853206, Discriminator loss: 0.79882521


100%|██████████| 235/235 [00:11<00:00, 19.82it/s]


Epoch 3 of 50 with 11.86 s
Generator loss: 2.26381321, Discriminator loss: 1.03289728


100%|██████████| 235/235 [00:11<00:00, 19.93it/s]


Epoch 4 of 50 with 11.79 s
Generator loss: 5.02727356, Discriminator loss: 0.95806929


100%|██████████| 235/235 [00:11<00:00, 20.14it/s]


Epoch 5 of 50 with 11.67 s
Generator loss: 1.54193485, Discriminator loss: 1.15454513


100%|██████████| 235/235 [00:12<00:00, 19.56it/s]


Epoch 6 of 50 with 12.01 s
Generator loss: 1.20995346, Discriminator loss: 1.21228088


100%|██████████| 235/235 [00:12<00:00, 19.34it/s]


Epoch 7 of 50 with 12.16 s
Generator loss: 1.88665737, Discriminator loss: 0.83374181


100%|██████████| 235/235 [00:11<00:00, 20.16it/s]


Epoch 8 of 50 with 11.66 s
Generator loss: 2.52689548, Discriminator loss: 0.76217125


100%|██████████| 235/235 [00:11<00:00, 20.04it/s]


Epoch 9 of 50 with 11.73 s
Generator loss: 3.09087994, Discriminator loss: 0.60641950


100%|██████████| 235/235 [00:11<00:00, 19.79it/s]


Epoch 10 of 50 with 11.88 s
Generator loss: 2.41959507, Discriminator loss: 0.63582949


100%|██████████| 235/235 [00:11<00:00, 20.45it/s]


Epoch 11 of 50 with 11.50 s
Generator loss: 2.79933195, Discriminator loss: 0.42930881


100%|██████████| 235/235 [00:11<00:00, 19.79it/s]


Epoch 12 of 50 with 11.88 s
Generator loss: 2.88485130, Discriminator loss: 0.55499313


100%|██████████| 235/235 [00:11<00:00, 19.90it/s]


Epoch 13 of 50 with 11.81 s
Generator loss: 2.63805512, Discriminator loss: 0.53504952


100%|██████████| 235/235 [00:11<00:00, 19.94it/s]


Epoch 14 of 50 with 11.79 s
Generator loss: 2.51752839, Discriminator loss: 0.64117173


100%|██████████| 235/235 [00:11<00:00, 19.77it/s]


Epoch 15 of 50 with 11.89 s
Generator loss: 2.62533147, Discriminator loss: 0.55121856


100%|██████████| 235/235 [00:11<00:00, 19.76it/s]


Epoch 16 of 50 with 11.90 s
Generator loss: 2.59575949, Discriminator loss: 0.62138507


100%|██████████| 235/235 [00:11<00:00, 19.87it/s]


Epoch 17 of 50 with 11.83 s
Generator loss: 3.01976066, Discriminator loss: 0.48891633


100%|██████████| 235/235 [00:11<00:00, 19.62it/s]


Epoch 18 of 50 with 11.98 s
Generator loss: 2.77189101, Discriminator loss: 0.55492506


100%|██████████| 235/235 [00:12<00:00, 19.43it/s]


Epoch 19 of 50 with 12.10 s
Generator loss: 2.62682214, Discriminator loss: 0.56405070


100%|██████████| 235/235 [00:11<00:00, 19.93it/s]


Epoch 20 of 50 with 11.80 s
Generator loss: 3.05919038, Discriminator loss: 0.55058092


100%|██████████| 235/235 [00:11<00:00, 19.62it/s]


Epoch 21 of 50 with 11.98 s
Generator loss: 2.82220148, Discriminator loss: 0.49172249


100%|██████████| 235/235 [00:11<00:00, 19.67it/s]


Epoch 22 of 50 with 11.95 s
Generator loss: 3.01892633, Discriminator loss: 0.46934201


100%|██████████| 235/235 [00:12<00:00, 19.41it/s]


Epoch 23 of 50 with 12.11 s
Generator loss: 2.46242728, Discriminator loss: 0.62859489


100%|██████████| 235/235 [00:12<00:00, 19.45it/s]


Epoch 24 of 50 with 12.08 s
Generator loss: 2.54905207, Discriminator loss: 0.59296365


100%|██████████| 235/235 [00:11<00:00, 20.16it/s]


Epoch 25 of 50 with 11.66 s
Generator loss: 2.54541469, Discriminator loss: 0.58360817


100%|██████████| 235/235 [00:12<00:00, 19.19it/s]


Epoch 26 of 50 with 12.25 s
Generator loss: 2.36031498, Discriminator loss: 0.63162709


100%|██████████| 235/235 [00:12<00:00, 19.52it/s]


Epoch 27 of 50 with 12.04 s
Generator loss: 2.48547286, Discriminator loss: 0.61072952


100%|██████████| 235/235 [00:11<00:00, 20.00it/s]


Epoch 28 of 50 with 11.76 s
Generator loss: 2.47518652, Discriminator loss: 0.63351383


100%|██████████| 235/235 [00:11<00:00, 19.78it/s]


Epoch 29 of 50 with 11.88 s
Generator loss: 2.43427240, Discriminator loss: 0.61507189


100%|██████████| 235/235 [00:11<00:00, 20.23it/s]


Epoch 30 of 50 with 11.62 s
Generator loss: 2.53115766, Discriminator loss: 0.61495362


100%|██████████| 235/235 [00:11<00:00, 20.08it/s]


Epoch 31 of 50 with 11.71 s
Generator loss: 2.39311050, Discriminator loss: 0.63753537


100%|██████████| 235/235 [00:11<00:00, 20.26it/s]


Epoch 32 of 50 with 11.60 s
Generator loss: 2.43121773, Discriminator loss: 0.63407204


100%|██████████| 235/235 [00:11<00:00, 20.37it/s]


Epoch 33 of 50 with 11.54 s
Generator loss: 2.27530244, Discriminator loss: 0.65865806


100%|██████████| 235/235 [00:11<00:00, 20.28it/s]


Epoch 34 of 50 with 11.59 s
Generator loss: 2.51160493, Discriminator loss: 0.61392676


100%|██████████| 235/235 [00:11<00:00, 20.41it/s]


Epoch 35 of 50 with 11.52 s
Generator loss: 2.35674927, Discriminator loss: 0.62927663


100%|██████████| 235/235 [00:11<00:00, 20.48it/s]


Epoch 36 of 50 with 11.48 s
Generator loss: 2.36552587, Discriminator loss: 0.62717471


100%|██████████| 235/235 [00:11<00:00, 20.33it/s]


Epoch 37 of 50 with 11.56 s
Generator loss: 2.15220213, Discriminator loss: 0.66677236


100%|██████████| 235/235 [00:11<00:00, 19.62it/s]


Epoch 38 of 50 with 11.98 s
Generator loss: 2.09346484, Discriminator loss: 0.71669081


100%|██████████| 235/235 [00:11<00:00, 20.39it/s]


Epoch 39 of 50 with 11.53 s
Generator loss: 2.10737166, Discriminator loss: 0.72410951


100%|██████████| 235/235 [00:11<00:00, 20.13it/s]


Epoch 40 of 50 with 11.67 s
Generator loss: 1.99674011, Discriminator loss: 0.75093832


100%|██████████| 235/235 [00:11<00:00, 20.17it/s]


Epoch 41 of 50 with 11.66 s
Generator loss: 2.12174105, Discriminator loss: 0.73641493


100%|██████████| 235/235 [00:11<00:00, 19.59it/s]


Epoch 42 of 50 with 12.00 s
Generator loss: 2.10711469, Discriminator loss: 0.74573004


100%|██████████| 235/235 [00:11<00:00, 20.55it/s]


Epoch 43 of 50 with 11.44 s
Generator loss: 2.10844531, Discriminator loss: 0.70655686


100%|██████████| 235/235 [00:11<00:00, 20.27it/s]


Epoch 44 of 50 with 11.59 s
Generator loss: 2.19442133, Discriminator loss: 0.70456548


100%|██████████| 235/235 [00:11<00:00, 20.17it/s]


Epoch 45 of 50 with 11.66 s
Generator loss: 2.03686439, Discriminator loss: 0.73611650


100%|██████████| 235/235 [00:11<00:00, 20.17it/s]


Epoch 46 of 50 with 11.65 s
Generator loss: 2.02050362, Discriminator loss: 0.76868241


100%|██████████| 235/235 [00:11<00:00, 20.72it/s]


Epoch 47 of 50 with 11.35 s
Generator loss: 1.97328225, Discriminator loss: 0.76440765


100%|██████████| 235/235 [00:11<00:00, 20.56it/s]


Epoch 48 of 50 with 11.43 s
Generator loss: 1.96521156, Discriminator loss: 0.76661925


100%|██████████| 235/235 [00:11<00:00, 20.32it/s]


Epoch 49 of 50 with 11.57 s
Generator loss: 1.96379870, Discriminator loss: 0.76014150


100%|██████████| 235/235 [00:11<00:00, 20.49it/s]


Epoch 50 of 50 with 11.47 s
Generator loss: 1.83736089, Discriminator loss: 0.80985167
Avg per epoch ptime: 11.77, total 50 epochs ptime: 632.22
Training finish!... save training results
