<a href="https://colab.research.google.com/github/JHyunjun/torch_GAN_AutoEncoder/blob/main/DC_GAN_AE(image).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
# Mount Google Drive so the dataset folder under /content/drive is readable
drive.mount(mountpoint='/content/drive')

In [None]:
# Author: Hyunjun JANG (KOR)
# https://github.com/JHyunjun
# DC-GAN AE (Deep Convolutional Generative Adversarial Network AutoEncoder) for transistor image anomaly detection
# Image copyright: MVTec AD dataset, https://www.mvtec.com/company/research/datasets/mvtec-ad

In [None]:
#Image preprocessing

from PIL import Image
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torchvision
from torchvision import transforms
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import pandas as pd

# Seed every RNG torch uses. torch.manual_seed seeds the CPU generator and all CUDA devices;
# the Discriminator/Generator below are created on the CPU, so seeding only CUDA
# (torch.cuda.manual_seed_all) left their weight initialisation unseeded.
torch.manual_seed(7)

target_img_size = 50

trans = transforms.Compose([
    transforms.Resize((target_img_size, target_img_size)),
    transforms.ToTensor(),
    # Maps each RGB channel from [0, 1] (ToTensor output) to [-1, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.ImageFolder(
    root="/content/drive/MyDrive/Colab Notebooks/Data/img/anomaly_transistor/for_coding",
    transform=trans,
)
classes = trainset.classes
classes

In [None]:
# ImageFolder's repr summarises the dataset (root, transform, number of datapoints); it is not a tensor shape.
print("trainset : ", trainset)
print("number of images : ", len(trainset))

In [None]:
# One batch holding the entire dataset, in ImageFolder's (class-sorted) order
trainloader = DataLoader(dataset=trainset, batch_size=len(trainset), shuffle=False)

In [None]:
dataiter = iter(trainloader)
# DataLoader iterators no longer provide a .next() method in current PyTorch; use the builtin next().
# batch_size == len(trainset), so this single batch holds every image and label.
images, labels = next(dataiter)
print(images.shape)
for i, label in enumerate(labels):
    print(i, "image is ", label)  # 0 is abnormal, 1 is normal

In [None]:
# Disabled image-check helper. Everything below is wrapped in a triple-quoted string,
# so none of it executes; the string is only evaluated as an expression.
# If re-enabled, imshow() undoes Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) with img / 2 + 0.5
# and transposes the (C, H, W) make_grid output to (H, W, C) for plt.imshow.
'''
# Image Check
def imshow(img) : 
  img = img / 2 + 0.5
  np_img = img.numpy()
  print("np_img : ",np_img.shape) #[3,206,818]
  plt.imshow(np.transpose(np_img, (1,2,0)))

  print(np_img.shape)
  print((np.transpose(np_img, (1,2,0))).shape)

print(images.shape)
imshow(torchvision.utils.make_grid(images, nrow = 8))
'''

In [None]:
# Undo transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)): map pixels from [-1, 1] back to [0, 1].
# Done in place on the whole batch (same effect as the former per-image loop),
# so every later cell sees [0, 1] images.
images.mul_(0.5).add_(0.5)

print(images[2].shape)  # (channels, target_img_size, target_img_size), i.e. [3, 50, 50]

In [None]:
filters = 16
latent_space = 1  # not used by the layers below
kernel_size = 5
padding_size = 2


class Discriminator(nn.Module):
    """W-GAN critic that maps one image to a single unbounded score.

    Three 5x5 convolutions with padding 2 (ReLU after each) keep the spatial size,
    the feature map is flattened, and a linear layer produces one score. No sigmoid
    is applied, as the Wasserstein loss in ``train`` requires raw scores.

    Expects a single unbatched image of shape (3, target_img_size, target_img_size):
    ``torch.flatten`` collapses every dimension, so an extra batch dimension would
    not match the linear layer's input size.
    """

    def __init__(self):
        super().__init__()
        self.a = nn.Conv2d(in_channels=3, out_channels=filters, kernel_size=kernel_size,
                           padding=padding_size, padding_mode='zeros')
        self.b = nn.ReLU()
        self.c = nn.Conv2d(in_channels=filters, out_channels=filters, kernel_size=kernel_size,
                           padding=padding_size)
        self.d = nn.ReLU()
        self.e = nn.Conv2d(in_channels=filters, out_channels=filters, kernel_size=kernel_size,
                           padding=padding_size)
        self.f = nn.ReLU()
        self.g = nn.Linear(filters * target_img_size * target_img_size, 1)

        self.sigmoid = nn.Sigmoid()  # not applied in forward(); kept for a BCE-based variant

        # Not used by train(), which computes the Wasserstein loss directly.
        self.loss_function = nn.BCELoss()

        # RMSprop, as in the original W-GAN recipe
        self.optimiser = torch.optim.RMSprop(self.parameters(), lr=5e-4)

        # Update counter; progress records the loss every 10 updates
        self.counter = 0
        self.progress = []

    def forward(self, inputs):
        """Return the critic score of ``inputs`` as a one-element tensor."""
        x = self.b(self.a(inputs))
        x = self.d(self.c(x))
        x = self.f(self.e(x))
        return self.g(torch.flatten(x))

    def train(self, inputs=True, targets=None, gens=None):
        """Run one critic update on a real image ``inputs`` and a generated image ``gens``.

        ``targets`` is accepted for interface compatibility but is unused by the
        Wasserstein loss. When called with only a bool (as ``nn.Module.eval()`` does via
        ``self.train(False)``), this falls back to ``nn.Module.train`` so mode switching
        keeps working even though this method shadows it.
        """
        if isinstance(inputs, bool):
            return super().train(inputs)
        if gens is None:
            raise TypeError("Discriminator.train() requires a generated sample 'gens'")

        real_score = self.forward(inputs)
        # Only the critic is updated here, so no gradient should flow back into the generator.
        fake_score = self.forward(gens.detach())

        # Wasserstein critic loss: widen the gap between real and fake scores
        loss = -(torch.mean(real_score) - torch.mean(fake_score))

        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())
        if self.counter % 1000 == 0:
            print("counter = ", self.counter)

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

        # W-GAN weight clipping keeps the critic (approximately) Lipschitz
        with torch.no_grad():
            for p in self.parameters():
                p.clamp_(-0.01, 0.01)

    def plot_progress(self):
        """Plot the recorded critic losses (one point every 10 updates)."""
        df = pd.DataFrame(self.progress, columns=['Discriminator loss'])
        # No ylim/yticks clamp at 0: Wasserstein critic losses are usually negative.
        df.plot(figsize=(16, 8), alpha=0.1, marker='.', grid=True)

In [None]:
# Constructing Generator for AutoEncoder Structure

filters = 16
latent_space = 1  # not used by the layers below
kernel_size = 5
padding_size = 2


class Generator(nn.Module):
    """Image-to-image generator acting as the autoencoder of the DC-GAN AE.

    Three 5x5 convolutions with padding 2 (ReLU after each) keep the spatial size, the
    feature map is flattened, and a linear layer maps it to 3 * H * W values that are
    reshaped into one (3, target_img_size, target_img_size) image. Outputs are not
    squashed (the sigmoid is not applied).

    Expects a single unbatched image: ``torch.flatten`` collapses every dimension,
    so a batch dimension would not match the linear layer's input size.
    """

    def __init__(self):
        super().__init__()
        self.a = nn.Conv2d(in_channels=3, out_channels=filters, kernel_size=kernel_size,
                           padding=padding_size, padding_mode='zeros')
        self.b = nn.ReLU()
        self.c = nn.Conv2d(in_channels=filters, out_channels=filters, kernel_size=kernel_size,
                           padding=padding_size)
        self.d = nn.ReLU()
        self.e = nn.Conv2d(in_channels=filters, out_channels=filters, kernel_size=kernel_size,
                           padding=padding_size)
        self.f = nn.ReLU()
        self.g = nn.Linear(filters * target_img_size * target_img_size,
                           target_img_size * target_img_size * 3)

        self.sigmoid = nn.Sigmoid()  # not applied in forward()

        # RMSprop, as in the original W-GAN recipe
        self.optimiser = torch.optim.RMSprop(self.parameters(), lr=5e-3)
        # Update counter; progress records the loss every 10 updates
        self.counter = 0
        self.progress = []

    def forward(self, inputs):
        """Reconstruct ``inputs`` as a (3, target_img_size, target_img_size) tensor."""
        x = self.b(self.a(inputs))
        x = self.d(self.c(x))
        x = self.f(self.e(x))
        x = self.g(torch.flatten(x))
        return x.reshape(3, target_img_size, target_img_size)

    def train(self, D=True, inputs=None, targets=None):
        """Run one generator update against the critic ``D`` on the image ``inputs``.

        ``targets`` is accepted for interface compatibility but is unused by the
        Wasserstein loss. When called with only a bool (as ``nn.Module.eval()`` does via
        ``self.train(False)``), this falls back to ``nn.Module.train`` so mode switching
        keeps working even though this method shadows it.
        """
        if isinstance(D, bool):
            return super().train(D)
        if inputs is None:
            raise TypeError("Generator.train() requires an input image 'inputs'")

        g_output = self.forward(inputs)
        d_output = D.forward(g_output)

        # W-GAN generator loss: raise the critic's score of the reconstruction
        loss = -torch.mean(d_output)

        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())

        # Only G's optimiser steps; gradients that also land in D's parameters are
        # cleared by D's own zero_grad() before its next update.
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_progress(self):
        """Plot the recorded generator losses (one point every 10 updates)."""
        df = pd.DataFrame(self.progress, columns=['Generator loss'])
        # No ylim/yticks clamp at 0: the Wasserstein generator loss can be negative.
        df.plot(figsize=(16, 8), alpha=0.1, marker='.', grid=True)

In [None]:
epoch0 = 1   # passes over the normal images
epoch1 = 10  # number of normal images used for training
epoch2 = 50  # D/G update pairs per image

D = Discriminator()
G = Generator()

for x in range(epoch0):
    for i in range(epoch1):
        # Images 0~5 are abnormal; the normal images used for training start at index 6.
        refs = images[i + 6]
        for _ in range(epoch2):
            # The critic step only updates D, so the fake sample needs no autograd graph through G.
            with torch.no_grad():
                fake = G.forward(refs)
            D.train(refs, torch.FloatTensor([1.0]), fake)
            G.train(D, refs, torch.FloatTensor([1.0]))
    print("Percentage : ", ((x + 1) / epoch0) * 100, "%")

In [None]:
# Plot the recorded loss curves of both networks: critic (D) first, then generator (G)
for net in (D, G):
    net.plot_progress()

In [None]:
plt.figure(figsize=(25, 12))

# Columns: the ten normal images 6..15. Rows: original, reconstruction, clamped error map.
# Inference only, so no autograd graph is built.
with torch.no_grad():
    for i in range(1, 11):
        test = images[i + 6 - 1]
        out = G.forward(test)
        # Only positive differences (test - out > 0) are kept; negatives clamp to 0.
        Error = torch.clamp(test - out, min=0, max=1)

        for row, img in enumerate((test, out, Error)):
            ax = plt.subplot(3, 10, i + 10 * row)
            plt.imshow(img.permute(1, 2, 0).cpu().numpy())  # (C, H, W) -> (H, W, C)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

plt.show()

In [None]:
# Normal: reconstructions and mean clamped error for dataset images 6..11 (the first six normal images).
# NOTE(review): the training loop fits G on images 6..15, so these are in-sample errors.

plt.figure(figsize=(25, 12))
Error_list_normal = np.zeros(6)

# Inference only: no autograd graph, so .item() / .numpy() can be used directly.
with torch.no_grad():
    for i in range(1, 7):
        test = images[i + 6 - 1]
        out = G.forward(test)
        # Only positive differences (test - out > 0) are kept; negatives clamp to 0.
        Error = torch.clamp(test - out, min=0, max=1)
        Error_list_normal[i - 1] = torch.mean(Error).item()

        # Row 0: original, row 1: reconstruction, row 2: error map
        for row, img in enumerate((test, out, Error)):
            ax = plt.subplot(3, 6, i + 6 * row)
            plt.imshow(img.permute(1, 2, 0).cpu().numpy())  # (C, H, W) -> (H, W, C)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

print(Error_list_normal)
plt.show()

In [None]:
# Abnormal: reconstructions and mean clamped error for the six abnormal images (dataset indices 0..5)
Error_list_abnormal = np.zeros(6)
plt.figure(figsize=(25, 12))

# Inference only: no autograd graph, so .item() / .numpy() can be used directly.
with torch.no_grad():
    for i in range(1, 7):
        # i runs 1..6, so i - 1 selects abnormal images 0..5. The former images[i] skipped
        # image 0 and included image 6, which is a normal training image.
        test = images[i - 1]
        out = G.forward(test)
        # Only positive differences (test - out > 0) are kept; negatives clamp to 0.
        Error = torch.clamp(test - out, min=0, max=1)
        Error_list_abnormal[i - 1] = torch.mean(Error).item()

        # Row 0: original, row 1: reconstruction, row 2: error map
        for row, img in enumerate((test, out, Error)):
            ax = plt.subplot(3, 6, i + 6 * row)
            plt.imshow(img.permute(1, 2, 0).cpu().numpy())  # (C, H, W) -> (H, W, C)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

print(Error_list_abnormal)
plt.show()

In [None]:
print("Normal : ", Error_list_normal)
print("Abnormal : ", Error_list_abnormal)

# Side-by-side mean reconstruction errors: red = normal images, blue = abnormal images
x = range(0, 6)
plt.plot(x, Error_list_normal, 'ro', label='Normal')
plt.plot(x, Error_list_abnormal, 'bo', label='Abnormal')
plt.xlabel('Image index within group')
plt.ylabel('Mean clamped reconstruction error')
plt.legend()
plt.show()