<a href="https://colab.research.google.com/github/Diishasing/GANs/blob/main/Conditional_GANs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

NECESSARY IMPORTS

In [None]:
import torch 
import torch.nn as nn
from torch.nn.modules.conv import ConvTranspose2d
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

DISCRIMINATOR

In [None]:
class Discriminator(nn.Module):
  """Conditional critic that scores an image together with its class label.

  The label is looked up in an embedding of size ``img_size * img_size``,
  reshaped into a single ``img_size x img_size`` plane and concatenated to the
  image as one extra input channel.

  The network ends in a plain convolution with no ``Sigmoid``: this model is
  trained as a Wasserstein critic with a gradient penalty, which needs raw,
  unbounded scores rather than probabilities.

  Args:
    channels_img: number of channels in the input images.
    features_d: base number of feature maps; doubled in each hidden block.
    num_classes: number of distinct class labels.
    img_size: height/width of the (square) input images. With the five
      stride-2 convolutions below, a 64x64 input yields a 1x1 score map.
  """

  def __init__(self, channels_img, features_d, num_classes, img_size):
    super().__init__()
    self.img_size = img_size
    self.disc = nn.Sequential(
        # +1 input channel for the label plane concatenated in forward().
        nn.Conv2d(
            channels_img + 1, features_d, kernel_size = 4, stride = 2, padding = 1
        ),
        nn.LeakyReLU(0.2),
        self._block(features_d, features_d * 2, 4, 2, 1),
        self._block(features_d * 2, features_d * 4, 4, 2, 1),
        self._block(features_d * 4, features_d * 8, 4, 2, 1),
        # Final layer emits the raw critic score (no squashing activation).
        nn.Conv2d(features_d * 8, 1, 4, 2, 0),
    )
    self.embed = nn.Embedding(num_classes, img_size * img_size)

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Return a Conv2d -> InstanceNorm2d -> LeakyReLU(0.2) stage."""
    return nn.Sequential(
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias = False
        ),
        nn.InstanceNorm2d(out_channels),
        nn.LeakyReLU(0.2),
    )

  def forward(self, x, labels):
    """Score a batch of images conditioned on their labels.

    Args:
      x: image batch; must be ``(N, channels_img, img_size, img_size)`` so it
        can be concatenated with the label plane.
      labels: integer class indices, one per image.

    Returns:
      The output of the convolutional stack (unbounded critic scores).
    """
    # (N, img_size * img_size) -> (N, 1, img_size, img_size) label plane.
    embedding = self.embed(labels).view(labels.shape[0], -1, self.img_size, self.img_size)
    x = torch.cat([x, embedding], dim = 1)
    return self.disc(x)

GENERATOR

In [None]:
class Generator(nn.Module):
  """Conditional generator that maps noise plus a class label to an image.

  The label is embedded into an ``embed_size`` vector, reshaped to
  ``(N, embed_size, 1, 1)`` and concatenated to the noise along the channel
  axis. Five transposed convolutions then upsample 1x1 -> 4 -> 8 -> 16 -> 32
  -> 64, and ``Tanh`` bounds the output to ``[-1, 1]``.

  Args:
    z_dim: number of channels in the input noise.
    channels_img: number of channels in the generated images.
    features_g: base number of feature maps in the generator.
    num_classes: number of distinct class labels.
    img_size: stored on the instance; the layer stack itself does not use it.
    embed_size: dimensionality of the label embedding.
  """

  def __init__(self, z_dim, channels_img, features_g, num_classes, img_size, embed_size):
    super().__init__()
    self.img_size = img_size
    self.gen = nn.Sequential(
        # Input channels = noise channels + label-embedding channels.
        self._block(z_dim + embed_size, features_g * 16, 4, 1, 0),
        self._block(features_g * 16, features_g * 8, 4, 2, 1),
        self._block(features_g * 8, features_g * 4, 4, 2, 1),
        self._block(features_g * 4, features_g * 2, 4, 2, 1),
        nn.ConvTranspose2d(
            features_g * 2, channels_img, kernel_size = 4, stride = 2, padding = 1,
        ),
        nn.Tanh(),
    )
    # Must be named `embed`: forward() reads self.embed (the former spelling
    # `embeb` made every forward pass raise AttributeError).
    self.embed = nn.Embedding(num_classes, embed_size)

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Return a ConvTranspose2d -> BatchNorm2d -> ReLU stage."""
    return nn.Sequential(
        nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias = True
        ),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    )

  def forward(self, x, labels):
    """Generate images for the given noise and labels.

    Args:
      x: noise batch; must be ``(N, z_dim, 1, 1)`` so it can be concatenated
        with the ``(N, embed_size, 1, 1)`` label embedding.
      labels: integer class indices, one per noise vector.

    Returns:
      Generated images of shape ``(N, channels_img, 64, 64)`` in ``[-1, 1]``.
    """
    # (N, embed_size) -> (N, embed_size, 1, 1) to match the noise layout.
    embedding = self.embed(labels).unsqueeze(2).unsqueeze(3)
    x = torch.cat([x, embedding], dim = 1)
    return self.gen(x)

In [None]:
def gradient_penalty(critic, labels, real, fake, device = 'cpu'):
  """Compute the gradient penalty on random real/fake interpolations.

  Each sample gets one random mixing weight in [0, 1); the critic is evaluated
  on the blended images and the penalty is the mean squared deviation of the
  per-sample gradient L2 norm from 1.

  Args:
    critic: callable taking ``(images, labels)`` and returning scores.
    labels: labels forwarded unchanged to ``critic``.
    real: batch of real images, shape ``(N, C, H, W)``.
    fake: batch of generated images, same shape as ``real``.
    device: device the mixing weights are moved to before blending.

  Returns:
    Scalar tensor holding the penalty; it stays attached to the autograd
    graph (``create_graph=True``) so it can be back-propagated.
  """
  n_samples, n_channels, height, width = real.shape

  # One weight per sample, copied across all channels and pixels of that sample.
  mix = torch.rand((n_samples, 1, 1, 1))
  mix = mix.repeat(1, n_channels, height, width).to(device)
  blended = real * mix + fake * (1 - mix)

  # Critic scores for the blended images.
  scores = critic(blended, labels)

  # Gradient of the scores with respect to the blended images themselves.
  (grads,) = torch.autograd.grad(
      outputs = scores,
      inputs = blended,
      grad_outputs = torch.ones_like(scores),
      create_graph = True,
      retain_graph = True,
  )

  # Flatten each sample's gradient to a vector and take its L2 norm.
  per_sample_norm = grads.view(grads.shape[0], -1).norm(2, dim = 1)
  return torch.mean((per_sample_norm - 1) ** 2)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Data ---
BATCH_SIZE = 64
IMAGE_SIZE = 64    # images are resized to this size before training
CHANNELS_IMG = 1
NUM_CLASSES = 10

# --- Model sizes ---
Z_DIM = 100            # channels of the generator's input noise
GEN_EMBEDDING = 100    # size of the generator's label embedding
FEATURES_CRITIC = 16
FEATURES_GEN = 16

# --- Optimisation ---
l_r = 1e-4
EPOCHS = 5
CRITIC_ITERATIONS = 5  # critic updates per generator update
LAMBDA_GP = 10         # gradient-penalty weight, as stated in the paper

In [None]:
# Preprocessing: resize to the training resolution, convert to a tensor, then
# normalise every channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            mean = [0.5] * CHANNELS_IMG,
            std = [0.5] * CHANNELS_IMG,
        ),
    ]
)

# MNIST training split, downloaded into /dataset if it is not already there.
dataset = datasets.MNIST(
    root = '/dataset',
    train = True,
    transform = transform,
    download = True,
)

# Shuffled mini-batches of (image, label) pairs.
dataloader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /dataset/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting /dataset/MNIST/raw/train-images-idx3-ubyte.gz to /dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /dataset/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting /dataset/MNIST/raw/train-labels-idx1-ubyte.gz to /dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting /dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to /dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting /dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to /dataset/MNIST/raw



In [None]:
def initialize_weights(model):
  """Re-draw the weights of conv, transposed-conv and batch-norm layers.

  Every ``Conv2d``, ``ConvTranspose2d`` and ``BatchNorm2d`` found in
  ``model.modules()`` has its ``weight`` filled in place from a normal
  distribution with mean 0.0 and standard deviation 0.02. Other layers, and
  all biases, are left untouched.
  """
  target_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
  for layer in model.modules():
    if not isinstance(layer, target_layers):
      continue
    nn.init.normal_(layer.weight.data, mean = 0.0, std = 0.02)

In [None]:
# Build both networks and move them to the selected device.
gen = Generator(
    z_dim = Z_DIM,
    channels_img = CHANNELS_IMG,
    features_g = FEATURES_GEN,
    num_classes = NUM_CLASSES,
    img_size = IMAGE_SIZE,
    embed_size = GEN_EMBEDDING,
).to(device)
critic = Discriminator(
    channels_img = CHANNELS_IMG,
    features_d = FEATURES_CRITIC,
    num_classes = NUM_CLASSES,
    img_size = IMAGE_SIZE,
).to(device)

# Replace the default layer initialisation (generator first, then critic).
initialize_weights(gen)
initialize_weights(critic)

# Both optimisers share the same learning rate and Adam betas.
opt_gen = optim.Adam(params = gen.parameters(), lr = l_r, betas = (0.0, 0.9))
opt_critic = optim.Adam(params = critic.parameters(), lr = l_r, betas = (0.0, 0.9))

# A fixed batch of 32 noise vectors, sampled once.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)

# TensorBoard writers and the global step used when logging image grids.
writer_real = SummaryWriter(log_dir = 'runs/real')
writer_fake = SummaryWriter(log_dir = 'runs/critic')
step = 0

# Put both networks in training mode.
gen.train()
critic.train()

Discriminator(
  (disc): Sequential(
    (0): Conv2d(2, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(128, 1, kernel_size=(4, 4), stride=(2, 2))
    (6): Sigmoid()
  )
  (embe

In [None]:
# Training loop: for every batch, update the critic CRITIC_ITERATIONS times
# (Wasserstein loss + gradient penalty), then update the generator once.
for epoch in range(EPOCHS):
  for batch_idx, (real, labels) in enumerate(dataloader):
    real = real.to(device)
    labels = labels.to(device)
    # Use the actual size of this batch (the last one may be smaller) without
    # overwriting the global BATCH_SIZE setting.
    cur_batch_size = real.shape[0]

    # critic loss
    # The loop variable is deliberately `_`: naming it `labels` would replace
    # the label tensor with an int and break the embedding lookups below.
    for _ in range(CRITIC_ITERATIONS):
      noise = torch.randn((cur_batch_size, Z_DIM, 1, 1)).to(device)
      fake = gen(noise, labels)
      critic_real = critic(real, labels).reshape(-1)
      critic_fake = critic(fake, labels).reshape(-1)
      gp = gradient_penalty(critic, labels, real, fake, device = device)
      loss_critic = (-(torch.mean(critic_real) - torch.mean(critic_fake)) +
                     LAMBDA_GP * gp)
      critic.zero_grad()
      # retain_graph: `fake` (and the generator graph behind it) is reused for
      # the generator update after this loop.
      loss_critic.backward(retain_graph = True)
      opt_critic.step()

    # generator loss: raise the critic's score on the last batch of fakes
    output = critic(fake, labels).reshape(-1)
    loss_gen = -torch.mean(output)
    gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    # print the losses and log image grids to TensorBoard
    if batch_idx % 100 == 0 and batch_idx > 0:
      print(f'Epoch [{epoch}/{EPOCHS}]\
            Batch [{batch_idx} / {len(dataloader)}] \
            Loss D {loss_critic:.4f} \
            Loss G {loss_gen:.4f} ')

      with torch.no_grad():
        fake = gen(noise, labels)

        img_grid_real = torchvision.utils.make_grid(real[:25], normalize = True)
        img_grid_fake = torchvision.utils.make_grid(fake[:25], normalize = True)

        # Real and generated grids go to their own writers.
        writer_real.add_image('real', img_grid_real, global_step = step)
        writer_fake.add_image('fake', img_grid_fake, global_step = step)

        step += 1


AttributeError: ignored