# Generative Adversarial Network

-------------
## Importing Libraries


In [None]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
import torch.optim as optim

In [None]:
import numpy as np
import os.path
from glob import glob
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import torchvision.datasets as dat_s
import torch.utils.data as data

-----
## GPU Device Configuration

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

-----
## Configuration

### Hyperparameters

In [None]:
latentSize = 100              # length of the noise vector z fed to the generator
hiddenSize = 100              # NOTE(review): not referenced by any visible cell
inputImgSize = 3 * 64 * 64    # flattened 64x64 image with 3 channels
n_epochs = 25                 # full passes over the dataset
batchSize = 128               # images per mini-batch
outputDirectory = "modelSaved"
learning_rate = 1e-5          # shared by the discriminator and generator optimizers

### Output Directory

In [None]:
# Ensure the directory for saved models exists. exist_ok=True makes this a
# no-op when the directory is already there, and avoids the race between a
# separate os.path.exists() check and os.makedirs().
os.makedirs(outputDirectory, exist_ok=True)

-----
## Image Processing

In [None]:
# Load the images under the root folder (ImageFolder expects one sub-folder
# per class). Each image is centre-cropped to 160x160, resized to 64x64,
# converted to a tensor in [0, 1], and normalised per channel to [-1, 1]
# (mean 0.5, std 0.5), the same range as the generator's Tanh output.
some_var = dat_s.ImageFolder(root="/home/mmvc/img",
                             transform=transforms.Compose([
                                 transforms.CenterCrop(160),
                                 transforms.Resize(64),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                             ]))
# Create the dataloader
dataloader_img = torch.utils.data.DataLoader(some_var, batch_size=batchSize,
                                             shuffle=True, num_workers=2)

# Sanity check: load the last sample of the dataset. Indexing with
# len(...) - 1 works for a dataset of any size.
img, target = some_var[len(some_var) - 1]
print(target)

-----
## Creating the Networks

### Discriminator

In [None]:
class Discriminator(nn.Module):
    """Fully connected GAN discriminator.

    Maps flattened images to a single score per sample. The final Sigmoid
    bounds the score to [0, 1].

    Args:
        in_features: length of one flattened input image. When None (the
            default), the module-level ``inputImgSize`` is used.
    """

    def __init__(self, in_features=None):
        super().__init__()
        if in_features is None:
            in_features = inputImgSize

        # Attribute names l1..l4 are unchanged, so state_dict keys stay the same.
        self.l1 = nn.Sequential(nn.Linear(in_features, 256), nn.LeakyReLU())
        self.l2 = nn.Sequential(nn.Linear(256, 512), nn.LeakyReLU())
        self.l3 = nn.Sequential(nn.Linear(512, 512), nn.LeakyReLU())
        self.l4 = nn.Sequential(nn.Linear(512, 1), nn.Sigmoid())

    def forward(self, x):
        """Return D(x): one score in [0, 1] per sample (last dimension of size 1).

        The last dimension of ``x`` must equal ``in_features``, so images
        must already be flattened.
        """
        output = self.l1(x)
        output = self.l2(output)
        output = self.l3(output)
        output = self.l4(output)
        return output
    

### Generator 

In [None]:
class Generator(nn.Module):
    """Fully connected GAN generator.

    Maps latent noise vectors to flattened images. The final Tanh bounds every
    output value to [-1, 1].

    Args:
        latent_dim: length of the input noise vector. When None (the default),
            the module-level ``latentSize`` is used.
        out_features: length of one flattened output image. When None (the
            default), the module-level ``inputImgSize`` is used.
    """

    def __init__(self, latent_dim=None, out_features=None):
        super().__init__()
        if latent_dim is None:
            latent_dim = latentSize
        if out_features is None:
            out_features = inputImgSize

        # Attribute names l1..l5 are unchanged, so state_dict keys stay the same.
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU())
        self.l2 = nn.Sequential(nn.Linear(256, 512), nn.ReLU())
        self.l3 = nn.Sequential(nn.Linear(512, 1024), nn.ReLU())
        self.l4 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())
        self.l5 = nn.Sequential(nn.Linear(1024, out_features), nn.Tanh())

    def forward(self, x):
        """Return G(x); the last dimension of ``x`` must equal ``latent_dim``."""
        output = self.l1(x)
        output = self.l2(output)
        output = self.l3(output)
        output = self.l4(output)
        output = self.l5(output)
        return output
        

### Sending Network to Device

In [None]:
# Build both networks and move their parameters onto the selected device.
# nn.Module.to() moves the module in place and returns it, so chaining is safe.
discriminator = Discriminator().to(device)
generator = Generator().to(device)


### Loss Function and Optimizer

In [None]:
# Binary cross-entropy on the discriminator's Sigmoid outputs.
lossFunction = nn.BCELoss()

# Both networks use Adam at the same learning rate. Plain SGD without momentum
# at a learning rate this small would likely leave the generator learning far
# more slowly than its Adam-driven adversary.
optimizerD = optim.Adam(discriminator.parameters(), lr=learning_rate)
optimizerG = optim.Adam(generator.parameters(), lr=learning_rate)


-----
## Training

In [None]:
# Alternating GAN training. Each step updates the discriminator once on a
# real and a fake batch, then updates the generator once. After every epoch,
# the mean per-batch loss of each network is appended to the history lists.
totalSteps = len(dataloader_img)
discriminatorLosses = []
generatorLosses = []

for epoch in range(n_epochs):
    epochDisLoss = 0.0
    epochGenLoss = 0.0
    nBatches = 0

    for i, (X, _) in enumerate(dataloader_img):
        # Use the real size of this batch. The last batch of an epoch is
        # smaller than batchSize when the dataset size is not a multiple of it.
        currentBatch = X.size(0)
        X = X.reshape(currentBatch, -1).to(device)

        realLabels = torch.ones(currentBatch, 1).to(device)
        fakeLabels = torch.zeros(currentBatch, 1).to(device)

        # Train the discriminator: target 1 for real images, 0 for generated ones.
        optimizerD.zero_grad()
        discriminatorResult_real = discriminator(X)
        lossDis_real = lossFunction(discriminatorResult_real, realLabels)
        lossDis_real.backward()

        z = torch.randn(currentBatch, latentSize).to(device)
        generatorResult_z = generator(z)
        # detach(): this loss should update only D, so no gradient reaches G here.
        discriminatorResult_z = discriminator(generatorResult_z.detach())
        lossDis_fake = lossFunction(discriminatorResult_z, fakeLabels)
        lossDis_fake.backward()

        optimizerD.step()

        # Train the generator: target 1 ("real") for D(G(z)).
        optimizerG.zero_grad()
        generatorResult = discriminator(generatorResult_z)
        loss_real = lossFunction(generatorResult, realLabels)
        # This backward also writes gradients into D's parameters. They are
        # cleared by optimizerD.zero_grad() at the start of the next step.
        loss_real.backward()

        optimizerG.step()

        disLoss = lossDis_fake.item() + lossDis_real.item()
        genLoss = loss_real.item()
        epochDisLoss += disLoss
        epochGenLoss += genLoss
        nBatches += 1

        if i % 300 == 0:
            print("Epoch: {}".format(epoch))
            print("Index: {}".format(i))
            print("Discriminator Loss: {}".format(disLoss))
            print("Generator Loss: {}".format(genLoss))

    if nBatches == 0:
        raise RuntimeError("dataloader_img yielded no batches; nothing to train on")

    # Record the epoch-average loss, not just the last mini-batch's value.
    discriminatorLosses.append(epochDisLoss / nBatches)
    generatorLosses.append(epochGenLoss / nBatches)
        

-----------

## Visualizing the Training Losses

In [None]:
# One x position per recorded epoch. `c` is also used by the
# discriminator-only plot below.
c = np.arange(len(discriminatorLosses))

# Draw both loss histories on one set of axes; labels feed the legend.
plt.figure(figsize=(10, 6))
plt.plot(c, discriminatorLosses, color="r", marker="+", label="Discriminator")
plt.plot(c, generatorLosses, color="k", marker="x", label="Generator")
plt.gca().set(title="Discriminator and Generator Loss", xlabel="Epoch", ylabel="Loss")
plt.legend()
plt.show()

In [None]:
# Plot the discriminator's loss history alone, against the same epoch indices `c`.
plt.figure(figsize=(10, 6))
plt.plot(c, discriminatorLosses, color="r", marker="+")
plt.gca().set(title="Discriminator Loss", xlabel="Epoch", ylabel="Loss")
plt.show()

## Summarizing Results

In [None]:
# Report the largest and smallest recorded loss for each network.
maxGen, minGen = max(generatorLosses), min(generatorLosses)
maxDis, minDis = max(discriminatorLosses), min(discriminatorLosses)

print("Maximum Generator Loss: ", maxGen)
print("Minimum Generator Loss: ", minGen)
print()
print("Maximum Discriminator Loss: ", maxDis)
print("Minimum Discriminator Loss: ", minDis)