In [None]:
import os
import sys
import time

import numpy as np
import matplotlib.pyplot as plt
import cv2
import utilsFunction
from model import Generator

import torch 
from torch import nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
from torchvision import datasets, transforms
import torchvision
import torch.optim as optim
import torchvision.utils as vutils

In [None]:
# Parameters
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Filesystem locations: dataset root and output directory.
dataPath = ''
outPath = "output/"

# Model / data dimensions.
batchSize = 32       # samples per mini-batch
imgChannels = 1      # channel count of the training images
zDimension = 100     # channel count of the latent noise fed to the generator
gDimension = 64      # base feature-map width of the generator
discDimension = 64   # base feature-map width of the discriminator

imgResize = 64       # images are resized to this size before training

# Optimisation settings.
epochNumber = 50
lr = 2e-4
seed = 1

In [None]:
# Seed PyTorch's global RNG. Weight initialisation, torch.randn noise and
# DataLoader shuffling all draw from it, so this must be set even when the
# run is on the GPU (modules are built on the CPU before .to(device)).
torch.manual_seed(seed)
# `device` is a non-empty string, so a bare `if device:` was always true;
# compare explicitly before touching the CUDA generator.
if device == "cuda":
    torch.cuda.manual_seed(seed)
# Let cuDNN auto-tune its convolution algorithms for the (fixed) input sizes.
cudnn.benchmark = True

# Generator and Discriminator Models

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator.

    Maps a latent tensor of shape [B, latenDim, 1, 1] to an image tensor of
    shape [B, outputChannels, 64, 64] with values squashed by Tanh.

    Args:
        latenDim: number of channels of the latent input.
        featuresGen: base feature-map width; hidden stages use 8x, 4x, 2x
            and 1x this value.
        outputChannels: number of channels of the generated image.
    """

    def __init__(self, latenDim, featuresGen, outputChannels):
        super().__init__()

        # (in_channels, out_channels, stride, padding) of every hidden stage.
        # The first stage expands 1x1 -> 4x4; each later one doubles the size.
        hiddenStages = (
            (latenDim, featuresGen * 8, 1, 0),
            (featuresGen * 8, featuresGen * 4, 2, 1),
            (featuresGen * 4, featuresGen * 2, 2, 1),
            (featuresGen * 2, featuresGen, 2, 1),
        )

        layers = []
        for inCh, outCh, stride, pad in hiddenStages:
            layers.append(nn.ConvTranspose2d(inCh, outCh, 4, stride, pad))
            layers.append(nn.BatchNorm2d(outCh))
            layers.append(nn.ReLU())

        # Output stage: last upsampling step (32x32 -> 64x64) followed by Tanh.
        layers.append(nn.ConvTranspose2d(featuresGen, outputChannels, 4, 2, 1))
        layers.append(nn.Tanh())

        # Keep everything in one Sequential named `gen` so the state_dict keys
        # stay `gen.<index>.*`.
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Run the latent batch `x` through the generator stack."""
        return self.gen(x)

In [None]:
class Discriminator(nn.Module):
    """Strided-convolution discriminator.

    Maps an image tensor of shape [B, imgChannels, 64, 64] to a score tensor
    of shape [B, 1, 1, 1]; the final Sigmoid keeps scores between 0 and 1.

    Args:
        imgChannels: number of channels of the input image.
        featuresDisc: base feature-map width; it is doubled at each of the
            three hidden stages (2x, 4x, 8x).
    """

    def __init__(self, imgChannels, featuresDisc) -> None:
        super().__init__()

        # Input stage: no batch norm, in-place LeakyReLU.
        stack = [
            nn.Conv2d(imgChannels, featuresDisc, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Three hidden stages; each halves the spatial size and doubles the width.
        width = featuresDisc
        for _ in range(3):
            stack.extend([
                nn.Conv2d(width, width * 2, 4, 2, 1),
                nn.BatchNorm2d(width * 2),
                nn.LeakyReLU(0.2),
            ])
            width = width * 2

        # Output stage: collapse the 4x4 map (width == featuresDisc*8) to one score.
        stack.extend([
            nn.Conv2d(width, 1, 4, 1, 0),
            nn.Sigmoid(),
        ])

        # Single Sequential named `disc` so state_dict keys stay `disc.<index>.*`.
        self.disc = nn.Sequential(*stack)

    def forward(self,x):
        """Score the image batch `x`."""
        return self.disc(x)

In [None]:
# Build both networks from the configured sizes and move them to the chosen device.
gen = Generator(
    latenDim=zDimension,
    featuresGen=gDimension,
    outputChannels=imgChannels,
).to(device)
disc = Discriminator(
    imgChannels=imgChannels,
    featuresDisc=discDimension,
).to(device)

In [None]:
# Smoke-test both networks on a tiny random batch. Shapes come from the
# configuration instead of literals so the check stays valid when
# zDimension / imgChannels / imgResize are changed.
smokeBatch = 2
with torch.no_grad():  # shape check only; no autograd graph needed
    outputGen = gen(torch.randn(smokeBatch, zDimension, 1, 1, device=device))
assert outputGen.shape[1]==imgChannels, "Image Channels not match to parameters"
assert outputGen.shape == (smokeBatch, imgChannels, imgResize, imgResize), \
    f"Generator output {tuple(outputGen.shape)} does not match imgResize={imgResize}"

with torch.no_grad():
    outputDisc = disc(torch.randn(smokeBatch, imgChannels, imgResize, imgResize, device=device))
assert outputDisc.shape == (smokeBatch, 1, 1, 1), \
    f"Discriminator output {tuple(outputDisc.shape)} is not one score per image"

# outputGen.shape # [2, 1, 64, 64]
# outputDisc.shape, outputDisc #[2, 1, 1, 1]

# Dataset and Dataloader

In [None]:
# Pre-processing pipeline applied to every image: resize, convert to a
# [0, 1] tensor, then shift/scale to [-1, 1].
# The result is still bound to the name `transforms` because the dataset cell
# reads it under that name. The pipeline is built through
# `torchvision.transforms` so that re-running this cell keeps working even
# after `transforms` no longer refers to the module.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(imgResize),
    torchvision.transforms.ToTensor(),
    # One-element tuples: per-channel mean / std for single-channel images.
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])

In [None]:
# Load MNIST from the configured dataset root (`dataPath`). download=False
# means the files must already be present there.
dataset = datasets.MNIST(root=dataPath, download=False, transform=transforms)
assert len(dataset) > 0,"dataset null value"
# Shuffled mini-batches of `batchSize` images for training.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchSize, shuffle=True)

You can also pass pin_memory=True when calling
torch.utils.data.DataLoader(). The loader then places each batch in
pinned (page-locked) host memory, which speeds up the copy from CPU
to GPU during training.

# Loss Functions and optimisers 

In [None]:
# Binary cross-entropy between discriminator scores and real/fake targets.
criterion = nn.BCELoss()

# Both networks are optimised with Adam using the same hyper-parameters.
adamBetas = (0.5, 0.999)
optimizerDisc = optim.Adam(disc.parameters(), lr=lr, betas=adamBetas)
optimizerGen = optim.Adam(gen.parameters(), lr=lr, betas=adamBetas)

# Two latent batches of identical shape, drawn in this order:
# one for training, one for the periodic sample previews.
latentShape = (batchSize, zDimension, 1, 1)
noiseVectotForGen = torch.randn(latentShape, device=device)
noiseVectorForGenTesting = torch.randn(latentShape, device=device)

# Define row and cols in the figure
rows, cols = 2, 1
fig = plt.figure(figsize=(6, 6))

# Training Loop

In [None]:
# Per-epoch average losses, appended to by the training loop.
DiscLoss, GenLoss = [], []

In [None]:
# Adversarial training loop.
# Each step: (1) update the discriminator on a real batch and a generated
# batch, (2) update the generator so the discriminator scores its output as real.
# Per-epoch average losses are appended to DiscLoss / GenLoss.
for i in range(0,epochNumber):
    print(f"=================================== EPOCH:{i}======================================================")

    start = time.time()
    trainDiscLoss = 0
    trainGenLoss = 0
    for data, _ in dataloader:
        data = data.to(device)

        # Draw a fresh latent batch every step, sized to the current real
        # batch. Reusing one fixed noise tensor would train the generator on
        # only `batchSize` latent points, and a fixed size breaks when the
        # last batch is smaller than `batchSize`.
        noise = torch.randn(data.size(0), zDimension, 1, 1, device=device)

        ## Train Discriminator: max log(D(x)) + log(1-D(G(z)))
        optimizerDisc.zero_grad()
        # 1) Real data should be recognised as real, i.e. max log(D(x))
        discRealOutput = disc(data) # output -> [B,1,1,1]
        discLossReal = criterion(discRealOutput, torch.ones_like(discRealOutput))

        # 2) Generated data should be recognised as fake, i.e. max log(1-D(G(z))).
        #    detach() keeps this backward pass out of the generator.
        genfakeOutput = gen(noise) # output -> [B,imgChannels,64,64]
        discFakeOuput = disc(genfakeOutput.detach()) # output -> [B,1,1,1]
        discLossFake = criterion(discFakeOuput, torch.zeros_like(discFakeOuput))

        finalDiscLoss = discLossReal + discLossFake
        trainDiscLoss += finalDiscLoss.item()
        finalDiscLoss.backward()
        optimizerDisc.step()

        ## Train Generator: max log(D(G(z)))
        # 3) Score the same fake batch with the updated discriminator and
        #    push those scores towards "real".
        optimizerGen.zero_grad()
        discReal = disc(genfakeOutput)
        # Targets take their shape from the scores they are compared with.
        lossGenerator = criterion(discReal, torch.ones_like(discReal))
        trainGenLoss += lossGenerator.item()
        lossGenerator.backward()
        optimizerGen.step()

    avgDiscLoss = trainDiscLoss/len(dataloader)
    avgGenLoss = trainGenLoss/len(dataloader)
    DiscLoss.append(avgDiscLoss)
    GenLoss.append(avgGenLoss)
    print(f"Discriminator Loss:{avgDiscLoss} and Generator Loss:{avgGenLoss}")

    # Every 5 epochs, preview generated samples next to real ones.
    if i%5==0:
        with torch.no_grad():
            # Fixed preview noise so successive previews are comparable.
            fake = gen(noiseVectorForGenTesting) # [B,imgChannels,64,64]
        imgGridFake = torchvision.utils.make_grid(fake[:5], normalize=True,nrow=5)
        # `data` is the last real batch of this epoch.
        imgGridReal = torchvision.utils.make_grid(data[:5], normalize=True,nrow=5)

        # Build a new figure per preview instead of adding subplots to a
        # figure created in another cell, and show it once after both
        # panels are drawn.
        previewFig, previewAxes = plt.subplots(2, 1, figsize=(6, 6))
        panels = ((imgGridFake, "Generated"), (imgGridReal, "Real"))
        for ax, (grid, title) in zip(previewAxes, panels):
            # make_grid returns [C,H,W]; imshow expects [H,W,C].
            ax.imshow(grid.permute(1,2,0).detach().cpu().numpy())
            ax.set_title(f"{title} samples, epoch {i}")
            ax.axis("off")
        plt.show()
        plt.close(previewFig)  # avoid accumulating open figures across epochs

    end = time.time()
    print(f"Time duration for Epoch no: {i} is {end-start}")
    print(f"=====================================================================================================")
        

In [None]:
# Persist the trained weights. Create the output directory first:
# torch.save does not create missing directories, and failing here would
# throw away the whole training run.
os.makedirs(outPath, exist_ok=True)
# `i` is the index of the last epoch executed by the training loop.
torch.save(gen.state_dict(), os.path.join(outPath, f"generatorEpoch{i}.pth"))
torch.save(disc.state_dict(), os.path.join(outPath, f"discriminatorEpoch{i}.pth"))

In [None]:
# 1-based epoch numbers for the x axis. The length follows the losses that
# were actually recorded, so the plot still works if training was stopped
# early or the training cell was run more than once.
x = list(range(1, len(DiscLoss) + 1))

In [None]:
# Plot the per-epoch average losses of both networks on one labelled axis.
lossFig, lossAx = plt.subplots()
lossAx.plot(x, DiscLoss, label = "Discriminator Loss")
lossAx.plot(x, GenLoss, label = "Generator Loss")
lossAx.set_xlabel("Epoch")
lossAx.set_ylabel("Average loss per batch")
lossAx.set_title("Training losses")
lossAx.legend()
plt.show()