In [1]:
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs('images', exist_ok=True)

In [2]:
n_epochs = 200
batch_size = 64
lr = 1e-3
b1 = .5
b2 = .999,
latent_dim = 100
num_classes = 10
img_size = 32
channels = 1
sample_interval = 400

cuda = True if torch.cuda.is_available() else False

## Define the Generator and the Discriminator

In [3]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.label_emb = nn.Embedding(num_classes, latent_dim)

        self.init_size = img_size // 4 # Initial size before upsampling
        
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128*self.init_size**2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, noise):
        out = self.l1(noise)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

In [4]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Returns layers of each discriminator block"""
            block = [   nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                        nn.LeakyReLU(0.2, inplace=True),
                        nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.conv_blocks = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = img_size // 2**4

        # Output layers
        self.adv_layer = nn.Sequential(nn.Linear(128*ds_size**2, 1),
                                        nn.Sigmoid())

        self.aux_layer = nn.Sequential( nn.Linear(128*ds_size**2, num_classes+1),
                                        nn.Softmax())

    def forward(self, img):
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)
        reality = self.adv_layer(out)
        label = self.aux_layer(out)

        return reality, label

In [5]:
# Loss functions
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()

In [6]:
generator = Generator()
discriminator = Discriminator()

In [7]:
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    auxiliary_loss.cuda()

Optional step to control weight initialization

In [8]:
def init_weights(m):
    # applied to nn.Module instance
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02) #this sets m.weight.data with a normal dist of 0 mean and 0.02 var
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

generator.apply(init_weights)
discriminator.apply(init_weights)        

Discriminator(
  (conv_blocks): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Dropout2d(p=0.25)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Dropout2d(p=0.25)
    (6): BatchNorm2d(32, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): LeakyReLU(negative_slope=0.2, inplace)
    (9): Dropout2d(p=0.25)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (12): LeakyReLU(negative_slope=0.2, inplace)
    (13): Dropout2d(p=0.25)
    (14): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  )
  (adv_layer): Sequential(
    (0): Linear(in_features=512, out_features=1, bias=True

Configure data loader

In [9]:
os.makedirs('./data/downloaded/mnist', exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('./data/downloaded/mnist', train=True, download=False,
                   transform=transforms.Compose([
                        transforms.Resize(img_size),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ])),
    batch_size=batch_size, shuffle=True)

In [10]:
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr)

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

In [11]:
def generate_noise():
    return Variable(FloatTensor(np.random.normal(0, 1, (batch_size, latent_dim))),requires_grad=False)

In [15]:
for epoch in range(n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        # batch_size = imgs.shape[0]

        # Adversarial ground truths
        ones = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        zeros = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)
        fake_class = Variable(LongTensor(batch_size).fill_(num_classes), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(FloatTensor))
        real_classification_labels = Variable(labels.type(LongTensor))


        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for ones images
        reality, pred_classification_label = discriminator(real_imgs)
        d_real_loss =  (adversarial_loss(reality, ones) + auxiliary_loss(pred_classification_label, real_classification_labels)) / 2

        # Loss for zeros images
        z = generate_noise()
        generated_images = generator(z)
        reality, fake_classification_label = discriminator(generated_images.detach())
        d_fake_loss =  (adversarial_loss(reality, zeros) + auxiliary_loss(fake_classification_label, fake_class)) / 2

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        # Calculate discriminator accuracy score
        pred = np.concatenate([pred_classification_label.data.cpu().numpy(), fake_classification_label.data.cpu().numpy()], axis=0)
        gt = np.concatenate([real_classification_labels.data.cpu().numpy(), fake_class.data.cpu().numpy()], axis=0)
        d_acc = np.mean(np.argmax(pred, axis=1) == gt)
        
        d_loss.backward()
        optimizer_D.step()
        
        
        
         # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise and labels as generator input
        z = generate_noise()

        # Generate a batch of images
        generated_images = generator(z)

        # Loss measures generator's ability to fool the discriminator
        reality, _ = discriminator(generated_images)
        g_loss = adversarial_loss(reality, ones) #here we want out reality scores from the generated images to be close to real images...

        g_loss.backward()
        optimizer_G.step()
        

        if i == 0: print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]" % (epoch, n_epochs, i, len(dataloader),
                                                            d_loss.item(), 100 * d_acc,
                                                            g_loss.item()))

        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(generated_images.data[:25], 'images/%d.png' % batches_done, nrow=5, normalize=True)

  input = module(input)


[Epoch 0/200] [Batch 0/938] [D loss: 1.545556, acc: 6%] [G loss: 0.705941]


KeyboardInterrupt: 

Save the models 

In [None]:
torch.save(generator.state_dict(), './generator.pt')
torch.save(discriminator.state_dict(), './discriminator.pt')