# OSGD - GAN Project - Fashion-MNIST


## Contents
1. Setup
2. Loading The Data
3. Defining the GAN Class
4. Instantiating a GAN
5. Training and Testing

## 1. Setup

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

from torch import nn, optim
from torch.autograd.variable import Variable

import math 
import seaborn as sns # To plot graphs
import pandas as pd

## 2. Loading The Data

In [None]:
# Fetch the Fashion-MNIST training split (downloading it on first use) and
# convert every image to a tensor via ToTensor.
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    download=True,
    train=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

In [None]:
# Helper function to display images
def show_images(images_arry, num_images_to_show, title):
    """Show up to ``num_images_to_show`` images side by side in one figure.

    images_arry: sequence of images (e.g. a list, or a batch tensor indexed on
        its first dimension). Each image is squeezed before display, so any
        size-1 dimensions (such as a single channel) are dropped.
    num_images_to_show: maximum number of images to display.
    title: text used as the figure's super-title.

    Displays the figure with plt.show(); returns None. Does nothing when there
    is no image to draw.
    """
    n = min(len(images_arry), num_images_to_show)  # Number of images to display
    if n < 1:
        # plt.subplots(1, 0) raises, so there is nothing sensible to draw.
        return

    # squeeze=False keeps `ax` a 2-D array even when n == 1; otherwise a bare
    # Axes object is returned and indexing it would fail.
    fig, ax = plt.subplots(1, n, squeeze=False)
    fig.suptitle(title)

    for index in range(n):
        # gray_r is matplotlib's reversed grey colormap: high values are drawn dark.
        ax[0, index].imshow(images_arry[index].squeeze(), cmap='gray_r')
        ax[0, index].axis('off')  # hide the x and y axes
    plt.show()

In [None]:
# Pull one small shuffled batch (5 images) to preview the data.
temp_data_loader = torch.utils.data.DataLoader(train_set, batch_size=5, shuffle=True)
temp_dataiter = iter(temp_data_loader)
# DataLoader iterators no longer provide a .next() method in current PyTorch;
# the builtin next() works on every version.
images, labels = next(temp_dataiter)

In [None]:
# Preview the sampled batch, then report the shape of a single image tensor.
show_images(images, 5, 'Fashion_MNIST')
print(images[0].shape)

## 3. Defining the GAN Class

In [None]:
class NeuralNetwork(torch.nn.Module):
    """Sequential container: feeds its input through a list of modules in order."""

    def __init__(self, ModuleList):
        """Keep ``ModuleList`` (a torch.nn.ModuleList) as this network's layers.

        Assigning it as an attribute registers the contained modules, and
        therefore their parameters, with this torch.nn.Module.
        """
        super().__init__()
        self.layers = ModuleList

    def forward(self, x):
        """Return ``x`` after applying every layer in turn (``x`` itself if there are none)."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out

In [None]:
def ones_target(n):
    """Return an ``(n, 1)`` float tensor of ones (the label GAN uses for real data).

    Returns a plain tensor: the old ``Variable`` wrapper has been a no-op since
    PyTorch 0.4 and is deprecated.
    """
    return torch.ones(n, 1)


def zeros_target(n):
    """Return an ``(n, 1)`` float tensor of zeros (the label GAN uses for generated data)."""
    return torch.zeros(n, 1)

In [None]:
class GAN():
    """A generator/discriminator pair trained adversarially with BCE loss.

    Both networks are NeuralNetwork instances built from the supplied
    torch.nn.ModuleLists. Call ``batch_train`` once per minibatch of real data.
    """

    def __init__(self, d_ModuleList, g_ModuleList, data_dim, latent_dim):
        """
        d_ModuleList, g_ModuleList: torch.nn.ModuleLists for the discriminator
            and generator respectively.
        data_dim: tuple, shape of one data item as fed to the discriminator,
            e.g. (784,) or (1, 28, 28).
        latent_dim: tuple, shape of one latent noise sample fed to the
            generator, e.g. (100,).
        """
        # Instantiate the discriminator & generator
        self.discriminator = NeuralNetwork(d_ModuleList)
        self.generator = NeuralNetwork(g_ModuleList)

        self.data_dim = data_dim
        self.latent_dim = latent_dim

        # Optimisers are created on the first batch_train call (the learning
        # rates are supplied there) and then reused, so Adam's running moment
        # estimates persist from batch to batch.
        self.d_optimiser = None
        self.g_optimiser = None

    def _ensure_optimisers(self, d_learning_rate, g_learning_rate):
        """Create each Adam optimiser once; on later calls only update its learning rate.

        Rebuilding Adam every batch would discard its first/second-moment
        estimates, making every update behave like Adam's very first step.
        """
        pairs = (
            ("d_optimiser", self.discriminator, d_learning_rate),
            ("g_optimiser", self.generator, g_learning_rate),
        )
        for attr, network, lr in pairs:
            # getattr default also covers instances created without __init__
            optimiser = getattr(self, attr, None)
            if optimiser is None:
                setattr(self, attr, optim.Adam(network.parameters(), lr))
            else:
                for group in optimiser.param_groups:
                    group["lr"] = lr

    def update_discriminator(self, real_data, generated_data):
        """Take one optimiser step on the discriminator.

        real_data: minibatch from the real data set (target label 1).
        generated_data: minibatch made by the generator (target label 0);
            expected to be detached from the generator's graph, as batch_train does.
        Returns (loss, mean prediction on real data, mean prediction on
        generated data) as numpy scalars. Requires self.d_optimiser to exist
        (batch_train creates it).
        """
        # Initial setup/clearing of gradients
        self.d_optimiser.zero_grad()
        loss = nn.BCELoss()
        N, M = real_data.size(0), generated_data.size(0)

        # Score real then fake data in a single forward pass
        x = self.discriminator(torch.cat([real_data, generated_data]))

        # Target labels in the same order: real (1) then fake (0)
        y = torch.cat([ones_target(N), zeros_target(M)])

        error = loss(x, y)
        error.backward()
        self.d_optimiser.step()

        preds = x.detach()
        # preds[N:] rather than preds[-M:]: with M == 0, [-0:] would select the whole batch
        return error.detach().numpy(), torch.mean(preds[:N]).numpy(), torch.mean(preds[N:]).numpy()

    def update_generator(self, generated_data):
        """Take one optimiser step on the generator.

        generated_data: minibatch made by the generator, still attached to its graph.
        Returns the generator's loss as a numpy scalar. Requires
        self.g_optimiser to exist (batch_train creates it).
        """
        self.g_optimiser.zero_grad()
        loss = nn.BCELoss()
        M = generated_data.size(0)

        x = self.discriminator(generated_data)

        # Generator wants its samples labelled real (1), the opposite of the discriminator's aim
        y = ones_target(M)

        # backward() also fills the discriminator's .grad, but only the generator
        # is stepped here; update_discriminator zeroes those grads before its own step.
        error = loss(x, y)
        error.backward()
        self.g_optimiser.step()

        return error.detach().numpy()

    def batch_train(self, data_batch, d_learning_rate, g_learning_rate):
        """Run one GAN training step on a minibatch of real data.

        data_batch: tensor whose first dimension is the batch size; it is
            reshaped to (n,) + self.data_dim before use.
        d_learning_rate, g_learning_rate: Adam learning rates for the
            discriminator and generator.
        Returns (generator loss, discriminator loss, mean prediction on real
        data, mean prediction on fake data) as numpy scalars.
        """
        self._ensure_optimisers(d_learning_rate, g_learning_rate)

        k = 1  # Discriminator steps per generator step; the original GAN paper uses 1
        n = data_batch.size(0)  # Number of data items in batch
        data_batch = data_batch.reshape((n,) + self.data_dim)  # loop-invariant, so done once

        for _ in range(k):
            # Fresh fake batch, detached so the discriminator step does not backprop into the generator
            generated_data = self.generate(n)
            d_error, avg_real_pred, avg_fake_pred = self.update_discriminator(
                data_batch, generated_data.detach()
            )

        # New fake batch: the discriminator has changed since the previous one was scored
        g_error = self.update_generator(self.generate(n))

        return g_error, d_error, avg_real_pred, avg_fake_pred

    def noise(self, n):
        """Return a tensor of shape (n,) + self.latent_dim drawn from a standard normal."""
        return torch.randn((n,) + self.latent_dim)

    def generate(self, n):
        """Return n generator outputs reshaped to (n,) + self.data_dim (gradients are tracked)."""
        return self.generator(self.noise(n)).view((n,) + self.data_dim)

## 4. Instantiating a GAN

Below are two GAN configurations: one whose discriminator is a CNN, and one built only from fully connected (linear) layers. Only one should be active at a time; disable the other to switch between them.

In [None]:
# Select the discriminator architecture. Only the chosen configuration is built;
# assigning both in sequence would construct the first and immediately discard it.
USE_CNN_DISCRIMINATOR = True  # False -> fully connected discriminator on flattened (784,) data

if USE_CNN_DISCRIMINATOR:
    fashion_GAN = GAN(nn.ModuleList([                      # Module list for discriminator
                        nn.Sequential(
                            # (1, 28, 28) -> conv k3 -> (128, 26, 26) -> pool 2 -> (128, 13, 13)
                            nn.Conv2d(1, 128, kernel_size=3, padding=0),
                            nn.ReLU(),
                            nn.MaxPool2d(2),

                            # (128, 13, 13) -> conv k3 -> (128, 11, 11), hence 128*11*11 below
                            nn.Conv2d(128, 128, kernel_size=3, padding=0),
                            nn.ReLU(),

                            nn.Flatten(),

                            nn.Linear(128*11*11, 1),
                            torch.nn.Sigmoid()
                        )]),
                      nn.ModuleList([                      # Module list for generator
                        nn.Sequential(
                            nn.Linear(100, 256),
                            nn.LeakyReLU(0.1),
                            nn.Dropout(0.3),

                            nn.Linear(256, 512),
                            nn.LeakyReLU(0.1),
                            nn.Dropout(0.3),

                            nn.Linear(512, 1024),
                            nn.LeakyReLU(0.1),

                            nn.Linear(1024, 784),          # 784 = 28*28; GAN.generate reshapes to data_dim
                            torch.nn.Sigmoid()
                        )]),
                      (1, 28, 28),                         # Tuple representing data dim into discriminator
                      (100,))                              # Tuple representing latent dim into generator
else:
    fashion_GAN = GAN(nn.ModuleList([
                        nn.Sequential(                     # Module list for discriminator
                            nn.Linear(784, 1024),
                            nn.LeakyReLU(0.2),
                            nn.Dropout(0.3),

                            nn.Linear(1024, 512),
                            nn.LeakyReLU(0.2),
                            nn.Dropout(0.3),

                            nn.Linear(512, 256),
                            nn.LeakyReLU(0.2),
                            nn.Dropout(0.3),

                            torch.nn.Linear(256, 1),
                            torch.nn.Sigmoid()
                        )]),
                      nn.ModuleList([                      # Module list for generator
                        nn.Sequential(
                            nn.Linear(200, 256),
                            nn.LeakyReLU(0.1),
                            nn.Dropout(0.3),

                            nn.Linear(256, 512),
                            nn.LeakyReLU(0.1),
                            nn.Dropout(0.3),

                            nn.Linear(512, 1024),
                            nn.LeakyReLU(0.1),

                            nn.Linear(1024, 784),
                            torch.nn.Sigmoid()
                        )]),
                      (784,),                              # Tuple representing data dim into discriminator
                      (200,))                              # Tuple representing latent dim into generator

In [None]:
# Shuffled minibatches of 512 real training images, consumed by the training loop
data_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=512)

## 5. Training and Testing

In [None]:
import time
start = time.time()

# Training Constants
N_EPOCHS = 5
d_lr = 0.0001*0.02  # discriminator learning rate
g_lr = 0.02         # generator learning rate

# Log info with no training
print("Training = "+str(0)+" % Complete:")
with torch.no_grad():  # display only, so no autograd graph is needed
    sample = fashion_GAN.generate(7).view(7, 1, 28, 28)
show_images(sample, 7, "EPOCH "+str(0))

n_batches = len(data_loader)
# Report roughly every 20% of an epoch. Checking a rounded float percentage with
# "% 20 == 0" rarely lands exactly on a multiple of 20, so reports were mostly skipped.
report_every = max(1, n_batches // 5)

for e in range(1, N_EPOCHS+1):
    for i, (images, labels) in enumerate(data_loader):
        g_error, d_error, avg_real_pred, avg_fake_pred = fashion_GAN.batch_train(images, d_lr, g_lr)  # d lr then g lr

        # EPOCH Progress info
        if i % report_every == 0:
            epoch_progress = round(i*1000/n_batches)/10
            print("\tEPOCH "+str(epoch_progress)+"% Complete:")
            print("\t\t Generator Loss =     "+str(g_error))
            print("\t\t Discriminator Loss = "+str(d_error))

            print("\t\t\t avg prediction on real data = "+str(avg_real_pred))  # discriminator wants 1, we want 0.5
            print("\t\t\t avg prediction on fake data = "+str(avg_fake_pred))  # discriminator wants 0, we want 0.5

            # Display Sample Images
            with torch.no_grad():
                sample = fashion_GAN.generate(7).view(7, 1, 28, 28)
            show_images(sample, 7, "EPOCH "+str(e))

    # Overall Progress
    print("Training = "+str(100*e/N_EPOCHS)+" % Complete:")

end = time.time()
# Convert seconds to minutes before building the string (dividing a str by 60 raises TypeError)
print('Total Training Time = ', str((end - start)/60) + " mins")