In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import tensorflow_datasets as tfds

In [2]:
import urllib3

# Silence urllib3 warnings so they do not clutter the dataset download output.
urllib3.disable_warnings()

# tfds.disable_progress_bar()   # Uncomment to hide the progress bar while the dataset downloads.

# as_supervised=True makes the datasets yield (image, label) tuples, which is
# the form normalize_and_resize_img(image, label) expects and matches the
# load performed before training further down. Without it the elements are
# feature dictionaries.
(ds_train, ds_test), ds_info = tfds.load(
    'cifar10',
    split=['train', 'test'],
    as_supervised=True,
    shuffle_files=True,
    with_info=True,
)

In [3]:
def normalize_and_resize_img(image, label):
    """Cast an image to `float32` and divide it by 255; the label is returned unchanged.

    Despite the function name, no resizing is performed: the resize step
    below is disabled.
    """
    # image = tf.image.resize(image, [32, 32])
    scaled_image = tf.cast(image, tf.float32) / 255.
    return scaled_image, label

In [4]:
def apply_normalize_on_dataset(ds, is_test=False, batch_size=16, shuffle_buffer=200):
    """Build the input pipeline: normalize, (shuffle), batch, (repeat), prefetch.

    Args:
        ds: dataset whose elements are (image, label) pairs accepted by
            `normalize_and_resize_img`.
        is_test: when True the data is neither shuffled nor repeated, so a
            single pass visits every element once in its original order.
        batch_size: number of examples per batch.
        shuffle_buffer: size of the shuffle buffer used for training data.

    Returns:
        The transformed dataset.
    """
    ds = ds.map(
        normalize_and_resize_img,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    if not is_test:
        # Shuffle individual examples *before* batching. Shuffling after
        # batch()/repeat() only reorders whole batches: the examples grouped
        # in each batch never change and the buffer mixes adjacent epochs.
        ds = ds.shuffle(shuffle_buffer)
    ds = ds.batch(batch_size)
    if not is_test:
        # Repeat indefinitely; the caller bounds an epoch with steps_per_epoch.
        ds = ds.repeat()
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
    return ds

In [5]:
class SELayer(tf.keras.layers.Layer):
    """Channel-gating layer: rescales the input by per-channel sigmoid weights.

    The weights are computed from the globally average-pooled input through
    two dense layers (`out_channel // r_ratio` units, then `out_channel`
    units).

    Args:
        out_channel: number of gate values produced; the gate is reshaped to
            (1, 1, out_channel) before multiplying the input.
        r_ratio: divisor applied to `out_channel` for the first dense layer.
    """

    def __init__(self, out_channel=64, r_ratio=4):
        super().__init__()
        self.output_channel = out_channel
        self.gap = tf.keras.layers.GlobalAveragePooling2D()
        self.fc1 = tf.keras.layers.Dense(out_channel // r_ratio)
        self.fc2 = tf.keras.layers.Dense(out_channel)
        self.reshape = tf.keras.layers.Reshape((1, 1, out_channel))

    def call(self, input_tensor, training=False):
        """Return `input_tensor` multiplied by its channel gate.

        `training` is accepted but not used by this layer.
        """
        pooled = self.gap(input_tensor)
        hidden = tf.nn.relu(self.fc1(pooled))
        gate = tf.nn.sigmoid(self.fc2(hidden))
        # Reshape to (1, 1, out_channel) so the gate broadcasts over the
        # two pooled axes of the input.
        return self.reshape(gate) * input_tensor

In [6]:
class SEBlock(tf.keras.layers.Layer):
    """Residual block: conv-BN-ReLU, conv-BN, SE gating, then a shortcut add.

    Args:
        channel: number of filters of both 3x3 convolutions (and of the
            1x1 shortcut convolution when one is used).
        strides: stride of the first convolution and of the shortcut
            convolution.
        down_sample: when True, the shortcut goes through a 1x1 convolution
            and batch normalization so it can be added to the main branch.
        r_ratio: reduction ratio forwarded to `SELayer`.
        downsampling: alias of `down_sample`; when given (not None) it takes
            precedence. Kept so callers using either keyword work.
    """

    def __init__(self, channel, strides=1, down_sample=False, r_ratio=4, downsampling=None):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(channel, kernel_size=3, strides=strides, padding='same')
        self.bn1 = tf.keras.layers.BatchNormalization()

        self.conv2 = tf.keras.layers.Conv2D(channel, kernel_size=3, strides=1, padding='same')
        self.bn2 = tf.keras.layers.BatchNormalization()

        self.se1 = SELayer(channel, r_ratio)

        # The original body read an undefined name `downsampling` while the
        # signature declared `down_sample`; resolve both spellings here.
        if downsampling is not None:
            down_sample = downsampling
        self.downsample = bool(down_sample)

        if self.downsample:
            self.downsample_conv = tf.keras.layers.Conv2D(channel, kernel_size=1, strides=strides, padding='same')
            self.downsample_bn = tf.keras.layers.BatchNormalization()

    def call(self, input_tensor, training=False):
        """Run the block; `training` is forwarded to the batch-norm layers."""
        x = self.conv1(input_tensor)
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv2(x)
        x = self.bn2(x, training=training)
        x = self.se1(x)

        shortcut = input_tensor
        if self.downsample:
            # Project the shortcut with the 1x1 convolution (same stride as
            # conv1) before it is added to the main branch.
            shortcut = self.downsample_conv(shortcut)
            shortcut = self.downsample_bn(shortcut, training=training)

        # No activation is applied after the addition.
        return x + shortcut

In [7]:
class SE_ResNet18(tf.keras.Model):
    """Image classifier: stem conv, four stages of `SEBlock`s, and a dense head.

    Args:
        layer_count: number of `SEBlock`s in each of the four stages.
    """

    def __init__(self, layer_count=2):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(64, kernel_size=7, strides=2, padding='same')
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.maxpool1 = tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding='same')

        self.seblock = tf.keras.models.Sequential()
        # NOTE(review): the stage widths are 64*i, i.e. 64, 128, 192, 256.
        # The usual ResNet-18 layout doubles per stage (64, 128, 256, 512);
        # confirm whether the linear growth is intended.
        for stage in range(1, 5):
            channels = 64 * stage
            for block_idx in range(layer_count):
                # The first block of every stage after the first changes the
                # channel count and halves the resolution, so its shortcut
                # needs the projection path.
                if stage > 1 and block_idx == 0:
                    self.seblock.add(SEBlock(channel=channels, strides=2, down_sample=True))
                else:
                    self.seblock.add(SEBlock(channels))

        self.gap = tf.keras.layers.GlobalAveragePooling2D()
        self.fc1 = tf.keras.layers.Dense(4096, activation='relu')
        self.fc2 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, input_tensor, training=False):
        """Return softmax outputs (10 units) for a batch of images.

        `training` is forwarded to the batch-norm layer and to the block
        stack.
        """
        x = self.conv1(input_tensor)
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        x = self.maxpool1(x)

        # Forward the flag explicitly so the blocks' batch-norm layers do not
        # depend on implicit propagation.
        x = self.seblock(x, training=training)

        x = self.gap(x)
        x = self.fc1(x)
        x = self.fc2(x)

        return x

In [8]:
BATCH_SIZE = 256
EPOCH = 20

# Load CIFAR-10 as (image, label) tuples, the element form that
# normalize_and_resize_img expects.
(ds_train, ds_test), ds_info = tfds.load(
    'cifar10',
    split=['train', 'test'],
    as_supervised=True,
    shuffle_files=True,
    with_info=True,
)
ds_train = apply_normalize_on_dataset(ds_train, batch_size=BATCH_SIZE)
# is_test=True: the evaluation data must not be repeated or shuffled, so
# every validation pass sees the same examples in the same order.
ds_test = apply_normalize_on_dataset(ds_test, is_test=True, batch_size=BATCH_SIZE)

In [9]:
se_resnet18 = SE_ResNet18()

# Labels are integer class ids, hence the sparse cross-entropy loss.
# `learning_rate` replaces the deprecated `lr` argument name.
se_resnet18.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, clipnorm=1.),
    metrics=['accuracy'],
)

# Whole batches per pass over each split (floor division drops the final
# partial batch).
steps_per_epoch = ds_info.splits['train'].num_examples // BATCH_SIZE
validation_steps = ds_info.splits['test'].num_examples // BATCH_SIZE

history_se_resnet18 = se_resnet18.fit(
    ds_train,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    epochs=EPOCH,
    validation_data=ds_test,
    verbose=1
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
