In [None]:
class CIFAR_CNN():
  """Train and inspect a small VGG-style CNN classifier (built for CIFAR-10).

  On construction the data is shuffled and split into training and
  validation subsets, and a compiled Keras ``Sequential`` model is built.
  Call ``Training()`` to fit the model and ``Plot()`` to show its curves.
  """

  def __init__(self, x, y, train_ratio=0.9, seed=None):
    """Split (x, y) into train/validation sets and build the model.

    Args:
      x: array of samples indexed along the first axis. The CNN presumably
        expects images shaped (height, width, channels), e.g. (32, 32, 3)
        for CIFAR-10 -- confirm against the caller.
      y: array of targets aligned with ``x`` along the first axis. Since the
        model uses ``categorical_crossentropy``, these are presumably one-hot
        encoded -- confirm against the caller.
      train_ratio: fraction of samples used for training; the rest are
        held out for validation.
      seed: optional integer seed for a reproducible shuffle. ``None`` uses
        NumPy's global random state (the original behaviour).
    """
    self.x_train, self.y_train, self.x_valid, self.y_valid = self.Split_Dataset(
        x, y, train_ratio=train_ratio, seed=seed)
    # Pass the per-sample shape so the model is built immediately
    # (e.g. model.summary() works before the first fit).
    self.model = self.CIFAR_Model(input_shape=self.x_train.shape[1:])
    # Filled in by Training(); Plot() checks it.
    self.history = None

  # split data into training and validation sets
  def Split_Dataset(self, x, y, train_ratio=0.9, seed=None):
    """Shuffle paired samples and split them into training/validation sets.

    Args:
      x: array indexed along the first axis.
      y: array with the same length as ``x``; ``y[i]`` pairs with ``x[i]``.
      train_ratio: fraction in (0, 1] of samples assigned to training
        (truncated with ``int``); the remainder becomes validation.
      seed: optional integer seed. ``None`` shuffles with NumPy's global
        random state; an integer gives a reproducible split.

    Returns:
      Tuple ``(x_train, y_train, x_valid, y_valid)``.

    Raises:
      ValueError: if ``x`` and ``y`` differ in length, or ``train_ratio``
        is outside (0, 1].
    """
    total_length = x.shape[0]
    if y.shape[0] != total_length:
      raise ValueError(
          'x and y must have the same number of samples, got %d and %d'
          % (total_length, y.shape[0]))
    if not 0.0 < train_ratio <= 1.0:
      raise ValueError('train_ratio must be in (0, 1], got %r' % (train_ratio,))

    # Apply one shared permutation to x and y so sample/label pairs stay aligned.
    if seed is None:
      index = np.arange(total_length)
      np.random.shuffle(index)
    else:
      index = np.random.RandomState(seed).permutation(total_length)
    x = x[index]
    y = y[index]

    # Separate the shuffled data into training and validation sets.
    training_length = int(total_length * train_ratio)
    return (x[:training_length], y[:training_length],
            x[training_length:], y[training_length:])

  # construct CNN model for CIFAR-10
  def CIFAR_Model(self, output_size=10, input_shape=None):
    """Build and compile the CNN.

    Architecture: three stages of two 3x3 'same'-padded ReLU Conv2D layers
    followed by 2x2 max pooling (64, 128, 128 filters), then
    Flatten -> Dense(256, ReLU) -> Dropout(0.2) -> Dense(output_size, softmax).
    Compiled with the Adam optimizer, categorical cross-entropy loss and an
    accuracy metric.

    Args:
      output_size: number of output classes (10 for CIFAR-10).
      input_shape: optional per-sample input shape, e.g. (32, 32, 3). When
        given, it is passed to the first layer so the model is built right
        away. When ``None``, building is deferred until the model first
        sees data (the original behaviour).

    Returns:
      The compiled Keras ``Sequential`` model.
    """
    model = Sequential()

    # Only forward input_shape when it is known, so the default call is unchanged.
    first_layer_kwargs = {} if input_shape is None else {'input_shape': tuple(input_shape)}

    # first stage (conv)
    model.add(Conv2D(64, (3, 3), padding='same', activation='relu', **first_layer_kwargs))
    model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    # second stage (conv)
    model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))
    model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    # third stage (conv)
    model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))
    model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())

    # fully connected hidden layer
    model.add(Dense(256, activation='relu'))
    model.add(Dropout(0.2))

    # output layer: one softmax probability per class
    model.add(Dense(output_size, activation='softmax'))

    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

    return model

  # training process
  def Training(self, epoch=50, batch_size=32):
    """Fit the model on the training split, validating on the held-out split.

    Args:
      epoch: number of training epochs.
      batch_size: mini-batch size passed to ``model.fit``.

    Returns:
      The Keras ``History`` object, which is also stored in ``self.history``.
    """
    self.history = self.model.fit(
        self.x_train, self.y_train,
        epochs=epoch, batch_size=batch_size,
        validation_data=(self.x_valid, self.y_valid))
    return self.history

  # plot learning information
  def Plot(self):
    """Show the loss and accuracy curves recorded by ``Training()``.

    Raises:
      RuntimeError: if ``Training()`` has not been run yet.
    """
    # getattr also covers instances created without __init__.
    if getattr(self, 'history', None) is None:
      raise RuntimeError('Training() must be called before Plot()')
    hist = self.history.history

    # learning curve: training and (when recorded) validation loss
    plt.plot(hist['loss'], color='blue', label='training loss')
    if 'val_loss' in hist:
      plt.plot(hist['val_loss'], color='red', label='validation loss')
    plt.title('Learning Curve')
    plt.xlabel('epochs')
    plt.ylabel('cross-entropy loss')
    plt.legend()
    plt.show()

    # accuracy curve: training and validation accuracy
    plt.plot(hist['accuracy'], color='blue', label='training accuracy')
    plt.plot(hist['val_accuracy'], color='red', label='validation accuracy')
    plt.title('Accuracy')
    plt.xlabel('epochs')
    plt.ylabel('Accuracy rate')
    plt.legend()
    plt.show()