In [1]:
## Implementation of "Conditional Generative Adversarial Nets" by Mirza & Osindero (2014)

In [3]:
## Import packages and modules
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
import matplotlib.pyplot as plt
# Optionally switch to a working directory; the relative './dataset' path used for MNIST
# resolves against it. Override with the CGAN_WORKDIR environment variable. The previous
# hard-coded location is kept as the default and is skipped if it does not exist, so the
# notebook no longer crashes on machines other than the author's.
WORK_DIR = os.environ.get("CGAN_WORKDIR", "/home/agastya/Downloads")
if os.path.isdir(WORK_DIR):
    os.chdir(WORK_DIR)
%matplotlib inline

In [24]:
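# Discriminator input: images of shape (batch_size, channels, img_size, img_size). They are flattened
# internally, so an already-flattened (batch_size, channels*img_size*img_size) tensor also works.
# Generator input: a 2D noise tensor of shape (batch_size, latent_dim).
# Labels: a 1D tensor of integer class indices of shape (batch_size,). They are NOT one-hot encoded.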

In [4]:
# ---- Optimisation ----
n_epochs = 200          # number of training epochs
batch_size = 64         # mini-batch size
lr = 0.0002             # Adam learning rate
b1, b2 = 0.5, 0.999     # Adam betas: first- and second-moment decay rates
n_cpu = 8               # number of CPU threads to use during batch generation

# ---- Model / data dimensions ----
latent_dim = 100        # size of the latent noise vector z
img_size = 32           # height and width of each image after resizing (MNIST is natively 28x28)
channels = 1            # number of image channels (MNIST is grayscale)
n_classes = 10          # number of class labels in the dataset

# ---- Logging ----
sample_interval = 400   # interval between image samples

In [6]:
# Shape (C, H, W) of a single image, and whether a CUDA device is available for training.
img_shape = (channels, img_size, img_size)
cuda = torch.cuda.is_available()

In [10]:
## 1.0 Data Preparation and Preprocessing
# Each MNIST digit is resized to img_size x img_size; the DataLoader then batches them
# into tensors of shape (batch_size, 1, img_size, img_size).
def mnist_data(root='./dataset', train=True):
    """Return the MNIST dataset with images resized to img_size and scaled to [-1, 1].

    Args:
        root: directory the dataset is stored in / downloaded to.
        train: load the training split if True, otherwise the test split.
    """
    compose = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        # ToTensor yields values in [0, 1]; (x - 0.5) / 0.5 maps them to [-1, 1], the range
        # of the generator's tanh output. MNIST images have a single channel, so mean/std
        # are 1-tuples: a 3-tuple fails to broadcast onto a 1-channel tensor.
        transforms.Normalize((0.5,), (0.5,)),
    ])
    return datasets.MNIST(root=root, train=train, transform=compose, download=True)


data = mnist_data()
data_loader = DataLoader(data, batch_size=batch_size, shuffle=True)

In [9]:
# Weight initialisation, applied to every sub-module via `module.apply(init_weights)`:
# Linear weights ~ N(0, 0.02); BatchNorm1d weights ~ N(1, 0.02) with zero bias.
def init_weights(m):
    """Initialise a single module in place, dispatching on its class name.

    A class whose name contains 'Linear' gets N(0, 0.02) weights (its bias is left as-is).
    A class whose name contains 'BatchNorm1d' gets N(1, 0.02) weights and a zero bias.
    Any other module is left untouched.
    """
    kind = type(m).__name__
    if 'Linear' in kind:
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm1d' in kind:
        torch.nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [72]:
# Conditional generator built from fully connected layers (not a convolutional DCGAN):
# the desired class label is embedded into a dense vector, concatenated with the noise z,
# and hidden layers 2-4 use BatchNorm with a Keras-style momentum of 0.8.
class Generator(nn.Module):
    """Conditional MLP generator mapping (noise z, class label) -> image.

    Args:
        n_classes: number of distinct class labels (rows of the label embedding table).
        input_dims: dimensionality of the latent noise vector z.
        output_dims: number of output values per sample; must equal prod(img_shape).
        img_shape: (channels, height, width) of each generated image.
    """

    # PyTorch's BatchNorm `momentum` is the weight of the *new* batch statistic
    # (running = (1 - momentum) * running + momentum * batch), the reverse of Keras,
    # so momentum=0.2 matches a Keras-style momentum of 0.8. It must be passed by
    # keyword: BatchNorm1d's second positional parameter is `eps`, and
    # BatchNorm1d(256, 0.8) would set eps=0.8 while leaving momentum at its default.
    bn_momentum = 0.2

    def __init__(self, n_classes, input_dims, output_dims, img_shape):
        super(Generator, self).__init__()

        self.n_classes = n_classes
        self.negative_slope = 0.2  # LeakyReLU slope for negative inputs
        self.input_dims = input_dims
        self.output_dims = output_dims
        self.img_shape = img_shape

        # One learned n_classes-dimensional vector per class label.
        self.label_embeds = nn.Embedding(n_classes, n_classes)

        self.layer1 = nn.Linear(input_dims + n_classes, 128)
        self.layer2 = nn.Linear(128, 256)
        self.batchnorm2 = nn.BatchNorm1d(256, momentum=self.bn_momentum)
        self.layer3 = nn.Linear(256, 512)
        self.batchnorm3 = nn.BatchNorm1d(512, momentum=self.bn_momentum)
        self.layer4 = nn.Linear(512, 1024)
        self.batchnorm4 = nn.BatchNorm1d(1024, momentum=self.bn_momentum)
        self.layer5 = nn.Linear(1024, output_dims)

    def forward(self, x, labels):
        """Generate a batch of images.

        Args:
            x: latent noise of shape (batch, input_dims).
            labels: integer class labels of shape (batch,).

        Returns:
            Tensor of shape (batch, *img_shape) with values in [-1, 1] (tanh output).
        """
        # Keep the indices on the embedding table's device, so labels sampled on the CPU
        # still work when the model lives on the GPU.
        labels = labels.to(self.label_embeds.weight.device)
        embeds = self.label_embeds(labels)
        x = torch.cat((x, embeds), -1)
        x = F.leaky_relu_(self.layer1(x), self.negative_slope)
        x = F.leaky_relu_(self.batchnorm2(self.layer2(x)), self.negative_slope)
        x = F.leaky_relu_(self.batchnorm3(self.layer3(x)), self.negative_slope)
        x = F.leaky_relu_(self.batchnorm4(self.layer4(x)), self.negative_slope)
        x = torch.tanh(self.layer5(x))
        return x.view(x.size(0), *self.img_shape)

In [2]:
# Conditional discriminator: the class label is embedded into a vector and concatenated
# with the flattened image; LeakyReLU (slope 0.2) activations, dropout on layers 2-3,
# sigmoid output.
class Discriminator(nn.Module):
    """Conditional MLP discriminator mapping (image, class label) -> probability of "real".

    Args:
        n_classes: number of distinct class labels (rows of the label embedding table).
        input_dims: number of values in one flattened image.
        dropout: dropout probability applied to the outputs of linear layers 2 and 3.
    """

    def __init__(self, n_classes, input_dims, dropout=0.4):
        super(Discriminator, self).__init__()

        # One learned n_classes-dimensional vector per class label.
        self.label_embeds = nn.Embedding(n_classes, n_classes)
        self.negative_slope = 0.2  # LeakyReLU slope for negative inputs

        self.layer1 = nn.Linear(n_classes + input_dims, 512)
        self.layer2 = nn.Linear(512, 512)
        self.dropout2 = nn.Dropout(dropout)
        self.layer3 = nn.Linear(512, 512)
        self.dropout3 = nn.Dropout(dropout)
        self.layer4 = nn.Linear(512, 1)

    def forward(self, x, labels):
        """Score a batch of images.

        Args:
            x: images with a leading batch dimension; everything after it is flattened,
               so both (batch, C, H, W) and (batch, input_dims) inputs are accepted.
            labels: integer class labels of shape (batch,).

        Returns:
            Tensor of shape (batch, 1) with probabilities in (0, 1).
        """
        # Keep the indices on the embedding table's device (labels may arrive on the CPU).
        labels = labels.to(self.label_embeds.weight.device)
        embeds = self.label_embeds(labels)
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = torch.cat((x.reshape(x.size(0), -1), embeds), -1)
        x = F.leaky_relu_(self.layer1(x), self.negative_slope)
        x = F.leaky_relu_(self.dropout2(self.layer2(x)), self.negative_slope)
        x = F.leaky_relu_(self.dropout3(self.layer3(x)), self.negative_slope)
        return torch.sigmoid(self.layer4(x))

NameError: name 'nn' is not defined

In [74]:
# Binary cross-entropy adversarial loss:
#   discriminator maximises log D(x) + log(1 - D(G(z)))
#   generator maximises log D(G(z))
adversarial_loss = nn.BCELoss()

# Number of values in one flattened image (channels * height * width).
n_pixels = channels * img_size ** 2
generator = Generator(n_classes, latent_dim, n_pixels, img_shape)
discriminator = Discriminator(n_classes, n_pixels)
generator.apply(init_weights)
discriminator.apply(init_weights)

# Move the models to the GPU *before* building the optimisers (as recommended by the
# torch.optim docs), so each optimiser is constructed over parameters in their final location.
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

gen_optim = optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
disc_optim = optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

In [78]:
# Sample latent noise z from a standard Gaussian distribution N(0, 1).
def noise(size, dim=100):
    """Return a (size, dim) tensor of N(0, 1) noise, on the GPU when one is available.

    Args:
        size: number of noise vectors (batch size).
        dim: length of each vector; defaults to 100 and should match the generator's latent_dim.
    """
    n = torch.randn(size, dim)  # already a fresh tensor; wrapping in torch.tensor() only copied it
    return n.cuda() if torch.cuda.is_available() else n

# BCE targets of 1 for samples that should be classified as real.
def real_data_targets(size):
    """Return a (size, 1) tensor of ones, on the GPU when one is available."""
    data = torch.ones(size, 1)  # no torch.tensor() wrapper: it only copied and emitted a warning
    return data.cuda() if torch.cuda.is_available() else data

# BCE targets of 0 for samples that should be classified as fake.
def fake_data_targets(size):
    """Return a (size, 1) tensor of zeros, on the GPU when one is available."""
    data = torch.zeros(size, 1)  # no torch.tensor() wrapper: it only copied and emitted a warning
    return data.cuda() if torch.cuda.is_available() else data

# Random class labels for generated samples, drawn uniformly from {0, ..., n_classes - 1}.
def fake_data_labels(size, n_classes=10):
    """Return a (size,) int64 tensor of uniform random class labels.

    Like the other sampling helpers, the result is placed on the GPU when one is
    available, so it can index the models' label embeddings directly.

    Args:
        size: number of labels (batch size).
        n_classes: number of classes; labels lie in [0, n_classes).
    """
    labels = torch.randint(low=0, high=n_classes, size=(size,), dtype=torch.long)
    return labels.cuda() if torch.cuda.is_available() else labels

In [68]:
def train_disc(optimizer, real_data, real_labels, gen_data, gen_labels):
    """Run one discriminator update on a real batch and a generated batch.

    The gradients of -log D(x) (real batch, target 1) and -log(1 - D(G(z))) (generated
    batch, target 0) are accumulated with one backward pass per batch. `optimizer` then
    takes a single step on their sum.

    Returns:
        (total_loss, real_prediction, fake_prediction), where total_loss is the sum of
        the two BCE losses.
    """
    optimizer.zero_grad()

    # Each pass: (images, class labels, function that builds the BCE target tensor).
    passes = (
        (real_data, real_labels, real_data_targets),  # real samples -> target 1
        (gen_data, gen_labels, fake_data_targets),    # generated samples -> target 0
    )
    losses, predictions = [], []
    for images, labels, make_targets in passes:
        prediction = discriminator(images, labels)
        loss = adversarial_loss(prediction, make_targets(images.size(0)))
        loss.backward()  # accumulate this batch's gradients before the next forward pass
        losses.append(loss)
        predictions.append(prediction)

    optimizer.step()

    real_prediction, fake_prediction = predictions
    return losses[0] + losses[1], real_prediction, fake_prediction

In [83]:
def train_gen(optimizer, fake_data, fake_labels):
    """Run one generator update with the loss -log D(G(z)).

    Generated samples are scored against "real" targets (ones), so minimising the BCE
    pushes D(G(z)) towards 1. `fake_data` needs to still be attached to the generator's
    graph for gradients to reach it. Only `optimizer`'s parameters are stepped.

    Returns:
        The scalar generator loss tensor.
    """
    optimizer.zero_grad()
    scores = discriminator(fake_data, fake_labels)
    gen_loss = adversarial_loss(scores, real_data_targets(fake_data.shape[0]))
    gen_loss.backward()
    optimizer.step()
    return gen_loss

In [84]:
def train_GAN(num_epochs):
    """Train the conditional GAN for `num_epochs` passes over `data_loader`.

    For every mini-batch, this runs one discriminator update followed by one generator
    update, using the global `generator`, `discriminator`, `disc_optim` and `gen_optim`.
    At the end of each epoch it prints the losses of that epoch's last mini-batch.
    """
    # Put every tensor on the device the models live on. Labels from the DataLoader
    # (and possibly from the label sampler) are created on the CPU.
    device = next(generator.parameters()).device

    for epoch in range(num_epochs):
        disc_error = gen_error = None
        for real_data, real_labels in data_loader:
            n = real_data.size(0)
            # Flatten to (batch, channels * H * W); view already returns a tensor, no copy needed.
            real_data = real_data.view(n, -1).to(device)
            real_labels = real_labels.to(device)

            # 1) Discriminator step. The generated batch is detached so no gradient
            #    flows into the generator here.
            fake_labels = fake_data_labels(n).to(device)
            fake_data = generator(noise(n).to(device), fake_labels).detach()
            disc_error, _, _ = train_disc(disc_optim, real_data, real_labels, fake_data, fake_labels)

            # 2) Generator step through the (just updated) discriminator. Only gen_optim
            #    steps, so the discriminator's weights are not changed by this update.
            fake_data = generator(noise(n).to(device), fake_labels)
            gen_error = train_gen(gen_optim, fake_data, fake_labels)

        if disc_error is not None:
            print(f"[Epoch {epoch + 1}/{num_epochs}] "
                  f"D loss: {disc_error.item():.4f}  G loss: {gen_error.item():.4f}")

In [85]:
# Short run: 10 epochs rather than the configured n_epochs.
train_GAN(num_epochs=10)

KeyboardInterrupt: 