Reference: https://github.com/eriklindernoren/PyTorch-GAN

In [1]:
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

### Generator architecture
* input random vector: 100 dim, labels
* embedding layer: embeds each label into a 100-dim vector (multiplied element-wise with the noise)
* linear layer: out_features 128 * 8 * 8
* batchnorm
* upsample: factor 2
* Conv2d: out_channel: 128, kernel size 3, stride 1, padding 1
* batchnorm
* leakyrelu: 0.2
* upsample: factor 2
* conv2d: out_channel: 64, kernel size 3, stride 1, padding 1
* batchnorm
* leakyrelu: 0.2
* conv2d: out_channel: 1, kernel size 3, stride 1, padding 1
* tanh

In [2]:
class Generator(nn.Module):
    """Conditional generator of an AC-GAN.

    The class label is embedded into a vector of the same size as the noise
    vector and multiplied element-wise with it. The product is projected by a
    linear layer to a ``128 x (img_size/4) x (img_size/4)`` feature map, which
    is upsampled twice (x2 each) and reduced to ``channels`` output channels.
    The final ``Tanh`` keeps pixel values in ``[-1, 1]``.

    Args:
        n_classes: number of distinct labels the embedding supports.
        latent_dim: size of the noise vector (and of the label embedding).
        img_size: height/width of the generated square image; must be
            divisible by 4 because the network upsamples by 2 twice.
        channels: number of channels of the generated image.

    Raises:
        ValueError: if ``img_size`` is not a positive multiple of 4.
    """

    def __init__(self, n_classes=10, latent_dim=100, img_size=32, channels=1):
        super().__init__()
        if img_size <= 0 or img_size % 4 != 0:
            raise ValueError(
                "img_size must be a positive multiple of 4, got %r" % (img_size,)
            )

        # Spatial size before the two x2 upsampling stages.
        self.init_size = img_size // 4

        # One learnable latent_dim-sized vector per class label.
        self.label_emb = nn.Embedding(n_classes, latent_dim)
        self.l1 = nn.Linear(latent_dim, 128 * self.init_size ** 2)
        self.conv_blocks = nn.Sequential(  # in: [batch, 128, init, init]
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),  # -> [2 * init, 2 * init]
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2),  # -> [4 * init, 4 * init] == img_size
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        """Generate images conditioned on ``labels``.

        Args:
            noise: float tensor of shape ``[batch, latent_dim]``.
            labels: integer (long) tensor of shape ``[batch]`` with class ids.

        Returns:
            Tensor of shape ``[batch, channels, img_size, img_size]``.
        """
        # Condition the noise on the label: element-wise product with its embedding.
        out = torch.mul(self.label_emb(labels), noise)
        out = self.l1(out)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)

        return img

### Discriminator architecture
* input: [1 , 32 , 32] image 
* conv2d: out_channel: 16, kernel size 3, stride 2, padding 1
* leakyrelu: 0.2
* dropout: 0.25
* Conv2d: out_channel: 32, kernel size 3, stride 2, padding 1
* leakyrelu: 0.2
* dropout: 0.25
* batchnorm
* Conv2d: out_channel: 64, kernel size 3, stride 2, padding 1
* leakyrelu: 0.2
* dropout: 0.25
* batchnorm
* Conv2d: out_channel: 128, kernel size 3, stride 2, padding 1
* leakyrelu: 0.2
* dropout: 0.25
* batchnorm
* two linear layers: one for adversarial loss, one for classification


In [3]:
class Discriminator(nn.Module):
    """AC-GAN discriminator with a real/fake head and a class head.

    Four stride-2 convolution blocks shrink the image; the flattened features
    feed two linear heads:

    * ``adv_layer``: a sigmoid probability that the image is real.
    * ``aux_layer``: raw class logits (no softmax, so they can be passed
      directly to ``nn.CrossEntropyLoss``).

    Args:
        n_classes: number of class logits produced by the auxiliary head.
        img_size: height/width of the square input image.
        channels: number of channels of the input image.
    """

    def __init__(self, n_classes=10, img_size=32, channels=1):
        super().__init__()

        def discriminator_block(in_features, out_features, bn=True):
            """Return the layers of one downsampling block as a list."""
            block = [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.LeakyReLU(0.2),
                nn.Dropout2d(0.25),
            ]
            if bn:
                block.append(nn.BatchNorm2d(out_features))
            return block

        self.model = nn.Sequential(  # in: [batch, channels, img_size, img_size]
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128)
        )

        # A 3x3 conv with stride 2 and padding 1 maps a side of n to ceil(n / 2);
        # apply that once per block to get the final spatial size.
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        flat_features = 128 * ds_size ** 2

        # Head 1: real/fake probability.
        self.adv_layer = nn.Sequential(
            nn.Linear(flat_features, 1), nn.Sigmoid()
        )
        # Head 2: class logits (softmax is left to the loss function).
        self.aux_layer = nn.Sequential(
            nn.Linear(flat_features, n_classes)
        )

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: float tensor of shape ``[batch, channels, img_size, img_size]``.

        Returns:
            Tuple ``(validity, label)`` where ``validity`` has shape
            ``[batch, 1]`` with values in ``[0, 1]`` and ``label`` has shape
            ``[batch, n_classes]`` holding unnormalized class logits.
        """
        out = self.model(img)
        out = out.view(out.shape[0], -1)  # flatten
        validity = self.adv_layer(out)
        label = self.aux_layer(out)

        return validity, label

In [4]:
# generator = Generator()

# input_sample = torch.randn(2,100)
# labels = torch.Tensor([0, 1]).long()
# output = generator(input_sample, labels)
# output.shape

In [5]:
# discriminator = Discriminator()
# validity, pred_label = discriminator(output)
# print(validity.shape)
# print(pred_label.shape)

In [6]:
# Losses: binary cross-entropy for the real/fake head (which ends in a sigmoid),
# cross-entropy for the class head (which outputs raw logits).
adversarial_loss = nn.BCELoss()
auxiliary_loss = nn.CrossEntropyLoss()

# Build both networks, then move their parameters to the GPU
# (Module.cuda() works in place on the module).
generator = Generator()
discriminator = Discriminator()
generator.cuda()
discriminator.cuda()

In [7]:
# Preprocessing: resize to 32x32, convert to a tensor, then normalize with
# mean 0.5 / std 0.5 so pixel values land in [-1, 1] like the generator's tanh output.
mnist_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
dataset = datasets.MNIST("./mnist", train=True, download=True, transform=mnist_transform)
# drop_last=True keeps every batch at exactly 64 samples.
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)

In [8]:
## One Adam optimizer per network, both with the same hyperparameters.
# NOTE(review): lr=2e-5 and beta2=0.9999 differ from the commonly used GAN
# settings (lr=2e-4, betas=(0.5, 0.999)) -- confirm these values are intentional.
adam_kwargs = dict(lr=0.00002, betas=(0.5, 0.9999))
optimizer_G = torch.optim.Adam(generator.parameters(), **adam_kwargs)
optimizer_D = torch.optim.Adam(discriminator.parameters(), **adam_kwargs)

In [9]:
# Create the directory that sample_image() writes its PNG grids into;
# exist_ok=True makes re-running this cell harmless.
os.makedirs("./acgan_images", exist_ok=True)

In [10]:
def sample_image(n_row, batches_done):
    """Save an ``n_row x n_row`` grid of generated digits as a PNG.

    Each row of the grid contains the classes ``0 .. n_row - 1`` in order,
    every cell generated from its own noise vector. The image is written to
    ``acgan_images/<batches_done>.png`` using the module-level ``generator``.

    Args:
        n_row: number of rows (and columns) of the grid; also the number of
            distinct class labels shown per row.
        batches_done: training step counter, used as the output file name.

    Raises:
        ValueError: if ``n_row`` exceeds the number of classes supported by
            the generator's label embedding (labels would be out of range).
    """
    n_classes = generator.label_emb.num_embeddings
    latent_dim = generator.label_emb.embedding_dim
    if n_row > n_classes:
        raise ValueError(
            "n_row=%d exceeds the %d classes known to the generator" % (n_row, n_classes)
        )

    # Sample noise
    z = torch.Tensor(np.random.normal(0, 1, (n_row ** 2, latent_dim))).cuda()
    # Labels 0 .. n_row - 1, repeated for each of the n_row rows
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = torch.LongTensor(labels).cuda()
    # Sampling needs no gradients; skip building the autograd graph.
    # The generator's train/eval mode is deliberately left untouched.
    with torch.no_grad():
        gen_imgs = generator(z, labels)
    save_image(gen_imgs, "acgan_images/%d.png" % batches_done, nrow=n_row, normalize=True)

In [13]:
# Total number of passes over the dataset (also shown in the progress log).
n_epochs = 200

for epoch in range(n_epochs):
    for i, (real_imgs, labels) in enumerate(dataloader):

        batch_size = real_imgs.shape[0]

        # Configure input
        real_imgs = real_imgs.cuda()
        labels = labels.cuda()

        # Adversarial targets, allocated once per step directly on the GPU
        # (1 = real, 0 = fake) instead of being built on the CPU and copied.
        valid = torch.ones((batch_size, 1), device=real_imgs.device)
        fake = torch.zeros((batch_size, 1), device=real_imgs.device)

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Sample noise and random class labels (integers 0-9) as generator input
        z = torch.Tensor(np.random.normal(0, 1, (batch_size, 100))).cuda()
        gen_labels = torch.LongTensor(np.random.randint(0, 10, batch_size)).cuda()
        gen_images = generator(z, gen_labels)

        # The generator wants its images judged real AND classified as the
        # label they were conditioned on.
        fake_validity, fake_label_pred = discriminator(gen_images)
        ce_loss = auxiliary_loss(fake_label_pred, gen_labels)
        adv_loss = adversarial_loss(fake_validity, valid)
        g_loss = (adv_loss + ce_loss)/2

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        # Loss for real images
        real_validity, real_label_pred = discriminator(real_imgs)
        real_loss = (adversarial_loss(real_validity, valid)
                     + auxiliary_loss(real_label_pred, labels))/2
        # Loss for fake images; detach() keeps this step from back-propagating
        # into the generator.
        fake_validity, fake_label_pred = discriminator(gen_images.detach())
        fake_loss = (adversarial_loss(fake_validity, fake)
                     + auxiliary_loss(fake_label_pred, gen_labels))/2
        # Total discriminator loss
        d_loss = (real_loss + fake_loss)/2
        d_loss.backward()
        optimizer_D.step()

        batches_done = epoch * len(dataloader) + i
        if batches_done % 100 == 0:
            # Classification accuracy over real + fake samples. Computed only
            # when logging, so the GPU->CPU copies do not stall every step.
            pred = np.concatenate([real_label_pred.detach().cpu().numpy(), fake_label_pred.detach().cpu().numpy()], axis=0)
            gt = np.concatenate([labels.detach().cpu().numpy(), gen_labels.detach().cpu().numpy()], axis=0)
            d_acc = np.mean(np.argmax(pred, axis=1) == gt)

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
                % (epoch, n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item())
            )

        if batches_done % 2000 == 0:
            sample_image(n_row=10, batches_done=batches_done)

[Epoch 0/200] [Batch 0/937] [D loss: 1.574347, acc: 13%] [G loss: 1.676080]
[Epoch 0/200] [Batch 100/937] [D loss: 1.527292, acc: 12%] [G loss: 1.530517]
[Epoch 0/200] [Batch 200/937] [D loss: 1.416282, acc: 21%] [G loss: 1.476052]
[Epoch 0/200] [Batch 300/937] [D loss: 1.414452, acc: 23%] [G loss: 1.502225]
[Epoch 0/200] [Batch 400/937] [D loss: 1.397690, acc: 31%] [G loss: 1.458599]
[Epoch 0/200] [Batch 500/937] [D loss: 1.309003, acc: 32%] [G loss: 1.510559]
[Epoch 0/200] [Batch 600/937] [D loss: 1.279793, acc: 40%] [G loss: 1.436679]
[Epoch 0/200] [Batch 700/937] [D loss: 1.207819, acc: 39%] [G loss: 1.366102]
[Epoch 0/200] [Batch 800/937] [D loss: 1.166805, acc: 42%] [G loss: 1.362931]
[Epoch 0/200] [Batch 900/937] [D loss: 1.097959, acc: 53%] [G loss: 1.310933]
[Epoch 1/200] [Batch 63/937] [D loss: 1.034429, acc: 55%] [G loss: 1.235254]
[Epoch 1/200] [Batch 163/937] [D loss: 0.991135, acc: 58%] [G loss: 1.168180]
[Epoch 1/200] [Batch 263/937] [D loss: 0.903980, acc: 67%] [G loss:

KeyboardInterrupt: 