<a href="https://colab.research.google.com/github/Hamza-Ali0237/PyTorch-GAN-Implementation/blob/main/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Implementing a GAN (DCGAN) from Scratch in PyTorch

Dataset: [https://www.kaggle.com/datasets/kvpratama/pokemon-images-dataset](https://www.kaggle.com/datasets/kvpratama/pokemon-images-dataset)

In [1]:
from google.colab import drive  # Colab-only module; this cell fails outside Google Colab

# Expose Google Drive at this mount point so the dataset path used later resolves.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [2]:
import torch as t
import torch.nn as nn
import torchvision as tv
import matplotlib.pyplot as plt

In [None]:
# Deep Convolutional Generator Block
def dc_gen_block(in_dim, out_dim, kernel_size, stride, padding=0):
    """Return a ConvTranspose2d -> BatchNorm2d -> ReLU upsampling block.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        kernel_size: Transposed-convolution kernel size.
        stride: Transposed-convolution stride.
        padding: Transposed-convolution padding. Defaults to 0, the
            previous behaviour. The output size is
            (n - 1) * stride - 2 * padding + kernel_size.
    """
    return nn.Sequential(
        nn.ConvTranspose2d(in_dim, out_dim, kernel_size, stride=stride, padding=padding),
        nn.BatchNorm2d(out_dim),
        nn.ReLU(),
    )


# Generator Loss Function
def gen_loss(gen, disc, batch_size, z_dim):
    """Non-saturating generator loss: BCE of D(G(z)) against the "real" label.

    Samples ``batch_size`` noise tensors of shape (z_dim, 1, 1) on the
    generator's device, generates fakes, and returns the mean
    BCE-with-logits loss of the discriminator's predictions against ones.
    Gradients flow through the discriminator into the generator.
    """
    noise = t.randn(batch_size, z_dim, 1, 1, device=next(gen.parameters()).device)
    fake = gen(noise)
    disc_pred = disc(fake)
    criterion = nn.BCEWithLogitsLoss()
    return criterion(disc_pred, t.ones_like(disc_pred))


# Deep Convolutional Generator Class
class DCGenerator(nn.Module):
    """DCGAN generator mapping a noise vector to a 3-channel image in [-1, 1].

    With the defaults (kernel_size=4, stride=2) the spatial size grows
    1 -> 4 -> 8 -> 16 -> 32 -> 64. The output is therefore 64x64, matching
    the ``Resize(64)`` training images. Other kernel/stride combinations
    upsample by exactly ``stride`` per block only when
    ``kernel_size - stride`` is even.

    Args:
        z_dim: Length of the input noise vector.
        kernel_size: Kernel size used by every transposed convolution.
        stride: Upsampling factor of every block after the first.
    """

    def __init__(self, z_dim, kernel_size=4, stride=2):
        super(DCGenerator, self).__init__()
        # Padding that makes each strided block upsample by exactly `stride`:
        # (n - 1) * s - 2p + k == n * s  <=>  p == (k - s) / 2.
        padding = (kernel_size - stride) // 2
        self.gen = nn.Sequential(
            # 1x1 noise -> kernel_size x kernel_size feature map (stride 1, no padding).
            dc_gen_block(z_dim, 512, kernel_size, 1),
            dc_gen_block(512, 256, kernel_size, stride, padding),
            dc_gen_block(256, 128, kernel_size, stride, padding),
            dc_gen_block(128, 64, kernel_size, stride, padding),
            nn.ConvTranspose2d(64, 3, kernel_size, stride=stride, padding=padding),
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate images from noise ``x`` of shape (N, z_dim) or (N, z_dim, 1, 1)."""
        x = x.view(len(x), -1, 1, 1)
        return self.gen(x)

In [None]:
# Deep Convolutional Discriminator Block
def dc_disc_block(in_dim, out_dim, kernel_size, stride, batch_norm=True):
    """Return a Conv2d -> [BatchNorm2d] -> LeakyReLU(0.2) downsampling block.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride. Padding is fixed at 1, so kernel_size=4
            with stride=2 halves the spatial size.
        batch_norm: Whether to insert BatchNorm2d. Defaults to True, the
            previous behaviour.
    """
    layers = [nn.Conv2d(in_dim, out_dim, kernel_size, stride=stride, padding=1)]
    if batch_norm:
        layers.append(nn.BatchNorm2d(out_dim))
    layers.append(nn.LeakyReLU(0.2))
    return nn.Sequential(*layers)


# Discriminator Loss Function

def disc_loss(gen, disc, real, batch_size, z_dim):
    """Discriminator BCE loss averaged over the fake and real halves.

    Fakes are generated from fresh noise on the generator's device and
    detached, so the generator receives no gradients from this loss.
    Fake predictions are scored against 0 and real predictions against 1.
    ``real`` is expected to already be on the discriminator's device.
    """
    noise = t.randn(batch_size, z_dim, 1, 1, device=next(gen.parameters()).device)
    fake = gen(noise).detach()
    fake_pred = disc(fake)
    real_pred = disc(real)
    criterion = nn.BCEWithLogitsLoss()
    fake_loss = criterion(fake_pred, t.zeros_like(fake_pred))
    real_loss = criterion(real_pred, t.ones_like(real_pred))
    return (fake_loss + real_loss) / 2


# Deep Convolutional Discriminator Class
class DCDiscriminator(nn.Module):
    """DCGAN discriminator producing one real/fake logit per 64x64 image.

    With the defaults (kernel_size=4, stride=2) the spatial size shrinks
    64 -> 32 -> 16 -> 8 -> 4. A final stride-1 convolution then collapses
    the 4x4 map into a single logit.

    Args:
        kernel_size: Kernel size of every convolution.
        stride: Downsampling stride of the feature-extraction blocks.
    """

    def __init__(self, kernel_size=4, stride=2):
        super(DCDiscriminator, self).__init__()
        self.disc = nn.Sequential(
            # The DCGAN guidelines omit BatchNorm on the discriminator's input layer.
            dc_disc_block(3, 64, kernel_size, stride, batch_norm=False),
            dc_disc_block(64, 128, kernel_size, stride),
            dc_disc_block(128, 256, kernel_size, stride),
            dc_disc_block(256, 512, kernel_size, stride),
            # 4x4 -> 1x1: a single raw logit (BCEWithLogitsLoss applies the sigmoid).
            nn.Conv2d(512, 1, kernel_size, stride=1, padding=0),
        )

    def forward(self, x):
        """Return raw logits of shape (N, 1) for 64x64 images ``x`` of shape (N, 3, 64, 64)."""
        return self.disc(x).view(len(x), -1)

# Loading the Dataset

In [None]:
# Define Transformations
# Resize(64) only fixes the shorter side, so center-crop to guarantee exactly
# 64x64 tensors; the default DataLoader collate cannot stack differently-sized
# images. Normalize then maps pixels from [0, 1] to [-1, 1], the range of the
# generator's Tanh output.
transform = tv.transforms.Compose([
    tv.transforms.Resize(64),
    tv.transforms.CenterCrop(64),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# ImageFolder expects <root>/<class_dir>/<image> and yields (image, class_index) pairs.
dataset = tv.datasets.ImageFolder(root="/content/drive/MyDrive/pokemon-dataset/pokemon", transform=transform)

# NOTE: the batch size is hard-coded here because the `batch_size`
# hyperparameter is defined in a later cell; keep the two in sync.
dataloader_train = t.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)

# Training the Model

In [None]:
# HYPERPARAMETERS
z_dim = 100       # length of the noise vector fed to the generator
lr = 2e-4         # Adam learning rate, shared by both optimizers
batch_size = 128  # intended images per batch (the DataLoader passes 128 directly)
EPOCHS = 10       # number of full passes over the training data

In [None]:
# Check for GPU availability: use CUDA when present, otherwise fall back to CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

# Models, placed on the selected device
gen = DCGenerator(z_dim).to(device)
disc = DCDiscriminator().to(device)

# Optimizers: both networks use Adam with identical settings
# (learning rate from the hyperparameter cell, beta1 = 0.5 as in the DCGAN paper).
adam_config = {"lr": lr, "betas": (0.5, 0.999)}
gen_opt = t.optim.Adam(gen.parameters(), **adam_config)
disc_opt = t.optim.Adam(disc.parameters(), **adam_config)

In [None]:
# Training Loop
# Each batch performs one discriminator update followed by one generator update.
for epoch in range(EPOCHS):
    # ImageFolder batches are (images, class_labels); the labels are unused by the GAN.
    for real, _ in dataloader_train:
        real = real.to(device)
        cur_batch_size = len(real)  # the last batch may be smaller than the DataLoader's batch size

        # Discriminator step. Fakes are detached inside disc_loss, so the generator is untouched.
        disc_opt.zero_grad()
        # Distinct names so the loss *functions* disc_loss / gen_loss are not overwritten.
        disc_loss_value = disc_loss(gen, disc, real, cur_batch_size, z_dim)
        disc_loss_value.backward()
        disc_opt.step()

        # Generator step. Gradients it leaves on disc are cleared by the next disc_opt.zero_grad().
        gen_opt.zero_grad()
        gen_loss_value = gen_loss(gen, disc, cur_batch_size, z_dim)
        gen_loss_value.backward()
        gen_opt.step()

    # Report the losses of the epoch's final batch.
    print(
        f"Epoch {epoch + 1}/{EPOCHS} | "
        f"disc loss: {disc_loss_value.item():.4f} | "
        f"gen loss: {gen_loss_value.item():.4f}"
    )