<a href="https://colab.research.google.com/github/JHyunjun/torch_2D-CNN/blob/main/2D_CNN_AE(AD)_image.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Make Google Drive visible to the notebook; the dataset is read from below
# this mount point.
from google.colab import drive
mount_point = '/content/drive'
drive.mount(mount_point)

In [None]:
# Code Maker : Hyunjun, JANG (KOR)
# https://github.com/JHyunjun
# 2D-CNN AE (convolutional autoencoder) for transistor image anomaly detection
# Image copyright: MVTec AD dataset, https://www.mvtec.com/company/research/datasets/mvtec-ad

In [None]:
#Image preprocessing

from PIL import Image
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torchvision
from torchvision import transforms
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import pandas as pd

# Seed the global RNG. torch.manual_seed seeds the CPU generator and every
# CUDA device; the previous torch.cuda.manual_seed_all(7) seeded CUDA only,
# leaving the CPU generator (which initialises the weights of this model,
# never moved to a GPU in this notebook) unseeded.
torch.manual_seed(7)

# Resize to 32x32, convert to a float tensor in [0, 1], then map each RGB
# channel to [-1, 1] via (x - 0.5) / 0.5.
trans = transforms.Compose([
  transforms.Resize((32,32)),
  transforms.ToTensor(),
  transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)),
])

# ImageFolder assigns one integer label per sub-directory, in sorted
# directory-name order.
trainset = torchvision.datasets.ImageFolder(root = "/content/drive/MyDrive/Colab Notebooks/Data/img/anomaly_transistor/for_coding", transform = trans)
classes = trainset.classes
classes

In [None]:
# Dataset summary: the Dataset repr (root, transforms, size) plus explicit
# image and class counts.
print("trainset : ", trainset)
print("number of images : ", len(trainset), ", number of classes : ", len(trainset.classes))

In [None]:
# One batch holding the entire (small) dataset; shuffle=False keeps the
# ImageFolder order so image indices line up with class labels.
full_batch_size = len(trainset)
trainloader = DataLoader(trainset, shuffle = False, batch_size = full_batch_size)

In [None]:
# Fetch the single full-dataset batch. DataLoader iterators no longer provide
# a `.next()` method in current PyTorch, so use the built-in next().
dataiter = iter(trainloader)
images, labels = next(dataiter)
print(images.shape)
for i, label in enumerate(labels) : 
  print( i, "image is ", label) # 0 is abnormal, 1 is normal

In [None]:
# Image check: show the whole batch as one grid.
# Off by default (this code used to live inside a no-op string literal);
# set SHOW_IMAGE_CHECK = True to render it.
SHOW_IMAGE_CHECK = False


def imshow(img) : 
  """Display a normalized (C, H, W) image tensor with matplotlib.

  Undoes Normalize(mean=0.5, std=0.5), mapping [-1, 1] back to [0, 1], and
  moves channels last because plt.imshow expects (H, W, C).
  """
  img = img / 2 + 0.5
  np_img = img.numpy()
  plt.imshow(np.transpose(np_img, (1,2,0)))


if SHOW_IMAGE_CHECK:
  print(images.shape)
  imshow(torchvision.utils.make_grid(images, nrow = 8))

In [None]:
# Undo Normalize(mean=0.5, std=0.5): map pixel values from [-1, 1] back to
# [0, 1] for the whole batch at once, in place. The autoencoder is trained on
# and compared against these [0, 1] images.
# NOTE: not idempotent -- re-running this cell rescales the images again.
images.div_(2).add_(0.5)

print(images[2].shape) # torch.Size([3, 32, 32])

In [None]:
# 2D-CNN Network

filters = 8        # channels of the outer conv layers
latent_space = 2   # channels of the bottleneck feature map
kernel_size = 5
padding_size = 2   # with kernel_size 5 and stride 1, keeps the 32x32 spatial size

class CNN(nn.Module) : 
  """Small convolutional autoencoder for 3x32x32 images.

  Encoder: conv(3 -> filters) -> LeakyReLU -> conv(filters -> latent_space)
  -> LeakyReLU. Decoder: conv(latent_space -> filters) -> ReLU -> Linear
  back to 3*32*32 values. The model owns its MSE loss and Adam optimiser;
  ``train(image)`` performs one optimisation step and appends the loss to
  ``self.progress``.
  """

  def __init__(self) : 
    super().__init__()
    
    # Input is a (3, 32, 32) image or a (N, 3, 32, 32) batch.
    self.a = nn.Conv2d(in_channels = 3, out_channels = filters, kernel_size = kernel_size, padding = padding_size, padding_mode = 'zeros')
    self.b = nn.LeakyReLU(0.02)
    self.c = nn.Conv2d(in_channels = filters, out_channels = latent_space, kernel_size = kernel_size, padding = padding_size)
    self.d = nn.LeakyReLU(0.02)
    self.e = nn.Conv2d(in_channels = latent_space , out_channels = filters, kernel_size = kernel_size, padding = padding_size)
    self.f = nn.ReLU()
    self.g = nn.Linear(filters * 32 * 32, 32 * 32 * 3)

    self.loss_function = nn.MSELoss()
    self.optimiser = torch.optim.Adam(self.parameters(), lr = 1e-3)
    self.progress = []  # one loss value per call to train(inputs)
    
  def forward(self, inputs) :
    """Reconstruct ``inputs``.

    Args:
      inputs: a (3, 32, 32) image tensor or a (N, 3, 32, 32) batch.

    Returns:
      A tensor with the same shape as ``inputs``.
    """
    batched = inputs.dim() == 4
    x = self.b(self.a(inputs))
    x = self.d(self.c(x))
    x = self.f(self.e(x))
    # Flatten per sample; an unbatched image flattens into a single vector.
    x = torch.flatten(x, start_dim = 1 if batched else 0)
    x = self.g(x)
    return x.reshape(-1, 3, 32, 32) if batched else x.reshape(3, 32, 32)

  def train(self, inputs = True) : 
    """Run one optimisation step on ``inputs`` and record its loss.

    This name shadows ``nn.Module.train(mode)``. ``nn.Module.eval()`` calls
    ``self.train(False)``, so a bool argument is forwarded to
    ``nn.Module.train`` (switching train/eval mode) instead of being
    treated as data. This keeps ``model.train()``, ``model.train(False)``
    and ``model.eval()`` working.

    Args:
      inputs: image tensor(s) accepted by ``forward``, or a bool mode flag.

    Returns:
      The loss as a Python float, or ``self`` when switching mode.
    """
    if isinstance(inputs, bool):
      return super().train(inputs)
    self.optimiser.zero_grad()
    outputs = self.forward(inputs)
    # (prediction, target) order; MSE is symmetric, so the value is unchanged.
    loss = self.loss_function(outputs, inputs)
    loss.backward()
    self.optimiser.step()
    loss_value = loss.item()
    self.progress.append(loss_value)
    return loss_value

  def plot_progress(self) : 
    """Plot the per-step training losses recorded by ``train``."""
    df = pd.DataFrame(self.progress, columns = ['2D CNN AE Loss'])
    df.plot(ylim=(0, None), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))

In [None]:
# Instantiate the autoencoder; its MSE loss and Adam optimiser are created in CNN.__init__.
D = CNN()

In [None]:
# Train on the normal images only: indices 0..5 are abnormal, 6..15 normal.
# Each normal image receives `epoch2` consecutive optimisation steps.
# NOTE(review): images are visited one after another rather than interleaved
# per epoch, so the last images may dominate; consider swapping the loops.
n_abnormal = 6
epoch1 = 10 # number of normal images
epoch2 = 10 # optimisation steps per image

for i in range(epoch1) : 
  refs = images[i + n_abnormal]
  for j in range(epoch2) : 
    D.train(refs)
  # i + 1 so the final image reports 100 %.
  print("Percentage : ",((i + 1) / epoch1) * 100,"%")
  



In [None]:
# Reconstruction check on the NORMAL images (indices 6..15, the ones used for
# training). Row 1: input, row 2: reconstruction, row 3: per-pixel absolute
# reconstruction error.
plt.figure(figsize = (25,12))

normal_indices = range(6, 16)
n_cols = len(normal_indices)

with torch.no_grad():  # inference only: no autograd graph needed
  for col, idx in enumerate(normal_indices, start = 1):
    test = images[idx]
    out = D.forward(test)
    # Absolute residual: shows both under- and over-reconstruction (the
    # previous clamp of the signed difference hid all negative errors).
    error = (test - out).abs().clamp(max = 1)

    # Clamp the reconstruction to [0, 1] so imshow does not clip with a warning.
    panels = (test, out.clamp(0, 1), error)
    for row, panel in enumerate(panels):
      ax = plt.subplot(3, n_cols, row * n_cols + col)
      plt.imshow(panel.permute(1, 2, 0).cpu().numpy())  # (C, H, W) -> (H, W, C)
      ax.get_xaxis().set_visible(False)
      ax.get_yaxis().set_visible(False)

plt.show()

In [None]:
# Reconstruction check on the ABNORMAL images (indices 0..5, never used in
# training). The previous loop showed images[1..6], skipping abnormal image 0
# and including normal image 6. Row 1: input, row 2: reconstruction,
# row 3: per-pixel absolute reconstruction error (larger errors are expected
# where an image deviates from the normal training images).
plt.figure(figsize = (25,12))

abnormal_indices = range(0, 6)
n_cols = len(abnormal_indices)

with torch.no_grad():  # inference only: no autograd graph needed
  for col, idx in enumerate(abnormal_indices, start = 1):
    test = images[idx]
    out = D.forward(test)
    # Absolute residual: shows both under- and over-reconstruction (the
    # previous clamp of the signed difference hid all negative errors).
    error = (test - out).abs().clamp(max = 1)

    # Clamp the reconstruction to [0, 1] so imshow does not clip with a warning.
    panels = (test, out.clamp(0, 1), error)
    for row, panel in enumerate(panels):
      ax = plt.subplot(3, n_cols, row * n_cols + col)
      plt.imshow(panel.permute(1, 2, 0).cpu().numpy())  # (C, H, W) -> (H, W, C)
      ax.get_xaxis().set_visible(False)
      ax.get_yaxis().set_visible(False)

plt.show()