# Conditional GAN (C-GAN)
Conditional GANs were originally proposed by [Mirza and Osindero](https://arxiv.org/pdf/1411.1784.pdf) in their work titled *Conditional Generative Adversarial Nets*. This notebook uses a basic implementation in which the generator and discriminator are MLPs that take the class label as an additional conditioning input. Both networks are trained with the Adam optimizer to play the adversarial minimax game (here with a least-squares, i.e. MSE, adversarial loss), and we showcase the approach on MNIST digit generation.

## Load Libraries

In [1]:
import os
import numpy as np
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets
from torch.autograd import Variable
from torch.utils.data import DataLoader

from torchvision.utils import save_image
import torchvision.transforms as transforms

## Set Parameters

In [2]:
# True when a CUDA device is usable; controls device placement below.
CUDA = torch.cuda.is_available()

In [3]:
NUM_CHANNELS = 1
N_CLASSES = 10
IMG_DIM = 32
BATCH_SIZE = 32
Z_DIM = 100 # Noise Vector Dimension
N_EPOCHS = 50
SAMPLE_INTERVAL = 400
IMG_SHAPE = (NUM_CHANNELS,IMG_DIM, IMG_DIM)

## Get MNIST Dataset

In [4]:
# create directory
os.makedirs("images", exist_ok=True)

# download dataset
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(IMG_DIM), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 160656104.23it/s]

Extracting data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to data/mnist/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 111133664.06it/s]


Extracting data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 49227630.81it/s]


Extracting data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 6303947.31it/s]


Extracting data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/mnist/MNIST/raw



## Discriminator Model

In [5]:
class Discriminator(nn.Module):
    """Conditional MLP discriminator that scores (image, class label) pairs.

    The label is mapped through a learned ``N_CLASSES``-dimensional embedding
    and concatenated with the flattened image. The network ends in a plain
    ``Linear(512, 1)``, so each sample gets one unbounded score (no sigmoid).
    """

    def __init__(self):
        super().__init__()

        self.label_embedding = nn.Embedding(N_CLASSES, N_CLASSES)

        flat_img_size = int(np.prod(IMG_SHAPE))
        layers = [
            nn.Linear(N_CLASSES + flat_img_size, 512),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Two hidden 512-unit layers, each regularized with dropout.
        for _ in range(2):
            layers += [
                nn.Linear(512, 512),
                nn.Dropout(0.4),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers.append(nn.Linear(512, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, img, labels):
        """Return a (batch, 1) score for each image/label pair."""
        flat_imgs = img.view(img.size(0), -1)
        label_vecs = self.label_embedding(labels)
        return self.model(torch.cat((flat_imgs, label_vecs), -1))

## Generator Model

In [11]:
class Generator(nn.Module):
    """Conditional MLP generator mapping (noise, class label) to an image.

    The label is mapped through a learned ``N_CLASSES``-dimensional embedding,
    concatenated with the noise vector, and passed through four
    Linear -> BatchNorm1d -> LeakyReLU blocks, a final Linear layer and
    ``Tanh``. Outputs therefore lie in [-1, 1] and are reshaped to ``IMG_SHAPE``.
    """

    def __init__(self):
        super(Generator, self).__init__()

        self.label_emb = nn.Embedding(N_CLASSES, N_CLASSES)

        def block(in_feat_shape, out_feat_shape):
            """Return the layers of one Linear -> BatchNorm1d -> LeakyReLU stage."""
            return [
                nn.Linear(in_feat_shape, out_feat_shape),
                # NOTE(review): 0.8 is BatchNorm1d's ``eps`` (its second
                # positional parameter), not ``momentum``. It is unusually large
                # but kept to preserve training behaviour; confirm whether
                # momentum=0.8 was intended.
                nn.BatchNorm1d(out_feat_shape, eps=0.8),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        self.model = nn.Sequential(
            *block(N_CLASSES + Z_DIM, 128),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(IMG_SHAPE))),
            nn.Tanh(),
        )

    def forward(self, z_vector, labels):
        """Generate images for noise ``z_vector`` conditioned on ``labels``.

        Args:
            z_vector: noise of shape (batch, Z_DIM).
            labels: integer class indices of shape (batch,).

        Returns:
            Tensor of shape (batch, *IMG_SHAPE) with values in [-1, 1].
        """
        # Concatenate the noise vector with the embedded label vector.
        input_vector = torch.cat((z_vector, self.label_emb(labels)), -1)
        img = self.model(input_vector)
        return img.view(img.size(0), *IMG_SHAPE)

## Attach Loss & Optimizers

In [12]:
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

# Loss function
adversarial_loss = torch.nn.MSELoss()

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [13]:
# Put the models and loss on the GPU when one is available.
if CUDA:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Legacy tensor constructors used to build inputs/targets on the same device.
Tensor = torch.cuda.FloatTensor if CUDA else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if CUDA else torch.LongTensor

In [14]:
def save_cgan_images(n_row, batches_done):
    """Save an ``n_row`` x ``n_row`` grid of generated digits.

    Every row is conditioned on labels ``0 .. n_row - 1`` in order, so each
    column of the grid shows a single class. The grid is written to
    ``images/{batches_done}.png``.

    Args:
        n_row: grid side length; must not exceed ``N_CLASSES``.
        batches_done: global batch counter, used as the output file name.

    Raises:
        ValueError: if ``n_row`` exceeds ``N_CLASSES`` (the label embedding has
            only ``N_CLASSES`` rows).
    """
    if n_row > N_CLASSES:
        raise ValueError(f"n_row ({n_row}) must not exceed N_CLASSES ({N_CLASSES})")

    # Sample noise, one vector per grid cell.
    z = Tensor(np.random.normal(0, 1, (n_row ** 2, Z_DIM)))
    # Labels 0..n_row-1 repeated once per row.
    labels = LongTensor(np.tile(np.arange(n_row, dtype=np.int64), n_row))

    # Sampling only: skip building an autograd graph. The generator stays in
    # whatever mode it is in (train mode during the training loop).
    with torch.no_grad():
        gen_imgs = generator(z, labels)
    save_image(gen_imgs, f"images/{batches_done}.png", nrow=n_row, normalize=True)

## Train C-GAN

In [None]:
for epoch in range(N_EPOCHS):
    for i, (imgs, labels) in enumerate(dataloader):
        # Size everything by the actual batch: the final batch is smaller than
        # BATCH_SIZE whenever the dataset size is not a multiple of it.
        batch_size = imgs.size(0)

        # Adversarial targets: 1.0 for real, 0.0 for fake (no gradients).
        valid = Tensor(batch_size, 1).fill_(1.0)
        fake = Tensor(batch_size, 1).fill_(0.0)

        # Move real images and their labels to the working device/dtype.
        real_imgs = imgs.type(Tensor)
        labels = labels.type(LongTensor)

        # ---- Train Generator ----
        optimizer_G.zero_grad()

        # Sample noise and random conditioning labels for the generator.
        z = Tensor(np.random.normal(0, 1, (batch_size, Z_DIM)))
        gen_labels = LongTensor(np.random.randint(0, N_CLASSES, batch_size))
        gen_imgs = generator(z, gen_labels)

        # The generator is rewarded when D scores its samples as real.
        g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), valid)
        g_loss.backward()
        optimizer_G.step()

        # ---- Train Discriminator ----
        # zero_grad also clears the gradients g_loss.backward() left on D.
        optimizer_D.zero_grad()

        # Average loss over real samples and (detached) generated samples.
        real_loss = adversarial_loss(discriminator(real_imgs, labels), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), gen_labels), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()
        print(f'Epoch: {epoch}/{N_EPOCHS}-Batch: {i}/{len(dataloader)}--D.loss:{d_loss.item():.4f},G.loss:{g_loss.item():.4f}')

        # Periodically save a grid of conditioned samples.
        batches_done = epoch * len(dataloader) + i
        if batches_done % SAMPLE_INTERVAL == 0:
            save_cgan_images(n_row=10, batches_done=batches_done)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch: 0/50-Batch: 638/1875--D.loss:0.2902,G.loss:0.1113
Epoch: 0/50-Batch: 639/1875--D.loss:0.1782,G.loss:1.0785
Epoch: 0/50-Batch: 640/1875--D.loss:0.1173,G.loss:0.5102
Epoch: 0/50-Batch: 641/1875--D.loss:0.1421,G.loss:0.3519
Epoch: 0/50-Batch: 642/1875--D.loss:0.1122,G.loss:0.4826
Epoch: 0/50-Batch: 643/1875--D.loss:0.1146,G.loss:0.7084
Epoch: 0/50-Batch: 644/1875--D.loss:0.1065,G.loss:0.4925
Epoch: 0/50-Batch: 645/1875--D.loss:0.0992,G.loss:0.5769
Epoch: 0/50-Batch: 646/1875--D.loss:0.0815,G.loss:0.6735
Epoch: 0/50-Batch: 647/1875--D.loss:0.0577,G.loss:0.8077
Epoch: 0/50-Batch: 648/1875--D.loss:0.0917,G.loss:0.5197
Epoch: 0/50-Batch: 649/1875--D.loss:0.0948,G.loss:0.6811
Epoch: 0/50-Batch: 650/1875--D.loss:0.0868,G.loss:0.5531
Epoch: 0/50-Batch: 651/1875--D.loss:0.0973,G.loss:0.7173
Epoch: 0/50-Batch: 652/1875--D.loss:0.2410,G.loss:0.1671
Epoch: 0/50-Batch: 653/1875--D.loss:0.4884,G.loss:2.0836
Epoch: 0/50-Batch: 654/