In [None]:
import torch 
from torch import nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt

import time

In [None]:
# --- Data / model geometry ---
imgSize = 28             # side length (pixels) of the square images the networks work on
classes = 10             # number of label classes (also the label-embedding width)
channels = 1             # image channels
latenDim = 100           # length of the generator's noise vector
zDimension = latenDim    # name used when sampling noise; derived so it cannot drift from latenDim
imgResize = imgSize      # target size for the Resize transform; derived so it cannot drift from imgSize

# --- Optimisation ---
epochNumber = 50
lr = 2e-4
batchSize = 32

# --- Reproducibility ---
seed = 1

# --- Hardware ---
device = "cuda" if torch.cuda.is_available() else "cpu"
device  # last expression of the cell, so the chosen device is actually displayed

In [None]:
# Seed the CPU generator unconditionally: weight initialisation (models are
# built on CPU before .to(device)) and DataLoader shuffling draw from it.
# The previous `if device:` test was always true because `device` is a
# non-empty string, and it only ever seeded the CUDA generator.
torch.manual_seed(seed)
if device == "cuda":
    # Explicitly seed every visible GPU as well.
    torch.cuda.manual_seed_all(seed)

# Let cuDNN benchmark and pick the fastest kernels for the observed input
# shapes. NOTE: auto-tuning may choose different algorithms between runs, so
# GPU results are not guaranteed to be bit-identical even with fixed seeds.
cudnn.benchmark = True

# Class Generator

In [None]:
class Generator(nn.Module):
    """Conditional generator: maps (noise, class label) to a flattened image.

    The label is embedded into a ``classes``-wide vector, concatenated with the
    noise vector and pushed through four Linear -> BatchNorm1d -> LeakyReLU(0.2)
    blocks, followed by a Linear layer and Tanh.

    Args:
        classes: number of distinct labels (also used as the embedding width).
        channels: number of image channels.
        imgSize: side length of the square output image.
        latenDim: length of the noise vector.
    """

    def __init__(self, classes, channels, imgSize, latenDim):
        super().__init__()
        self.classes = classes
        self.channels = channels
        self.imgSize = imgSize
        self.latenDim = latenDim
        self.imgShape = (self.channels, self.imgSize, self.imgSize)
        self.labelEmbedding = nn.Embedding(self.classes, self.classes)

        self.interLayers = nn.Sequential(
            self._linearBlock(self.latenDim + self.classes, 128),
            self._linearBlock(128, 256),
            self._linearBlock(256, 512),
            self._linearBlock(512, 1024)
        )

        # Project the 1024 hidden features onto one value per output pixel.
        # Derived from imgShape instead of a literal 784 so other image
        # geometries work; for 1x28x28 this is still 784.
        outFeatures = self.channels * self.imgSize * self.imgSize
        self.linear = nn.Linear(1024, outFeatures)
        self.tanh = nn.Tanh()

    def _linearBlock(self, inputHiddens, outputHiddens):
        """Return a Linear -> BatchNorm1d -> LeakyReLU(0.2) block."""
        return nn.Sequential(
            nn.Linear(in_features=inputHiddens, out_features=outputHiddens),
            nn.BatchNorm1d(outputHiddens),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, noise, labels):
        """Generate flattened images for the given noise and labels.

        Args:
            noise: tensor of shape [B, latenDim]; trailing singleton dims such
                as [B, latenDim, 1, 1] are accepted and flattened.
            labels: integer tensor of shape [B] with values in [0, classes).

        Returns:
            Tensor of shape [B, channels * imgSize * imgSize] with values in
            [-1, 1] (Tanh output).
        """
        # Collapse everything after the batch dim so 2-D and 4-D noise both work.
        noise = noise.flatten(start_dim=1)
        labelVec = self.labelEmbedding(labels)       # [B, classes]
        z = torch.concat([noise, labelVec], dim=1)   # [B, latenDim + classes]
        x = self.interLayers(z)
        return self.tanh(self.linear(x))
        

In [None]:
# Build the generator from the notebook-level hyperparameters and move it to
# the selected device.
gen = Generator(
    classes=classes,
    channels=channels,
    imgSize=imgSize,
    latenDim=latenDim,
).to(device)
# Optional shape check (inputs must live on the same device as `gen`):
# outGen = gen(torch.randn(2, 100, device=device), torch.randint(low=0, high=10, size=(2,), device=device))
# outGen.shape

In [None]:
# Scratch check: draw two random integer labels in [0, 10).
torch.randint(0, 10, (2,))

# Class Discriminator

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator: scores (image, class label) pairs.

    The label is embedded into a ``classes``-wide vector, concatenated with the
    flattened image and passed through three Linear -> Dropout(0.4) ->
    LeakyReLU(0.2) blocks, then two Linear layers and a Sigmoid.

    Args:
        classes: number of distinct labels (also used as the embedding width).
        channels: number of image channels.
        imgSize: side length of the square input image.
        latentDim: stored on the instance; not used by the layers built here.
    """

    def __init__(self, classes, channels, imgSize, latentDim):
        super().__init__()

        self.classes = classes
        self.channels = channels
        self.imgSize = imgSize
        # Use the constructor argument; the previous code read the notebook
        # global `latenDim` here, ignoring (or failing without) the parameter.
        self.latentDim = latentDim
        self.imgShape = (self.channels, self.imgSize, self.imgSize)
        self.labelEmbedding = nn.Embedding(self.classes, self.classes)

        self.interLayers = nn.Sequential(
            self._linearBlock(self.classes + int(np.prod(self.imgShape)), 1024),
            self._linearBlock(1024, 512),
            self._linearBlock(512, 256),
        )
        # NOTE(review): linear1 feeds linear2 with no activation in between, so
        # the pair is equivalent to a single affine map — confirm this is intended.
        self.linear1 = nn.Linear(256, 128)
        self.linear2 = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()

    def _linearBlock(self, inHiddens, outHiddens):
        """Return a Linear -> Dropout(0.4) -> LeakyReLU(0.2) block."""
        return nn.Sequential(
            nn.Linear(in_features=inHiddens, out_features=outHiddens),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2)
        )

    def forward(self, x, labels):
        """Return the probability that each (image, label) pair is real.

        Args:
            x: images, either flattened [B, channels * imgSize * imgSize] or
                unflattened [B, channels, imgSize, imgSize].
            labels: integer tensor of shape [B] with values in [0, classes).

        Returns:
            Tensor of shape [B, 1] with values in [0, 1] (Sigmoid output).
        """
        # Collapse everything after the batch dim so image batches straight
        # from a DataLoader work as well as pre-flattened ones.
        x = x.flatten(start_dim=1)
        labelVec = self.labelEmbedding(labels)    # [B, classes]
        y = torch.concat([x, labelVec], dim=1)
        outputs = self.interLayers(y)
        return self.sigmoid(self.linear2(self.linear1(outputs)))
        

In [None]:
# Build the discriminator from the notebook-level hyperparameters and move it
# to the selected device.
disc = Discriminator(
    classes=classes,
    channels=channels,
    imgSize=imgSize,
    latentDim=latenDim,
).to(device)
# Optional shape check (inputs must live on the same device as `disc`):
# discOut = disc(torch.randn(2, 784, device=device), torch.randint(0, classes, size=(2,), device=device))
# discOut.shape

# Dataset and Dataloader

In [None]:
# Import the module under a private alias: this cell rebinds the name
# `transforms` to the pipeline object, so referring to the module through that
# name would raise AttributeError the second time the cell is run.
from torchvision import transforms as _tvTransforms

# Preprocessing pipeline applied to every image. The result keeps the name
# `transforms` because the dataset cell passes it as `transform=transforms`.
transforms = _tvTransforms.Compose([
    _tvTransforms.Resize(imgResize),
    _tvTransforms.ToTensor(),
    # One (mean, std) entry for the single channel: (x - 0.5) / 0.5.
    # Written as real 1-tuples; `(0.5)` is just the float 0.5.
    _tvTransforms.Normalize((0.5,), (0.5,)),
])

In [None]:
# MNIST training split with the preprocessing pipeline attached. `root=""`
# leaves the download location relative to the current working directory.
data = datasets.MNIST(
    root="",
    train=True,
    transform=transforms,
    download=True,
)
# Shuffled mini-batches of `batchSize` samples.
dataloader = DataLoader(dataset=data, batch_size=batchSize, shuffle=True)

# Loss Function, Optimiser and Noise Vector

In [None]:
# Binary cross-entropy; both networks are scored against the discriminator's
# sigmoid output.
criterion = nn.BCELoss()

# One Adam optimizer per network, sharing learning rate and betas.
optimDisc = torch.optim.Adam(params=disc.parameters(), lr=lr, betas=(0.5, 0.999))
optimGen = torch.optim.Adam(params=gen.parameters(), lr=lr, betas=(0.5, 0.999))

# Fixed noise batches, drawn once (training batch first, then the test batch).
# NOTE(review): these are 4-D (batchSize, zDimension, 1, 1), while the comment in
# Generator.forward documents noise as [B, 100] — confirm the consumer
# reshapes them before concatenating with the label embedding.
noiseVectotForGen = torch.randn((batchSize, zDimension, 1, 1), device=device)
noiseVectorForGenTesting = torch.randn((batchSize, zDimension, 1, 1), device=device)

# Figure handle and grid layout for plotting.
fig = plt.figure(figsize=(6, 6))
rows = 2
cols = 1

# Training Loop

In [None]:
# Loss histories for the discriminator and generator; both start empty
# (presumably appended to by the training loop — confirm against its body).
DiscLoss, GenLoss = [], []