In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

In [None]:
class Generator(nn.Module):
    """Fully connected generator that maps an input vector to a single-channel image.

    The network is a stack of ``num_of_layers`` linear layers whose widths move
    in equal steps from ``input_size`` towards ``image_width * image_height``.
    Every hidden linear layer is followed by LeakyReLU and Dropout; the output
    layer is followed by Tanh, so generated pixel values lie in [-1, 1].

    Args:
        image_width: Width of the generated image in pixels.
        image_height: Height of the generated image in pixels.
        num_of_layers: Number of linear layers in the stack (must be >= 1).
        input_size: Length of each input vector fed to the first layer.
        drop_out_rate: Dropout probability used after every hidden layer.

    Raises:
        ValueError: If ``num_of_layers`` is smaller than 1.
    """

    def __init__(self, image_width: int, image_height: int, num_of_layers: int, input_size: int, drop_out_rate: float):
        super().__init__()
        if num_of_layers < 1:
            raise ValueError(f"num_of_layers must be >= 1, got {num_of_layers}")

        self.num_pixel = image_width * image_height
        self.num_of_layers = num_of_layers
        self.input_size = input_size
        self.drop_out_rate = drop_out_rate
        self.image_width = image_width
        self.image_height = image_height

        # This is how many neurons I am going to increase/decrease between each layer
        increase_decrease_size = int((self.num_pixel - self.input_size) / self.num_of_layers)

        # Width at every layer boundary: input_size, input_size + step, ..., num_pixel.
        # The last width is pinned to num_pixel because the integer step may not
        # divide the gap exactly.
        layer_sizes = [input_size + increase_decrease_size * i for i in range(num_of_layers)]
        layer_sizes.append(self.num_pixel)

        layers = []
        last_index = len(layer_sizes) - 2
        for index, (in_features, out_features) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            layers.append(nn.Linear(in_features, out_features))
            if index < last_index:
                # Hidden layers get the non-linearity and dropout; the output layer does not.
                layers.append(nn.LeakyReLU())
                layers.append(nn.Dropout(self.drop_out_rate))
        layers.append(nn.Tanh())

        # For num_of_layers == 6 this yields the same modules (and the same
        # state_dict keys) as the previously hand-written six-layer stack.
        self.generator = nn.Sequential(*layers)

    def forward(self, tensor):
        """Generate images from a batch of input vectors.

        Args:
            tensor: Batch of input vectors whose last dimension is ``input_size``.

        Returns:
            Tensor reshaped to ``(-1, 1, image_width, image_height)``.
        """
        tensor = self.generator(tensor)
        # NOTE(review): the reshape uses (width, height) order, while PyTorch image
        # batches are conventionally (N, C, H, W). The two agree for square images
        # only - confirm before using non-square sizes.
        images = tensor.view(-1, 1, self.image_width, self.image_height)
        return images

In [None]:
class Discriminator(nn.Module):
    """Fully connected discriminator/classifier over flattened images.

    The network is a stack of ``num_of_layers`` linear layers whose widths
    shrink in equal steps from ``image_width * image_height`` towards
    ``input_size`` and end in 11 output logits. Every hidden linear layer is
    followed by LeakyReLU and Dropout; the output layer returns raw logits.

    Args:
        image_width: Width of the input image in pixels.
        image_height: Height of the input image in pixels.
        num_of_layers: Number of linear layers in the stack (must be >= 1).
        input_size: Reference width used to compute the per-layer step
            (the same value that is passed to the generator in this notebook).
        drop_out_rate: Dropout probability used after every hidden layer.

    Raises:
        ValueError: If ``num_of_layers`` is smaller than 1.
    """

    def __init__(self, image_width: int, image_height: int, num_of_layers: int, input_size: int, drop_out_rate: float):
        super().__init__()
        if num_of_layers < 1:
            raise ValueError(f"num_of_layers must be >= 1, got {num_of_layers}")

        self.num_pixel = image_width * image_height
        self.num_of_layers = num_of_layers
        self.input_size = input_size
        self.drop_out_rate = drop_out_rate

        # This is how many neurons I am going to increase/decrease between each layer
        increase_decrease_size = int((self.num_pixel - self.input_size) / self.num_of_layers)

        # Number of output logits. The helpers in this notebook build 11-wide
        # targets: find_label one-hot encodes the class label and
        # create_fake_label sets index 10 for generated images.
        num_outputs = 11

        # Width at every layer boundary: num_pixel, input_size + (n-1)*step, ...,
        # input_size + step, num_outputs.
        layer_sizes = [self.num_pixel]
        layer_sizes.extend(input_size + increase_decrease_size * i for i in range(num_of_layers - 1, 0, -1))
        layer_sizes.append(num_outputs)

        layers = []
        last_index = len(layer_sizes) - 2
        for index, (in_features, out_features) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            layers.append(nn.Linear(in_features, out_features))
            if index < last_index:
                # Hidden layers get the non-linearity and dropout; the output layer does not.
                layers.append(nn.LeakyReLU())
                layers.append(nn.Dropout(self.drop_out_rate))

        # For num_of_layers == 6 this yields the same modules (and the same
        # state_dict keys) as the previously hand-written six-layer stack.
        self.discriminator = nn.Sequential(*layers)

    def forward(self, images):
        """Score a batch of images.

        Args:
            images: Tensor that can be viewed as ``(-1, image_width * image_height)``.

        Returns:
            Tensor of shape ``(batch, 11)`` holding raw (un-normalised) logits.
        """
        tensor = images.view(-1, self.num_pixel)
        tensor = self.discriminator(tensor)
        return tensor

In [None]:
# Hyperparameters
batch_size = 256        # samples per mini-batch handed to the DataLoader
epochs = 20             # intended number of training epochs
lr = 2e-4               # initial learning rate of both Adam optimizers
weight_decay = 1e-3     # weight decay of both Adam optimizers
drop_out_rate = 0.3     # dropout probability for generator and discriminator

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Prepare the dataset: convert each image to a tensor, then rescale it with
# mean 0.5 / std 0.5 so pixel values end up in [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Download (if necessary) and load the MNIST training split
trainset = datasets.MNIST(root='~/.pytorch/MNIST_data/', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)

In [None]:
# Instantiate both networks with matching geometry (28x28 images, 6 layers,
# 20-dimensional generator input) and move them to the selected device
generator = Generator(
    image_width=28,
    image_height=28,
    num_of_layers=6,
    input_size=20,
    drop_out_rate=drop_out_rate,
).to(device)
discriminator = Discriminator(
    image_width=28,
    image_height=28,
    num_of_layers=6,
    input_size=20,
    drop_out_rate=drop_out_rate,
).to(device)

In [None]:
class ExponentialLRWithMin:
    """Exponential learning-rate decay that never drops below a floor.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts with an ``'lr'`` entry.
        gamma: Factor the learning rate is multiplied by on every ``step()``.
        min_lr: Lower bound for the learning rate.
    """

    def __init__(self, optimizer, gamma, min_lr):
        self.optimizer = optimizer
        self.gamma = gamma
        self.min_lr = min_lr

    def step(self):
        """Multiply every group's learning rate by ``gamma``, clamped to ``min_lr``."""
        factor, floor = self.gamma, self.min_lr
        for group in self.optimizer.param_groups:
            decayed = group['lr'] * factor
            group['lr'] = max(decayed, floor)

In [None]:
# Loss functions: cross-entropy for the 11-way classification targets,
# mean squared error for the auxiliary consistency/reconstruction terms
criterion = nn.CrossEntropyLoss()
auxiliary_criterion = nn.MSELoss()

# One Adam optimizer per network, both with the same learning rate and weight decay
optimizer_G = opt.Adam(params=generator.parameters(), lr=lr, weight_decay=weight_decay)
optimizer_D = opt.Adam(params=discriminator.parameters(), lr=lr, weight_decay=weight_decay)

# Learning-rate schedules: multiply by 0.95 per step, never below 2e-5
scheduler_generator = ExponentialLRWithMin(optimizer=optimizer_G, gamma=0.95, min_lr=2e-5)
scheduler_discriminator = ExponentialLRWithMin(optimizer=optimizer_D, gamma=0.95, min_lr=2e-5)

# Define the image augmentation (random translation of up to 20% per axis).
#
# The images passed through this pipeline are already normalised to [-1, 1]
# (dataset images via Normalize(0.5, 0.5), generated images via Tanh), but
# ToPILImage scales float tensors by 255 as if they were in [0, 1]. Without the
# first step every negative pixel would be corrupted by the uint8 conversion.
# Normalize(mean=-1, std=2) computes (x + 1) / 2 and so maps [-1, 1] back to
# [0, 1]; the final Normalize restores the [-1, 1] range.
#
# NOTE: this rebinds the name `transform`. The dataset keeps its own reference
# to the earlier Compose object, so data loading is unaffected.
transform = transforms.Compose([transforms.Normalize((-1.0,), (2.0,)),
                                transforms.ToPILImage(),
                                transforms.RandomAffine(0, translate = (0.2, 0.2)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

In [None]:
def input_creator(labels):
    """Build generator inputs: 10 noise values followed by a 10-way one-hot label.

    Args:
        labels: Integer tensor of class indices in [0, 9]. Any shape is accepted;
            it is flattened first. An empty tensor yields an empty batch.

    Returns:
        Tensor of shape ``(labels.numel(), 20)`` on the module-level ``device``.
        Columns 0-9 hold standard-normal noise, columns 10-19 the one-hot label.
    """
    flat_labels = labels.reshape(-1).to(device=device, dtype=torch.long)

    # Draw the whole batch of noise at once instead of one sample per Python iteration.
    noise = torch.randn(flat_labels.numel(), 10, device=device)
    one_hot = F.one_hot(flat_labels, num_classes=10).to(noise.dtype)

    return torch.cat((noise, one_hot), dim=1)

In [None]:
def find_label(labels):
    """One-hot encode class labels into 11-wide target vectors.

    Index 10 is never set here; ``create_fake_label`` uses it to mark
    generated images.

    Args:
        labels: Integer tensor of class indices in [0, 10]. Any shape is accepted;
            it is flattened first. An empty tensor yields an empty batch.

    Returns:
        Floating-point tensor of shape ``(labels.numel(), 11)`` on the
        module-level ``device``.
    """
    flat_labels = labels.reshape(-1).to(device=device, dtype=torch.long)
    # F.one_hot returns int64; cast to the default float dtype, which is what
    # torch.zeros(11) produced in the per-sample version of this helper.
    return F.one_hot(flat_labels, num_classes=11).to(torch.get_default_dtype())

In [None]:
def create_fake_label(output):
    """Build the target for generated images: column 10 set to 1, everything else 0.

    Args:
        output: 2-D tensor whose shape and dtype the target copies
            (it must have at least 11 columns).

    Returns:
        Tensor shaped like ``output`` on the module-level ``device``.
    """
    fake_target = torch.zeros_like(output)
    fake_target = fake_target.to(device)
    # Column 10 is the extra "fake" slot after the ten digit classes.
    fake_target[:, 10] = 1
    return fake_target

In [None]:
def augment_image(image):
    """Apply the module-level augmentation ``transform`` to every image of a batch.

    Args:
        image: Batch tensor; ``image[i]`` is handed to ``transform`` for each
            index along the first dimension.

    Returns:
        Stacked tensor of the transformed images on the module-level ``device``.
    """
    # The pipeline goes through a PIL image, so it is applied one image at a
    # time rather than to the whole batch.
    augmented = [transform(image[index]) for index in range(image.shape[0])]
    return torch.stack(augmented).to(device)

In [None]:
def small_noise_for_latent_space(input_tensors):
    """Create a small random perturbation for the noise part of generator inputs.

    Args:
        input_tensors: 2-D tensor; only its shape, dtype and device are used.

    Returns:
        Tensor shaped like ``input_tensors`` holding standard-normal noise scaled
        by 0.07, with every column from index 10 onward set to zero.
    """
    perturbation = torch.randn_like(input_tensors).mul_(0.07)
    # Leave columns 10 and above untouched by the perturbation.
    perturbation[:, 10:].zero_()
    return perturbation

In [None]:
#Training

# Losses of the last batch of the most recent epoch (used for logging)
epoch_discriminator_loss = 0
epoch_generator_loss = 0

# Number of optimizer steps each network takes per batch
train_discriminator_num = 1
train_generator_num = 1

# One label per digit, so every end-of-epoch preview shows 0-9 in order
visualization_labels = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).to(device)

# Train for a fixed number of epochs so the notebook terminates and the cells
# below (saving, final visualisation) are reachable with "Run All".
for epoch in range(epochs):
    for images, labels in trainloader:
        # Create the image augmentation of real images
        augmented_real_images = augment_image(images)

        images = images.to(device)

        # ---- Train discriminator ----
        for _ in range(train_discriminator_num):
            # Real images should be classified as their digit class
            output_real = discriminator(images)
            label = find_label(labels)
            loss_discriminator_real = criterion(output_real, label)

            # Generated images should be classified as the extra "fake" class.
            # They are detached because only the discriminator is updated here;
            # this avoids a wasted backward pass through the generator.
            input_tensors = input_creator(labels)
            images_generated = generator(input_tensors).detach()
            output_fake = discriminator(images_generated)
            label = create_fake_label(output_fake)
            loss_discriminator_fake = criterion(output_fake, label)

            # Create the image augmentation of fake images
            augmented_fake_images = augment_image(images_generated)

            # Auxiliary consistency loss: the discriminator output for an
            # augmented image should match its output for the original image
            output_real_augmented = discriminator(augmented_real_images)
            loss_real_augmented = auxiliary_criterion(output_real_augmented, output_real)

            output_fake_augmented = discriminator(augmented_fake_images)
            loss_fake_augmented = auxiliary_criterion(output_fake_augmented, output_fake)

            optimizer_D.zero_grad()
            loss_discriminator = (loss_discriminator_real
                                  + 0.2 * loss_discriminator_fake
                                  + 0.2 * loss_real_augmented
                                  + 0.2 * loss_fake_augmented)
            loss_discriminator.backward()
            optimizer_D.step()

        # ---- Train generator ----
        for _ in range(train_generator_num):
            # The generator is rewarded when the discriminator assigns its
            # images to the digit class they were conditioned on
            input_tensors = input_creator(labels)
            images_generated = generator(input_tensors)
            output = discriminator(images_generated)
            label = find_label(labels)
            loss_generator = criterion(output, label)

            # Pixel-wise MSE between generated images and the real batch
            loss_mse = auxiliary_criterion(images_generated, images)

            optimizer_G.zero_grad()
            loss_generator_final = loss_generator + loss_mse
            loss_generator_final.backward()
            optimizer_G.step()

    # Log once per epoch using the last batch. This no longer depends on the
    # final batch being smaller than batch_size.
    epoch_discriminator_loss = loss_discriminator.item()
    epoch_generator_loss = loss_generator.item()
    print(f"Epoch {epoch + 1}/{epochs}")
    print("Discriminator loss", epoch_discriminator_loss)
    print("Generator loss", epoch_generator_loss)

    # Learning-rate decay is currently disabled; uncomment to enable it.
    # scheduler_generator.step()
    # scheduler_discriminator.step()

    # Preview one generated image per digit. Dropout is switched off (eval mode)
    # and no autograd graph is built for the preview.
    generator.eval()
    with torch.no_grad():
        input_tensors = input_creator(visualization_labels.cpu())
        images_generated = generator(input_tensors).to("cpu")
    generator.train()

    fig, axes = plt.subplots(1, 10, figsize=(15, 2))
    for digit, ax in enumerate(axes):
        ax.imshow(images_generated[digit][0])
        ax.set_title(str(digit))
        ax.axis("off")
    plt.show()
    plt.close(fig)

print("Done")

In [None]:
# This saves the models.
# The whole module objects are serialised (not just their state_dict) to the
# files "generator" and "discriminator" in the current working directory.
torch.save(obj=generator, f="generator")
torch.save(obj=discriminator, f="discriminator")

In [None]:
# Create the labels for visualization: one sample for each digit 0-9
labels = torch.arange(10)
input_tensors = input_creator(labels)
# Keep both names bound to the same tensor
input_tensor = input_tensors

In [None]:
# Use matplotlib to visualize
# Switch the generator to eval mode so its Dropout layers are inactive, and
# skip autograd bookkeeping since no gradients are needed for plotting.
generator.eval()
with torch.no_grad():
    images_generated = generator(input_tensor)

for i in range(10):
    plt.imshow(images_generated[i][0].to("cpu").detach())
    plt.show()