<a href="https://colab.research.google.com/github/JHyunjun/torch_GAN/blob/main/torch_GAN_simplepattern.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#Created by Hyunjun JANG
#Training a GAN on a simple pattern to check its performance

import torch
import torch.nn as nn

import pandas
import matplotlib.pyplot as plt
import random


# Half-width of the uniform jitter added around every target value, so the
# "real" samples are not all identical.
noise = 0.1


def generate_real():
    """Return one noisy sample of the target pattern [1, 1, 0, 1, 0, 0].

    Each element is drawn uniformly from [value - noise, value + noise] and
    the six draws are packed, in order, into a 1-D torch.FloatTensor.
    """
    pattern = (1, 1, 0, 1, 0, 0)
    jittered = [random.uniform(level - noise, level + noise) for level in pattern]
    return torch.FloatTensor(jittered)




In [None]:
#The Discriminator learns to treat generate_real() samples as the true pattern and generate_random() samples as the false pattern

class Discriminator(nn.Module):
    """Small classifier that scores a 6-value pattern with a number in (0, 1).

    The network is Linear(6, 3) -> Sigmoid -> Linear(3, 1) -> Sigmoid. It is
    optimised with plain SGD (lr=0.01) against a mean-squared-error loss.
    Every 10th training step the loss is stored in ``self.progress``.
    """

    def __init__(self):
        # initialise parent pytorch class
        super().__init__()

        # define neural network layers
        self.model = nn.Sequential(
            nn.Linear(6, 3),
            nn.Sigmoid(),
            nn.Linear(3, 1),
            nn.Sigmoid(),
        )

        # create loss function
        self.loss_function = nn.MSELoss()

        # create optimiser, simple stochastic gradient descent
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)

        # counter and accumulator for progress
        self.counter = 0
        self.progress = []

    def forward(self, inputs):
        """Run ``inputs`` through the network and return its output."""
        return self.model(inputs)

    def train(self, inputs=True, targets=None):
        """Perform one optimisation step, or switch the training mode.

        This method shares its name with ``nn.Module.train(mode)``, which
        ``nn.Module.eval()`` calls as ``self.train(False)``. To keep those
        standard calls working, a call without ``targets`` is forwarded to
        the parent implementation with ``inputs`` as the (positional) mode
        flag, and the module itself is returned.

        Args:
            inputs: tensor fed to the network for a training step, or a bool
                mode flag when ``targets`` is omitted.
            targets: tensor the network output is compared against.

        Returns:
            ``None`` after a training step; ``self`` after a mode switch.
        """
        if targets is None:
            # nn.Module.train(mode) / nn.Module.eval() path.
            return super().train(inputs)

        # calculate the output of the network
        outputs = self.forward(inputs)

        # calculate loss
        loss = self.loss_function(outputs, targets)

        # increase counter and accumulate error every 10
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())
        if self.counter % 10000 == 0:
            print("counter = ", self.counter)

        # zero gradients, perform a backward pass, update weights
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_progress(self):
        """Plot the losses recorded in ``self.progress``."""
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))

In [None]:
def generate_random(size):
    """Return a tensor of shape ``size`` filled with ``torch.rand`` samples.

    ``torch.rand`` draws from the uniform distribution on [0, 1).
    """
    return torch.rand(size)

In [None]:
#Training the Discriminator on its own:
#real-pattern samples are labelled 1.0, uniformly random samples are labelled 0.0
D = Discriminator()
targets, non_target = torch.FloatTensor([1.0]), torch.FloatTensor([0.0])

for i in range(10000):
    # one step on a real sample, then one step on a random sample
    D.train(generate_real(), targets)
    D.train(generate_random(6), non_target)

In [None]:
#Checking the performance of Discriminator
D.plot_progress()

# First line: score for a real-pattern sample (trained towards 1.0).
# Second line: score for a uniformly random sample (trained towards 0.0).
print(
    D.forward(generate_real()).item(),
    D.forward(generate_random(6)).item(),
    sep="\n",
)

In [None]:
# Constructing Generator

class Generator(nn.Module) : 
  def __init__(self) : 
    super().__init__()

    self.model = nn.Sequential(
        nn.Linear(1,3),
        nn.Sigmoid(),
        nn.Linear(3,6),
        nn.Sigmoid()
    )

    self.optimiser = torch.optim.SGD(self.parameters(), lr = 0.005)
    self.counter = 0
    self.progress = []

    pass

  def forward(self, inputs) : 
    return self.model(inputs)

  def train(self, D, inputs, targets) : 
    g_output = self.forward(inputs)
    d_output = D.forward(g_output)
    loss = D.forward(g_output)
    loss = D.loss_function(d_output, targets)

    self.counter+=1;
    if (self.counter % 10 == 0) :
      self.progress.append(loss.item())
      pass

    self.optimiser.zero_grad()
    loss.backward()
    self.optimiser.step()

  def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))
        
  pass

In [None]:
#Check a Generator: feed a freshly initialised (untrained) generator the 1-value input 0.5.
#The forward call is the cell's final expression, so its result (a 6-value tensor) is the cell's displayed value.
G = Generator()
G.forward(torch.FloatTensor([0.5]))

In [None]:
#Training both Generator and Discriminator for pattern
D = Discriminator()
G = Generator()
image_list = []

# These tensors never change and are not modified by training, so build them
# once instead of allocating four new tensors on every iteration.
real_label = torch.FloatTensor([1.0])
fake_label = torch.FloatTensor([0.0])
generator_input = torch.FloatTensor([0.5])

for i in range(10000):
    # 1) discriminator step on a real sample, labelled 1.0
    D.train(generate_real(), real_label)
    # 2) discriminator step on a generated sample, labelled 0.0;
    #    detach() keeps this step from back-propagating into the generator
    D.train(G.forward(generator_input).detach(), fake_label)
    # 3) generator step: push the discriminator's verdict towards 1.0
    G.train(D, generator_input, real_label)

    # keep a snapshot of the generator's output every 1000 iterations
    if i % 1000 == 0:
        image_list.append(G.forward(generator_input).detach().numpy())

In [None]:
#Plotting the Generator Loss

import numpy as np

G.plot_progress()

# Show the generator's current output for the input 0.5. As a bare expression
# in the middle of the cell this value was computed and then discarded, so it
# is printed explicitly.
print(G.forward(torch.FloatTensor([0.5])))

# Draw the collected snapshots as an image; the transpose puts one snapshot
# per column (assuming each image_list entry is one generator output array).
plt.figure(figsize=(16, 8))
plt.imshow(np.array(image_list).T, interpolation='none', cmap='Reds')