In [None]:
# TODO give the discriminator more time to train, as controlled by the DISCRIMINATOR_ADVANTAGE global variable
# TODO make the generator connect up to a FC layer after convolutional upscaling
# TODO https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/78

import torch
import torch.utils.data.dataloader
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

import numpy as np
import pandas as pd
import random
from tqdm import tqdm
from tqdm.notebook import trange

import matplotlib.pyplot as plt

# NOTE(review): this only binds a Python variable named CUDA_LAUNCH_BLOCKING;
# none of the visible cells read it and it does not modify the process
# environment. If synchronous CUDA kernel launches were intended for
# debugging, the environment variable of the same name has to be set before
# CUDA is initialised -- confirm the intent.
CUDA_LAUNCH_BLOCKING=1

In [None]:
# Number of full passes over the training loader.
NUM_EPOCHS = 100
# The generator optimizer only steps on iterations whose counter is a multiple
# of this value, and the generator learning rate is divided by it.
DISCRIMINATOR_ADVANTAGE = 1
# GENERATOR_ADVANTAGE = 1
# Learning rates handed to the two Adam optimizers.
DISCRIMINATOR_LEARNING_RATE = 0.0005
GENERATOR_LEARNING_RATE = 0.001
# Channel count of the noise tensor fed to the generator.
LATENT_SPACE_SIZE = 100
# Side length of an MNIST image (not referenced by the visible cells).
IMAGE_SIZE = 28
# Channels produced by the generator and consumed by the discriminator.
CHANNEL_COUNT = 1

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Targets for the discriminator output: it is trained towards 0 for real
# images and towards 1 for generated ones.
REAL_DATA_LABEL = 0
GENERATED_DATA_LABEL = 1

In [None]:
# Mini-batch size shared by the training and test loaders.
BATCH_SIZE = 15

# Both pipelines convert to a tensor and then normalise with mean 0.5 and
# std 0.5, which maps ToTensor's [0, 1] range onto [-1, 1] (the range of the
# generator's final Tanh).
# The three-channel variant is not used by the visible cells.
three_channel_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

one_channel_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])

# MNIST training split: downloaded on first use, reshuffled every epoch.
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=one_channel_transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

# MNIST test split: same preprocessing, kept in dataset order.
testset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=one_channel_transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
)


In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator that grows a 1x1 noise map into an image.

    Every transposed convolution uses stride 1 and padding 0, so each stage
    enlarges the spatial side by ``kernel_size - 1``: 1 -> 4 -> 8 -> 16 -> 28.

    The attribute names (``layer1`` .. ``layer5``) and the position of each
    sub-module inside its ``nn.Sequential`` define the state-dict keys, so
    they are kept stable for checkpoint loading.
    """

    def generate_hidden_layer(self, input_channels, output_channels, output_up_scaling_factor, padding):
        """Build one upscaling stage: ConvTranspose2d -> BatchNorm2d -> LeakyReLU.

        Args:
            input_channels (int): channels entering the stage.
            output_channels (int): channels leaving the stage.
            output_up_scaling_factor (int): passed to ``ConvTranspose2d`` as
                its kernel size (the stride is fixed to 1).
            padding (int): padding of the transposed convolution.

        Returns:
            nn.Sequential: the three modules in the order listed above.
        """
        upsample = nn.ConvTranspose2d(
            input_channels,
            output_channels,
            kernel_size=output_up_scaling_factor,
            stride=1,
            padding=padding,
        )
        normalise = nn.BatchNorm2d(output_channels)
        activation = nn.LeakyReLU()
        return nn.Sequential(upsample, normalise, activation)

    def __init__(self):
        super().__init__()

        # Fully connected front-end. It is registered (and therefore part of
        # the state dict) but forward() does not apply it.
        self.layer1 = nn.Sequential(
            nn.Linear(LATENT_SPACE_SIZE, 400),
            nn.Tanh(),
        )

        # Spatial side after each stage is given on the right.
        self.layer2 = self.generate_hidden_layer(LATENT_SPACE_SIZE, 100, 4, 0)  # 4x4
        self.layer3 = self.generate_hidden_layer(100, 50, 5, 0)                 # 8x8
        self.layer4 = self.generate_hidden_layer(50, 25, 9, 0)                  # 16x16

        # Output stage: Tanh squashes pixel values into [-1, 1].
        to_image = nn.ConvTranspose2d(25, CHANNEL_COUNT, kernel_size=13, stride=1, padding=0)  # 28x28
        self.layer5 = nn.Sequential(
            to_image,
            nn.BatchNorm2d(CHANNEL_COUNT),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map a batch of noise to a batch of images.

        Args:
            x (torch.Tensor): gaussian noise; the training loop supplies it
                with shape (batch, LATENT_SPACE_SIZE, 1, 1).

        Returns:
            torch.Tensor: output of ``layer5`` (values in [-1, 1]).
        """
        # layer1 is intentionally skipped here; only the convolutional
        # stages are applied, in order.
        for stage in (self.layer2, self.layer3, self.layer4, self.layer5):
            x = stage(x)
        return x
        
        
        
class Discriminator(nn.Module):
    """Convolutional classifier that scores each image with a value in (0, 1).

    Attribute names and the position of each sub-module inside its
    ``nn.Sequential`` define the state-dict keys, so they are kept stable for
    checkpoint loading.
    """

    def generate_convolutional_layer(self, in_channels, out_channels, convolution_kernel_size, pooling_kernel_size, stride):
        """Build one stage: Conv2d (padding 1) -> AvgPool2d -> LeakyReLU -> BatchNorm2d.

        Args:
            in_channels (int): channels entering the convolution.
            out_channels (int): channels produced by the convolution.
            convolution_kernel_size (int): kernel size of the convolution.
            pooling_kernel_size (int): kernel size of the average pooling.
            stride (int): stride of the convolution.

        Returns:
            nn.Sequential: the four modules in the order listed above.
        """
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=convolution_kernel_size,
            stride=stride,
            padding=1,
        )
        pooling = nn.AvgPool2d(pooling_kernel_size)
        activation = nn.LeakyReLU()
        normalise = nn.BatchNorm2d(out_channels)
        return nn.Sequential(convolution, pooling, activation, normalise)

    def __init__(self):
        super().__init__()

        self.convolutional_layer1 = self.generate_convolutional_layer(CHANNEL_COUNT, 10, 3, 2, 1)
        self.convolutional_layer2 = self.generate_convolutional_layer(10, 30, 3, 2, 1)
        # NOTE(review): this stage expects 20 input channels while
        # convolutional_layer2 emits 30. forward() does not apply it, so the
        # mismatch is dormant; its weight shape is still part of the state dict.
        self.convolutional_layer3 = self.generate_convolutional_layer(20, 40, 3, 2, 1)

        # LazyLinear infers its input width from the first tensor it sees.
        self.fully_connected1 = nn.Sequential(
            nn.LazyLinear(128),
            nn.LeakyReLU(),
        )
        # Registered but not applied in forward().
        self.fully_connected2 = nn.Sequential(
            nn.LazyLinear(86),
            nn.LeakyReLU(),
        )
        # Single sigmoid unit: the per-image score.
        self.fully_connected3 = nn.Sequential(
            nn.LazyLinear(1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x (torch.Tensor): image batch with CHANNEL_COUNT channels; the
                visible cells feed MNIST batches and generator outputs.

        Returns:
            torch.Tensor: shape (batch, 1), sigmoid outputs.
        """
        features = self.convolutional_layer2(self.convolutional_layer1(x))
        # Collapse everything except the batch dimension so the fully
        # connected layers receive one row per image.
        flat = features.view(features.shape[0], -1)
        hidden = self.fully_connected1(flat)
        # convolutional_layer3 and fully_connected2 are deliberately skipped.
        return self.fully_connected3(hidden)

    
def weights_init(m):
    """Re-initialise a module in place, chosen by its class name.

    Intended for ``module.apply(weights_init)``.

    * Class name contains ``Conv``: weights drawn from N(0, 0.02).
    * Class name contains ``BatchNorm``: weights drawn from N(1, 0.02) and
      bias set to 0.
    * Anything else is left untouched.

    Args:
        m (nn.Module): the module to initialise.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)
        

# Build both networks on DEVICE and resume them from the saved checkpoints.
# To start from scratch instead, skip the load_state_dict calls and
# initialise with `.apply(weights_init)`:
#   generator = Generator(); generator.apply(weights_init); generator.to(DEVICE)
#   discriminator = Discriminator(); discriminator.apply(weights_init); discriminator.to(DEVICE)
#
# map_location=DEVICE remaps the stored tensors onto the device in use, so a
# checkpoint written on a GPU machine still loads when only a CPU is present.
#
# NOTE: torch.nn.utils.clip_grad_norm_ only rescales gradients that already
# exist. Calling it here, before any backward pass, has no effect; clipping
# belongs between backward() and optimizer.step().
generator = Generator().to(DEVICE)
generator.load_state_dict(torch.load("models/generator_MNIST", map_location=DEVICE))
generator_optimizer = torch.optim.Adam(
    generator.parameters(),
    lr=GENERATOR_LEARNING_RATE / DISCRIMINATOR_ADVANTAGE,
)

discriminator = Discriminator().to(DEVICE)
discriminator.load_state_dict(torch.load("models/discriminator_MNIST", map_location=DEVICE))
discriminator_optimizer = torch.optim.Adam(
    discriminator.parameters(),
    lr=DISCRIMINATOR_LEARNING_RATE,
)

# Binary cross-entropy on the discriminator's sigmoid output.
# loss_function = nn.MSELoss()
loss_function = nn.BCELoss()

In [None]:
def generate_gaussian_noise(batch_size, dimensions):
    """Draw standard-normal noise for the generator.

    Args:
        batch_size (int): number of noise samples.
        dimensions (int): number of channels per sample.

    Returns:
        torch.Tensor: shape (batch_size, dimensions, 1, 1), allocated on DEVICE.
    """
    noise_shape = (batch_size, dimensions, 1, 1)
    return torch.randn(*noise_shape, device=DEVICE)


# Per-iteration loss histories, plotted after training.
generator_losses = []
discriminator_losses = []

# Global iteration counter; decides on which iterations the generator steps.
cur_training_iteration = 0

for epoch_index in range(NUM_EPOCHS):
    # The class labels delivered by the loader are not needed for GAN training.
    for real_images, _ in tqdm(trainloader, desc=f"epoch {epoch_index}"):
        real_images = real_images.to(DEVICE)
        # Use the actual batch size so a short final batch still lines up
        # with the label tensor.
        current_batch_size = real_images.size(0)

        # Target tensor reused for every loss below; created directly on
        # DEVICE so the loop also runs without a GPU.
        labels = torch.zeros((current_batch_size, 1), device=DEVICE)

        #
        # Train the discriminator network
        #
        discriminator_optimizer.zero_grad()

        # real data
        labels.fill_(REAL_DATA_LABEL)
        discriminator_loss_real = loss_function(discriminator(real_images), labels)
        discriminator_loss_real.backward()

        # fake data
        labels.fill_(GENERATED_DATA_LABEL)
        generated_images = generator(generate_gaussian_noise(current_batch_size, LATENT_SPACE_SIZE))
        # detach(): the discriminator update must not push gradients into the
        # generator. Without it, the gradient of this loss accumulates in the
        # generator's parameters and opposes the generator loss below.
        discriminator_loss_generated = loss_function(discriminator(generated_images.detach()), labels)
        discriminator_loss_generated.backward()

        discriminator_losses.append(discriminator_loss_real.item() + discriminator_loss_generated.item())

        # Clip after backward() and before step(), when gradients exist.
        torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1)
        discriminator_optimizer.step()

        #
        # Train the generator network
        #
        # The generator is rewarded when the updated discriminator assigns
        # the "real" label to its images.
        labels.fill_(REAL_DATA_LABEL)
        generator_loss = loss_function(discriminator(generated_images), labels)
        generator_loss.backward()
        generator_losses.append(generator_loss.item())

        # The generator only steps every DISCRIMINATOR_ADVANTAGE iterations;
        # its gradients keep accumulating in between.
        if cur_training_iteration % DISCRIMINATOR_ADVANTAGE == 0:
            torch.nn.utils.clip_grad_norm_(generator.parameters(), 1)
            generator_optimizer.step()
            generator_optimizer.zero_grad()

        cur_training_iteration += 1

In [None]:
from numpy import moveaxis

# Loss curves, one point per training iteration.
plt.plot(discriminator_losses)
plt.title("Discriminator loss")
plt.xlabel("training iteration")
plt.ylabel("loss (real + generated)")
plt.show()

plt.plot(generator_losses)
plt.title("Generator loss")
plt.xlabel("training iteration")
plt.ylabel("loss")
plt.show()

# Sample a batch from the generator. no_grad() avoids building an autograd
# graph that is never used here.
with torch.no_grad():
    test_images = generator(generate_gaussian_noise(BATCH_SIZE, LATENT_SPACE_SIZE))

# Map the Tanh range [-1, 1] onto integers.
# NOTE(review): the negative factor flips polarity (+1 -> -255, -1 -> 0), so
# the image is shown inverted -- confirm this is intended.
test_images = (-255 * ((test_images + 1) / 2)).cpu().numpy().astype(int)

# Channels-first -> channels-last for matplotlib; shows the sample at index 1.
test_image = moveaxis(test_images[1], 0, 2)
print(test_image)
plt.title("Generated sample")
plt.imshow(test_image, cmap="gray")


In [None]:
# Persist the current weights of both networks. The "...2" file names differ
# from the "models/generator_MNIST" / "models/discriminator_MNIST" checkpoints
# loaded during setup, so those files are not overwritten.
torch.save(generator.state_dict(), "models/generator_MNIST2")
torch.save(discriminator.state_dict(), "models/discriminator_MNIST2")