This code is based on https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/cgan/cgan.py

In [15]:
import torch
import torch.nn as nn

from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

import os
import numpy as np
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1. Dataset

In [16]:
data_dir = './data'
os.makedirs(data_dir, exist_ok=True)

In [21]:
params = {
    'num_classes': 10,
    'latent_space': 100,
    'input_size': (1,32,32),
    'image_size': 32,
    'lr': 2e-4,
    'b1': 0.5,
    'b2': 0.999,
    'epochs': 100,
    'batch_size': 64,
}

In [22]:
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(params['image_size']),
    transforms.Normalize([0.5],[0.5]),
    ])
train_dataset = datasets.MNIST(data_dir, train=True, transform=train_transform, download=True)
train_dataloader = DataLoader(train_dataset, params['batch_size'], shuffle=True)

# 2. Model

In [23]:
class Generator(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.num_classes = params['num_classes']
        self.latent_dim = params['latent_space']
        self.input_size = params['input_size']

        # Label embedding matrix
        self.label_emb = nn.Embedding(self.num_classes, self.num_classes)

        self.model = nn.Sequential(
            *self.block(self.latent_dim + self.num_classes, 128, normalize=False),
            *self.block(128,256),
            *self.block(256,512),
            *self.block(512,1024),
            nn.Linear(1024, int(np.prod(self.input_size))),
            nn.Tanh()
        )
        
    def block(self, in_channels, out_channels, normalize=True):
        layers = []
        layers.append(nn.Linear(in_channels, out_channels)) # fc layer
        if normalize:
            layers.append(nn.BatchNorm1d(out_channels, 0.8)) # Batch Normalization
        layers.append(nn.LeakyReLU(0.2)) # LeakyReLU
        return layers

    def forward(self, noise, labels):
        # Concatenate label embedding and image to produce input
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img = self.model(gen_input)
        img = img.view(img.size(0), *self.input_size)
        return img

In [24]:
class Discriminator(nn.Module):
    def __init__(self,params):
        super().__init__()
        self.num_classes = params['num_classes']
        self.input_size = params['input_size']

        # Label embedding matrix
        self.label_emb = nn.Embedding(self.num_classes, self.num_classes)

        self.model = nn.Sequential(
            nn.Linear(self.num_classes + int(np.prod(self.input_size)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
        )

    def forward(self, img, labels):
        # Concatenate label embedding and image to produce input
        dis_input = torch.cat((img.view(img.size(0), -1), self.label_emb(labels)), -1)
        validity = self.model(dis_input)
        return validity

# 3. Train

In [25]:
# Loss function
adversarial_loss = torch.nn.MSELoss()

# Initialize generator and discriminator
generator = Generator(params)
discriminator = Discriminator(params)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=params['lr'], betas=(params['b1'], params['b2']))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=params['lr'], betas=(params['b1'], params['b2']))

In [26]:
def sample_image(n_row, batches_done):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    z = torch.randn(n_row ** 2, params['latent_space'])
    # Get labels ranging from 0 to n_classes for n rows
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = torch.tensor(labels)
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)

In [27]:
batch_count = 0

for epoch in tqdm(range(params['epochs'])):
    for i, (imgs, labels) in enumerate(train_dataloader):
        
        # Adversarial ground truths
        valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device)
        fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device)

        real_imgs = imgs.to(device)
        real_labels = labels.to(device)

        #-----------------
        # Train Generator 
        #-----------------
        optimizer_G.zero_grad()

        # Sample noise as Generator input
        z = torch.randn(imgs.size(0), params['latent_space']).to(device)
        gen_labels = torch.randint(0, params['num_classes'], (imgs.size(0),)).to(device)
        
        # Generate a batch of images
        gen_imgs = generator(z, gen_labels) 

        # Loss measure Generator's ability to fool the discriminator
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid) 
        
        g_loss.backward()
        optimizer_G.step()

        #---------------------
        # Train Discriminator 
        #---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs, real_labels), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), labels), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()


        batches_done = epoch * len(train_dataloader) + i
        if batches_done % 1000 == 0:      
            print(
                "[Epoch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, params['epochs'], d_loss.item(), g_loss.item())
            )
            sample_image(n_row=10, batches_done=batches_done)

  0%|          | 0/100 [00:00<?, ?it/s]

[Epoch 0/100] [D loss: 0.505711] [G loss: 1.002621]


  1%|          | 1/100 [00:42<1:09:47, 42.30s/it]

[Epoch 1/100] [D loss: 0.089273] [G loss: 0.406389]


  2%|▏         | 2/100 [01:25<1:09:30, 42.56s/it]

[Epoch 2/100] [D loss: 0.070553] [G loss: 0.391528]


  3%|▎         | 3/100 [02:07<1:08:56, 42.65s/it]

[Epoch 3/100] [D loss: 0.057716] [G loss: 0.400965]


  4%|▍         | 4/100 [02:49<1:07:41, 42.31s/it]

[Epoch 4/100] [D loss: 0.052597] [G loss: 0.337107]


  5%|▌         | 5/100 [03:31<1:07:00, 42.32s/it]

[Epoch 5/100] [D loss: 0.046736] [G loss: 0.179180]


  6%|▌         | 6/100 [04:14<1:06:26, 42.41s/it]

[Epoch 6/100] [D loss: 0.064955] [G loss: 0.159578]


  7%|▋         | 7/100 [04:57<1:05:57, 42.56s/it]

[Epoch 7/100] [D loss: 0.059494] [G loss: 0.446391]


  8%|▊         | 8/100 [05:40<1:05:24, 42.66s/it]

[Epoch 8/100] [D loss: 0.052291] [G loss: 0.191006]


  9%|▉         | 9/100 [06:22<1:04:25, 42.48s/it]

[Epoch 9/100] [D loss: 0.053831] [G loss: 0.280842]


 10%|█         | 10/100 [07:03<1:02:59, 41.99s/it]

[Epoch 10/100] [D loss: 0.054730] [G loss: 0.360515]


 11%|█         | 11/100 [07:44<1:01:48, 41.67s/it]

[Epoch 11/100] [D loss: 0.048098] [G loss: 0.152394]


 12%|█▏        | 12/100 [08:24<1:00:43, 41.40s/it]

[Epoch 12/100] [D loss: 0.064987] [G loss: 0.158490]


 13%|█▎        | 13/100 [09:06<59:55, 41.33s/it]  

[Epoch 13/100] [D loss: 0.050852] [G loss: 0.150052]


 14%|█▍        | 14/100 [09:47<59:06, 41.24s/it]

[Epoch 14/100] [D loss: 0.043077] [G loss: 0.159060]


 15%|█▌        | 15/100 [10:28<58:19, 41.17s/it]

[Epoch 15/100] [D loss: 0.059799] [G loss: 0.142306]


 17%|█▋        | 17/100 [11:49<56:37, 40.93s/it]

[Epoch 17/100] [D loss: 0.057359] [G loss: 0.126982]


 18%|█▊        | 18/100 [12:30<55:46, 40.81s/it]

[Epoch 18/100] [D loss: 0.051211] [G loss: 0.199906]


 19%|█▉        | 19/100 [13:10<55:00, 40.75s/it]

[Epoch 19/100] [D loss: 0.076723] [G loss: 0.101398]


 20%|██        | 20/100 [13:51<54:19, 40.74s/it]

[Epoch 20/100] [D loss: 0.051624] [G loss: 0.166433]


 21%|██        | 21/100 [14:32<53:40, 40.76s/it]

[Epoch 21/100] [D loss: 0.070136] [G loss: 0.206847]


 22%|██▏       | 22/100 [15:13<53:01, 40.79s/it]

[Epoch 22/100] [D loss: 0.066544] [G loss: 0.065077]


 23%|██▎       | 23/100 [15:54<52:24, 40.84s/it]

[Epoch 23/100] [D loss: 0.045901] [G loss: 0.139393]


 24%|██▍       | 24/100 [16:35<51:45, 40.87s/it]

[Epoch 24/100] [D loss: 0.044567] [G loss: 0.128716]


 25%|██▌       | 25/100 [17:15<51:05, 40.88s/it]

[Epoch 25/100] [D loss: 0.055202] [G loss: 0.061420]


 26%|██▌       | 26/100 [17:56<50:21, 40.84s/it]

[Epoch 26/100] [D loss: 0.027464] [G loss: 0.115194]


 27%|██▋       | 27/100 [18:37<49:40, 40.83s/it]

[Epoch 27/100] [D loss: 0.050658] [G loss: 0.213541]


 28%|██▊       | 28/100 [19:18<49:01, 40.86s/it]

[Epoch 28/100] [D loss: 0.058289] [G loss: 0.133112]


 29%|██▉       | 29/100 [19:59<48:24, 40.91s/it]

[Epoch 29/100] [D loss: 0.030345] [G loss: 0.186383]


 30%|███       | 30/100 [20:40<47:45, 40.94s/it]

[Epoch 30/100] [D loss: 0.070685] [G loss: 0.212864]


 31%|███       | 31/100 [21:21<47:02, 40.91s/it]

[Epoch 31/100] [D loss: 0.081577] [G loss: 0.316992]


 33%|███▎      | 33/100 [22:43<45:43, 40.94s/it]

[Epoch 33/100] [D loss: 0.074529] [G loss: 0.045183]


 34%|███▍      | 34/100 [23:24<45:00, 40.91s/it]

[Epoch 34/100] [D loss: 0.053890] [G loss: 0.106573]


 35%|███▌      | 35/100 [24:05<44:20, 40.93s/it]

[Epoch 35/100] [D loss: 0.052124] [G loss: 0.177002]


 36%|███▌      | 36/100 [24:46<43:43, 41.00s/it]

[Epoch 36/100] [D loss: 0.078220] [G loss: 0.156333]


 37%|███▋      | 37/100 [25:27<43:04, 41.02s/it]

[Epoch 37/100] [D loss: 0.031375] [G loss: 0.174125]


 38%|███▊      | 38/100 [26:08<42:22, 41.01s/it]

[Epoch 38/100] [D loss: 0.053705] [G loss: 0.146353]


 39%|███▉      | 39/100 [26:49<41:40, 40.99s/it]

[Epoch 39/100] [D loss: 0.045095] [G loss: 0.116664]


 40%|████      | 40/100 [27:30<40:58, 40.98s/it]

[Epoch 40/100] [D loss: 0.073593] [G loss: 0.067274]


 41%|████      | 41/100 [28:11<40:21, 41.04s/it]

[Epoch 41/100] [D loss: 0.042656] [G loss: 0.048722]


 42%|████▏     | 42/100 [28:52<39:50, 41.22s/it]

[Epoch 42/100] [D loss: 0.028218] [G loss: 0.102089]


 43%|████▎     | 43/100 [29:33<39:02, 41.09s/it]

[Epoch 43/100] [D loss: 0.053765] [G loss: 0.054577]


 44%|████▍     | 44/100 [30:14<38:16, 41.01s/it]

[Epoch 44/100] [D loss: 0.039797] [G loss: 0.154720]


 45%|████▌     | 45/100 [30:55<37:34, 40.99s/it]

[Epoch 45/100] [D loss: 0.032044] [G loss: 0.077832]


 46%|████▌     | 46/100 [31:36<36:56, 41.05s/it]

[Epoch 46/100] [D loss: 0.049331] [G loss: 0.125000]


 47%|████▋     | 47/100 [32:17<36:12, 40.99s/it]

[Epoch 47/100] [D loss: 0.041658] [G loss: 0.054657]


 49%|████▉     | 49/100 [33:39<34:46, 40.92s/it]

[Epoch 49/100] [D loss: 0.037010] [G loss: 0.157126]


 50%|█████     | 50/100 [34:20<34:04, 40.89s/it]

[Epoch 50/100] [D loss: 0.065854] [G loss: 0.089815]


 51%|█████     | 51/100 [35:01<33:23, 40.89s/it]

[Epoch 51/100] [D loss: 0.043531] [G loss: 0.111748]


 52%|█████▏    | 52/100 [35:42<32:46, 40.97s/it]

[Epoch 52/100] [D loss: 0.037985] [G loss: 0.190178]


 53%|█████▎    | 53/100 [36:23<32:04, 40.94s/it]

[Epoch 53/100] [D loss: 0.044804] [G loss: 0.192128]


 54%|█████▍    | 54/100 [37:03<31:21, 40.90s/it]

[Epoch 54/100] [D loss: 0.039262] [G loss: 0.113817]


 55%|█████▌    | 55/100 [37:44<30:39, 40.88s/it]

[Epoch 55/100] [D loss: 0.055224] [G loss: 0.191202]


 56%|█████▌    | 56/100 [38:25<29:56, 40.82s/it]

[Epoch 56/100] [D loss: 0.069175] [G loss: 0.115806]


 57%|█████▋    | 57/100 [39:06<29:14, 40.81s/it]

[Epoch 57/100] [D loss: 0.038438] [G loss: 0.114606]


 58%|█████▊    | 58/100 [39:46<28:32, 40.78s/it]

[Epoch 58/100] [D loss: 0.045158] [G loss: 0.214839]


 59%|█████▉    | 59/100 [40:27<27:50, 40.74s/it]

[Epoch 59/100] [D loss: 0.042550] [G loss: 0.117932]


 60%|██████    | 60/100 [41:08<27:07, 40.69s/it]

[Epoch 60/100] [D loss: 0.047722] [G loss: 0.177577]


 61%|██████    | 61/100 [41:48<26:29, 40.74s/it]

[Epoch 61/100] [D loss: 0.050498] [G loss: 0.204233]


 62%|██████▏   | 62/100 [42:29<25:48, 40.74s/it]

[Epoch 62/100] [D loss: 0.035130] [G loss: 0.214614]


 63%|██████▎   | 63/100 [43:10<25:07, 40.74s/it]

[Epoch 63/100] [D loss: 0.051811] [G loss: 0.132927]


 65%|██████▌   | 65/100 [44:31<23:46, 40.75s/it]

[Epoch 65/100] [D loss: 0.035350] [G loss: 0.103066]


 66%|██████▌   | 66/100 [45:12<23:06, 40.78s/it]

[Epoch 66/100] [D loss: 0.041190] [G loss: 0.170456]


 67%|██████▋   | 67/100 [45:53<22:26, 40.79s/it]

[Epoch 67/100] [D loss: 0.059769] [G loss: 0.345856]


 68%|██████▊   | 68/100 [46:34<21:46, 40.82s/it]

[Epoch 68/100] [D loss: 0.032655] [G loss: 0.194548]


 69%|██████▉   | 69/100 [47:15<21:05, 40.83s/it]

[Epoch 69/100] [D loss: 0.023603] [G loss: 0.152115]


 70%|███████   | 70/100 [47:56<20:27, 40.92s/it]

[Epoch 70/100] [D loss: 0.050673] [G loss: 0.077486]


 71%|███████   | 71/100 [48:37<19:48, 40.97s/it]

[Epoch 71/100] [D loss: 0.042907] [G loss: 0.115298]


 72%|███████▏  | 72/100 [49:18<19:06, 40.94s/it]

[Epoch 72/100] [D loss: 0.040670] [G loss: 0.089380]


 73%|███████▎  | 73/100 [49:59<18:25, 40.95s/it]

[Epoch 73/100] [D loss: 0.027937] [G loss: 0.111976]


 74%|███████▍  | 74/100 [50:40<17:42, 40.88s/it]

[Epoch 74/100] [D loss: 0.057868] [G loss: 0.082366]


 75%|███████▌  | 75/100 [51:21<17:02, 40.89s/it]

[Epoch 75/100] [D loss: 0.036110] [G loss: 0.146628]


 76%|███████▌  | 76/100 [52:01<16:20, 40.86s/it]

[Epoch 76/100] [D loss: 0.060849] [G loss: 0.029700]


 77%|███████▋  | 77/100 [52:42<15:40, 40.87s/it]

[Epoch 77/100] [D loss: 0.053666] [G loss: 0.139401]


 78%|███████▊  | 78/100 [53:23<14:58, 40.83s/it]

[Epoch 78/100] [D loss: 0.032040] [G loss: 0.096708]


 79%|███████▉  | 79/100 [54:04<14:17, 40.84s/it]

[Epoch 79/100] [D loss: 0.037944] [G loss: 0.077633]


 81%|████████  | 81/100 [55:25<12:55, 40.79s/it]

[Epoch 81/100] [D loss: 0.032480] [G loss: 0.072371]


 82%|████████▏ | 82/100 [56:06<12:13, 40.75s/it]

[Epoch 82/100] [D loss: 0.039520] [G loss: 0.140143]


 83%|████████▎ | 83/100 [56:47<11:32, 40.75s/it]

[Epoch 83/100] [D loss: 0.034250] [G loss: 0.130723]


 84%|████████▍ | 84/100 [57:27<10:51, 40.74s/it]

[Epoch 84/100] [D loss: 0.039962] [G loss: 0.138818]


 85%|████████▌ | 85/100 [58:08<10:11, 40.79s/it]

[Epoch 85/100] [D loss: 0.056669] [G loss: 0.094562]


 86%|████████▌ | 86/100 [58:49<09:30, 40.77s/it]

[Epoch 86/100] [D loss: 0.063081] [G loss: 0.057624]


 87%|████████▋ | 87/100 [59:30<08:51, 40.86s/it]

[Epoch 87/100] [D loss: 0.050173] [G loss: 0.086229]


 88%|████████▊ | 88/100 [1:00:11<08:10, 40.89s/it]

[Epoch 88/100] [D loss: 0.022860] [G loss: 0.081306]


 89%|████████▉ | 89/100 [1:00:52<07:29, 40.86s/it]

[Epoch 89/100] [D loss: 0.053206] [G loss: 0.155510]


 90%|█████████ | 90/100 [1:01:33<06:48, 40.89s/it]

[Epoch 90/100] [D loss: 0.031498] [G loss: 0.079928]


 91%|█████████ | 91/100 [1:02:14<06:08, 40.93s/it]

[Epoch 91/100] [D loss: 0.030679] [G loss: 0.169252]


 92%|█████████▏| 92/100 [1:02:55<05:26, 40.87s/it]

[Epoch 92/100] [D loss: 0.039595] [G loss: 0.097331]


 93%|█████████▎| 93/100 [1:03:36<04:47, 41.13s/it]

[Epoch 93/100] [D loss: 0.052948] [G loss: 0.068106]


 94%|█████████▍| 94/100 [1:04:19<04:09, 41.55s/it]

[Epoch 94/100] [D loss: 0.044773] [G loss: 0.163996]


 95%|█████████▌| 95/100 [1:05:01<03:28, 41.64s/it]

[Epoch 95/100] [D loss: 0.038456] [G loss: 0.240753]


 97%|█████████▋| 97/100 [1:06:31<02:10, 43.51s/it]

[Epoch 97/100] [D loss: 0.053702] [G loss: 0.139786]


 98%|█████████▊| 98/100 [1:07:13<01:26, 43.17s/it]

[Epoch 98/100] [D loss: 0.033137] [G loss: 0.115518]


 99%|█████████▉| 99/100 [1:07:56<00:42, 42.93s/it]

[Epoch 99/100] [D loss: 0.032435] [G loss: 0.069675]


100%|██████████| 100/100 [1:08:37<00:00, 41.18s/it]
