<a href="https://colab.research.google.com/github/Hamza-Ali0237/PyTorch-GAN-Implementation/blob/main/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Implementing a GAN from Scratch in PyTorch

Dataset: [https://www.kaggle.com/datasets/kvpratama/pokemon-images-dataset](https://www.kaggle.com/datasets/kvpratama/pokemon-images-dataset)

In [None]:
# Mount Google Drive at /content/drive. This needs the google.colab package,
# i.e. a Colab runtime. Presumably the dataset is read from Drive — confirm.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch as t
import torch.nn as nn
import torchvision as tv
import matplotlib.pyplot as plt

In [None]:
# Deep Convolutional Generator Block
def dc_gen_block(in_dim, out_dim, kernel_size, stride):
  """Build one upsampling stage of the generator.

  The stage is a transposed convolution from ``in_dim`` to ``out_dim``
  channels, followed by batch normalisation over ``out_dim`` channels and a
  ReLU activation.

  Args:
    in_dim: Number of input channels.
    out_dim: Number of output channels.
    kernel_size: Kernel size passed to the transposed convolution.
    stride: Stride passed to the transposed convolution.

  Returns:
    An ``nn.Sequential`` holding the three layers in that order.
  """
  upsample = nn.ConvTranspose2d(in_dim, out_dim, kernel_size, stride=stride)
  normalise = nn.BatchNorm2d(out_dim)
  activation = nn.ReLU()
  return nn.Sequential(upsample, normalise, activation)

# Generator Loss Function
def gen_loss(gen, disc, num_images, z_dim, device=None):
  """Return the generator's adversarial loss for one batch of fakes.

  Samples ``num_images`` noise vectors of length ``z_dim``, passes them
  through ``gen``, scores the result with ``disc`` and compares those scores
  against an all-ones ("real") target using binary cross-entropy on logits.

  Args:
    gen: Callable mapping a ``(num_images, z_dim)`` noise tensor to fakes.
    disc: Callable mapping the fakes to raw (unnormalised) logits.
    num_images: Number of noise vectors to sample.
    z_dim: Length of each noise vector.
    device: Optional device on which the noise is created. ``None`` keeps
      PyTorch's default device, which matches the previous behaviour.

  Returns:
    The loss tensor produced by ``nn.BCEWithLogitsLoss`` with its default
    settings.
  """
  noise = t.randn(num_images, z_dim, device=device)
  fake = gen(noise)
  disc_pred = disc(fake)
  criterion = nn.BCEWithLogitsLoss()
  # torch is imported as `t` in this file; the bare name `torch` is undefined.
  # The generator wants its fakes classified as real, hence the ones target.
  return criterion(disc_pred, t.ones_like(disc_pred))

# Deep Convolutional Generator Class
class DCGenerator(nn.Module):
  """Deep convolutional generator mapping noise vectors to 3-channel images.

  The network stacks three ``dc_gen_block`` stages
  (``in_dim`` -> 1024 -> 512 -> 256 channels) and finishes with a transposed
  convolution to 3 channels followed by ``Tanh``.

  With the defaults (``kernel_size=4``, ``stride=2``, no padding) each
  transposed convolution maps a side length ``s`` to ``(s - 1) * 2 + 4``, so
  a 1x1 input grows 1 -> 4 -> 10 -> 22 -> 46.

  Args:
    in_dim: Length of each input noise vector (channels of the 1x1 input).
    kernel_size: Kernel size used by every transposed convolution.
    stride: Stride used by every transposed convolution.
  """

  def __init__(self, in_dim, kernel_size=4, stride=2):
    super().__init__()
    self.in_dim = in_dim
    self.gen = nn.Sequential(
        dc_gen_block(in_dim, 1024, kernel_size, stride),
        dc_gen_block(1024, 512, kernel_size, stride),
        dc_gen_block(512, 256, kernel_size, stride),
        # Output layer: the class is nn.ConvTranspose2d (no trailing "d" on
        # "Transpose"). No BatchNorm/ReLU here; Tanh bounds values to [-1, 1].
        nn.ConvTranspose2d(256, 3, kernel_size, stride=stride),
        nn.Tanh()
    )

  def forward(self, x):
    """Generate images from a batch of noise vectors.

    Args:
      x: Tensor that can be viewed as ``(len(x), in_dim, 1, 1)``.

    Returns:
      The output of the generator stack for that batch.
    """
    # Treat every noise vector as an in_dim-channel 1x1 feature map.
    x = x.view(len(x), self.in_dim, 1, 1)
    return self.gen(x)

In [None]:
# Deep Convolutional Discriminator Block
def dc_disc_block(in_dim, out_dim, kernel_size, stride):
  """Build one downsampling stage of the discriminator.

  The stage is a strided convolution from ``in_dim`` to ``out_dim`` channels,
  followed by batch normalisation and a LeakyReLU with negative slope 0.2.

  Args:
    in_dim: Number of input channels.
    out_dim: Number of output channels.
    kernel_size: Kernel size passed to the convolution.
    stride: Stride passed to the convolution.

  Returns:
    An ``nn.Sequential`` holding the three layers in that order.
  """
  layers = [
      nn.Conv2d(in_dim, out_dim, kernel_size, stride=stride),
      nn.BatchNorm2d(out_dim),
      nn.LeakyReLU(0.2),
  ]
  return nn.Sequential(*layers)

# Deep Convolutional Discriminator Class
class DCDiscriminator(nn.Module):
  """Deep convolutional discriminator producing raw logits for 3-channel images.

  The network stacks two ``dc_disc_block`` stages (3 -> 512 -> 1024 channels)
  and finishes with a plain convolution down to a single channel. No sigmoid
  is applied, so the outputs are logits.

  Args:
    kernel_size: Kernel size used by every convolution.
    stride: Stride used by every convolution.
  """

  def __init__(self, kernel_size=4, stride=2):
    # The original passed the undefined name `DCDiscriminator` to super()
    # while the class itself was spelled `DCDiscrimantor`.
    super().__init__()
    self.disc = nn.Sequential(
        dc_disc_block(3, 512, kernel_size, stride),
        dc_disc_block(512, 1024, kernel_size, stride),
        nn.Conv2d(1024, 1, kernel_size, stride)
    )

  def forward(self, x):
    """Score a batch of images.

    Args:
      x: Batch of images fed to the convolutional stack.

    Returns:
      Tensor of shape ``(len(x), -1)``: the single-channel score map of each
      sample, flattened into one row.
    """
    x = self.disc(x)
    return x.view(len(x), -1)


# Backward-compatible alias for the original (misspelled) class name.
DCDiscrimantor = DCDiscriminator