<a href="https://colab.research.google.com/github/LucasLU-ZY/dlaicourse/blob/master/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import tensorflow as tf
import math
from tensorflow.keras import layers, Sequential
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
# Destination handed to model.save() by the Evaluate callback below.
# NOTE(review): None looks like a placeholder - presumably this must be set to
# a real path before training, otherwise the save call receives None; confirm.
BEST_MODEL_PATH_TEST = None

In [4]:
class Evaluate(tf.keras.callbacks.Callback):
    '''
    Save the model after an epoch whenever its training loss is at least as
    low as the lowest loss seen so far.

    The model is written to the module-level BEST_MODEL_PATH_TEST.
    '''

    def __init__(self, model):
        '''
        Initialization
            :param model:
                Model that we are training
        '''
        super().__init__()
        # Start from a large sentinel so that any realistic first-epoch loss
        # counts as an improvement
        self.lowest = 1e10
        # Register the model through the callback's own hook instead of
        # assigning ``self.model`` directly: some Keras versions expose
        # ``model`` as a read-only property, where plain assignment fails.
        self.set_model(model)

    def on_epoch_end(self, epoch, logs=None):
        '''
        Override function on_epoch_end to do the save job after an epoch,
        if the loss is not higher than the best one so far, save the model
            :param epoch:
                The epoch we are now (unused)
            :param logs:
                The information in this epoch; the 'loss' entry is read.
                May be None or lack 'loss', in which case nothing is saved
            :return:
                None
        '''
        # ``logs`` defaults to None, so guard before looking up the loss
        loss = (logs or {}).get('loss')
        if loss is None:
            return

        # ``<=`` (not ``<``): a tie with the best loss also refreshes the save
        if loss <= self.lowest:
            self.lowest = loss
            self.model.save(BEST_MODEL_PATH_TEST)

In [5]:
class BasicBlock(layers.Layer):
    '''
    Basic residual block: Conv-BN-ReLU followed by Conv-BN-ReLU, whose result
    is added to a shortcut branch computed from the block input.
    '''
    def __init__(self, filter_num, stride = (1, 1)):
        '''
        Create the layers of the block
            :param filter_num:
                Number of filters of both 3x3 convolutions (and of the 1x1
                shortcut convolution when one is created)
            :param stride:
                Stride of the first convolution and of the shortcut
                convolution. The shortcut is a plain identity only when this
                compares equal to the tuple (1, 1)
        '''
        super(BasicBlock, self).__init__()

        self.conv1 = layers.Conv2D(filter_num, kernel_size=(3, 3), strides=stride, padding='same')
        self.bn1 = layers.BatchNormalization()
        # Stateless activation, shared by both stages of the block
        self.relu = layers.Activation('relu')

        self.conv2 = layers.Conv2D(filter_num, kernel_size=(3, 3), strides=(1, 1), padding='same')
        self.bn2 = layers.BatchNormalization()

        # Merge layer is created once here instead of on every forward pass
        self.merge = layers.Add()

        # NOTE: the comparison is deliberately left against the tuple (1, 1).
        # An integer stride such as 1 compares unequal to it, so a 1x1
        # convolution shortcut is created in that case as well. ResNet in this
        # file passes integer strides, and its 64-channel stem feeds a
        # 32-filter stage, so it depends on that projection being present.
        if stride != (1, 1):
            # if stride is not (1, 1), we use down sampling (1x1 convolution)
            self.downsample = Sequential()
            self.downsample.add(layers.Conv2D(filter_num, kernel_size=(1, 1), strides=stride))
        else:
            # Identity shortcut: the input is added unchanged
            self.downsample = lambda x: x

    def call(self, inputs, training=None):
        '''
        Overwrite function call, when the class is called, execute this function to run the Basic Block
            :param inputs:
                Source of inputs
            :param training:
                If Training; forwarded to the batch normalization layers
            :return:
                Output of the Basic Block
        '''
        # [b, h, w, c]
        # First stage: Conv -> BN -> ReLU
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)

        # Second stage: Conv -> BN -> ReLU. The activation is applied to the
        # residual branch before the addition; the sum itself is not activated.
        out = self.conv2(out)
        out = self.bn2(out, training=training)
        out = self.relu(out)

        # Shortcut branch: identity or 1x1 convolution, chosen in __init__
        identity = self.downsample(inputs)

        output = self.merge([out, identity])

        return output

In [6]:
class ResNet(tf.keras.Model):
    '''
    Create a Residual NetWork: a convolutional stem, four stages of
    BasicBlock, then dropout, global average pooling and a softmax classifier.
    '''
    def __init__(self, layer_dims, num_classes=4):
        '''
        Initialize the Residual NetWork
            :param layer_dims:
                Sequence whose first 4 entries give the number of Basic Blocks
                in each of the 4 stages
            :param num_classes:
                Number of classes to be classified
        '''
        super(ResNet, self).__init__()

        # Stem: 3x3 convolution with 64 filters, BN, ReLU and a stride-1 max pool
        self.stem = Sequential([layers.Conv2D(64, (3, 3), strides=(1, 1)),
                                layers.BatchNormalization(),
                                layers.Activation('relu'),
                                layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding='same')
        ])
        # Four stages with 32, 64, 128 and 256 filters; all use the default stride
        self.layer1 = self.build_resblock(32, layer_dims[0])
        self.layer2 = self.build_resblock(64, layer_dims[1])
        self.layer3 = self.build_resblock(128, layer_dims[2])
        self.layer4 = self.build_resblock(256, layer_dims[3])

        self.drop = layers.Dropout(0.25)
        self.avgpool = layers.GlobalAveragePooling2D()
        self.flatten = layers.Flatten()
        self.fc = layers.Dense(num_classes, activation="softmax")

    def call(self, inputs, training=None, mask=None):
        '''
        Overwrite function call, when the class is called, execute this function to run the Residual NetWork
            :param inputs:
                Source of inputs
            :param training:
                If Training; forwarded to the stem, the residual stages and dropout
            :param mask:
                We don't use this here
            :return:
                Softmax output of the classifier, one value per class
        '''
        x = self.stem(inputs, training=training)

        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)

        # Dropout is applied on the feature maps, before pooling
        x = self.drop(x, training=training)
        # [b, h, w, 256] => [b, 256]
        x = self.avgpool(x)
        # Already [b, 256] after global pooling; Flatten is kept as in the
        # original pipeline
        x = self.flatten(x)
        # [b, 256] => [b, num_classes]
        x = self.fc(x)

        return x

    def build_resblock(self, filter_num, blocks, stride=1):
        '''
        A function to build Residual Block (one stage of the network)
            :param filter_num:
                Filter number of each Basic Block
            :param blocks:
                Number of blocks we will create; the first block is always
                added, so values below 1 still yield one block
            :param stride:
                Stride of the first Basic Block; every following block uses 1.
                The value is forwarded as-is to BasicBlock, which compares it
                against the tuple (1, 1), so an integer stride makes each
                block use a 1x1 convolution shortcut
            :return:
                Residual Block (a Sequential of Basic Blocks)
        '''
        res_block = Sequential()
        # may down sample
        res_block.add(BasicBlock(filter_num, stride))

        # Remaining blocks keep the spatial size
        for _ in range(1, blocks):
            res_block.add(BasicBlock(filter_num, 1))
        return res_block

In [7]:
def resnet18(num_classes=4):
    '''
    Create ResNet 18
        :param num_classes:
            Number of classes to be classified, forwarded to ResNet
        :return:
            ResNet with 2 Basic Blocks in each of its 4 stages
    '''
    return ResNet([2, 2, 2, 2], num_classes=num_classes)

def resnet34(num_classes=4):
    '''
    Create ResNet 34
        :param num_classes:
            Number of classes to be classified, forwarded to ResNet
        :return:
            ResNet with 3, 4, 6 and 3 Basic Blocks in its 4 stages
    '''
    return ResNet([3, 4, 6, 3], num_classes=num_classes)