Reference: https://github.com/eriklindernoren/PyTorch-GAN

In [1]:
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

### Generator architecture
* input random vector: 100 dim, labels
* embedding layer: embeds labels into 100 dim (multiplied element-wise with the random vector)
* linear layer: out_features 128 * 8 * 8
* batchnorm
* upsample: factor 2
* Conv2d: out_channel: 128, kernel size 3, stride 1, padding 1
* batchnorm
* leakyrelu: 0.2
* upsample: factor 2
* conv2d: out_channel: 64, kernel size 3, stride 1, padding 1
* batchnorm
* leakyrelu: 0.2
* conv2d: out_channel: 1, kernel size 3, stride 1, padding 1
* tanh

In [2]:
class Generator(nn.Module):
    """Class-conditional ACGAN generator.

    A noise vector is multiplied element-wise with a learned embedding of the
    requested class label, projected to a ``128 x init_size x init_size``
    feature map and upsampled twice (x2 each) into an image whose values lie
    in ``[-1, 1]`` (final ``Tanh``).

    Args:
        n_classes: Number of distinct labels the embedding can represent.
        latent_dim: Size of the noise vector (and of the label embedding).
        img_size: Height/width of the generated square image. Must be a
            multiple of 4 because the network upsamples by 2 twice.
        channels: Number of channels in the generated image.

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 4.

    The defaults reproduce the original fixed architecture
    (10 classes, 100-dim noise, 1 x 32 x 32 output).
    """

    def __init__(self, n_classes=10, latent_dim=100, img_size=32, channels=1):
        super().__init__()
        if img_size <= 0 or img_size % 4 != 0:
            raise ValueError(
                "img_size must be a positive multiple of 4, got %r" % (img_size,)
            )

        # Spatial size of the feature map before the two x2 upsampling stages.
        self.init_size = img_size // 4

        # One learned latent_dim-sized vector per class label.
        self.label_emb = nn.Embedding(n_classes, latent_dim)

        # Projects the conditioned noise to a flat 128 * init_size^2 vector.
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: Float tensor of shape ``(batch, latent_dim)``.
            labels: Integer tensor of shape ``(batch,)`` with class indices.

        Returns:
            Tensor of shape ``(batch, channels, img_size, img_size)``.
        """
        # Condition the noise on the label via an element-wise product.
        gen_input = torch.mul(self.label_emb(labels), noise)
        out = self.l1(gen_input)
        # Reshape the flat projection into a 128-channel feature map.
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

### Discriminator architecture
* input: [1 , 32 , 32] image 
* conv2d: out_channel: 16, kernel size 3, stride 2, padding 1
* leakyrelu: 0.2
* dropout: 0.25
* Conv2d: out_channel: 32, kernel size 3, stride 2, padding 1
* leakyrelu: 0.2
* dropout: 0.25
* batchnorm
* Conv2d: out_channel: 64, kernel size 3, stride 2, padding 1
* leakyrelu: 0.2
* dropout: 0.25
* batchnorm
* Conv2d: out_channel: 128, kernel size 3, stride 2, padding 1
* leakyrelu: 0.2
* dropout: 0.25
* batchnorm
* two linear layers: one for adversarial loss, one for classification


In [3]:
class Discriminator(nn.Module):
    """ACGAN discriminator with a real/fake head and a class-label head.

    Four stride-2 convolution blocks shrink the input image; the flattened
    features feed two separate linear heads.

    Args:
        n_classes: Number of outputs of the auxiliary classification head.
        img_size: Height/width of the square input images. Used only to size
            the linear heads.
        channels: Number of channels in the input images.

    The defaults reproduce the original fixed architecture
    (1 x 32 x 32 input, 512 flattened features, 10 classes).
    """

    def __init__(self, n_classes=10, img_size=32, channels=1):
        super().__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Return the layers of one downsampling block as a list."""
            block = [
                nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                nn.LeakyReLU(0.2),
                nn.Dropout2d(0.25),
            ]
            if bn:
                block.append(nn.BatchNorm2d(out_filters))
            return block

        # Shape comments assume the default 32 x 32 input.
        self.conv_blocks = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),  # 16 x 16 x 16
            *discriminator_block(16, 32),                  # 32 x 8 x 8
            *discriminator_block(32, 64),                  # 64 x 4 x 4
            *discriminator_block(64, 128),                 # 128 x 2 x 2
        )

        # A 3x3 / stride-2 / padding-1 convolution maps a side of n to
        # ceil(n / 2); apply that once per block to get the final side.
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        n_features = 128 * ds_size ** 2  # 512 for the default 32 x 32 input

        # Real/fake probability (Sigmoid) and raw class logits (no softmax).
        self.adv_layer = nn.Sequential(nn.Linear(n_features, 1), nn.Sigmoid())
        self.aux_layer = nn.Sequential(nn.Linear(n_features, n_classes))

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Tensor of shape ``(batch, channels, img_size, img_size)``.

        Returns:
            Tuple ``(validity, label)``: ``validity`` has shape ``(batch, 1)``
            with values in ``[0, 1]``; ``label`` has shape
            ``(batch, n_classes)`` and holds unnormalised class scores.
        """
        out = self.conv_blocks(img)
        # Flatten everything except the batch dimension.
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        label = self.aux_layer(out)
        return validity, label

In [4]:
# generator = Generator()

# input_sample = torch.randn(2,100)
# labels = torch.Tensor([0, 1]).long()
# output = generator(input_sample, labels)
# output.shape

In [5]:
# Build both networks and move their parameters onto the GPU.
generator = Generator()
generator.cuda()
discriminator = Discriminator()
discriminator.cuda()

# Criteria: binary cross-entropy for the real/fake output,
# cross-entropy for the class-label output.
adversarial_loss = nn.BCELoss()
auxiliary_loss = nn.CrossEntropyLoss()

In [6]:
# Preprocessing: resize to 32, convert to a tensor, then normalise the single
# channel with mean 0.5 and std 0.5.
mnist_transform = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# MNIST training split, downloaded to ../data/mnist when not already present.
dataset = datasets.MNIST(
    "../data/mnist",
    train=True,
    download=True,
    transform=mnist_transform,
)

# Shuffled batches of 64; the final incomplete batch is dropped.
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)

In [7]:
# Generator and discriminator each get their own Adam optimizer,
# configured with the same learning rate and beta coefficients.
adam_settings = {"lr": 0.00002, "betas": (0.5, 0.9999)}
optimizer_G = torch.optim.Adam(generator.parameters(), **adam_settings)
optimizer_D = torch.optim.Adam(discriminator.parameters(), **adam_settings)

In [8]:
# Folder for the generated sample grids; no error if it already exists.
output_dir = "./acgan_images"
os.makedirs(output_dir, exist_ok=True)

In [9]:
def sample_image(n_row, batches_done):
    """Save an ``n_row`` x ``n_row`` grid of generated digits as a PNG.

    Each row of the grid holds labels ``0 .. n_row - 1`` from left to right,
    so every column shows one class. The image is written to
    ``acgan_images/<batches_done>.png``.

    Args:
        n_row: Grid side length. Labels up to ``n_row - 1`` are requested, so
            it must not exceed the number of classes in the generator's label
            embedding.
        batches_done: Training step counter, used as the output file name.

    Uses the module-level ``generator`` and requires a CUDA device.
    """
    # Sample noise: one 100-dim vector per grid cell.
    z = torch.Tensor(np.random.normal(0, 1, (n_row ** 2, 100))).cuda()
    # Labels 0 .. n_row - 1, repeated once for every row of the grid.
    labels = torch.LongTensor(np.tile(np.arange(n_row), n_row)).cuda()
    # The samples are only written to disk, so skip building an autograd graph.
    with torch.no_grad():
        gen_imgs = generator(z, labels)
    save_image(gen_imgs, "acgan_images/%d.png" % batches_done, nrow=n_row, normalize=True)

In [10]:
# Alternating ACGAN training: one generator update followed by one
# discriminator update per batch, with periodic logging and sample dumps.
n_epochs = 200
batches_per_epoch = len(dataloader)

for epoch in range(n_epochs):
    for i, (real_imgs, labels) in enumerate(dataloader):

        batch_size = real_imgs.shape[0]

        # Move the real batch to the GPU.
        real_imgs, labels = real_imgs.cuda(), labels.cuda()

        # Latent noise and random target classes for the fake batch.
        z = torch.Tensor(np.random.normal(0, 1, (batch_size, 100))).cuda()
        gen_labels = torch.LongTensor(np.random.randint(0, 10, batch_size)).cuda()

        # Adversarial targets: 1 means "real", 0 means "fake".
        valid = torch.ones((batch_size, 1)).cuda()
        fake = torch.zeros((batch_size, 1)).cuda()

        # ---- Generator update ----
        # The generator is rewarded when its images are scored as real and
        # classified as the label they were generated for.
        optimizer_G.zero_grad()
        gen_imgs = generator(z, gen_labels)

        validity, fake_aux = discriminator(gen_imgs)
        g_adv = adversarial_loss(validity, valid)
        g_aux = auxiliary_loss(fake_aux, gen_labels)
        g_loss = 0.5 * (g_adv + g_aux)
        g_loss.backward()
        optimizer_G.step()

        # ---- Discriminator update ----
        optimizer_D.zero_grad()

        # Real batch: should be scored real and classified correctly.
        real_pred, real_aux = discriminator(real_imgs)
        d_real_adv = adversarial_loss(real_pred, valid)
        d_real_aux = auxiliary_loss(real_aux, labels)
        d_real_loss = (d_real_adv + d_real_aux) / 2

        # Fake batch: detached so no gradient flows back into the generator.
        fake_pred, fake_aux = discriminator(gen_imgs.detach())
        d_fake_adv = adversarial_loss(fake_pred, fake)
        d_fake_aux = auxiliary_loss(fake_aux, gen_labels)
        d_fake_loss = (d_fake_adv + d_fake_aux) / 2

        # Average of the real and fake halves.
        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # Classification accuracy of the auxiliary head over real + fake images.
        pred = torch.cat([real_aux, fake_aux], dim=0).detach().cpu().numpy()
        gt = torch.cat([labels, gen_labels], dim=0).detach().cpu().numpy()
        d_acc = np.mean(pred.argmax(axis=1) == gt)

        # Global step counter across epochs.
        batches_done = epoch * batches_per_epoch + i

        # Progress line every 100 steps.
        if batches_done % 100 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
                % (epoch, n_epochs, i, batches_per_epoch, d_loss.item(), 100 * d_acc, g_loss.item())
            )

        # Sample grid every 2000 steps.
        if batches_done % 2000 == 0:
            sample_image(n_row=10, batches_done=batches_done)

[Epoch 0/200] [Batch 0/937] [D loss: 1.579751, acc: 6%] [G loss: 1.586793]
[Epoch 0/200] [Batch 100/937] [D loss: 1.463645, acc: 12%] [G loss: 1.529694]
[Epoch 0/200] [Batch 200/937] [D loss: 1.371188, acc: 28%] [G loss: 1.520335]
[Epoch 0/200] [Batch 300/937] [D loss: 1.336450, acc: 26%] [G loss: 1.473397]
[Epoch 0/200] [Batch 400/937] [D loss: 1.280214, acc: 36%] [G loss: 1.504840]
[Epoch 0/200] [Batch 500/937] [D loss: 1.243498, acc: 34%] [G loss: 1.447753]
[Epoch 0/200] [Batch 600/937] [D loss: 1.276594, acc: 29%] [G loss: 1.469209]
[Epoch 0/200] [Batch 700/937] [D loss: 1.202138, acc: 39%] [G loss: 1.451305]
[Epoch 0/200] [Batch 800/937] [D loss: 1.127721, acc: 39%] [G loss: 1.487319]
[Epoch 0/200] [Batch 900/937] [D loss: 1.113358, acc: 45%] [G loss: 1.505359]
[Epoch 1/200] [Batch 63/937] [D loss: 1.039862, acc: 50%] [G loss: 1.382103]
[Epoch 1/200] [Batch 163/937] [D loss: 0.985844, acc: 60%] [G loss: 1.278590]
[Epoch 1/200] [Batch 263/937] [D loss: 0.956709, acc: 54%] [G loss: 

KeyboardInterrupt: 