Deep Residual Learning for Image Recognition

In [9]:
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers

In [10]:
def Identity_shortcut(x,filters,strides = 1):
    shortcut = x
    x = layers.Conv2D(filters = filters, kernel_size = 3, strides = strides, padding = 'same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(filters = filters, kernel_size = 3, padding = 'same')(x)
    x = layers.BatchNormalization()(x)
    if strides != 1: # zero-padding
        shortcut = layers.MaxPool2D(pool_size = 2, strides=2)(shortcut)
        shortcut = tf.pad(shortcut, tf.constant([[0, 0,], [0, 0,], [0, 0,], [0, filters//2]]), "CONSTANT")
    x = layers.Add()([shortcut,x])
    x = layers.ReLU()(x)
    return x

def Projection_shortcut(x,filters,strides = 1):
    shortcut = x
    x = layers.Conv2D(filters = filters, kernel_size = 3, strides = strides, padding = 'same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(filters = filters, kernel_size = 3, padding = 'same')(x)
    x = layers.BatchNormalization()(x)
    if strides != 1: # 1x1 convolution
        shortcut = layers.Conv2D(filters = filters, kernel_size = 1, strides = strides, 
                                padding = 'same')(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)
    
    x = layers.Add()([shortcut,x])
    x = layers.ReLU()(x)
    return x

def resnet_18(shape = (224,224,3)):
    inputs = layers.Input(shape = shape)
    x = layers.Conv2D(filters = 64, kernel_size = (7,7),strides = 2,padding = 'same')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.MaxPool2D(pool_size = (3,3), strides = 2, padding = 'same')(x)
    
    x = Identity_shortcut(x,64)
    x = Identity_shortcut(x,64)
    
    x = Projection_shortcut(x,128,2)
    x = Identity_shortcut(x,128)
    
    x = Projection_shortcut(x,256,2)
    x = Identity_shortcut(x,256)
    
    x = Projection_shortcut(x,512,2)
    x = Identity_shortcut(x,512)
    
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Flatten()(x)
    outputs = layers.Dense(1000, activation='softmax')(x)
    
    model = keras.Model(inputs=inputs,outputs=outputs)
    
    model.compile(optimizer = 'adam',loss = 'sparse_categorical_crossentropy',
        metrics=['accuracy'])
    
    return model

In [11]:
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
import datetime
import os

model = resnet_18(shape=(224,224,3)) # or create_plain_net()
model.summary()


(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

model.fit(
    x=x_train,
    y=y_train,
    epochs=20,
    verbose=1,
    validation_data=(x_test, y_test),
    batch_size=128,
    callbacks=[cp_callback, tensorboard_callback]
)

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
conv2d_46 (Conv2D)              (None, 112, 112, 64) 9472        input_5[0][0]                    
__________________________________________________________________________________________________
batch_normalization_46 (BatchNo (None, 112, 112, 64) 256         conv2d_46[0][0]                  
__________________________________________________________________________________________________
re_lu_38 (ReLU)                 (None, 112, 112, 64) 0           batch_normalization_46[0][0]     
____________________________________________________________________________________________

'\n(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n\nmodel.fit(\n    x=x_train,\n    y=y_train,\n    epochs=20,\n    verbose=1,\n    validation_data=(x_test, y_test),\n    batch_size=128,\n    callbacks=[cp_callback, tensorboard_callback]\n)'