<a href="https://colab.research.google.com/github/JHyunjun/torch_GAN/blob/main/ConditionalGAN_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#Created by Hyunjun JANG
#Conditional GAN trained on MNIST digits: both networks are conditioned on a one-hot digit label

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import random

# Path to the MNIST CSV in Colab's bundled sample_data folder.
csv_file = '/content/sample_data/mnist_test.csv'

class MnistDataset(Dataset):
    """Map-style dataset over a header-less MNIST CSV file.

    Each row is ``label, pixel, pixel, ...``; ``plot_image`` reshapes the
    pixel columns to 28x28.
    """

    def __init__(self, csv_file):
        # Load the whole table up front; items are sliced out per index.
        self.data_df = pd.read_csv(csv_file, header=None)

    def __len__(self):
        """Number of rows (samples) in the CSV."""
        return len(self.data_df)

    def __getitem__(self, index):
        """Return ``(label, image_values, target)`` for row ``index``.

        * ``label`` -- the raw value of column 0.
        * ``image_values`` -- float tensor of the remaining columns / 255.0.
        * ``target`` -- length-10 float tensor with 1.0 at position ``label``.

        An out-of-range ``index`` raises ``IndexError`` from pandas, which is
        also what ends a plain ``for ... in dataset`` loop.
        """
        df = self.data_df
        label = df.iloc[index, 0]

        target = torch.zeros(10)
        target[label] = 1.0

        raw_pixels = df.iloc[index, 1:].values
        image_values = torch.FloatTensor(raw_pixels) / 255.0

        return label, image_values, target

    def plot_image(self, index):
        """Draw row ``index`` as a 28x28 image titled with its label."""
        df = self.data_df
        pixels = df.iloc[index, 1:].values.reshape(28, 28)
        caption = "label = " + str(df.iloc[index, 0])
        plt.title(caption)
        plt.imshow(pixels, interpolation='none', cmap='Reds')


# Build the dataset and report how many rows (samples) it holds.
mnist_dataset = MnistDataset(csv_file)
# len() is O(1). np.shape() on a Dataset would instead try to turn every
# (label, image, target) tuple into one numpy array: slow, and an error on
# recent numpy because the tuples are ragged.
print(len(mnist_dataset))

In [None]:
# Show one sample (row 17) with its label in the title.
mnist_dataset.plot_image(17)

In [None]:
# Conditional discriminator: scores an (image, one-hot label) pair. In this
# notebook it is trained towards 1.0 for real MNIST pairs and 0.0 for
# random-noise or generator-produced pairs.

# Width of the concatenated input: 784 image pixels + 10 label entries.
Data_size = 784+10


class Discriminator(nn.Module):
    """Conditional GAN discriminator for flattened MNIST images.

    The image and its one-hot label are concatenated along the last
    dimension and mapped by a small MLP to a single sigmoid probability.
    """

    def __init__(self):
        # Initialise the parent PyTorch class.
        super().__init__()

        # (image + label) -> 200 hidden units -> 1 probability.
        self.model = nn.Sequential(
            nn.Linear(Data_size, 200),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(200),
            nn.Linear(200, 1),
            nn.Sigmoid(),
        )

        # Binary cross-entropy on the sigmoid output. The Generator also
        # borrows this loss through ``D.loss_function``.
        self.loss_function = nn.BCELoss()

        # Adam optimiser over all discriminator parameters.
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

        # Step counter and sampled loss history, used by plot_progress().
        self.counter = 0
        self.progress = []

    def forward(self, image_tensor, label_tensor):
        """Return the probability that ``image_tensor`` is real for ``label_tensor``.

        Accepts a single sample (shapes ``(784,)`` and ``(10,)``) or a batch
        (``(B, 784)`` and ``(B, 10)``): the two are joined on the last dim.
        """
        inputs = torch.cat((image_tensor, label_tensor), dim=-1)
        return self.model(inputs)

    def train(self, inputs=True, label_tensor=None, targets=None):
        """Do one Adam step on a single example, or set the training mode.

        Args:
            inputs: image tensor to score. If this is a bool and the other
                two arguments are omitted, the call is treated as
                ``nn.Module.train(mode)``. That is how ``self.eval()`` and
                ``self.train()`` invoke it, so they keep working despite
                this override.
            label_tensor: one-hot label paired with ``inputs``.
            targets: desired output, e.g. ``torch.FloatTensor([1.0])`` for
                real pairs and ``torch.FloatTensor([0.0])`` for fake ones.

        Returns:
            ``self`` when used as a mode switch, otherwise ``None``.
        """
        if isinstance(inputs, bool) and label_tensor is None and targets is None:
            return super().train(inputs)

        outputs = self.forward(inputs, label_tensor)
        loss = self.loss_function(outputs, targets)

        # Record every 10th loss for plotting; print a heartbeat every 5000 steps.
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())
        if self.counter % 5000 == 0:
            print("counter = ", self.counter)

        # Zero gradients first: the Generator's backward pass also deposits
        # gradients on these parameters.
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_progress(self):
        """Plot the sampled discriminator loss history."""
        df = pd.DataFrame(self.progress, columns=['Discriminator loss'])
        # (0, None): pin the bottom of the y axis at 0, leave the top automatic.
        df.plot(ylim=(0, None), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))

In [None]:
def generate_random(size):
    """Return a tensor of ``size`` values drawn uniformly from [0, 1)."""
    return torch.rand(size)

def generate_random_seed(size):
    """Return a tensor of ``size`` standard-normal values (generator seed)."""
    return torch.randn(size)

def generate_random_image(size):
    """Return a uniform [0, 1) noise tensor of ``size`` values (a fake image)."""
    return generate_random(size)

def generate_random_one_hot(size):
    """Return a float one-hot tensor of length ``size`` with a random hot index."""
    one_hot = torch.zeros(size)
    # randrange(size) draws the same index as randint(0, size - 1).
    hot_index = random.randrange(size)
    one_hot[hot_index] = 1.0
    return one_hot

In [None]:
# Pre-train the discriminator on its own. Each real MNIST sample (target
# 1.0) is followed by a uniform-noise image with a random one-hot label
# (target 0.0).
D = Discriminator()
targets = torch.FloatTensor([1.0])
non_target = torch.FloatTensor([0.0])

for label, image_data_tensor, label_tensor in mnist_dataset:
    # Real pair -> 1.0
    D.train(image_data_tensor, label_tensor, targets)
    # Noise image + random label -> 0.0
    D.train(generate_random_image(784), generate_random_one_hot(10), non_target)

In [None]:
# Sanity-check the pre-trained discriminator. Plot its loss history, then
# print its scores for 4 real samples (from the first 51 rows) and for 4
# noise images with random labels. Given the training targets, real scores
# should be near 1.0 and noise scores near 0.0.
D.plot_progress()

# Inference only: no autograd graph is needed for these forward passes.
with torch.no_grad():
    for i in range(4):
        label, image_data_tensor, label_tensor = mnist_dataset[random.randint(0, 50)]
        print(D.forward(image_data_tensor, label_tensor).item())

    for i in range(4):
        print(D.forward(generate_random_image(784), generate_random_one_hot(10)).item())

In [None]:
# Generator class

class Generator(nn.Module):
    """Conditional GAN generator for flattened MNIST images.

    A 100-value seed and a 10-entry one-hot label are concatenated and
    mapped by a small MLP to 784 sigmoid outputs (pixel values in (0, 1)).
    """

    def __init__(self):
        # Initialise the parent PyTorch class.
        super().__init__()

        # Network layers: (seed + label) -> 200 hidden units -> 784 pixels.
        self.model = nn.Sequential(
            nn.Linear(100+10, 200),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(200),
            nn.Linear(200, 784),
            nn.Sigmoid(),
        )

        # Optimiser. There is no loss function here: train() borrows the
        # discriminator's.
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

        # Step counter and sampled loss history, used by plot_progress().
        self.counter = 0
        self.progress = []

    def forward(self, seed_tensor, label_tensor):
        """Generate image(s) from ``seed_tensor`` conditioned on ``label_tensor``.

        Accepts a single sample (``(100,)`` and ``(10,)``) or a batch
        (``(B, 100)`` and ``(B, 10)``): the two are joined on the last dim.
        """
        inputs = torch.cat((seed_tensor, label_tensor), dim=-1)
        return self.model(inputs)

    def train(self, D=True, inputs=None, label_tensor=None, targets=None):
        """Do one generator step against discriminator ``D``, or set the mode.

        Args:
            D: the discriminator. It must provide ``forward(image, label)``
                and ``loss_function``. If this is a bool and the other
                arguments are omitted, the call is treated as
                ``nn.Module.train(mode)``. That is how ``self.eval()`` and
                ``self.train()`` invoke it, so they keep working despite
                this override.
            inputs: random seed tensor fed to the generator.
            label_tensor: one-hot label used for both networks.
            targets: label the generator wants D to assign (typically 1.0).

        Returns:
            ``self`` when used as a mode switch, otherwise ``None``.
        """
        if (isinstance(D, bool) and inputs is None
                and label_tensor is None and targets is None):
            return super().train(D)

        # Generate, then let the discriminator judge the result.
        g_output = self.forward(inputs, label_tensor)
        d_output = D.forward(g_output, label_tensor)
        loss = D.loss_function(d_output, targets)

        # Record every 10th loss for plotting.
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())

        # Only the generator's parameters are stepped here. backward() also
        # leaves gradients on D's parameters; this method does not clear
        # them. The notebook's Discriminator zeroes its own gradients before
        # each of its updates.
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_images(self, label):
        """Plot a 2x3 grid of images generated for digit ``label``."""
        label_tensor = torch.zeros((10))
        label_tensor[label] = 1.0
        f, axarr = plt.subplots(2, 3, figsize=(16, 8))
        for i in range(2):
            for j in range(3):
                # Use this instance, not whatever the notebook-global ``G`` is.
                sample = self.forward(generate_random_seed(100), label_tensor)
                axarr[i, j].imshow(sample.detach().cpu().numpy().reshape(28, 28), interpolation='none', cmap='Blues')

    def plot_progress(self):
        """Plot the sampled generator loss history."""
        df = pd.DataFrame(self.progress, columns=['loss'])
        # (0, None): pin the bottom of the y axis at 0, leave the top automatic.
        df.plot(ylim=(0, None), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))

In [None]:
# Smoke-test an untrained generator: generate one image for a random label,
# print the output tensor and its shape, and display it as 28x28.
G = Generator()
random_label_0 = generate_random_one_hot(10)
# Inference only; use the label drawn above (it was previously unused).
with torch.no_grad():
    output = G.forward(generate_random_seed(100), random_label_0)
print(output.shape, output)
img = output.detach().numpy().reshape(28, 28)
plt.imshow(img, interpolation='none', cmap='Reds')

In [None]:
# Start adversarial training from scratch: fresh discriminator and generator
# (replacing the pre-trained D and the smoke-test G from earlier cells).
D = Discriminator()
G = Generator()


In [None]:
# Adversarial training. For every real sample, in order:
#   1. D learns the real (image, label) pair as 1.0;
#   2. D learns a generated image for a random label as 0.0 (the image is
#      detached, so this step does not update G);
#   3. G learns to make D output 1.0 for a fresh random label.
epochs = 12

# The target tensors are constant across steps, so build them once.
real_target = torch.FloatTensor([1.0])
fake_target = torch.FloatTensor([0.0])

for epoch in range(epochs):
    print("epoch = ", epoch+1)
    for label, image_data_tensor, label_tensor in mnist_dataset:
        D.train(image_data_tensor, label_tensor, real_target)

        random_label = generate_random_one_hot(10)
        D.train(G.forward(generate_random_seed(100), random_label).detach(), random_label, fake_target)

        random_label = generate_random_one_hot(10)
        G.train(D, generate_random_seed(100), random_label, real_target)

In [None]:
# Plot the recorded loss histories of both networks: discriminator first, then generator.
D.plot_progress()
G.plot_progress()

In [None]:
# Six samples (2x3 grid) generated for digit label 9.
G.plot_images(9)

In [None]:
# Six samples (2x3 grid) generated for digit label 4.
G.plot_images(4)

In [None]:
# Six samples (2x3 grid) generated for digit label 1.
G.plot_images(1)