In [None]:
# Install third-party dependencies: keras-tuner from PyPI and classification_models from GitHub.
# NOTE(review): neither install is version-pinned (and the GitHub one tracks the default branch),
# so re-running this notebook later may pick up different releases. Consider `%pip install`
# with pinned versions / a pinned commit for reproducibility.
!pip install -q -U keras-tuner
!pip install git+https://github.com/qubvel/classification_models.git

In [None]:
# Mount Google Drive so the project's local helper modules can be imported from it.
from google.colab import drive
drive.mount('/content/drive')

import sys

# Project directory on the mounted Drive; expected to contain the local modules
# (vgg16_module, resxext_module, resnet50_module) imported by this notebook.
MODULE_DIR = '/content/drive/MyDrive/COMPARISON_SYSTEM/'
# Only append once: re-running this cell would otherwise add duplicate sys.path entries.
if MODULE_DIR not in sys.path:
    sys.path.append(MODULE_DIR)

In [None]:
# Import Module
from vgg16_module import *
from resxext_module import *
from resnet50_module import *
# from 현규_모듈 import *

In [None]:
import os
import numpy as np
import pandas as pd
from tensorflow import keras

In [None]:
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

# **CLASS: Comparison System**

In [None]:
class ComparisonSystem:
  """Run several image-classification modules on the same data and rank their results.

  Typical workflow: setSystem() -> runSystem() -> sortResult() -> getResult() / getModule().
  The module classes (ModuleVGG16, myResNext, modelResnet50) live in external files; this
  class only relies on them exposing preprocessing(), inference(), preditResult() and
  getModel(num).
  """

  # Data and initial values
  def __init__(self, X, y, num_class=10, epochs=16, rows=32, cols=32):
    """Store the dataset and the hyper-parameters shared by every module.

    Args:
      X: image array; presumably shaped (N, rows, cols, 3) -- TODO confirm against the modules.
      y: labels aligned with X.
      num_class: number of target classes.
      epochs: number of training epochs handed to every module.
      rows, cols: image rows and columns; combined with 3 channels they form IMG_DIM.
    """
    self.X = X
    self.y = y

    self.NUM_CLASS = num_class
    self.EPOCHS = epochs

    self.BATCH_SIZE = 64
    self.RANDOM_SEED = 1023

    self.ROWS = rows
    self.COLS = cols
    self.IMG_DIM = (self.ROWS, self.COLS, 3)

    # Filled by setSystem() / runSystem(); initialised here so the attributes always exist.
    self.moduleList = []
    self.result = []
    # Whether dataSplit() has already scaled self.X (prevents dividing by 255 twice on re-run).
    self._normalized = False


  # Data split: Train:Valid:Test = 6:2:2
  def dataSplit(self):
    """Scale pixel values by 1/255 and split the data 6:2:2 into train / valid / test.

    Stores X_train/y_train, X_valid/y_valid and X_test/y_test on the instance. Both splits
    use RANDOM_SEED, so the partition is reproducible. Calling this more than once does not
    re-scale self.X.
    """
    # Scale only once: a second division would silently shrink the data on a notebook re-run.
    if not getattr(self, '_normalized', False):
      self.X = self.X.astype(float) / 255
      self._normalized = True

    # Hold out 40% of the samples, then split that hold-out in half into valid and test.
    self.X_train, X_rest, self.y_train, y_rest = train_test_split(
        self.X, self.y, test_size=int(len(self.X) * 0.4), random_state=self.RANDOM_SEED)
    self.X_valid, self.X_test, self.y_valid, self.y_test = train_test_split(
        X_rest, y_rest, test_size=0.5, random_state=self.RANDOM_SEED)


  # Module Class Setting
  def setSystem(self):
    """Instantiate every comparison module with the shared data and hyper-parameters.

    Order matters: getModule() addresses modules by index (0=VGG16, 1=ResNeXt, 2=ResNet50).
    """
    self.moduleList = []

    self.moduleList.append(ModuleVGG16(self.X, self.y, self.NUM_CLASS, self.IMG_DIM, self.BATCH_SIZE, self.EPOCHS, self.RANDOM_SEED))
    self.moduleList.append(myResNext(self.X, self.y, self.NUM_CLASS, self.BATCH_SIZE, self.EPOCHS, self.RANDOM_SEED))
    self.moduleList.append(modelResnet50(self.X, self.y, self.NUM_CLASS, self.IMG_DIM, self.BATCH_SIZE, self.EPOCHS, self.RANDOM_SEED))
    # self.moduleList.append(<DenseNet module>)  # placeholder for the planned 4th module


  # Running Module: VGG16 - ResNext - ResNet50 (DenseNet planned)
  def runSystem(self):
    """Run preprocessing, inference and result collection for every module, in order.

    Returns:
      The accumulated list of result entries from all modules (also kept in self.result).
      Entries are presumably dicts with 'f1', 'accuracy' and 'inference_time' keys, as read
      by the sort helpers below -- TODO confirm against the module implementations.
    """
    self.result = []

    for md in self.moduleList:
      md.preprocessing()
      md.inference()
      # NOTE(review): 'preditResult' (sic) is the method name the external modules must expose.
      mdResult = md.preditResult()

      # Add this module's results to the pool shared by ALL modules
      self.result.extend(mdResult)

    return self.result


  # Sort keys: negate the score so ascending sort puts the best value first
  def sortF1(self, e):
    return -e['f1']

  def sortACC(self, e):
    return -e['accuracy']

  def sortIT(self, e):
    # Lower inference time is better, so no negation.
    return e['inference_time']


  # Sorting Result: F1, Accuracy, Inference Time
  def sortResult(self):
    """Build three rankings of self.result: by F1 (desc), accuracy (desc), inference time (asc)."""
    self.resultF1 = sorted(self.result, key=self.sortF1)
    self.resultACC = sorted(self.result, key=self.sortACC)
    self.resultIT = sorted(self.result, key=self.sortIT)


  # Getting Best Result
  def getResult(self, standard='all', num=0):
    """Return ranked results; requires sortResult() to have been called.

    Args:
      standard: 'all' -> DataFrame of the best entry for F1, accuracy and inference time
                (in that row order); 'f1', 'accuracy' or 'inference time' -> the entry at
                rank `num` for that criterion.
      num: 0-based rank used for the single-criterion lookups.

    Raises:
      ValueError: if `standard` is not one of the supported values.
    """
    if standard == 'all':  # default: best F1, ACC and IT entries
      self.bestResult = [self.resultF1[0], self.resultACC[0], self.resultIT[0]]
      return pd.DataFrame(self.bestResult)

    if standard == 'f1':
      return self.resultF1[num]
    if standard == 'accuracy':
      return self.resultACC[num]
    if standard == 'inference time':
      return self.resultIT[num]

    raise ValueError(
        "standard must be 'all', 'f1', 'accuracy' or 'inference time', got %r" % (standard,))


  # Get Model what you want
  def getModule(self, module, num=0):
    """Return model `num` from the named module via its getModel(num).

    Args:
      module: 'vgg16', 'resnext' or 'resnet' (positions 0, 1, 2 of moduleList).
      num: index forwarded to the module's getModel().

    Raises:
      ValueError: if `module` is not a known module name.
    """
    # Positions must match the append order in setSystem(); add 'densenet': 3 when it exists.
    positions = {'vgg16': 0, 'resnext': 1, 'resnet': 2}
    if module not in positions:
      raise ValueError(
          "module must be one of %s, got %r" % (sorted(positions), module))
    return self.moduleList[positions[module]].getModel(num)

# **CLASS: Image Resizing**
- https://koos808.tistory.com/42

In [None]:
from PIL import Image
import os

class ResizeImage:
  """Resize every image in the sub-folders of raw_path and save RGB copies under data_path.

  Expected layout: raw_path/<token>/<image files>. The output mirrors it as
  data_path/<token>/<same file name>.
  """
  def __init__(self, raw_path, data_path, rows, cols):
    self.raw_path = raw_path    # root directory of the original images
    self.data_path = data_path  # root directory where resized images are written

    # Entries directly under the source root (expected to be one folder per token/class)
    self.token_list = os.listdir(raw_path)

    self.ROWS = rows  # output height in pixels
    self.COLS = cols  # output width in pixels

  def resizing(self):
    """Resize every image of every token folder to ROWS x COLS, convert to RGB and save it.

    Non-directory entries under raw_path are skipped; destination folders (including any
    missing parents) are created as needed. Prints the file count per folder and an
    'end ::: <token>' line after each folder.
    """
    for token in self.token_list:
      image_path = os.path.join(self.raw_path, token)
      # Skip stray files (e.g. .DS_Store) that sit next to the token folders.
      if not os.path.isdir(image_path):
        continue
      save_path = os.path.join(self.data_path, token)

      # Create the destination folder if it does not exist yet
      os.makedirs(save_path, exist_ok=True)

      # Every file inside this source folder
      data_list = os.listdir(image_path)
      print(len(data_list))

      # Resize and save every image
      for name in data_list:
        # The context manager closes the source file handle once the image is processed.
        with Image.open(os.path.join(image_path, name)) as im:
          # Convert first so palette/greyscale images are resampled in RGB.
          # PIL sizes are (width, height), i.e. (cols, rows).
          resized = im.convert('RGB').resize((self.COLS, self.ROWS))
        resized.save(os.path.join(save_path, name))
      print('end ::: ' + token)

# **CLASS: Visualize history**

In [None]:
import matplotlib.pyplot as plt

class show_graph:
  """Plot train/validation accuracy and loss curves from a recorded training history.

  `history` must be a mapping whose 'history' entry holds per-epoch sequences under
  'accuracy', 'val_accuracy', 'loss' and 'val_loss'.
  """
  def __init__(self, history):
    curves = history['history']
    self.accuracy = curves['accuracy']
    self.val_accuracy = curves['val_accuracy']
    self.loss = curves['loss']
    self.val_loss = curves['val_loss']

  def _plot_curves(self, train, valid, title, ylabel):
    """Draw one train/valid pair of per-epoch curves on a fresh figure and show it."""
    # A dedicated figure avoids drawing on top of whatever figure pyplot considers current.
    _, ax = plt.subplots()
    ax.plot(train)
    ax.plot(valid)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend(['Train', 'Valid'], loc='upper left')
    plt.show()

  def show_accuracy(self):
    """Plot training vs. validation accuracy per epoch."""
    self._plot_curves(self.accuracy, self.val_accuracy, 'Model accuracy', 'Accuracy')

  def show_loss(self):
    """Plot training vs. validation loss per epoch."""
    self._plot_curves(self.loss, self.val_loss, 'Model loss', 'Loss')

# **Use System with CIFAR-10**

In [None]:
from keras.datasets import cifar10

if __name__ == '__main__':
  # Load CIFAR-10 and merge its train/test halves into a single pool of samples.
  (X_train, Y_train), (X_test, Y_test) = cifar10.load_data()
  X = np.concatenate((X_train, X_test))
  y = np.concatenate((Y_train, Y_test))

  cSystem = ComparisonSystem(X, y, num_class=10, epochs=16, rows=32, cols=32)

  # NOTE(review): dataSplit() is never called, so each module receives the merged, unscaled
  # pool; presumably every module splits/normalises internally -- confirm.
  cSystem.setSystem()
  cSystem.runSystem()

  cSystem.sortResult()

  # Best entry per criterion (F1, accuracy, inference time). Named best_df instead of `all`
  # so the built-in all() is not shadowed in the kernel afterwards.
  best_df = cSystem.getResult()

  print('\n\n###------------- ALL --------------###\n')
  print(best_df)

  f1Result = cSystem.getResult('f1')
  accResult = cSystem.getResult('accuracy')
  itResult = cSystem.getResult('inference time')

  print('\n\n###------------- F1 --------------###\n')
  print(f1Result)
  print('\n\n###------------- ACCURACY --------------###\n')
  print(accResult)
  print('\n\n###------------- INFERENCE TIME --------------###\n')
  print(itResult)

  # Assumes the best-F1 result entry carries the {'history': {...}} structure show_graph
  # reads -- TODO confirm against the module implementations.
  graph = show_graph(f1Result)
  graph.show_accuracy()
  graph.show_loss()

**참고 자료**
- https://koos808.tistory.com/42
