<a href="https://colab.research.google.com/github/SarveshD7/Conditional-GAN-Pytorch/blob/main/Conditional_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

# Optimisation
LEARNING_RATE = 5e-5
BATCH_SIZE = 64
NUM_EPOCHS = 5
CRITIC_ITERATIONS = 5  # critic updates per generator update

# Data: square images, one channel, 10 class labels.
IMAGE_SIZE = 64
CHANNELS_IMG = 1
NUM_CLASSES = 10

# Model sizes
Z_DIM = 100          # length of the generator's noise vector
GEN_EMBEDDING = 100  # length of the generator's label embedding
FEATURES_critic = 64
FEATURES_GEN = 64

# Weight-clipping bound of the original (non-GP) WGAN; the gradient-penalty
# training loop in this notebook does not read it.
WEIGHT_CLIP = 0.01

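The idea is to condition both networks on the class label. The critic receives the label as one extra image-sized input channel, a learned embedding reshaped to `IMAGE_SIZE x IMAGE_SIZE`. The generator receives a label embedding concatenated to its noise vector. This way the generator can be asked to produce an image of a specific class.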

In [None]:
def gradient_penalty(critic, labels, real, fake, device="cpu"):
  """Compute the WGAN-GP gradient penalty for a conditional critic.

  Random points are sampled on the straight lines between ``real`` and
  ``fake`` images, using one mixing coefficient per sample. They are scored
  with ``critic(images, labels)``, and the function penalises how far the
  per-sample L2 norm of the score's gradient w.r.t. the images is from 1.

  Args:
    critic: callable ``critic(images, labels)`` returning one score per sample
      (any shape whose leading dimension is the batch).
    labels: class labels, forwarded unchanged to ``critic``.
    real: batch of real images, shape (N, C, H, W).
    fake: batch of generated images, same shape as ``real``. It may or may
      not carry an autograd graph.
    device: device on which the random mixing coefficients are created.

  Returns:
    Scalar tensor: the batch mean of ``(||grad||_2 - 1) ** 2``.
  """
  batch_size = real.shape[0]
  # One epsilon per sample; broadcasting spreads it over C, H and W.
  epsilon = torch.rand((batch_size, 1, 1, 1), device=device, dtype=real.dtype)
  interpolated_images = real * epsilon + fake * (1 - epsilon)
  # If neither input tracks gradients (e.g. a detached `fake`), the mix is a
  # leaf without grad and autograd.grad would refuse it, so enable it here.
  if not interpolated_images.requires_grad:
    interpolated_images.requires_grad_(True)

  # Calculate critic scores on the interpolated images.
  mixed_scores = critic(interpolated_images, labels)
  gradient = torch.autograd.grad(
      inputs=interpolated_images,
      outputs=mixed_scores,
      grad_outputs=torch.ones_like(mixed_scores),
      create_graph=True,  # the penalty itself must be differentiable
      retain_graph=True,
  )[0]

  gradient = gradient.view(gradient.shape[0], -1)
  gradient_norm = gradient.norm(2, dim=1)
  penalty = torch.mean((gradient_norm - 1) ** 2)
  return penalty

In [None]:
class Discriminator(nn.Module):
  """Conditional WGAN-GP critic that scores (image, label) pairs.

  The label is embedded into a learned ``img_size * img_size`` vector. That
  vector is reshaped into one extra image-sized channel and concatenated to
  the input image, so the first conv sees ``channels_img + 1`` channels. For
  64x64 inputs, strided convs reduce the spatial size to 1x1. The output is
  a raw, unbounded score of shape (N, 1, 1, 1); there is no sigmoid, as
  expected of a WGAN critic.

  Args:
    channels_img: number of channels of the input images.
    features_d: base channel width; doubled at each downsampling block.
    num_classes: number of distinct labels accepted by ``forward``.
    img_size: side length of the square input images.
  """

  def __init__(self, channels_img, features_d, num_classes, img_size):
    super(Discriminator, self).__init__()
    self.img_size = img_size
    self.critic = nn.Sequential(
        # Input: N x (channels_img + 1 label channel) x 64 x 64
        nn.Conv2d(channels_img + 1, features_d, kernel_size=4, stride=2, padding=1),  # 32x32
        nn.LeakyReLU(0.2),
        self._block(features_d, features_d * 2, 4, 2, 1),  # 16x16
        self._block(features_d * 2, features_d * 4, 4, 2, 1),  # 8x8
        self._block(features_d * 4, features_d * 8, 4, 2, 1),  # 4x4
        nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),  # 1x1
    )
    # Lookup table: label index -> flat img_size*img_size vector, later
    # reshaped into the extra input channel.
    self.embed = nn.Embedding(num_classes, img_size * img_size)

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Downsampling unit: bias-free Conv2d -> InstanceNorm2d(affine) -> LeakyReLU(0.2)."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
        # WGAN-GP penalises per-sample gradients, so batch-coupling BatchNorm
        # is avoided; InstanceNorm stands in for the suggested LayerNorm.
        nn.InstanceNorm2d(out_channels, affine=True),
        nn.LeakyReLU(0.2),
    )

  def forward(self, x, labels):
    """Score images ``x`` (N, channels_img, img_size, img_size) given integer ``labels`` (N,)."""
    embedding = self.embed(labels).view(labels.shape[0], 1, self.img_size, self.img_size)
    x = torch.cat([x, embedding], dim=1)  # N x (C + 1) x img_size x img_size
    return self.critic(x)


In [None]:
class Generator(nn.Module):
  """Conditional DCGAN-style generator mapping (noise, label) to an image.

  Each label is looked up in a learned ``embed_size``-dim table. The result is
  stacked onto the noise along the channel axis, and the resulting
  N x (z_dim + embed_size) x 1 x 1 tensor is upsampled by transposed
  convolutions 1 -> 4 -> 8 -> 16 -> 32 -> 64 pixels. A final Tanh keeps
  pixels in [-1, 1].
  """

  def __init__(self, z_dim, channels_img, features_g, num_classes, img_size, embed_size):
    super().__init__()
    self.img_size = img_size
    # (in_channels, out_channels, stride, padding) per upsampling stage;
    # every stage uses a 4x4 kernel.
    stages = [
        (z_dim + embed_size, features_g * 16, 1, 0),  # 1x1   -> 4x4
        (features_g * 16, features_g * 8, 2, 1),      # 4x4   -> 8x8
        (features_g * 8, features_g * 4, 2, 1),       # 8x8   -> 16x16
        (features_g * 4, features_g * 2, 2, 1),       # 16x16 -> 32x32
    ]
    upsampling = [self._block(c_in, c_out, 4, step, pad) for c_in, c_out, step, pad in stages]
    self.gen = nn.Sequential(
        *upsampling,
        nn.ConvTranspose2d(features_g * 2, channels_img, kernel_size=4, stride=2, padding=1),  # 32x32 -> 64x64
        nn.Tanh(),  # squash pixels into [-1, 1]
    )
    self.embed = nn.Embedding(num_classes, embed_size)

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Upsampling unit: bias-free ConvTranspose2d followed by ReLU (no normalisation layer)."""
    layers = [
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)

  def forward(self, x, labels):
    """Generate images from noise ``x`` (N, z_dim, 1, 1) and integer ``labels`` (N,)."""
    # (N, embed_size) -> (N, embed_size, 1, 1) so it stacks with the noise channels.
    label_map = self.embed(labels)[:, :, None, None]
    return self.gen(torch.cat((x, label_map), dim=1))


In [None]:
# Initializing weights of the model with a normal distribution with given mean and standard deviation
def initialize_weights(model):
  """Redraw selected weights of ``model`` in place from N(mean=0.0, std=0.02).

  Only the ``weight`` of Conv2d, ConvTranspose2d and BatchNorm2d submodules
  is touched, visited in ``model.modules()`` order. Biases and every other
  layer type keep their existing values. Returns None.
  """
  initialised_types = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
  for layer in filter(lambda mod: isinstance(mod, initialised_types), model.modules()):
    nn.init.normal_(layer.weight.data, mean=0.0, std=0.02)

In [None]:
# Preprocessing pipeline:
# 1. Resize so the shorter side is IMAGE_SIZE.
# 2. Convert to a float tensor in [0, 1].
# 3. Normalise each channel with mean 0.5 / std 0.5, mapping it to [-1, 1],
#    the same range as the generator's Tanh output.
# NOTE(review): this rebinds `transforms`, shadowing the
# `torchvision.transforms` module alias imported at the top. The fully
# qualified `torchvision.transforms.*` calls keep the cell re-runnable. A
# rename (e.g. `transform`) would also have to update the MNIST dataset cell.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(IMAGE_SIZE),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
])

In [None]:
# MNIST training split, downloaded to the Colab path on first use and
# preprocessed by the `transforms` pipeline. It is served in shuffled
# mini-batches of BATCH_SIZE. The last batch is not dropped, so it may be
# smaller.
dataset = datasets.MNIST(
    root="/content/dataset/",
    train=True,
    download=True,
    transform=transforms,
)
loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Conditional generator and critic, moved to the selected device.
gen = Generator(
    Z_DIM, CHANNELS_IMG, FEATURES_GEN, NUM_CLASSES, IMAGE_SIZE, GEN_EMBEDDING
).to(device)
critic = Discriminator(
    CHANNELS_IMG, FEATURES_critic, NUM_CLASSES, IMAGE_SIZE
).to(device)

# Redraw the conv weights of both networks (generator first, then critic).
for net in (gen, critic):
  initialize_weights(net)

In [None]:
LAMBDA_GP = 10  # weight of the gradient penalty in the critic loss
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
# No BCE criterion: the WGAN-GP losses below are plain means of critic scores.

# Training loop
for epoch in range(NUM_EPOCHS):
  print(f"Started epoch-> {epoch}")
  for batch_idx, (real, labels) in enumerate(loader):
    real = real.to(device)
    labels = labels.to(device)
    # The loader does not drop the last partial batch, so size the noise from
    # the actual batch rather than BATCH_SIZE.
    cur_batch_size = real.shape[0]

    # Train the critic CRITIC_ITERATIONS times for every generator update.
    for _ in range(CRITIC_ITERATIONS):
      noise = torch.randn((cur_batch_size, Z_DIM, 1, 1), device=device)
      # The generator is conditional: it needs the labels as well as the noise.
      fake = gen(noise, labels)
      critic_real = critic(real, labels).reshape(-1)
      critic_fake = critic(fake, labels).reshape(-1)
      gp = gradient_penalty(critic, labels, real, fake, device=device)
      loss_critic = (
          -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
      )
      critic.zero_grad()
      # retain_graph=True: the generator step below backpropagates through the
      # same `fake` graph again, and PyTorch frees it after backward() otherwise.
      loss_critic.backward(retain_graph=True)
      opt_critic.step()

    # Train the generator: maximise the critic's score on the latest fakes.
    # gen.zero_grad() also clears gradients that leaked into the generator
    # through `fake` during the critic updates.
    output = critic(fake, labels).reshape(-1)
    loss_gen = -torch.mean(output)
    gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()
  print(f"Completed epoch-> {epoch}")
  print(f"Completed epoch-> {epoch}")
