<a href="https://colab.research.google.com/github/JHyunjun/torch_GAN/blob/main/torch_GAN_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#Created by Hyunjun JANG
#Training a simple GAN on MNIST digit images to check its performance

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import random

# Path of the MNIST CSV file passed to MnistDataset; the file is read without
# a header row, with the label in column 0 followed by the pixel values.
csv_file = '/content/sample_data/mnist_test.csv'

class MnistDataset(Dataset):
    """Dataset backed by an MNIST CSV file holding one image per row.

    Column 0 of every row is the digit label and the remaining columns are
    the pixel values. The file is read without a header row.
    """

    def __init__(self, csv_file):
        """Read the whole CSV at ``csv_file`` into ``self.data_df``."""
        self.data_df = pd.read_csv(csv_file, header=None)

    def __len__(self):
        """Return the number of rows (images) in the file."""
        return len(self.data_df)

    def __getitem__(self, index):
        """Return ``(label, image_values, target)`` for row ``index``.

        ``image_values`` is a float tensor of the pixel values divided by
        255, and ``target`` is a length-10 one-hot tensor with 1.0 at
        position ``label``.
        """
        # Normalise the 0-255 image values to the 0-1 range.
        pixels = self.data_df.iloc[index, 1:].values
        image_values = torch.FloatTensor(pixels) / 255.0

        # Image label and its one-hot target tensor.
        label = self.data_df.iloc[index, 0]
        target = torch.zeros((10))
        target[label] = 1.0

        return label, image_values, target

    def plot_image(self, index):
        """Draw row ``index`` as a 28x28 image, titled with its label."""
        label = self.data_df.iloc[index, 0]
        pixels = self.data_df.iloc[index, 1:].values.reshape(28, 28)
        plt.title("label = " + str(label))
        plt.imshow(pixels, interpolation='none', cmap='Reds')


# Load the dataset and report how many images it holds.  len() is used instead
# of np.shape(): the dataset is not an array, so np.shape() would first have to
# convert every (label, image, target) item into a single NumPy array.
mnist_dataset = MnistDataset(csv_file)
print(len(mnist_dataset))

In [None]:
#The Discriminator is trained to output 1 for real MNIST images and 0 for
#fake inputs (random noise, or images produced by the Generator).
#Data_size is the number of values in one flattened image; it is the input
#width of the Discriminator and the output width of the Generator.
Data_size = 784
class Discriminator(nn.Module):
    """Two-layer fully connected network that scores an input as real or fake.

    ``Data_size`` input values are mapped through a 200-unit sigmoid layer to
    a single sigmoid output. Training uses mean-squared-error loss and plain
    stochastic gradient descent.
    """

    def __init__(self):
        # initialise parent pytorch class
        super().__init__()

        # define neural network layers
        self.model = nn.Sequential(
            nn.Linear(Data_size, 200),
            nn.Sigmoid(),
            nn.Linear(200, 1),
            nn.Sigmoid(),
        )

        # create loss function
        self.loss_function = nn.MSELoss()

        # create optimiser, simple stochastic gradient descent
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)

        # counter and accumulator for progress
        self.counter = 0
        self.progress = []

    def forward(self, inputs):
        """Return the network output for ``inputs``."""
        return self.model(inputs)

    def train(self, inputs=True, targets=None):
        """Run one optimisation step, or switch the training mode.

        This method has the same name as ``nn.Module.train(mode)``, which
        ``nn.Module.eval()`` also relies on. To keep those working, a call
        without ``targets`` is forwarded to the parent implementation with
        ``inputs`` as the mode flag.

        Args:
            inputs: Input tensor for a training step, or the boolean mode
                when ``targets`` is omitted.
            targets: Desired output tensor for ``inputs``.

        Returns:
            ``None`` after a training step; the module itself when used as a
            mode switch (as ``nn.Module.train`` does).
        """
        if targets is None:
            return super().train(inputs)

        # calculate the output of the network and its loss
        outputs = self.forward(inputs)
        loss = self.loss_function(outputs, targets)

        # increase counter, record the loss every 10 steps and report
        # progress every 10000 steps
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())
        if self.counter % 10000 == 0:
            print("counter = ", self.counter)

        # zero gradients, perform a backward pass, update weights
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_progress(self):
        """Plot the losses recorded in ``self.progress``."""
        df = pd.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))

In [None]:
def generate_random(size):
    """Return a tensor of shape ``size`` filled by ``torch.rand`` (uniform on [0, 1))."""
    return torch.rand(size)

def generate_random_seed(size):
    """Return a tensor of shape ``size`` filled by ``torch.randn`` (standard normal)."""
    return torch.randn(size)

In [None]:
#Training the Discriminator on its own: every real MNIST image (target 1.0)
#is followed by one vector of uniform random noise (target 0.0)
D = Discriminator()
targets = torch.FloatTensor([1.0])
non_target = torch.FloatTensor([0.0])

for _label, image_data_tensor, _target in mnist_dataset:
    D.train(image_data_tensor, targets)
    # Size the noise from Data_size so it always matches the Discriminator's
    # input layer instead of repeating the literal 784.
    D.train(generate_random(Data_size), non_target)

In [None]:
#Checking the performance of the Discriminator by plotting the losses it
#recorded during training
D.plot_progress()

In [None]:
# Constructing Generator

class Generator(nn.Module):
    """Two-layer fully connected network that turns one value into an image.

    A single input value is mapped through a 200-unit sigmoid layer to
    ``Data_size`` sigmoid outputs. The network is trained through a
    discriminator's loss using plain stochastic gradient descent.
    """

    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(1, 200),
            nn.Sigmoid(),
            nn.Linear(200, Data_size),
            nn.Sigmoid(),
        )

        # Only this network's parameters are handed to the optimiser.
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.005)

        # counter and accumulator for progress
        self.counter = 0
        self.progress = []

    def forward(self, inputs):
        """Return the generated output for ``inputs``."""
        return self.model(inputs)

    def train(self, D=True, inputs=None, targets=None):
        """Run one optimisation step against ``D``, or switch training mode.

        This method has the same name as ``nn.Module.train(mode)``, which
        ``nn.Module.eval()`` also relies on. To keep those working, a call
        with neither ``inputs`` nor ``targets`` is forwarded to the parent
        implementation with ``D`` as the mode flag.

        Args:
            D: Discriminator providing ``forward`` and ``loss_function``, or
                the boolean mode when used as a mode switch.
            inputs: Seed tensor fed to the generator.
            targets: Output the discriminator should give for the generated
                data.

        Returns:
            ``None`` after a training step; the module itself when used as a
            mode switch (as ``nn.Module.train`` does).
        """
        if inputs is None and targets is None:
            return super().train(D)

        # Generate, then score the result with the discriminator.  The
        # discriminator is evaluated exactly once per step.
        g_output = self.forward(inputs)
        d_output = D.forward(g_output)
        loss = D.loss_function(d_output, targets)

        # increase counter and record the loss every 10 steps
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())

        # The backward pass runs through D as well, but only the generator's
        # own optimiser takes a step here.
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_progress(self):
        """Plot the losses recorded in ``self.progress``."""
        df = pd.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))

In [None]:
#Sanity check: feed one random value through an untrained Generator and
#show the output as a 28x28 image
G = Generator()
output = G.forward(generate_random(1))
img = np.reshape(output.detach().numpy(), (28, 28))
plt.imshow(img, cmap='Reds', interpolation='none')

In [None]:
#Training the Generator and the Discriminator against each other, one MNIST
#image per step
D = Discriminator()
G = Generator()

# The target tensors never change, so build them once instead of allocating
# new ones on every iteration.
real_target = torch.FloatTensor([1.0])
fake_target = torch.FloatTensor([0.0])

for _label, image_data_tensor, _target in mnist_dataset:
    # Discriminator: a real image should score 1.0 ...
    D.train(image_data_tensor, real_target)
    # ... and a generated image 0.0.  detach() cuts the generated tensor out
    # of the Generator's graph for this Discriminator step.
    D.train(G.forward(generate_random(1)).detach(), fake_target)

    # Generator: trained so that the Discriminator scores its output as 1.0.
    G.train(D, generate_random(1), real_target)

In [None]:
#Plotting the recorded training losses of the Discriminator and the Generator
D.plot_progress()
G.plot_progress()

In [None]:
#Checking the result of the Generator: draw six generated samples in a
#2x3 grid, filled row by row

f, axarr = plt.subplots(2, 3, figsize=(16, 8))
for ax in axarr.flat:
    output = G.forward(generate_random(1))
    img = output.detach().numpy().reshape(28, 28)
    ax.imshow(img, interpolation='none', cmap='Reds')