#### About
1. Conditional GANs are GANs that allow us to condition the network with additional information such as class labels. It means that during the phase of training, Images are passed to the network along with their actual class labels for it to learn the difference between them.
2. The limitation of generating random samples with a GAN is overcome via a conditional GAN i.e control on output is maintained. For e.g - In Fashion MNIST, CGAN can help output all jacket's image or equivalent customization.
3. The loss function of GANs quoted below
![gan_loss.png](gan_loss.png)
is modified by conditioning class labels as i.e conditional probabilities
![cgan_loss.png](cgan_loss.png)
4. In this notebook, We'll implement CGANs on Fashion MNIST dataset.

In [1]:
from torch import optim
import os
import torchvision.utils as utils
import numpy as numpy
from torchvision import  datasets
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
#downloading fashion mnist dataset
dataset_path = os.path.join('./data', 'FashionMNIST')
os.makedirs(dataset_path, exist_ok=True)
model_path = os.path.join('./model', 'FashionMNIST')
os.makedirs(model_path,exist_ok=True)
samples_path = os.path.join('./samples','FashionMNIST')
os.makedirs(samples_path,exist_ok=True)

#defining the transform
transform = transforms.Compose([transforms.Resize([32,32]),
                                transforms.ToTensor(),
                                transforms.Normalize([0.5],[0.5])])

dataset = datasets.FashionMNIST(dataset_path, train=True, download=True, transform=transform)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw



In [3]:
train_loader = DataLoader(dataset=dataset, batch_size=256, shuffle=True, num_workers=4, drop_last=True)



In [4]:
for batch in train_loader:
    print(batch)
    break


[tensor([[[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]]],


        [[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]]],


        [[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]]],


        ...,


        [[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1.,

In [5]:
# storing generated images
def generate_store_image(z,fixed_label,epoch=0):
    #putting generator model to eval mode
    gen.eval()
    fake_imgs = gen(z,fixed_label)
    fake_imgs = (fake_imgs+1)/2
    fake_imgs_ = utils.make_grid(fake_imgs, normalize=False, nrow=10)
    utils.save_image(fake_imgs_, os.path.join(samples_path, 'sample_'+str(epoch)+'.png'))
    

#### Model Architecture 


In [6]:
def convolution_block(in_channels,out_channels, kernel=4,stride=2, pad=1,bias=False, transpose=False):
    module= []
    if transpose:
        module.append(nn.ConvTranspose2d(in_channels,out_channels,kernel,stride, pad, bias=bias))
    else:
        module.append(nn.Conv2d(in_channels,out_channels,kernel,stride,pad,bias=bias))
    if bias == False:
        #use batch norm
        module.append(nn.BatchNorm2d(out_channels))
    
    return nn.Sequential(*module)

In [7]:
class Generator(nn.Module):
    def __init__(self,z_dim=10, num_classes=10, label_embed_size=5, channels=3, conv_dim=64):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, label_embed_size)
        self.transpose_conv1 =convolution_block(z_dim+label_embed_size,conv_dim*4, pad=0, transpose=True)
        self.transpose_conv2 = convolution_block(conv_dim*4, conv_dim*2, transpose=True)
        self.transpose_conv3 = convolution_block(conv_dim*2, conv_dim, transpose=True)
        self.transpose_conv4 = convolution_block(conv_dim, channels, transpose=True,bias=True) #no batch norm

        for m in self.modules():
            #initialising weights
            if isinstance(m,nn.Conv2d) or isinstance(m,nn.ConvTranspose2d):
                nn.init.normal_(m.weight, 0.0, 0.02)
            if isinstance(m,nn.BatchNorm2d):
                nn.init.constant_(m.weight,1)
                nn.init.constant_(m.bias,0)

    #enforcing label in forward pass
    def forward(self,x,label):
        #reshaping x
        x = x.reshape([x.shape[0],-1,1,1])
        label_embed = self.label_embedding(label)
        label_embed = label_embed.reshape([label_embed.shape[0],-1,1,1])
        x = torch.cat((x,label_embed),dim=1)
        x = F.relu(self.transpose_conv1(x))
        x = F.relu(self.transpose_conv2(x))
        x = F.relu(self.transpose_conv3(x))
        x = torch.tanh(self.transpose_conv4(x))
        return x

In [8]:
class Discriminator(nn.Module):
    def __init__(self, num_classes=10, channels=3, conv_dim=64):
        super(Discriminator, self).__init__()
        self.image_size = 32
        self.label_embedding = nn.Embedding(num_classes, self.image_size*self.image_size)
        self.conv1 = convolution_block(channels + 1, conv_dim, bias=True)
        self.conv2 = convolution_block(conv_dim, conv_dim * 2)
        self.conv3 = convolution_block(conv_dim * 2, conv_dim * 4)
        self.conv4 = convolution_block(conv_dim * 4, 1, kernel=4, stride=1, pad=0, bias=True)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0.0, 0.02)

            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, label):
        alpha = 0.2
        label_embed = self.label_embedding(label)
        label_embed = label_embed.reshape([label_embed.shape[0], 1, self.image_size, self.image_size])
        x = torch.cat((x, label_embed), dim=1)
        x = F.leaky_relu(self.conv1(x), alpha)
        x = F.leaky_relu(self.conv2(x), alpha)
        x = F.leaky_relu(self.conv3(x), alpha)
        x = torch.sigmoid(self.conv4(x))
        return x.squeeze()

In [9]:
Z_DIM=10
LABEL_EMBEDDING_SIZE=5
NUM_CLASSES=10
IMGS_TO_DISPLAY_PER_CLASS=10
LOAD_MODEL = False
CHANNELS=1
EPOCHS =100
BATCH_SIZE=256
gen = Generator(z_dim=Z_DIM, num_classes=NUM_CLASSES, label_embed_size=LABEL_EMBEDDING_SIZE, channels=CHANNELS)
dis = Discriminator(num_classes=NUM_CLASSES, channels=CHANNELS)


In [10]:
if LOAD_MODEL:
    gen.load_state_dict(torch.load(os.path.join(model_path,'gen.pth')))
    dis.load_state_dict(torch.load(os.path.join(model_path,'dis.pth')))

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gen = gen.to(device)
dis = dis.to(device)

In [12]:
gen

Generator(
  (label_embedding): Embedding(10, 5)
  (transpose_conv1): Sequential(
    (0): ConvTranspose2d(15, 256, kernel_size=(4, 4), stride=(2, 2), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (transpose_conv2): Sequential(
    (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (transpose_conv3): Sequential(
    (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (transpose_conv4): Sequential(
    (0): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
)

In [13]:
dis

Discriminator(
  (label_embedding): Embedding(10, 1024)
  (conv1): Sequential(
    (0): Conv2d(2, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
  (conv2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv4): Sequential(
    (0): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1))
  )
)

In [14]:
# Loss function
loss_function = nn.BCELoss()

In [15]:
# Define Optimizers
g_opt = optim.Adam(gen.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=2e-5)
d_opt = optim.Adam(dis.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=2e-5)

In [16]:
# fixing images for visualization
fixed_z = torch.randn(IMGS_TO_DISPLAY_PER_CLASS*NUM_CLASSES, Z_DIM)
fixed_label = torch.arange(0, NUM_CLASSES)
fixed_label = torch.repeat_interleave(fixed_label, IMGS_TO_DISPLAY_PER_CLASS)

In [17]:
fixed_label

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9,
        9, 9, 9, 9])

In [18]:
#labels
real_label = torch.ones(BATCH_SIZE)
fake_label = torch.zeros(BATCH_SIZE)

In [19]:
# transferring to device
real_label, fake_label = real_label.to(device), fake_label.to(device)
fixed_z, fixed_label = fixed_z.to(device), fixed_label.to(device)

In [22]:
# training model
for epoch in range(EPOCHS):
    gen.train()
    dis.train()
    for step, batch in enumerate(train_loader):
        x_real,x_label = batch
        z_fake = torch.randn(BATCH_SIZE,Z_DIM).to(device)

        x_real, x_label = x_real.to(device),x_label.to(device)

        # generate fake data
        x_fake = gen(z_fake,x_label)

        # train discriminator
        fake_out = dis(x_fake.detach(), x_label)
        real_out = dis(x_real.detach(), x_label)

        d_loss = (loss_function(fake_out,fake_label)+ loss_function(real_out,real_label))/21

        d_opt.zero_grad()
        d_loss.backward()
        d_opt.step()

        # Training generator
        fake_out = dis(x_fake,x_label)
        g_loss = loss_function(fake_out, real_label)
        g_opt.zero_grad()
        g_loss.backward()
        g_opt.step()

        if step%10==0:
            print("Epoch - {},step-{}, Discriminator Loss - {}, Generator Loss - {}".format(epoch,step,d_loss.item(),g_loss.item()))

    if epoch+1%10==0:
        torch.save(gen.state_dict(),os.path.join(model_path,'gen.pth'))
        torch.save(dis.state_dict(),os.path.join(model_path,'dis.pth'))

        generate_store_image(fixed_z, fixed_label, epoch=epoch+1)
    
    generate_store_image(fixed_z,fixed_label)



Epoch - 0,step-0, Discriminator Loss - 0.07328560203313828, Generator Loss - 1.573000431060791
Epoch - 0,step-10, Discriminator Loss - 0.01734599657356739, Generator Loss - 2.8096652030944824
Epoch - 0,step-20, Discriminator Loss - 0.015111202374100685, Generator Loss - 2.6701512336730957
Epoch - 0,step-30, Discriminator Loss - 0.013419258408248425, Generator Loss - 3.9305756092071533
Epoch - 0,step-40, Discriminator Loss - 0.011228839866816998, Generator Loss - 2.7949867248535156
Epoch - 0,step-50, Discriminator Loss - 0.023467816412448883, Generator Loss - 3.6591405868530273
Epoch - 0,step-60, Discriminator Loss - 0.019955407828092575, Generator Loss - 1.7215840816497803
Epoch - 0,step-70, Discriminator Loss - 0.02801300771534443, Generator Loss - 2.088083267211914
Epoch - 0,step-80, Discriminator Loss - 0.02797212451696396, Generator Loss - 2.356624126434326
Epoch - 0,step-90, Discriminator Loss - 0.0442306213080883, Generator Loss - 2.6178503036499023
Epoch - 0,step-100, Discrimina