In [1]:
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import os
import torch.optim as optim

In [2]:
# Directory intended for the MNIST files, located under the current working directory.
# The second component must be relative: os.path.join throws away everything that
# precedes an absolute component, so '/mnist' would resolve to the filesystem root.
absolute_path = os.path.join(os.getcwd(), 'mnist')

In [3]:
# Convert each PIL image to a tensor; no further normalisation is applied.
transform = transforms.Compose([transforms.ToTensor()])

# `train` selects the split explicitly. Without train=False the "test" dataset
# would silently be a second copy of the training split.
train_dataset = MNIST(os.getcwd(), train=True, transform=transform, download=True)
test_dataset  = MNIST(os.getcwd(), train=False, transform=transform, download=True)

# Mini-batches for training; the held-out split is served as one single batch.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)

one_image_loader = DataLoader(train_dataset, batch_size=1, shuffle=True) # For visualising

# Generator

In [4]:
class Generator(nn.Module):
    """Decoder-style network that turns latent vectors into 28x28 images.

    A linear projection expands the latent vector to a 64x7x7 feature map,
    which two stride-2 transposed convolutions upsample to 14x14 and then
    28x28. The final sigmoid keeps every output value in [0, 1].

    Args:
        latent_dim: Size of the input latent vector.
        output_channels: Number of channels in the generated image.
        leaky_relu_parameter: Negative slope used by both LeakyReLU layers.
    """

    def __init__(self, latent_dim, output_channels, leaky_relu_parameter=0.1):
        super().__init__()

        self.latent_dim = latent_dim
        self.output_channels = output_channels
        self.leaky_relu_parameter = leaky_relu_parameter

        seed_channels, seed_size = 64, 7
        flat_features = seed_channels * seed_size * seed_size
        slope = self.leaky_relu_parameter

        # Latent vector -> (64, 7, 7) feature map.
        projection = [
            nn.Linear(latent_dim, flat_features),
            nn.BatchNorm1d(flat_features),
            nn.LeakyReLU(slope),
            nn.Unflatten(1, (seed_channels, seed_size, seed_size)),
        ]

        # Each transposed convolution doubles the spatial size: 7 -> 14 -> 28.
        upsampling = [
            nn.ConvTranspose2d(
                in_channels=seed_channels,
                out_channels=32,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            ),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(slope),
            nn.ConvTranspose2d(
                in_channels=32,
                out_channels=output_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            ),
            nn.Sigmoid(),
        ]

        self.net = nn.Sequential(*projection, *upsampling)

    def forward(self, X):
        """Run a batch of latent vectors ``X`` through the network."""
        generated = self.net(X)
        return generated

In [5]:
# Throwaway generator used only to sanity-check the output shape below.
generator = Generator(latent_dim=64, output_channels=1)

In [6]:
# One latent vector with a leading batch axis: shape (1, 64).
random_vars = torch.randn(64).reshape(1, -1)

In [7]:
# Switch to evaluation mode, then push the single latent vector through the
# network to confirm the shape of the generated image.
generator.train(False)
generator(random_vars).size()

torch.Size([1, 1, 28, 28])

# Discriminator

In [40]:
class Discriminator(nn.Module):
    """Convolutional classifier over single-channel 28x28 images.

    Two conv -> batch-norm -> ReLU -> max-pool stages shrink the image from
    28x28 to 14x14 to 7x7; the flattened features then pass through a small
    fully connected head that emits ``output_dim`` raw scores (no activation).

    Args:
        output_dim: Number of values produced per input image.
    """

    def __init__(self, output_dim):
        super().__init__()

        self.output_dim = output_dim

        def conv_stage(in_ch, out_ch):
            # Padding of 1 keeps the spatial size; the pooling layer halves it.
            return [
                nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]

        hidden_units = 64

        self.net = nn.Sequential(
            *conv_stage(1, 64),   # 28x28 -> 14x14
            *conv_stage(64, 32),  # 14x14 -> 7x7
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, hidden_units),
            nn.BatchNorm1d(hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, self.output_dim),
        )

    def forward(self, X):
        """Return the raw scores for a batch of images ``X``."""
        scores = self.net(X)
        return scores

In [None]:
# Throwaway two-output discriminator, put in evaluation mode for a quick
# forward-pass check (the model used for training is built later with output_dim=1).
discriminator = Discriminator(output_dim=2)
discriminator.train(False)

In [43]:
# Pull one batch from the batch_size=1 loader and keep only its first element
# (the image tensor), discarding the label.
example_img = next(iter(one_image_loader))[0]

In [None]:
# Smoke test: run the single example image through the discriminator and display the raw output.
discriminator(example_img)

# Initialisation

In [46]:
def initialize_weights(model):
    """Re-initialise every convolutional and linear layer inside ``model``.

    Walks ``model.modules()`` and, for each Conv2d, ConvTranspose2d or Linear
    layer, redraws the weight with ``nn.init.kaiming_uniform_`` (default
    arguments) and sets the bias, when the layer has one, to zero. All other
    module types are left untouched.

    Args:
        model: The ``nn.Module`` whose layers are re-initialised in place.

    Returns:
        None.
    """
    weighted_layer_types = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)

    for layer in model.modules():
        if not isinstance(layer, weighted_layer_types):
            continue

        nn.init.kaiming_uniform_(layer.weight)

        # Layers built with bias=False expose bias as None.
        if layer.bias is None:
            continue
        nn.init.constant_(layer.bias, 0)

In [47]:
# Fresh models for training: a 64-dimensional latent space producing one-channel
# images, and a discriminator emitting a single raw score per image.
generator = Generator(latent_dim=64, output_channels=1)
discriminator = Discriminator(output_dim=1)

In [None]:
# initialize_weights already iterates over model.modules() on its own, so it is
# called once per network. Passing it to nn.Module.apply would invoke it again
# on every submodule and redraw the same layers several times over.
initialize_weights(generator)
initialize_weights(discriminator)

# Training

In [52]:
# --- Hyperparameters -------------------------------------------------------
lr = 0.0002
batch_size = train_loader.batch_size
latent_dim = generator.latent_dim
num_epochs = 50

# Two-sided label smoothing: real targets become 0.9 and fake targets 0.1.
smooth_factor = 0.1
real_label_smoothed = 1 - smooth_factor
fake_label_smoothed = smooth_factor

# Create every tensor on whatever device the generator's parameters live on,
# so the loop works unchanged whether or not the models were moved to a GPU.
device = next(generator.parameters()).device

optimizer_g = optim.Adam(generator.parameters(), lr=lr)
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr)

# The discriminator emits raw scores, so the sigmoid is folded into the loss.
criterion = torch.nn.BCEWithLogitsLoss()

# Make train mode explicit: an earlier cell may have left a model in eval mode.
generator.train()
discriminator.train()

for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(train_loader):
        real_images = real_images.to(device)

        # Size targets and noise from the batch actually received; the final
        # batch of an epoch can be smaller than the loader's batch_size.
        current_batch_size = real_images.size(0)
        real_labels = torch.full((current_batch_size, 1), real_label_smoothed, device=device)
        fake_labels = torch.full((current_batch_size, 1), fake_label_smoothed, device=device)

        # ---- Discriminator step ------------------------------------------
        optimizer_d.zero_grad()

        outputs_real = discriminator(real_images)
        loss_real = criterion(outputs_real, real_labels)

        noise = torch.randn(current_batch_size, latent_dim, device=device)
        fake_images = generator(noise)

        # detach() stops this loss from back-propagating into the generator.
        outputs_fake = discriminator(fake_images.detach())
        loss_fake = criterion(outputs_fake, fake_labels)

        loss_d = loss_real + loss_fake
        loss_d.backward()
        optimizer_d.step()

        # ---- Generator step ----------------------------------------------
        optimizer_g.zero_grad()

        # Reuse the fake batch generated above (the generator has not changed
        # since) and score it with the freshly updated discriminator.
        outputs_fake = discriminator(fake_images)
        # NOTE(review): the generator is trained against the smoothed "real"
        # target (0.9), as in the original code; confirm this is intended.
        loss_g = criterion(outputs_fake, real_labels)

        loss_g.backward()
        optimizer_g.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], '
                    f'D Loss: {loss_d.item():.4f}, G Loss: {loss_g.item():.4f}')


Epoch [1/50], Step [100/1875], D Loss: 1.4377, G Loss: 0.7539
Epoch [1/50], Step [200/1875], D Loss: 1.3946, G Loss: 0.7458
Epoch [1/50], Step [300/1875], D Loss: 1.3462, G Loss: 0.7564
Epoch [1/50], Step [400/1875], D Loss: 1.3158, G Loss: 0.7722
Epoch [1/50], Step [500/1875], D Loss: 1.2724, G Loss: 0.8109
Epoch [1/50], Step [600/1875], D Loss: 1.2743, G Loss: 0.8314
Epoch [1/50], Step [700/1875], D Loss: 1.2585, G Loss: 0.8436
Epoch [1/50], Step [800/1875], D Loss: 1.2764, G Loss: 0.8803
Epoch [1/50], Step [900/1875], D Loss: 1.2027, G Loss: 0.8632
Epoch [1/50], Step [1000/1875], D Loss: 1.1746, G Loss: 0.8886
Epoch [1/50], Step [1100/1875], D Loss: 1.1172, G Loss: 0.9249
Epoch [1/50], Step [1200/1875], D Loss: 1.1797, G Loss: 0.9159
Epoch [1/50], Step [1300/1875], D Loss: 1.1305, G Loss: 0.9786
Epoch [1/50], Step [1400/1875], D Loss: 1.1118, G Loss: 0.9885
Epoch [1/50], Step [1500/1875], D Loss: 1.0294, G Loss: 1.0222
Epoch [1/50], Step [1600/1875], D Loss: 1.0553, G Loss: 1.0645
E

[E thread_pool.cpp:110] Exception in thread pool task: mutex lock failed: Invalid argument


KeyboardInterrupt: 