<a href="https://colab.research.google.com/github/JHyunjun/torch_GAN/blob/main/ConditionalGAN_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#Created by Hyunjun JANG
#Training a conditional GAN on MNIST to check its performance

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import random

# Path to the header-less MNIST CSV (Colab-style path). MnistDataset reads
# column 0 of each row as the digit label and the remaining columns as the
# 28x28 pixel values.
csv_file = '/content/sample_data/mnist_test.csv'

class MnistDataset(Dataset):
    """MNIST samples loaded from a header-less CSV file.

    Column 0 of every row is the digit label; the remaining 784 columns are
    the pixel intensities of a 28x28 image.
    """

    def __init__(self, csv_file):
        # No header row: every line of the file is one sample.
        self.data_df = pd.read_csv(csv_file, header=None)

    def __len__(self):
        return self.data_df.shape[0]

    def __getitem__(self, index):
        """Return ``(label, image_values, target)`` for row ``index``.

        ``image_values`` is a float tensor of the pixels scaled by 1/255 and
        ``target`` is a length-10 one-hot float tensor marking ``label``.
        """
        label = self.data_df.iloc[index, 0]
        pixels = self.data_df.iloc[index, 1:].values

        one_hot = torch.zeros(10)
        one_hot[label] = 1.0

        return label, torch.FloatTensor(pixels) / 255.0, one_hot

    def plot_image(self, index):
        """Draw row ``index`` as a 28x28 image titled with its label."""
        digit = self.data_df.iloc[index, 0]
        pixels = self.data_df.iloc[index, 1:].values.reshape(28, 28)
        plt.title("label = " + str(digit))
        plt.imshow(pixels, interpolation='none', cmap='Reds')


# Load the CSV and report how many samples it holds. len() is used rather
# than np.shape(): numpy would try to convert the whole Dataset (every
# (label, image, target) tuple) into an array, which reads every row and then
# fails on the ragged tuples (ValueError on NumPy >= 1.24).
mnist_dataset = MnistDataset(csv_file)
print(len(mnist_dataset))

In [None]:
# Preview one sample (row 17) as an image titled with its label.
mnist_dataset.plot_image(index=17)

In [None]:
# The Discriminator learns to score a real MNIST image paired with its one-hot
# label as 1.0 and a fake (noise or generated) image/label pair as 0.0.

# Length of one discriminator input: 784 image values + 10 one-hot label values.
Data_size = 784+10


class Discriminator(nn.Module):
    """Conditional discriminator: (image, one-hot label) -> probability "real".

    forward() concatenates a 1-D image tensor with a 1-D label tensor and maps
    the Data_size-long vector to one sigmoid score. train(inputs, label_tensor,
    targets) runs one Adam update on the BCE loss; train(mode) / eval() keep
    their usual nn.Module meaning.
    """

    def __init__(self):
        # initialise parent pytorch class
        super().__init__()

        # Data_size -> 200 -> 1, with a sigmoid so the output lies in (0, 1).
        self.model = nn.Sequential(
            nn.Linear(Data_size, 200),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(200),
            nn.Linear(200, 1),
            nn.Sigmoid(),
        )

        # Binary cross-entropy on the sigmoid output (targets 1.0 real, 0.0 fake).
        self.loss_function = nn.BCELoss()

        # Adam optimiser over this network's parameters.
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

        # Number of training steps taken, and the loss sampled every 10 steps.
        self.counter = 0
        self.progress = []

    def forward(self, image_tensor, label_tensor):
        """Return the discriminator score for one (image, label) pair.

        The two tensors are concatenated along dim 0, so together they must
        hold Data_size elements.
        """
        inputs = torch.cat((image_tensor, label_tensor))
        return self.model(inputs)

    def train(self, inputs=True, label_tensor=None, targets=None):
        """Run one training step, or set the nn.Module training mode.

        Called as ``train(inputs, label_tensor, targets)`` it scores the pair,
        computes the BCE loss against ``targets`` and applies one Adam step
        (returns None). Called with a single bool -- which is what
        ``nn.Module.eval()`` and ``train(mode)`` do -- it forwards to
        ``nn.Module.train`` and returns ``self``.
        """
        if isinstance(inputs, bool):
            # This method shadows nn.Module.train; keep eval()/train(mode) working.
            return super().train(inputs)

        outputs = self.forward(inputs, label_tensor)
        loss = self.loss_function(outputs, targets)

        # Record the loss every 10 steps; print a heartbeat every 3000 steps.
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())
        if self.counter % 3000 == 0:
            print("counter = ", self.counter)

        # zero gradients, perform a backward pass, update weights
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_progress(self):
        """Plot the recorded loss samples (one point per 10 training steps)."""
        df = pd.DataFrame(self.progress, columns=['Discriminator loss'])
        df.plot(ylim=0, figsize=(16, 8), alpha=0.1, marker='.', grid=True,
                yticks=(0, 0.25, 0.5, 1.0, 5.0))

In [None]:
def generate_random(size):
    """Return a float tensor of ``size`` values drawn uniformly from [0, 1)."""
    return torch.rand(size)


def generate_random_seed(size):
    """Return a float tensor of ``size`` standard-normal values (generator seed)."""
    return torch.randn(size)


def generate_random_image(size):
    """Return ``size`` uniform [0, 1) noise values to use as a fake image."""
    return torch.rand(size)


def generate_random_one_hot(size):
    """Return a length-``size`` float tensor with a single 1.0 at a random index."""
    one_hot = torch.zeros(size)
    one_hot[random.randrange(size)] = 1.0
    return one_hot

In [None]:
# Pre-train the Discriminator on its own: each real sample (image + its own
# one-hot label) is taught as target 1.0, then a uniform-noise image with a
# random one-hot label is taught as target 0.0.
D = Discriminator()
targets = torch.FloatTensor([1.0])
non_target = torch.FloatTensor([0.0])

# Keep these loop names: they are notebook globals that other cells may read.
for label, image_data_tensor, label_tensor in mnist_dataset:
    D.train(image_data_tensor, label_tensor, targets)

    noise_image = generate_random_image(784)
    noise_label = generate_random_one_hot(10)
    D.train(noise_image, noise_label, non_target)

In [None]:
# Check the Discriminator's performance: plot its BCE training loss, which
# Discriminator.train records every 10 steps.
D.plot_progress()

In [None]:
# Generator: maps a 100-value random seed plus a 10-value one-hot label to a
# flattened 28x28 image whose values lie in (0, 1) (sigmoid output).

class Generator(nn.Module):
    """Conditional generator: (seed, one-hot label) -> flattened image.

    It has no loss function of its own: train() scores the generated image
    with the Discriminator passed in and uses that Discriminator's
    loss_function. train(mode) / eval() keep their usual nn.Module meaning.
    """

    def __init__(self):
        super().__init__()

        # (100 seed + 10 label) -> 200 -> image size (Data_size minus the label part).
        self.model = nn.Sequential(
            nn.Linear(100 + 10, 200),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(200),
            nn.Linear(200, Data_size - 10),
            nn.Sigmoid(),
        )

        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

        # Number of training steps taken, and the loss sampled every 10 steps.
        self.counter = 0
        self.progress = []

    def forward(self, seed_tensor, label):
        """Return the image generated from ``seed_tensor`` conditioned on ``label``.

        The two tensors are concatenated along dim 0, so they are expected to
        be 1-D with 100 and 10 elements (the first Linear layer's input size).
        """
        # Condition on the label argument itself, not on any global tensor.
        inputs = torch.cat((seed_tensor, label))
        return self.model(inputs)

    def train(self, D=True, inputs=None, label_tensor=None, targets=None):
        """Run one generator step against ``D``, or set the nn.Module mode.

        Called as ``train(D, inputs, label_tensor, targets)`` it generates an
        image from the seed ``inputs`` and ``label_tensor``, and scores it with
        ``D``. It then computes ``D.loss_function`` against ``targets`` and
        steps this Generator's optimiser (returns None). Called with a single
        bool (as ``nn.Module.eval()`` / ``train(mode)`` do) it forwards to
        ``nn.Module.train`` and returns ``self``.
        """
        if isinstance(D, bool):
            # This method shadows nn.Module.train; keep eval()/train(mode) working.
            return super().train(D)

        g_output = self.forward(inputs, label_tensor)
        d_output = D.forward(g_output, label_tensor)
        loss = D.loss_function(d_output, targets)

        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())

        # backward() also leaves gradients on D's parameters, but only this
        # Generator's optimiser steps here.
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_progress(self):
        """Plot the recorded loss samples (one point per 10 training steps)."""
        df = pd.DataFrame(self.progress, columns=['Generator loss'])
        df.plot(ylim=0, figsize=(16, 8), alpha=0.1, marker='.', grid=True,
                yticks=(0, 0.25, 0.5, 1.0, 5.0))

    def plot_images(self, label):
        """Check the result: show a 2x3 grid of images generated for digit ``label``."""
        label_tensor = torch.zeros((10))
        label_tensor[label] = 1.0

        f, axarr = plt.subplots(2, 3, figsize=(16, 8))
        for i in range(2):
            for j in range(3):
                # Use this instance, not a global G, so any Generator can plot itself.
                image = self.forward(generate_random_seed(100), label_tensor)
                axarr[i, j].imshow(image.detach().cpu().numpy().reshape(28, 28),
                                   interpolation='none', cmap='Reds')

In [None]:
# Smoke-test an untrained Generator: feed it one random seed and a random
# one-hot label, print the output tensor, then show it as a 28x28 image.
G = Generator()
random_label_0 = generate_random_one_hot(10)
output = G.forward(generate_random_seed(100), random_label_0)
print(output.shape, output)
img = output.detach().numpy().reshape((28, 28))
plt.imshow(img, cmap='Reds', interpolation='none')

In [None]:
# Adversarial training. For every MNIST sample, in this order:
#   1. D learns the real (image, label) pair as 1.0;
#   2. D learns a detached G sample for a random label as 0.0;
#   3. G is updated so that D scores a fresh G sample (new random label) as 1.0.
D = Discriminator()
G = Generator()

epochs = 10

for epoch in range(epochs):
    print("epoch = ", epoch + 1)
    # Keep these loop names: they are notebook globals that other cells may read.
    for label, image_data_tensor, label_tensor in mnist_dataset:
        D.train(image_data_tensor, label_tensor, torch.FloatTensor([1.0]))

        random_label = generate_random_one_hot(10)
        fake_image = G.forward(generate_random_seed(100), random_label).detach()
        D.train(fake_image, random_label, torch.FloatTensor([0.0]))

        random_label = generate_random_one_hot(10)
        G.train(D, generate_random_seed(100), random_label, torch.FloatTensor([1.0]))

In [None]:
# Plot the loss histories of both networks: Discriminator first, then Generator.
D.plot_progress()
G.plot_progress()

In [None]:
# Show a 2x3 grid of Generator samples conditioned on digit 9.
G.plot_images(label=9)

In [None]:
# Check 1: one sample from a fresh random seed and a random label.
# seed1 is reused as the start point of the interpolation in Check 3.
seed1 = generate_random_seed(100)
random_label1 = generate_random_one_hot(10)
out1 = G.forward(seed1, random_label1)
img1 = out1.detach().numpy().reshape((28, 28))
plt.imshow(img1, cmap='Blues', interpolation='none')

In [None]:
# Check 2: another sample from a fresh random seed and a random label.
# seed2 is reused as the end point of the interpolation in Check 3.
seed2 = generate_random_seed(100)
random_label2 = generate_random_one_hot(10)
out2 = G.forward(seed2, random_label2)
img2 = out2.detach().numpy().reshape((28, 28))
plt.imshow(img2, cmap='Blues', interpolation='none')

In [None]:
# Check 3: walk the seed linearly from seed1 to seed2 and plot every step.
rows, cols = 3, 4
# Number of interpolation intervals: 11 for a 3x4 grid, so the first cell
# uses exactly seed1 and the last cell exactly seed2.
steps = rows * cols - 1

count = 0

# plot a 3 row, 4 column array of generated images
f, axarr = plt.subplots(rows, cols, figsize=(16, 8))
for i in range(rows):
    for j in range(cols):
        seed = seed1 + (seed2 - seed1) / steps * count
        # NOTE(review): a fresh random label is drawn for every cell, so the
        # grid mixes seed interpolation with label changes; reuse a single
        # label if the intent is to isolate the effect of the seed.
        random_label3 = generate_random_one_hot(10)
        output = G.forward(seed, random_label3)
        img = output.detach().numpy().reshape(28, 28)
        axarr[i, j].imshow(img, interpolation='none', cmap='Blues')
        count = count + 1