In [1]:
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import os
import torch.optim as optim

In [2]:
# Directory for MNIST-related files, located under the current working directory.
# The second component must be relative: with a leading slash os.path.join
# discards os.getcwd() and returns the filesystem-root path '/mnist'.
absolute_path = os.path.join(os.getcwd(), 'mnist')

In [3]:
# ToTensor turns each PIL image into a float tensor (values scaled to [0, 1]).
transform = transforms.Compose([transforms.ToTensor()])

# The split must be chosen explicitly: MNIST defaults to train=True, so without
# train=False the "test" dataset would just be the training images again.
train_dataset = MNIST(os.getcwd(), train=True, transform=transform, download=True)
test_dataset = MNIST(os.getcwd(), train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# The whole test split is served as one batch, in a fixed order.
test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)

one_image_loader = DataLoader(train_dataset, batch_size=1, shuffle=True) # For visualising

# Generator

In [4]:
class Generator(nn.Module):
    """Decode a batch of latent vectors into 28x28 images.

    A linear layer projects each latent vector to a 64-channel 7x7 feature
    map; two stride-2 transposed convolutions then double the spatial size
    twice (7 -> 14 -> 28). A final sigmoid keeps every output value in (0, 1).

    Args:
        latent_dim: length of each input latent vector.
        output_channels: number of channels in the generated image.
        leaky_relu_parameter: negative slope used by both LeakyReLU layers.
    """

    def __init__(self, latent_dim, output_channels, leaky_relu_parameter = 0.1):
        super().__init__()

        self.latent_dim = latent_dim
        self.output_channels = output_channels
        self.leaky_relu_parameter = leaky_relu_parameter

        # Shared settings of the two upsampling layers; with these values each
        # transposed convolution exactly doubles height and width.
        upsample = dict(kernel_size=3, stride=2, padding=1, output_padding=1)

        layers = [
            # Project the latent vector and reshape it into a 64x7x7 feature map.
            nn.Linear(latent_dim, 64 * 7 * 7),
            nn.BatchNorm1d(64 * 7 * 7),
            nn.LeakyReLU(leaky_relu_parameter),
            nn.Unflatten(1, (64, 7, 7)),
            # 7x7 -> 14x14
            nn.ConvTranspose2d(in_channels=64, out_channels=32, **upsample),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(leaky_relu_parameter),
            # 14x14 -> 28x28
            nn.ConvTranspose2d(in_channels=32, out_channels=output_channels, **upsample),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, X):
        """Run ``X`` of shape (batch, latent_dim) through the network.

        Returns a tensor of shape (batch, output_channels, 28, 28).
        """
        return self.net(X)

In [5]:
# Generator taking 64-dimensional latent vectors and emitting single-channel images.
generator = Generator(latent_dim=64, output_channels=1)

In [6]:
# One standard-normal latent vector, already shaped as a batch of size 1: (1, 64).
random_vars = torch.randn(1, 64)

In [7]:
# Shape sanity check on a single latent vector. Eval mode makes the BatchNorm
# layers use their running statistics, which is what allows a batch of one.
generator.eval()
# Only the shape is needed, so skip autograd bookkeeping for this forward pass.
with torch.no_grad():
    generated_shape = generator(random_vars).shape
generated_shape

torch.Size([1, 1, 28, 28])

# Discriminator

In [28]:
class Discriminator(nn.Module):
    """Convolutional classifier that scores 28x28 images with two raw logits.

    Two conv -> batch-norm -> ReLU -> max-pool stages shrink the input from
    28x28 to 7x7; the flattened features then pass through a small fully
    connected head that ends in a 2-unit linear layer (no activation).

    Args:
        input_channels: number of channels in the input images. Defaults to 1,
            which matches the single-channel images used in this notebook.
    """

    def __init__(self, input_channels=1):
        super().__init__()

        self.input_channels = input_channels

        self.net = nn.Sequential(
            nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=3, stride=1, padding=1), # 28x28
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2), # 14x14
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2), # 7x7
            nn.Flatten(),
            # The 32*7*7 input size ties this head to 28x28 input images.
            nn.Linear(32*7*7, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 2)
        )

    def forward(self, X):
        """Return logits of shape (batch, 2) for ``X`` of shape (batch, input_channels, 28, 28)."""
        return self.net(X)

In [30]:
# Build the discriminator and put it in eval mode in one step (eval() returns
# the module itself); the bare name on the last line displays the layer summary.
discriminator = Discriminator().eval()
discriminator

Discriminator(
  (net): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Flatten(start_dim=1, end_dim=-1)
    (9): Linear(in_features=1568, out_features=64, bias=True)
    (10): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): Linear(in_features=64, out_features=2, bias=True)
  )
)

In [31]:
# Pull one batch from the shuffled batch-size-1 loader and keep its first
# element, which is what gets passed to the discriminator below.
example_batch = next(iter(one_image_loader))
example_img = example_batch[0]

In [32]:
# Sanity check: one image in, one row of two raw (unnormalised) scores out.
discriminator(example_img)

tensor([[ 0.0856, -0.0078]], grad_fn=<AddmmBackward0>)

# Initialisation

In [33]:
def initialize_weights(model):
    """Re-initialise the conv, transposed-conv and linear layers of ``model`` in place.

    Every ``nn.Conv2d``, ``nn.ConvTranspose2d`` and ``nn.Linear`` found in
    ``model.modules()`` gets Xavier-uniform weights and, if it has a bias,
    an all-zero bias. All other modules are left untouched.

    Args:
        model: the ``nn.Module`` whose submodules are visited.

    Returns:
        None.
    """
    weighted_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
    for layer in model.modules():
        if not isinstance(layer, weighted_layers):
            continue  # containers, normalisation and activation layers are skipped
        nn.init.xavier_uniform_(layer.weight)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)

In [35]:
# Fresh instances of both networks, replacing the ones used for the shape checks.
generator = Generator(latent_dim=64, output_channels=1)
discriminator = Discriminator()