In [1]:
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, BatchNormalization, Activation, add, Flatten, AveragePooling2D, concatenate, Dense
from tensorflow.keras.models import Model

from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

width = 32
height = 32

In [2]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [3]:
num_classes = 10
y_train = tf.keras.utils.to_categorical(y_train)
y_test = tf.keras.utils.to_categorical(y_test)

In [4]:
validation_images, validation_labels = x_train[:500], y_train[:500]
train_images, validation_labels = x_train[500:], y_train[500:]

In [5]:
train_images.shape

(49500, 32, 32, 3)

In [8]:
def inception(x, filters):
    pre_layer = x
    f1, f2, f3, f4 = filters

    conv1 = Conv2D(f1, kernel_size = (1, 1), padding = 'same', activation = 'relu')(pre_layer)

    conv2 = Conv2D(f4, kernel_size = (1, 1), padding = 'same', activation = 'relu')(pre_layer)
    conv2 = Conv2D(f2, kernel_size = (3, 3), padding = 'same', activation = 'relu')(conv2)

    conv3 = Conv2D(f4, kernel_size = (1, 1), padding = 'same', activation = 'relu')(pre_layer)
    conv3 = Conv2D(f3, kernel_size = (5, 5), padding = 'same', activation = 'relu')(conv3)

    max_pool = MaxPooling2D(pool_size = (3, 3), strides = 1, padding = 'same')(pre_layer)
    max_pool = Conv2D(f4, kernel_size = (1, 1), padding = 'same')(max_pool)

    concat = concatenate([conv1, conv2, conv3, max_pool], axis = -1)
    return concat

In [9]:
input_shape = x_train[0].shape
inputs = Input(shape = input_shape)

x = Conv2D(64, kernel_size = (7, 7), strides = 2, padding = 'same', activation = 'relu')(inputs)
x = BatchNormalization()(x)
x = Conv2D(192, kernel_size = (3, 3), padding = 'same', activation = 'relu')(x)
x = BatchNormalization()(x)

x = inception(x, [64, 128, 32, 32])
x = inception(x, [128, 192, 96, 64])
x = MaxPooling2D(pool_size = (3, 3), strides = 2, padding = 'same')(x)

x = inception(x, [192, 208, 48, 64])

aux1 = AveragePooling2D(pool_size = (5, 5), strides = 3, padding = 'valid')(x)
aux1 = Conv2D(128, kernel_size = (1, 1), padding = 'same', activation = 'relu')(x)
aux1 = Flatten()(aux1)
aux1 = Dense(512, activation = 'relu')(aux1)
aux1 = Dense(10, activation = 'softmax')(aux1)

x = inception(x, [160, 224, 64, 4])

x = inception(x, [120, 256, 64, 64])
x = inception(x, [112, 288, 64, 64])

aux2 = AveragePooling2D(pool_size = (5, 5), strides = 3, padding = 'valid')(x)
aux2 = Conv2D(128, kernel_size = (1, 1), padding = 'same', activation = 'relu')(x)
aux2 = Flatten()(aux2)
aux2 = Dense(512, activation = 'relu')(aux2)
aux2 = Dense(10, activation = 'softmax')(aux2)

x = inception(x, [256, 320, 128, 128])
x = inception(x, [256, 320, 128, 128])
x = inception(x, [384, 384, 128, 128])

x = AveragePooling2D(pool_size = (4, 4), padding = 'valid')(x)   # padding 안쓰겠다는 뜻
x = Dropout(0.4)(x)
x = Flatten()(x)

output = Dense(10, activation = 'softmax')(x)

model = Model(inputs = inputs, outputs = [aux1, aux2, output])
model.summary()



Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 32, 32, 3)]  0           []                               
                                                                                                  
 conv2d_10 (Conv2D)             (None, 16, 16, 64)   9472        ['input_3[0][0]']                
                                                                                                  
 batch_normalization_4 (BatchNo  (None, 16, 16, 64)  256         ['conv2d_10[0][0]']              
 rmalization)                                                                                     
                                                                                                  
 conv2d_11 (Conv2D)             (None, 16, 16, 192)  110784      ['batch_normalization_4[0][0]

                                                                                                  
 conv2d_35 (Conv2D)             (None, 8, 8, 64)     6464        ['conv2d_34[0][0]']              
                                                                                                  
 conv2d_36 (Conv2D)             (None, 8, 8, 4)      2052        ['max_pooling2d_5[0][0]']        
                                                                                                  
 concatenate_4 (Concatenate)    (None, 8, 8, 452)    0           ['conv2d_31[0][0]',              
                                                                  'conv2d_33[0][0]',              
                                                                  'conv2d_35[0][0]',              
                                                                  'conv2d_36[0][0]']              
                                                                                                  
 conv2d_38

                                                                  'conv2d_60[0][0]',              
                                                                  'conv2d_61[0][0]']              
                                                                                                  
 conv2d_63 (Conv2D)             (None, 8, 8, 128)    106624      ['concatenate_8[0][0]']          
                                                                                                  
 conv2d_65 (Conv2D)             (None, 8, 8, 128)    106624      ['concatenate_8[0][0]']          
                                                                                                  
 max_pooling2d_10 (MaxPooling2D  (None, 8, 8, 832)   0           ['concatenate_8[0][0]']          
 )                                                                                                
                                                                                                  
 conv2d_62