In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor

#### Preprocessing the MNIST dataset

Before choosing an input normalization, compute the global `mean` and standard deviation (`std`) of the pixel intensities over the whole MNIST training set, with pixels scaled to [0, 1].

In [None]:
# Download MNIST (training split by default) and compute global pixel statistics.
# `.data` exposes the raw uint8 images as a tensor (the old `train_data` alias is
# deprecated), so no transform is needed for this computation.
mnist_train_raw = datasets.MNIST(root='MNIST', download=True)

# Convert once and scale uint8 [0, 255] -> [0, 1]; `.item()` yields plain floats.
mnist_pixels = mnist_train_raw.data.float() / 255
mnist_mean, mnist_std = mnist_pixels.mean().item(), mnist_pixels.std().item()
print(f"MNIST mean={mnist_mean} & std={mnist_std}")

#### Loading MNIST into a DataLoader

In [None]:
image_size = 64
batch_size = 64

# The generator ends in Tanh, so its samples lie in [-1, 1]. Real images must be
# in the same range, otherwise the discriminator can separate real from fake by
# pixel range alone. Normalize(0.5, 0.5) maps ToTensor's [0, 1] onto [-1, 1].
# Standardizing with the MNIST mean/std instead would map bright pixels to
# (1 - mean) / std, which for MNIST is well above Tanh's maximum of 1.
dataset = datasets.MNIST(root='MNIST', download=True, transform=transforms.Compose([
    transforms.Resize(size=image_size),           # upsample 28x28 digits to image_size x image_size
    transforms.Grayscale(num_output_channels=3),  # replicate the gray channel: G and D here use 3 channels
    transforms.ToTensor(),                        # PIL image -> float tensor (C, H, W) in [0, 1]
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
]))

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=1)

In [None]:
# Show the first raw training digit (28x28, before resize/normalization) with its label.
fig, ax = plt.subplots(figsize=(3, 3))
ax.imshow(dataset.data[0], cmap='gray')
ax.set_title(f"MNIST sample 0 (label {int(dataset.targets[0])})")
ax.axis('off')
plt.show()

In [None]:
# Custom weight initialization, applied to the generator and discriminator via `Module.apply`.
def weights_init(m):
    """Initialize one module in place, DCGAN style.

    - Any module whose class name contains ``'Conv'`` (e.g. ``Conv2d``,
      ``ConvTranspose2d``) gets its weight drawn from N(0, 0.02).
    - Any module whose class name contains ``'BatchNorm'`` gets its weight drawn
      from N(1, 0.02) and its bias set to 0.
    - All other modules are left unchanged.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
    elif 'BatchNorm' in layer_type:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

In [None]:
class Generator(nn.Module):
    """DCGAN generator: maps a latent tensor to a 64x64 image with values in [-1, 1].

    Args:
        nz: Number of latent channels; the input is expected to be (N, nz, 1, 1).
        ngf: Base feature-map width; hidden layers use 8*ngf, 4*ngf, 2*ngf and
            ngf channels (512/256/128/64 with the default).
        nc: Number of output image channels.

    Spatial size goes 1 -> 4 -> 8 -> 16 -> 32 -> 64. The first transposed
    convolution (kernel 4, stride 1, no padding) expands the 1x1 latent to 4x4.
    Each following one (kernel 4, stride 2, padding 1) exactly doubles H and W.
    Without padding, every stride-2 layer would produce 2*H + 2 instead, giving
    a 94x94 output.
    """

    def __init__(self, nz=100, ngf=64, nc=3):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # (N, nz, 1, 1) -> (N, 8*ngf, 4, 4)
            nn.ConvTranspose2d(in_channels=nz, out_channels=ngf * 8, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(num_features=ngf * 8),
            nn.ReLU(True),
            # -> (N, 4*ngf, 8, 8)
            nn.ConvTranspose2d(in_channels=ngf * 8, out_channels=ngf * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=ngf * 4),
            nn.ReLU(True),
            # -> (N, 2*ngf, 16, 16)
            nn.ConvTranspose2d(in_channels=ngf * 4, out_channels=ngf * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=ngf * 2),
            nn.ReLU(True),
            # -> (N, ngf, 32, 32)
            nn.ConvTranspose2d(in_channels=ngf * 2, out_channels=ngf, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=ngf),
            nn.ReLU(True),
            # -> (N, nc, 64, 64); Tanh bounds the output to [-1, 1]
            nn.ConvTranspose2d(in_channels=ngf, out_channels=nc, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Generate images from a latent tensor of shape (N, nz, 1, 1)."""
        return self.main(x)

In [None]:
# Build the generator and initialize its weights; Module.apply returns the module itself.
generator = Generator().apply(weights_init)
print(generator)

In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator: scores 64x64 images with a probability of being real.

    Args:
        nc: Number of input image channels.
        ndf: Base feature-map width; hidden layers use ndf, 2*ndf, 4*ndf and
            8*ndf channels (64/128/256/512 with the default).
        negative_slope: LeakyReLU slope for negative inputs. The default 0.2 is
            the DCGAN-paper value; PyTorch's own default is 0.01.

    For a (N, nc, 64, 64) input, four stride-2 convolutions shrink the spatial
    size 64 -> 32 -> 16 -> 8 -> 4. A final kernel-4, stride-1, unpadded
    convolution then collapses 4x4 to 1x1. The result is one sigmoid score per
    image, of shape (N, 1, 1, 1); use ``.view(-1)`` for a flat (N,) vector.
    """

    def __init__(self, nc=3, ndf=64, negative_slope=0.2):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # (N, nc, 64, 64) -> (N, ndf, 32, 32); no BatchNorm on the input layer
            nn.Conv2d(in_channels=nc, out_channels=ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope, inplace=True),
            # -> (N, 2*ndf, 16, 16)
            nn.Conv2d(in_channels=ndf, out_channels=ndf * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=ndf * 2),
            nn.LeakyReLU(negative_slope, inplace=True),
            # -> (N, 4*ndf, 8, 8)
            nn.Conv2d(in_channels=ndf * 2, out_channels=ndf * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=ndf * 4),
            nn.LeakyReLU(negative_slope, inplace=True),
            # -> (N, 8*ndf, 4, 4)
            nn.Conv2d(in_channels=ndf * 4, out_channels=ndf * 8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=ndf * 8),
            nn.LeakyReLU(negative_slope, inplace=True),
            # -> (N, 1, 1, 1): stride 1 / no padding collapses the 4x4 map to a single score
            nn.Conv2d(in_channels=ndf * 8, out_channels=1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return real/fake probabilities of shape (N, 1, 1, 1) for images x."""
        return self.main(x)
        

In [None]:
# Build the discriminator and initialize its weights; Module.apply returns the module itself.
discriminator = Discriminator().apply(weights_init)
print(discriminator)

In [None]:
# Labels for real (1.0) and fake (0.0) samples; floats so they can serve directly as BCE targets.
real_label, fake_label = 1.0, 0.0

In [None]:
# Binary cross-entropy on the discriminator's sigmoid outputs.
loss = nn.BCELoss()

# Adam settings recommended for DCGAN (Radford et al., 2015): lr = 2e-4, beta1 = 0.5.
# PyTorch's defaults (lr = 1e-3, beta1 = 0.9) are commonly reported to make
# DCGAN training less stable.
LEARNING_RATE = 2e-4
ADAM_BETAS = (0.5, 0.999)

optimizer_discriminator = torch.optim.Adam(params=discriminator.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
optimizer_generator = torch.optim.Adam(params=generator.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)

#### Training

In [None]:
# Loss histories for the discriminator and generator (presumably filled by the training loop).
loss_discriminator, loss_generator = [], []
