Referecnce : https://github.com/eriklindernoren/PyTorch-GAN

In [1]:
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

### Generator architecture
* input random vector: 100 dim
* linear layer: out_features 128 * 8 * 8
* batchnorm
* upsample: factor 2
* Conv2d: out_channel: 128, kernel size 3, stride 1, padding 1
* batchnorm
* leakyrelu: 0.2
* upsample: factor 2
* conv2d: out_channel: 64, kernel size 3, stride 1, padding 1
* batchnorm
* leakyrelu: 0.2
* conv2d: out_channel: 1, kernel size 3, stride 1, padding 1
* tanh

In [2]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        self.l1 = nn.Sequential(nn.Linear(100, 128*8*8)) 
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 1, 3, stride=1, padding=1),
            nn.Tanh()
        )


    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.size(0), 128, 8, 8) # [128, 8, 8]
        img = self.conv_blocks(out)
        
        return img

### Discriminator architecture
* input: [1 , 32 , 32] image 
* conv2d: out_channel: 16, kernel size 3, stride 2, padding 1
* leakyrelu: 0.2
* dropout: 0.25
* Conv2d: out_channel: 32, kernel size 3, stride 2, padding 1
* leakyrelu: 0.2
* dropout: 0.25
* batchnorm
* Conv2d: out_channel: 64, kernel size 3, stride 2, padding 1
* leakyrelu: 0.2
* dropout: 0.25
* batchnorm
* Conv2d: out_channel: 128, kernel size 3, stride 2, padding 1
* leakyrelu: 0.2
* dropout: 0.25
* batchnorm
* linear: out_features 1
* sigmoid

In [3]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, kernel_size = 3, stride = 2, padding = 1), nn.LeakyReLU(0.2), 
                     nn.Dropout2d(0.25)]
            
            if bn:
                block.append(nn.BatchNorm2d(out_filters))
            
            return block
            
        self.model = nn.Sequential(# [1, 32, 32]
                *discriminator_block(1, 16, bn=False),# [16, 16, 16]
                *discriminator_block(16, 32),# [32, 8, 8]
                *discriminator_block(32, 64),# [64, 4, 4]
                *discriminator_block(64, 128))# [128, 2, 2]
        
        self.adv_layer = nn.Sequential(nn.Linear(128*2*2, 1), nn.Sigmoid())

    def forward(self, img):# [1, 32, 32]
        out = self.model(img)
        out = out.view(out.size(0), -1)
        out = self.adv_layer(out)
        
        return out


In [4]:
# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator().cuda()
discriminator = Discriminator().cuda()

In [5]:
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=64,
    shuffle=True,
    drop_last=True)

In [6]:
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.00002, betas=(0.5, 0.9999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.00002, betas=(0.5, 0.9999))

In [7]:
os.makedirs("./dcgan_images", exist_ok=True)

In [8]:
for epoch in range(200):
    for i, (real_imgs, _) in enumerate(dataloader):

        real_imgs = real_imgs.cuda()

        # Sample noise as generator input
        z = torch.Tensor(np.random.normal(0, 1, (real_imgs.shape[0], 100))).cuda()

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), torch.ones((gen_imgs.size(0),1)).cuda())

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), torch.ones((gen_imgs.size(0),1)).cuda())
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), torch.zeros((gen_imgs.size(0),1)).cuda())
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        batches_done = epoch * len(dataloader) + i


        if batches_done % 100 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, 64, i, len(dataloader), d_loss.item(), g_loss.item())
            )

        if batches_done % 2000 == 0:
            save_image(gen_imgs.data[:25], "dcgan_images/%d.png" % batches_done, nrow=5, normalize=True)

[Epoch 0/64] [Batch 0/937] [D loss: 0.706032] [G loss: 0.784511]
[Epoch 0/64] [Batch 100/937] [D loss: 0.562451] [G loss: 0.707362]
[Epoch 0/64] [Batch 200/937] [D loss: 0.568465] [G loss: 0.682858]
[Epoch 0/64] [Batch 300/937] [D loss: 0.508166] [G loss: 0.890154]
[Epoch 0/64] [Batch 400/937] [D loss: 0.459021] [G loss: 1.005925]
[Epoch 0/64] [Batch 500/937] [D loss: 0.403770] [G loss: 1.179395]
[Epoch 0/64] [Batch 600/937] [D loss: 0.343128] [G loss: 1.198708]
[Epoch 0/64] [Batch 700/937] [D loss: 0.276520] [G loss: 1.454750]
[Epoch 0/64] [Batch 800/937] [D loss: 0.240279] [G loss: 1.685916]
[Epoch 0/64] [Batch 900/937] [D loss: 0.193468] [G loss: 2.117780]
[Epoch 1/64] [Batch 63/937] [D loss: 0.154905] [G loss: 2.165761]


KeyboardInterrupt: 