# ECE 50024 Final Project - Conditional Generative Adversarial Networks by Abhinav Rao

## Setup

In [2]:
# Standard Libraries
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time
import math
import json
import os
import pickle 
%matplotlib inline

# Neural Network Libraries
import torch
import torch.nn as nn
import torchvision
from torch.autograd import Variable
from torchvision.utils import make_grid

#for colab purposes only
# from google.colab import drive
# drive.mount('/content/gdrive')

<b>Accessing GPU if available and defining Log Folder path</b>

In [5]:
# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Folder that receives hyperparameters, the training log, sample grids and the saved generator.
logPath = 'GAN/log11'

<b>Selecting Hyperparameters and saving to log file</b>

In [6]:
# Training hyperparameters. Every later cell reads these globals, so tune them here.
batchSize = 32  # samples per gradient step
leakyReLUNegSlope = 0.2  # LeakyReLU slope for negative inputs (0 would give a plain ReLU)
dropoutRate = 0.3  # fraction of discriminator activations zeroed by dropout while training
dataDim = 784  # 28x28 image, flattened
labelDim = 10  # digit classes 0-9 (also used as the label-embedding width)
noiseDim = 100  # dimension of the generator's input noise z
learningRateG = 1e-4  # Adam learning rate of the generator
learningRateD = 1e-4  # Adam learning rate of the discriminator
nEpochs = 100  # passes over the training set
# Discriminator updates per generator update (1 = standard alternating GAN training).
# NOTE: the model still uses a sigmoid output + BCE loss, so this is not a Wasserstein GAN.
nCritic = 5
nLog = 10  # approximate number of sample-grid snapshots taken during training
randomSeed = 0  # seeds torch and numpy so runs are repeatable (up to GPU nondeterminism)

torch.manual_seed(randomSeed)
np.random.seed(randomSeed)

hyperParamDict = {
    'batch size':batchSize,
    'generator learning rate': learningRateG,
    'discriminator learning rate': learningRateD,
    'number of epochs':nEpochs,
    'dropout rate': dropoutRate,
    'leaky RELU negative slope': leakyReLUNegSlope,
    'wasserstein nCritic': nCritic,  # key name kept for compatibility with existing logs
    'noise dimension': noiseDim,
    'label dimension': labelDim,
    'number of log snapshots': nLog,
    'random seed': randomSeed}

# The log folder may not exist yet (fresh checkout, new run number); without this,
# open() raises FileNotFoundError.
os.makedirs(logPath, exist_ok=True)
with open(os.path.join(logPath,'hyperparameters.txt'), 'w') as logfile:
    json.dump(hyperParamDict, logfile)

## Model Definition

<b>Downloading data and Pre-Processing</b>

In [7]:
# ToTensor turns each image into a float tensor in [0, 1]. Normalize then applies
# (x - 0.5) / 0.5, which maps pixels to [-1, 1], the same range as the generator's Tanh output.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
])

# MNIST training split, downloaded into ./data on first use and reshuffled every epoch.
dataLoader = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.MNIST(root='data',
                                       train=True,
                                       download=True,
                                       transform=transform),
    batch_size=batchSize,
    shuffle=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz


11.2%

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100.0%


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz


100.0%

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz



100.0%
100.0%


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



<b>Defining Discriminator</b>

In [8]:
class Discriminator(nn.Module):
    """Conditional MLP discriminator.

    Each class label is looked up in a learned embedding and concatenated to the
    flattened image. The result goes through fully connected LeakyReLU/Dropout
    layers and ends in a sigmoid, giving the probability that the (image, label)
    pair is real. Training scores real pairs against 1 and generated pairs against 0.
    """

    def __init__(self):
        super().__init__()  # register submodules with nn.Module
        # Trainable lookup table: one labelDim-sized vector per class.
        # This is a learned embedding, not a fixed one-hot encoding.
        self.labelEmbedding = nn.Embedding(labelDim, labelDim)
        self.model = nn.Sequential(
            nn.Linear(dataDim + labelDim, 1000),
            nn.LeakyReLU(leakyReLUNegSlope, inplace=True),
            nn.Dropout(dropoutRate),
            nn.Linear(1000, 500),
            nn.LeakyReLU(leakyReLUNegSlope, inplace=True),
            nn.Dropout(dropoutRate),
            nn.Linear(500, 250),
            nn.LeakyReLU(leakyReLUNegSlope, inplace=True),
            nn.Dropout(dropoutRate),
            nn.Linear(250, 1),
            nn.Sigmoid()  # probability in (0, 1), as nn.BCELoss expects
        )

    def forward(self, x, y):
        """Score a batch of (image, label) pairs.

        Args:
            x: image batch with dataDim values per sample, e.g. (B, 28, 28) from the
               generator or (B, 1, 28, 28) from the data loader (assumed; confirm).
            y: integer class labels of shape (B,).

        Returns:
            1-D tensor of shape (B,) with the probability that each pair is real.
        """
        # reshape instead of view so non-contiguous inputs are accepted too.
        x = x.reshape(x.size(0), dataDim)
        c = self.labelEmbedding(y)
        x = torch.cat([x, c], 1)
        out = self.model(x)
        # view(-1) instead of squeeze(): for a batch of one, squeeze() gives a 0-d
        # tensor that no longer matches the (1,)-shaped BCE target.
        return out.view(-1)

<b>Defining Generator</b>

In [9]:
class Generator(nn.Module):
    """Conditional MLP generator: (noise z, class label) -> square image in (-1, 1).

    The label is looked up in a learned embedding and concatenated to the noise
    vector. A widening stack of LeakyReLU layers then maps it to dataDim pixels,
    and a final Tanh squashes them.
    """

    def __init__(self):
        super().__init__()
        # Created before the linear layers to keep the original parameter-initialisation order.
        self.labelEmbedding = nn.Embedding(labelDim, labelDim)

        widths = [noiseDim + labelDim, 250, 500, 1000]
        layers = []
        for fanIn, fanOut in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fanIn, fanOut), nn.LeakyReLU(leakyReLUNegSlope, inplace=True)]
        # Output layer: one value per pixel, squashed into (-1, 1).
        layers += [nn.Linear(widths[-1], dataDim), nn.Tanh()]
        # Same module order/indices as an explicit nn.Sequential(...) listing.
        self.model = nn.Sequential(*layers)

    def forward(self, z, labels):
        """Return a (batch, side, side) image batch, with side = sqrt(dataDim).

        Args:
            z: noise batch holding noiseDim values per sample.
            labels: integer class labels, one per sample.
        """
        batch = z.size(0)
        side = int(np.sqrt(dataDim))
        conditioned = torch.cat([z.view(batch, noiseDim), self.labelEmbedding(labels)], 1)
        return self.model(conditioned).view(batch, side, side)

### Defining training functions

<b>Instantiating Generator and Discriminator and defining Loss and Optimizers</b>

In [10]:
# Build the generator first, then the discriminator. Construction order decides which
# random draws each network's initial weights come from, so keep it.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Binary cross-entropy on the discriminator's sigmoid output (real = 1, fake = 0).
loss = nn.BCELoss()

# One Adam optimizer per network, each with its own learning rate.
generatorOptimizer = torch.optim.Adam(generator.parameters(), lr=learningRateG)
discriminatorOptimizer = torch.optim.Adam(discriminator.parameters(), lr=learningRateD)

<b>Defining the discriminator step, the generator step, and the full GAN training loop</b>

In [11]:
def trainGenerator(batchSize, discriminator, generator, generatorOptimizer, loss):
    """Take one optimizer step on the generator.

    Samples batchSize noise vectors with random class labels, generates images, and
    scores them with the discriminator against "real" targets (ones). This is the
    non-saturating generator loss, -log D(G(z), y).

    The backward pass also writes gradients into the discriminator's parameters. Those
    are cleared by discriminatorOptimizer.zero_grad() at the start of the next
    discriminator step.

    Returns:
        The generator loss as a Python float.
    """
    generatorOptimizer.zero_grad()

    noise = torch.randn(batchSize, noiseDim).to(device)
    fakeLabels = torch.LongTensor(np.random.randint(0, labelDim, batchSize)).to(device)
    fakeImages = generator(noise, fakeLabels)

    realTargets = torch.ones(batchSize).to(device)
    generatorLoss = loss(discriminator(fakeImages, fakeLabels), realTargets)

    generatorLoss.backward()
    generatorOptimizer.step()
    return generatorLoss.item()

def trainDiscriminator(batchSize, discriminator, generator, discriminatorOptimizer, loss, realData, realLabels):
    """Take one optimizer step on the discriminator.

    Args:
        batchSize: number of fake samples to generate; should equal len(realData).
        discriminator, generator: the two networks.
        discriminatorOptimizer: optimizer over the discriminator's parameters.
        loss: criterion applied as loss(probabilities, targets), e.g. nn.BCELoss.
        realData, realLabels: a batch of real images and their class labels.

    Returns:
        The summed real + fake loss as a Python float.
    """
    discriminatorOptimizer.zero_grad()

    # Real (image, label) pairs should be scored as real (target 1).
    realLoss = loss(discriminator(realData, realLabels), torch.ones(batchSize).to(device))

    # Generated pairs with random labels should be scored as fake (target 0).
    z = torch.randn(batchSize, noiseDim).to(device)
    fakeLabels = torch.LongTensor(np.random.randint(0, labelDim, batchSize)).to(device)
    # Only the discriminator is updated here. Running the generator without autograd
    # avoids backpropagating through it on every critic step and stops stale
    # gradients from piling up in its parameters.
    with torch.no_grad():
        fakeData = generator(z, fakeLabels)
    fakeLoss = loss(discriminator(fakeData, fakeLabels), torch.zeros(batchSize).to(device))

    discriminatorLoss = realLoss + fakeLoss
    discriminatorLoss.backward()
    discriminatorOptimizer.step()
    return discriminatorLoss.item()

def trainGAN(dataLoader, generator, discriminator, nEpochs=50, nLog = 10, nCritic=1):
    """Train the conditional GAN and collect per-epoch diagnostics.

    Uses the notebook-level discriminatorOptimizer, generatorOptimizer, loss,
    batchSize and device. Each real batch gets nCritic discriminator steps (all on
    that same real batch), followed by one generator step.

    Args:
        dataLoader: iterable of (images, labels) batches.
        generator, discriminator: the networks to train, updated in place.
        nEpochs: number of passes over dataLoader.
        nLog: about how many 10-image sample snapshots to keep. A value <= 0 disables them.
        nCritic: discriminator updates per generator update.

    Returns:
        (generator, performanceDF, log)
        - performanceDF has columns epoch (0-indexed), time (seconds), and
          discriminatorLoss / generatorLoss. The two losses come from the last
          training step of each epoch, not an epoch average.
        - log maps the 1-indexed epoch number (as a string) to that epoch's generated samples.
    """
    times = []
    discriminatorLosses = []
    generatorLosses = []
    log = {}
    # Snapshot interval. Clamped to at least 1 so nEpochs < nLog cannot cause a
    # modulo-by-zero.
    logEvery = max(1, nEpochs // nLog) if nLog > 0 else None

    for epoch in range(nEpochs):
        print('Epoch '+str(epoch+1), end=' - Discriminator Loss: ')
        start = time.time()
        # NaN rather than an unbound name if the loader is empty or nCritic == 0.
        discriminatorLoss = generatorLoss = float('nan')
        generator.train()
        discriminator.train()  # make sure dropout is active while training

        for data, labels in dataLoader:
            realData = data.to(device)
            realLabels = labels.to(device)

            for _ in range(nCritic):
                discriminatorLoss = trainDiscriminator(len(realData), discriminator,
                                                       generator, discriminatorOptimizer, loss,
                                                       realData, realLabels)

            generatorLoss = trainGenerator(batchSize, discriminator, generator, generatorOptimizer, loss)
        stop = time.time()
        print(round(discriminatorLoss,4), end=' - Generator Loss: ')
        discriminatorLosses.append(discriminatorLoss)
        print(round(generatorLoss,4))
        generatorLosses.append(generatorLoss)
        times.append(stop - start)

        if logEvery is not None and (epoch+1) % logEvery == 0:
            # The snapshot is kept for the whole run; without no_grad each one would
            # hold on to the generator's autograd graph.
            with torch.no_grad():
                log[str(epoch+1)] = generateSamples(generator, 10, random = False)

    performanceDF = pd.DataFrame({
        'epoch': range(nEpochs),
        'time': times,
        'discriminatorLoss': discriminatorLosses,
        'generatorLoss': generatorLosses,
    })

    return generator, performanceDF, log

<b>Diagnostic Functions: sampling, plotting, and saving results</b>

In [12]:
def generateSamples(generator, nSamples, random=True):
    """Generate nSamples images from the conditional generator, without autograd.

    Args:
        generator: conditional generator, called as generator(noise, labels).
        nSamples: number of images to produce.
        random: if True, labels are drawn uniformly from 0..labelDim-1. If False,
            labels cycle 0, 1, ..., labelDim-1, 0, ..., so the samples come out in
            class order.

    Returns:
        The generator's output for nSamples noise vectors. The result does not
        require grad; with the Generator defined in this notebook it has shape
        (nSamples, 28, 28).
    """
    if random:
        labels = np.random.randint(0, labelDim, nSamples)
    else:
        # Follow labelDim instead of a hard-coded 10 so the cycle matches the class count.
        labels = np.arange(nSamples) % labelDim
    noise = torch.randn(nSamples, noiseDim).to(device)
    with torch.no_grad():
        return generator(noise, torch.LongTensor(labels).to(device))

def displaySamples(samples, title = 'Digits', savePath = False, ncols = 10):
    """Show images in a grid with ncols columns, optionally saving the figure.

    Args:
        samples: tensor indexed along dim 0, each item a 2-D image
            (e.g. generator output of shape (N, 28, 28)).
        title: figure title. It is also used as the file name when saving.
            NOTE(review): titles containing ':' are not valid file names on Windows.
        savePath: directory to save into. Any falsy value (the default False) skips saving.
        ncols: number of grid columns. The default of 10 gives one row per 10 digits.
    """
    nSamples = samples.shape[0]
    if nSamples == 0:
        return  # nothing to draw; plt.subplots rejects nrows=0
    nrows = math.ceil(nSamples / ncols)
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(1.5 * ncols, 2 * nrows), squeeze=False)
    for i in range(nrows * ncols):
        cell = ax[i // ncols][i % ncols]
        if i < nSamples:
            cell.imshow(samples[i].cpu().detach().numpy(), cmap='gray')
            cell.set_yticklabels([])
            cell.set_xticklabels([])
        else:
            fig.delaxes(cell)  # remove unused cells in the last row
    fig.subplots_adjust(wspace=0.1, hspace=0.1)
    fig.suptitle(title)
    if savePath:
        fig.savefig(os.path.join(savePath, title))
    plt.show()
    # Release the figure; saveLog calls this once per snapshot in a loop.
    plt.close(fig)

def saveLog(trainedGenerator, trainingPerformance, dataLog, logPath):
    """Write the outputs of trainGAN to logPath.

    Creates logPath if needed, then writes:
      - trainingLog.csv: the per-epoch performance table.
      - one sample-grid image per snapshot in dataLog (via displaySamples).
      - generator.pkl: the whole generator module, pickled.
    """
    os.makedirs(logPath, exist_ok=True)
    trainingPerformance.to_csv(os.path.join(logPath, 'trainingLog.csv'), index=False)
    for epochKey, samples in dataLog.items():
        displaySamples(samples, title = 'Epoch '+epochKey+ ': All Digits', savePath = logPath)
    # Pickling the full nn.Module means the Generator class must be importable when
    # loading. Unpickling runs arbitrary code, so only load trusted files.
    with open(os.path.join(logPath, 'generator.pkl'), 'wb') as modelfile:
        pickle.dump(trainedGenerator, modelfile)
    print('Training details saved')

<b>Training the GAN and saving the results</b>

In [13]:
# Train for nEpochs, then write the loss/time table, sample grids and the pickled generator to logPath.
trainedGenerator, trainingPerformance, dataLog = trainGAN(
    dataLoader,
    generator,
    discriminator,
    nEpochs=nEpochs,
    nLog=nLog,
    nCritic=nCritic,
)
saveLog(trainedGenerator, trainingPerformance, dataLog, logPath)

Epoch 1 - Discriminator Loss: 0.0003 - Generator Loss: 12.8884
Epoch 2 - Discriminator Loss: 0.0031 - Generator Loss: 8.3807
Epoch 3 - Discriminator Loss: 