In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder, MNIST
from torchvision import transforms
from torch import autograd
from torch.autograd import Variable
from torchvision.utils import make_grid

In [None]:
# Preprocessing pipeline applied to every image handed out by the dataset.
transform = transforms.Compose([
    # Turn the image into a tensor.
    transforms.ToTensor(),
    # value = (value - 0.5) / 0.5, which maps the range [0, 1] onto [-1, 1].
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

In [None]:
batch_size = 32
# Shuffled mini-batch loader over the MNIST training split; the data is
# downloaded into ./data and every image goes through `transform`.
data_loader = torch.utils.data.DataLoader(
    MNIST(root='data', train=True, transform=transform, download=True),
    shuffle=True,
    batch_size=batch_size,
)

In [None]:
class Discriminator(nn.Module): #inherited from nn.Module
    """Conditional discriminator: scores (image, label) pairs as real or fake.

    The image is flattened to 784 values, concatenated with a learned
    10-dimensional embedding of its class label, and passed through an MLP
    that ends in a sigmoid, so each score lies in (0, 1).
    """

    def __init__(self):
        super().__init__()      #init of nn.Module

        self.label_emb = nn.Embedding(10, 10)  # embedding label, assigns a 10 length vector to each label

        self.model = nn.Sequential(
            nn.Linear(794, 1024),               # 794 = 784 pixels + 10 label-embedding values
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x, labels):
        """Return one real/fake probability per sample.

        Args:
            x: batch of images; any shape that can be viewed as (batch, 784).
            labels: integer class labels (0-9), one per image.

        Returns:
            Tensor of shape (batch,) with values in (0, 1).
        """
        x = x.view(x.size(0), 784) #resize image to flatten it
        c = self.label_emb(labels) #embedding labels
        x = torch.cat([x, c], 1)   #concat x and c
        out = self.model(x)
        # Drop only the trailing feature dimension. A bare squeeze() would also
        # remove the batch dimension when batch == 1 and return a 0-dim tensor
        # that no longer matches a (batch,) target in the loss.
        return out.squeeze(1)

In [None]:
class Generator(nn.Module):
    """Conditional generator: maps (noise, label) pairs to 28x28 images.

    A 100-dimensional noise vector is concatenated with a learned
    10-dimensional label embedding and pushed through an MLP whose final
    Tanh keeps every output pixel in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # One learned 10-length vector per class label (10 labels).
        self.label_emb = nn.Embedding(10, 10)

        layers = [
            nn.Linear(110, 256),                # 110 inputs = 100 noise + 10 label-embedding values
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 784),               # 784 = 28 * 28 output pixels
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z, labels):
        """Generate one image per (noise, label) pair.

        Args:
            z: noise tensor viewable as (batch, 100).
            labels: integer class labels, one per noise vector.

        Returns:
            Tensor of shape (batch, 28, 28).
        """
        n_samples = z.size(0)
        noise = z.view(n_samples, 100)
        label_vec = self.label_emb(labels)                # look up label embeddings
        joint = torch.cat((noise, label_vec), dim=1)      # (batch, 110)
        flat_image = self.model(joint)
        return flat_image.view(n_samples, 28, 28)         # reshape output to image size

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so this
# cell does not crash on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator().to(device)
discriminator = Discriminator().to(device)
criterion = nn.BCELoss() #binary cross entropy loss
# One Adam optimizer per network so each can be stepped independently.
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)

In [None]:
def generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion):
    """Run one optimisation step on the generator.

    Samples `batch_size` noise vectors and random labels, generates images,
    and updates the generator so the discriminator scores them as real.

    Args:
        batch_size: number of fake samples to generate for this step.
        discriminator: model called as discriminator(images, labels).
        generator: model called as generator(z, labels); must own at least
            one parameter, since its device is used for the sampled tensors.
        g_optimizer: optimizer over the generator's parameters.
        criterion: loss called as criterion(prediction, target).

    Returns:
        The generator loss for this step as a Python float.
    """
    # Create tensors on the generator's device instead of assuming CUDA.
    device = next(generator.parameters()).device

    g_optimizer.zero_grad() #zeroes out previous accumulated gradients
    z = torch.randn(batch_size, 100, device=device)  #generate random noise of 100 length
    fake_labels = torch.randint(0, 10, (batch_size,), device=device) #generate random labels in [0, 9]
    fake_images = generator(z, fake_labels) #generate fake images
    validity = discriminator(fake_images, fake_labels) #pass to discriminator to classify real or fake
    # Target is all ones because the generator wants the discriminator to
    # classify every fake as real; ones_like keeps shape/device in sync.
    g_loss = criterion(validity, torch.ones_like(validity))
    g_loss.backward() #backprop
    g_optimizer.step() #update weights and biases
    return g_loss.item() #return loss value

def discriminator_train_step(batch_size, discriminator, generator, d_optimizer, criterion, real_images, labels):
    """Run one optimisation step on the discriminator.

    The loss is the sum of the loss on the real batch (target 1) and the loss
    on a freshly generated fake batch (target 0).

    Args:
        batch_size: number of fake samples to generate for this step.
        discriminator: model called as discriminator(images, labels).
        generator: model called as generator(z, labels).
        d_optimizer: optimizer over the discriminator's parameters.
        criterion: loss called as criterion(prediction, target).
        real_images: batch of real images; its device is used for the
            sampled noise and labels.
        labels: class labels belonging to `real_images`.

    Returns:
        The combined discriminator loss for this step as a Python float.
    """
    # Create tensors next to the real batch instead of assuming CUDA.
    device = real_images.device

    d_optimizer.zero_grad() #zeroes out previous accumulated gradients

    # train with real images
    real_validity = discriminator(real_images, labels)
    real_loss = criterion(real_validity, torch.ones_like(real_validity)) #target of ones, as all are real

    # train with fake images
    # The generator is not updated here, so generate the fakes without
    # building an autograd graph. This skips a needless backward pass through
    # the generator and leaves no stale gradients on its parameters.
    with torch.no_grad():
        z = torch.randn(batch_size, 100, device=device) #generate random noise of 100 length
        fake_labels = torch.randint(0, 10, (batch_size,), device=device) #generate random labels in [0, 9]
        fake_images = generator(z, fake_labels) #generate fake images
    fake_validity = discriminator(fake_images, fake_labels) #pass to discriminator to classify real or fake
    fake_loss = criterion(fake_validity, torch.zeros_like(fake_validity)) #target of zeros, as all are fake

    d_loss = real_loss + fake_loss #combine losses
    d_loss.backward() #backprop
    d_optimizer.step() #update weights and biases
    return d_loss.item()

In [None]:
num_epochs = 50
n_critic = 5 # number of times the discriminator is trained per one generator training step
display_step = 50
# Keep every batch on the same device as the models instead of assuming CUDA.
train_device = next(generator.parameters()).device
for epoch in range(num_epochs):
    print('Starting epoch {}...'.format(epoch), end=' ')
    for i, (images, labels) in enumerate(data_loader):

        step = epoch * len(data_loader) + i + 1  # global 1-based step counter across epochs
        real_images = images.to(train_device)
        real_labels = labels.to(train_device)
        generator.train()       #tells the model that it is in training phase

        # Several discriminator updates on the same real batch per generator update.
        for _ in range(n_critic):
            d_loss = discriminator_train_step(len(real_images), discriminator, generator, d_optimizer, criterion, real_images, real_labels)

        g_loss = generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion)

        if step % display_step == 0:
            # Build a 3x3 preview grid, one sample for each of the digits 0-8.
            # Sampling needs no gradients, and the preview labels get their own
            # name so they do not overwrite the batch labels.
            generator.eval()
            with torch.no_grad():
                z = torch.randn(9, 100, device=train_device)
                sample_labels = torch.arange(9, device=train_device)
                sample_images = generator(z, sample_labels).unsqueeze(1)
            # NOTE(review): `grid` is computed but never rendered or logged here.
            grid = make_grid(sample_images, nrow=3, normalize=True)
    print('Done!')

In [None]:
# Persist the trained generator weights.
torch.save(generator.state_dict(), 'generator_state.pt')

# Sample 10 images for each digit 0-9 and show them as a 10x10 grid
# (one row per digit). Sampling is inference only, so no gradients are needed.
sample_device = next(generator.parameters()).device
with torch.no_grad():
    z = torch.randn(100, 100, device=sample_device)
    labels = torch.tensor([i for i in range(10) for _ in range(10)], dtype=torch.long, device=sample_device)
    images = generator(z, labels).unsqueeze(1)  # add the channel dimension expected by make_grid
grid = make_grid(images, nrow=10, normalize=True)
fig, ax = plt.subplots(figsize=(10,10))
# matplotlib needs host memory: a CUDA tensor cannot be converted to a NumPy
# array, so move the grid to the CPU before plotting (channels last).
ax.imshow(grid.permute(1, 2, 0).cpu().numpy(), cmap='binary')
ax.axis('off')

In [None]:
def generate_digit(generator, digit):
    """Generate a single image of `digit` and return it as a PIL image.

    Args:
        generator: model called as generator(z, labels); must own at least
            one parameter, since its device is used for the sampled tensors.
        digit: class label to generate, an integer in [0, 9].

    Returns:
        A PIL image built from the (1, 28, 28) generator output rescaled
        from [-1, 1] to [0, 1].

    Raises:
        ValueError: if `digit` is outside [0, 9].
    """
    # Reject bad labels up front: an out-of-range embedding index raises an
    # opaque error, and on CUDA it triggers a device-side assert.
    if not 0 <= digit <= 9:
        raise ValueError('digit must be in the range 0-9, got {!r}'.format(digit))

    # Create tensors on the generator's device instead of assuming CUDA.
    device = next(generator.parameters()).device
    with torch.no_grad():  # inference only, no autograd graph needed
        z = torch.randn(1, 100, device=device)
        label = torch.tensor([digit], dtype=torch.long, device=device)
        img = generator(z, label).cpu()
    img = 0.5 * img + 0.5  # undo the [-1, 1] normalisation -> [0, 1]
    return transforms.ToPILImage()(img)
generate_digit(generator, 8)