In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Setting the constants
# BATCHSIZE: images per training batch. CHANNELS, WIDTH, HEIGHT: shape of a
# single MNIST image (1 x 28 x 28).
BATCHSIZE, CHANNELS, WIDTH, HEIGHT = 100, 1, 28, 28
# Target labels used with the binary cross-entropy loss.
FAKE, REAL = 0, 1
# Seed torch's RNGs (CPU and all CUDA devices) so weight initialisation, noise
# sampling and DataLoader shuffling are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

In [None]:
### Downloading the MNIST dataset

!wget www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxvf MNIST.tar.gz
transform=transforms.Compose([
                              transforms.ToTensor(),
                              transforms.Normalize((0.5,), (0.5,)),
                             ])
mnist_set = MNIST(root='./', download=True, transform=transform)
dataset = DataLoader(mnist_set, batch_size=BATCHSIZE)

In [None]:
class Discriminator(nn.Module):
    """Convolutional critic that scores how "real" an image looks.

    For a CHANNELS x 28 x 28 input, the three convolutions shrink the spatial
    size 28 -> 22 -> 15 -> 1. A final sigmoid turns the single remaining
    value into the probability that the image came from the real dataset
    rather than from the generator.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(CHANNELS, 32, kernel_size=7, stride=1)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=8, stride=1)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 1, kernel_size=8, stride=8)

    def forward(self, input):
        hidden = input
        # Each hidden stage is conv -> leaky ReLU -> batch norm; in this
        # architecture normalisation comes after the activation.
        stages = ((self.conv1, self.batchnorm1),
                  (self.conv2, self.batchnorm2))
        for conv, norm in stages:
            hidden = norm(F.leaky_relu(conv(hidden)))
        return torch.sigmoid(self.conv3(hidden))

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator that turns latent noise into an image.

    Input: noise of shape (N, latent_dim, 1, 1), e.g. from ``torch.randn``.
    Output: images of shape (N, out_channels, 28, 28) with values in [-1, 1]
    (tanh). Spatial size per transposed convolution: 1 -> 4 -> 8 -> 14 -> 15 -> 28.

    Args:
        latent_dim: channel count of the input noise. This used to be tied to
            BATCHSIZE (both happen to be 100); it is now independent, so the
            batch size can change without changing the model.
        out_channels: channel count of the generated image (1 for MNIST).
    """

    def __init__(self, latent_dim=100, out_channels=1):
        super().__init__()
        self.latent_dim = latent_dim
        # ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
        self.conv1 = nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0)
        self.batchnorm1 = nn.BatchNorm2d(512)
        self.conv2 = nn.ConvTranspose2d(512, 256, 4, 2, 1)
        self.batchnorm2 = nn.BatchNorm2d(256)
        self.conv3 = nn.ConvTranspose2d(256, 128, 4, 2, 2)
        self.batchnorm3 = nn.BatchNorm2d(128)
        self.conv4 = nn.ConvTranspose2d(128, 64, 4, 1, 1)
        self.batchnorm4 = nn.BatchNorm2d(64)
        self.conv5 = nn.ConvTranspose2d(64, out_channels, 4, 2, 2)

    def forward(self, input):
        output = input
        # Hidden stages: transposed conv -> ReLU -> batch norm.
        stages = ((self.conv1, self.batchnorm1),
                  (self.conv2, self.batchnorm2),
                  (self.conv3, self.batchnorm3),
                  (self.conv4, self.batchnorm4))
        for deconv, norm in stages:
            output = norm(F.relu(deconv(output)))
        # tanh keeps pixels in [-1, 1], the range of the normalised real data.
        return torch.tanh(self.conv5(output))

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
D = Discriminator().to(device)
G = Generator().to(device)
# loss function, binary cross-entropy since we have 0 and 1 as labels
bce = nn.BCELoss()
optimizerD = torch.optim.Adam(D.parameters(), lr=0.0002)
optimizerG = torch.optim.Adam(G.parameters(), lr=0.0002)
epochs = 7
# Report losses every `log_every` batches instead of flooding the output with
# one report per batch.
log_every = 100
# Noise size G expects; read from the model so it always matches G's first layer.
latent_dim = G.conv1.in_channels
n_samples = len(dataset.dataset)

for epoch in range(epochs):
    for i, (real_images, _) in enumerate(dataset):
        # .to(device) (not .cuda()) so the loop also runs on CPU-only machines.
        real_images = real_images.to(device)
        # Size labels/noise by the actual batch: the last batch may be smaller
        # than BATCHSIZE when the dataset size is not a multiple of it.
        batch_size = real_images.size(0)

        # --- Discriminator step on real images: target REAL ---
        D.zero_grad()
        label = torch.full((batch_size,), REAL, dtype=torch.float, device=device)
        realError = bce(D(real_images).view(-1), label)
        realError.backward()

        # --- Discriminator step on generated images: target FAKE ---
        label.fill_(FAKE)
        noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
        fake_images = G(noise)
        # detach(): this step updates D only, so don't backpropagate into G.
        fakeError = bce(D(fake_images.detach()).view(-1), label)
        fakeError.backward()
        addedError = realError + fakeError
        optimizerD.step()

        # --- Generator step: G wants D to label its images REAL ---
        G.zero_grad()
        label.fill_(REAL)
        generatorError = bce(D(fake_images).view(-1), label)
        # This backward also writes gradients into D's parameters; they are
        # cleared by D.zero_grad() at the start of the next iteration.
        generatorError.backward()
        optimizerG.step()

        if (i + 1) % log_every == 0 or (i + 1) == len(dataset):
            print('Samples: {}/{}\nEpoch: {}\nLoss D: {:.4f}\nLoss G: {:.4f}\n'.format(
                min((i + 1) * BATCHSIZE, n_samples), n_samples, epoch,
                addedError.item(), generatorError.item()))

In [None]:
# plot images
# Draw two samples from the trained generator. no_grad(): this is inference
# only, so don't build an autograd graph.
# NOTE(review): G is still in training mode here (eval() is never called), so
# its BatchNorm layers normalise with this batch's statistics; call G.eval()
# first if samples should use the running statistics instead.
with torch.no_grad():
    noise = torch.randn(BATCHSIZE, G.conv1.in_channels, 1, 1, device=device)
    images = G(noise).cpu()
fig, (ax1, ax2) = plt.subplots(1, 2)
for ax, idx in ((ax1, 64), (ax2, 32)):
    ax.imshow(images[idx, 0], cmap='gray')
    ax.set_title('Generated sample {}'.format(idx))
    ax.axis('off')
plt.show()