This code is based on https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/cgan/cgan.py

In [None]:
import torch
import torch.nn as nn

from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

import os
import numpy as np
from tqdm import tqdm

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# 1. Dataset

In [None]:
# Directory handed to the dataset loader; created up front so it always exists
# (an already-existing directory is not an error).
data_dir = './data'
os.makedirs(name=data_dir, exist_ok=True)

In [None]:
# Hyper-parameters shared by the models, the optimizers and the training loop.
params = dict(
    num_classes=10,          # number of label classes fed to the embeddings
    latent_space=100,        # length of the generator's noise vector
    input_size=(1, 28, 28),  # shape of a single image
    lr=2e-4,                 # Adam learning rate
    b1=0.5,                  # Adam beta1
    b2=0.999,                # Adam beta2
    epochs=100,              # passes over the training set
    batch_size=64,           # samples per DataLoader batch
)

In [None]:
# Convert images to tensors, then normalize with mean 0.5 / std 0.5.
train_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)

# MNIST training split, downloaded into data_dir when not already present.
train_dataset = datasets.MNIST(
    root=data_dir,
    train=True,
    transform=train_transform,
    download=True,
)

# Reshuffled mini-batches of params['batch_size'] samples.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=params['batch_size'],
    shuffle=True,
)

# 2. Model

In [None]:
class Generator(nn.Module):
    """Conditional generator: maps a noise vector plus a class label to an image.

    Values read from ``params``:
        num_classes: number of label classes (also used as the embedding width).
        latent_space: length of the input noise vector.
        input_size: shape of one generated image; the output is reshaped to
            ``(batch, *input_size)``.
    """

    def __init__(self, params):
        super().__init__()
        self.num_classes = params['num_classes']
        self.latent_dim = params['latent_space']
        self.input_size = params['input_size']

        # Label embedding matrix
        self.label_emb = nn.Embedding(self.num_classes, self.num_classes)

        self.model = nn.Sequential(
            *self.block(self.latent_dim + self.num_classes, 128, normalize=False),
            *self.block(128, 256),
            *self.block(256, 512),
            *self.block(512, 1024),
            nn.Linear(1024, int(np.prod(self.input_size))),
            nn.Tanh()  # bounds every output value to [-1, 1]
        )

    def block(self, in_channels, out_channels, normalize=True):
        """Return the layers of one hidden block as a list.

        The block is Linear -> (optional BatchNorm1d) -> LeakyReLU(0.2).
        """
        layers = [nn.Linear(in_channels, out_channels)]  # fc layer
        if normalize:
            # 0.8 must be passed by keyword: the second positional parameter of
            # BatchNorm1d is `eps`, so a positional 0.8 would be added to the
            # batch variance and largely disable the normalization.
            # NOTE(review): 0.8 is presumably meant as momentum - confirm.
            layers.append(nn.BatchNorm1d(out_channels, momentum=0.8))
        layers.append(nn.LeakyReLU(0.2))
        return layers

    def forward(self, noise, labels):
        """Generate one image per (noise row, label) pair.

        Args:
            noise: tensor whose last dimension is ``latent_space`` wide.
            labels: integer class indices, one per noise row.

        Returns:
            Tensor of shape ``(batch, *input_size)``.
        """
        # Concatenate label embedding and noise to produce the network input
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img = self.model(gen_input)
        img = img.view(img.size(0), *self.input_size)
        return img

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator: scores an (image, class label) pair.

    Values read from ``params``:
        num_classes: number of label classes (also used as the embedding width).
        input_size: shape of one image; images are flattened before scoring.

    The network ends in a plain Linear layer, so the score is unbounded.
    """

    def __init__(self, params):
        super().__init__()
        self.num_classes = params['num_classes']
        self.input_size = params['input_size']

        # One learnable vector of width num_classes per class label.
        self.label_emb = nn.Embedding(self.num_classes, self.num_classes)

        hidden = 512
        flat_pixels = int(np.prod(self.input_size))

        # Input layer takes the flattened image followed by the label embedding.
        layers = [
            nn.Linear(self.num_classes + flat_pixels, hidden),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Two identical hidden stages: Linear -> Dropout -> LeakyReLU.
        for _ in range(2):
            layers += [
                nn.Linear(hidden, hidden),
                nn.Dropout(0.4),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Single-value output head.
        layers.append(nn.Linear(hidden, 1))

        self.model = nn.Sequential(*layers)

    def forward(self, img, labels):
        """Return one score per sample, shape ``(batch, 1)``.

        Args:
            img: batch of images; flattened to ``(batch, -1)`` internally.
            labels: integer class indices, one per image.
        """
        flat_img = img.view(img.size(0), -1)
        label_vec = self.label_emb(labels)
        # Flattened image first, label embedding second, joined on the last axis.
        return self.model(torch.cat((flat_img, label_vec), -1))

# 3. Train

In [None]:
# Loss function: mean squared error between discriminator scores and the
# 1.0 / 0.0 targets built in the training loop.
adversarial_loss = torch.nn.MSELoss()

# Initialize generator and discriminator and place them on the training device.
# Every batch in the training loop is moved to `device`, so the models must live
# there too; otherwise the forward pass fails with a device mismatch on CUDA.
generator = Generator(params).to(device)
discriminator = Discriminator(params).to(device)

# Optimizers (created after the models are on their final device)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=params['lr'], betas=(params['b1'], params['b2']))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=params['lr'], betas=(params['b1'], params['b2']))

In [None]:
def sample_image(n_row, batches_done):
    """Saves an n_row x n_row grid of generated digits to images/<batches_done>.png.

    Each grid row holds labels 0 .. n_row - 1 in order, so column j shows
    class j. Labels wrap modulo params['num_classes'] so that n_row larger
    than the number of classes still yields valid embedding indices.

    Args:
        n_row: number of rows (and columns) in the grid.
        batches_done: global batch counter, used as the output file name.
    """
    # save_image does not create missing directories.
    os.makedirs("images", exist_ok=True)

    # Build the inputs on whatever device the generator's weights live on, so
    # noise and labels always match the model.
    gen_device = next(generator.parameters()).device

    # Sample noise
    z = torch.randn(n_row ** 2, params['latent_space'], device=gen_device)
    # Labels 0 .. n_row - 1 repeated for each of the n_row rows
    labels = (torch.arange(n_row).repeat(n_row) % params['num_classes']).to(gen_device)

    # No gradients are needed for a preview image.
    with torch.no_grad():
        gen_imgs = generator(z, labels)
    save_image(gen_imgs, "images/%d.png" % batches_done, nrow=n_row, normalize=True)

In [None]:
batch_count = 0

# Alternating training: for every batch, one generator step followed by one
# discriminator step.
for epoch in tqdm(range(params['epochs'])):
    for i, (imgs, labels) in enumerate(train_dataloader):
        # The last batch of an epoch may be smaller than params['batch_size'].
        n_samples = imgs.size(0)

        # Adversarial ground truths
        valid = torch.ones(n_samples, 1, device=device)
        fake = torch.zeros(n_samples, 1, device=device)

        real_imgs = imgs.to(device)
        real_labels = labels.to(device)

        #-----------------
        # Train Generator 
        #-----------------
        optimizer_G.zero_grad()

        # Sample noise and random class labels as Generator input
        z = torch.randn(n_samples, params['latent_space'], device=device)
        gen_labels = torch.randint(0, params['num_classes'], (n_samples,), device=device)

        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)

        # Loss measures Generator's ability to fool the discriminator
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid)

        g_loss.backward()
        optimizer_G.step()

        #---------------------
        # Train Discriminator 
        #---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples.
        # Fake images are scored with the labels they were generated for
        # (gen_labels), not with the unrelated labels of the real batch; detach()
        # keeps this step from back-propagating into the generator.
        real_loss = adversarial_loss(discriminator(real_imgs, real_labels), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), gen_labels), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # Log the losses and save a sample grid every 1000 batches.
        batches_done = epoch * len(train_dataloader) + i
        if batches_done % 1000 == 0:
            print(
                "[Epoch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, params['epochs'], d_loss.item(), g_loss.item())
            )
            sample_image(n_row=10, batches_done=batches_done)