In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random
import torch
from tqdm import tqdm
import torch.nn as nn
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor
import torchvision.utils as vutils
import matplotlib.animation as animation
from IPython.display import HTML

#### Preprocessing the MNIST dataset

To normalize the input data, we first compute the pixel `mean` and `std` over the whole MNIST training set.

In [None]:
# Load the raw MNIST training split (downloaded on first run).
dataset = datasets.MNIST(root='MNIST', download=True, transform=transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float torch.Tensor scaled to [0, 1]
]))

# `dataset.data` holds the raw pixel values on a 0-255 scale (hence the division by 255);
# `train_data` is only a deprecated alias for it. Dividing first yields statistics on the
# same [0, 1] scale that ToTensor produces. `.item()` turns the 0-dim tensors into floats.
mnist_pixels = dataset.data.float() / 255
mnist_mean, mnist_std = mnist_pixels.mean().item(), mnist_pixels.std().item()
print(f"MNIST mean={mnist_mean} & std={mnist_std}")

#### Loading MNIST into a DataLoader

In [None]:
image_size = 64  # the DCGAN generator/discriminator below work on 64x64 images
batch_size = 64

dataset = datasets.MNIST(root='MNIST', download=True, transform=transforms.Compose([
    transforms.Resize(size=image_size),  # interpolate the original digits up to image_size
    transforms.ToTensor(),  # PIL image -> float torch.Tensor in [0, 1]
    # Map pixels to [-1, 1] so real images live in the same range as the generator's Tanh
    # output. Standardising with `mnist_mean` / `mnist_std` instead would send a white pixel
    # to (1 - mean) / std, which is > 1 whenever mean + std < 1: a value the generator can
    # never produce, letting the discriminator tell real from fake trivially.
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]))

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=1)

In [None]:
# Preview one training digit. Indexing `.data` reads the stored image directly, so the
# dataset transforms (resize / normalisation) are not applied to what is shown here.
first_image = dataset.data[0]
plt.imshow(first_image, cmap='gray')

In [None]:
# custom weights initialization based on DCGAN paper
def weights_init(m):
    """Re-initialise one module's parameters in place, following the DCGAN recipe.

    Intended to be passed to ``nn.Module.apply`` so it visits every sub-module.
    Modules are matched by class name:

    * name contains ``Conv`` (e.g. ``Conv2d``, ``ConvTranspose2d``): weight ~ N(0, 0.02);
    * name contains ``BatchNorm``: weight ~ N(1, 0.02) and bias set to 0.

    Any other module is left untouched. Returns ``None``.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_type:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

In [None]:
noise_size = 100  # length of the latent vector z fed to the generator
seed = 1
# Seed every RNG the notebook draws from: Python's `random`, PyTorch (latent noise, weight
# init, DataLoader shuffling) and NumPy, which the training loop uses for label smoothing.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed);

In [None]:
class Generator(nn.Module):
    """DCGAN generator: maps a batch of latent vectors to images.

    A ``(N, latent_size, 1, 1)`` noise tensor is upsampled by five transposed
    convolutions: the first projects 1x1 -> 4x4, and each following stride-2 layer
    doubles the resolution (4 -> 8 -> 16 -> 32 -> 64). Hidden layers use
    BatchNorm + ReLU; the output layer uses Tanh, so generated pixels lie in [-1, 1]
    and real training images should be scaled to that same range.

    Args:
        latent_size: channels of the input noise. When omitted, the notebook-level
            ``noise_size`` is used (the original behaviour).
        feature_maps: width of the last hidden layer (``ngf`` in DCGAN); the earlier
            hidden layers use 8x, 4x and 2x this value.
        channels: channels of the generated image (1 for MNIST).
    """

    def __init__(self, latent_size=None, feature_maps=64, channels=1):
        super().__init__()
        if latent_size is None:
            latent_size = noise_size  # notebook-level constant from the config cell
        hidden_widths = [feature_maps * 8, feature_maps * 4, feature_maps * 2, feature_maps]

        layers = []
        in_channels = latent_size
        for index, out_channels in enumerate(hidden_widths):
            # The first layer turns the 1x1 noise into a 4x4 map; the rest double the size.
            stride, padding = (1, 0) if index == 0 else (2, 1)
            layers += [
                nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4,
                                   stride=stride, padding=padding, bias=False),
                nn.BatchNorm2d(num_features=out_channels),
                nn.ReLU(True),
            ]
            in_channels = out_channels

        # Output layer: no BatchNorm, Tanh bounds the pixels to [-1, 1].
        layers += [
            nn.ConvTranspose2d(in_channels=in_channels, out_channels=channels, kernel_size=4,
                               stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        # Same module order/indices as the hand-written version, so state_dict keys match.
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Generate images from noise ``x`` of shape ``(N, latent_size, 1, 1)``."""
        return self.main(x)

In [None]:
# Build the generator and re-initialise its conv / batch-norm weights via `weights_init`;
# `apply` returns the module itself, so the result can be assigned directly.
generator = Generator().apply(weights_init)
print(generator)

In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator: scores each image with a probability of being real.

    Four stride-2 convolutions halve the resolution (64 -> 32 -> 16 -> 8 -> 4), and a
    final unpadded 4x4 convolution collapses it to 1x1, followed by a Sigmoid. For a
    ``(N, channels, 64, 64)`` input the output has shape ``(N, 1, 1, 1)`` with values
    in [0, 1], suitable for ``nn.BCELoss``.

    Args:
        channels: channels of the input images (1 for MNIST).
        feature_maps: width of the first convolution (``ndf`` in DCGAN); the later
            layers use 2x, 4x and 8x this value.
        negative_slope: LeakyReLU slope. DCGAN uses 0.2; the previous value of 0.02
            made the activations almost plain ReLU.
    """

    def __init__(self, channels=1, feature_maps=64, negative_slope=0.2):
        super().__init__()
        widths = [feature_maps, feature_maps * 2, feature_maps * 4, feature_maps * 8]

        layers = [
            # Input layer has no BatchNorm (DCGAN convention).
            nn.Conv2d(in_channels=channels, out_channels=widths[0], kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=negative_slope, inplace=True),
        ]
        for in_channels, out_channels in zip(widths, widths[1:]):
            layers += [
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4,
                          stride=2, padding=1, bias=False),
                nn.BatchNorm2d(num_features=out_channels),
                nn.LeakyReLU(negative_slope=negative_slope, inplace=True),
            ]
        layers += [
            # 4x4 -> 1x1 score map, squashed to a probability.
            nn.Conv2d(in_channels=widths[-1], out_channels=1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]
        # Same module order/indices as the hand-written version, so state_dict keys match.
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return real/fake probabilities of shape ``(N, 1, 1, 1)`` for images ``x``."""
        return self.main(x)
        

In [None]:
# Build the discriminator and re-initialise its conv / batch-norm weights via `weights_init`;
# `apply` returns the module itself, so the result can be assigned directly.
discriminator = Discriminator().apply(weights_init)
print(discriminator)

In [None]:
# BCE target values: 1 means "real", 0 means "fake".
real_label, fake_label = 1., 0.

# One latent batch kept fixed for the whole run, so snapshots taken during training
# show how the generator's output for the same inputs evolves.
fixed_noise = torch.randn(batch_size, noise_size, 1, 1)

In [None]:
# Binary cross-entropy on the discriminator's Sigmoid probabilities.
loss = nn.BCELoss()

# Tricks to train the GANs: SGD for the discriminator, Adam for the generator
# (presumably following the "ganhacks" advice).
# NOTE(review): plain SGD at lr=0.0002 with no momentum makes the discriminator learn very
# slowly compared with the generator; confirm this imbalance is intended.
optimizer_discriminator = torch.optim.SGD(params=discriminator.parameters(), lr=0.0002)
# beta1=0.5 follows the DCGAN paper, which reports that Adam's default beta1=0.9
# made training oscillate and become unstable.
optimizer_generator = torch.optim.Adam(params=generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

#### Training

In [None]:
loss_discriminator = []       # per-iteration discriminator loss (real + fake)
loss_generator = []           # per-iteration generator loss
images_from_fixed_noise = []  # image grids generated from `fixed_noise`, used for the animation
EPOCHS = 4
counter = 0  # iteration count across all epochs

for epoch in range(EPOCHS):
    # Wrapping the DataLoader itself (not `enumerate`) lets tqdm know the number of batches.
    for i, (real_images, _) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}")):
        # each DataLoader item is a pair < image tensor batch, true digit labels >
        batch = real_images.size(0)

        # #############################################
        # Training the Discriminator with a real batch
        discriminator.zero_grad()  # "PyTorch accumulates the gradients on subsequent backward passes"
        # Label smoothing: one noisy "real" target per batch. BCELoss expects targets in
        # [0, 1], so the normal sample is clipped into that range.
        real_target = float(np.clip(np.random.normal(1., 0.1, 1)[0], 0., 1.))
        labels = torch.full(size=(batch,), fill_value=real_target, dtype=torch.float)
        discriminator_from_real_output = discriminator(real_images).view(-1)  # forward propagation
        loss_discriminator_from_real_samples = loss(discriminator_from_real_output, labels)
        loss_discriminator_from_real_samples.backward()
        d_x = discriminator_from_real_output.mean().item()  # mean D(x); ideally approaches 0.5

        # #############################################
        # Training the Discriminator with a fake batch
        fake_noise = torch.randn(size=(batch, noise_size, 1, 1))  # latent vectors z
        fake_batch = generator(fake_noise)  # G(z)
        fake_target = float(np.clip(np.random.normal(0.1, 0.05, 1)[0], 0., 1.))  # smoothed "fake" target
        labels.fill_(value=fake_target)
        # detach(): this step updates D only, so no gradient should reach the generator.
        discriminator_from_fake_output = discriminator(fake_batch.detach()).view(-1)
        loss_discriminator_from_fake_samples = loss(discriminator_from_fake_output, labels)
        loss_discriminator_from_fake_samples.backward()  # adds to the gradients of the real batch
        d_g_z_1 = discriminator_from_fake_output.mean().item()  # mean D(G(z)) before the D update

        loss_discriminator_from_real_and_fake_samples = loss_discriminator_from_real_samples + loss_discriminator_from_fake_samples
        optimizer_discriminator.step()

        # #############################################
        # Training the Generator with random noise
        generator.zero_grad()
        labels.fill_(value=real_label)  # the generator wants its samples classified as real
        generator_from_noise_output = discriminator(fake_batch).view(-1)  # scored by the updated D
        loss_generator_from_noise = loss(generator_from_noise_output, labels)
        loss_generator_from_noise.backward()
        d_g_z_2 = generator_from_noise_output.mean().item()  # mean D(G(z)) after the D update
        optimizer_generator.step()

        # Bookkeeping
        loss_d = loss_discriminator_from_real_and_fake_samples.item()
        loss_g = loss_generator_from_noise.item()
        loss_discriminator.append(loss_d)
        loss_generator.append(loss_g)
        if i % 100 == 0:
            # tqdm.write prints without breaking the progress bar.
            tqdm.write(f"Epoch #{epoch} \t Loss_D={loss_d:.2f} \t Loss_G={loss_g:.2f} \t "
                       f"D(x)={d_x:.4f} \t D(G(z_1))={d_g_z_1:.4f} \t D(G(z_2))={d_g_z_2:.4f}")

        # Snapshot the generator's output for the fixed noise every 100 iterations and at the very end.
        is_last_iteration = (epoch == EPOCHS - 1) and (i == len(dataloader) - 1)
        if (counter % 100 == 0) or is_last_iteration:
            with torch.no_grad():
                fake = generator(fixed_noise).detach().cpu()
            images_from_fixed_noise.append(vutils.make_grid(fake, padding=2, normalize=True))

        counter += 1
    generator.train()  # keep the generator in training mode for the next epoch

In [None]:
# Animate the grids captured from `fixed_noise` during training, one frame per snapshot.
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
frames = []
for grid in images_from_fixed_noise:
    # grids are channel-first; imshow expects channel-last, hence the (1, 2, 0) transpose
    artist = plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)
    frames.append([artist])  # ArtistAnimation takes one list of artists per frame
ani = animation.ArtistAnimation(fig, frames, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

In [None]:
# Write the training animation to disk. No writer is passed, so matplotlib's default
# movie writer is used (presumably ffmpeg; confirm it is installed).
ani.save(filename='dcgan_mnist.mp4', metadata={'author': 'MUDSS'})