In [13]:
from tensorflow import Tensor
from tensorflow.keras.layers import Input,Conv2D,ReLU,BatchNormalization,Add,AveragePooling2D,Flatten,Dense
from tensorflow.keras.models import Model

In [31]:
def relu_bn(inputs: Tensor) -> Tensor:
    """Apply a ReLU activation and then batch normalization to ``inputs``.

    A fresh ``ReLU`` layer and a fresh ``BatchNormalization`` layer are created
    on every call, so nothing is shared between call sites.

    Args:
        inputs: tensor to activate and normalize.

    Returns:
        The output of ``BatchNormalization`` applied to the ReLU output.
    """
    activated = ReLU()(inputs)
    # Normalization is applied after the activation, not before it.
    return BatchNormalization()(activated)

def residual_block(x: Tensor, downsample: bool, filters: int, kernal_size: int = 3) -> Tensor:
    """Build one residual block: two convolutions plus a shortcut connection.

    Main path: Conv2D -> relu_bn -> Conv2D. The shortcut is ``x`` itself, or a
    1x1 stride-2 convolution of ``x`` when ``downsample`` is True. The two paths
    are summed with ``Add`` and passed through ``relu_bn``.

    Args:
        x: input tensor of the block.
        downsample: when True, the first convolution and the shortcut
            projection both use stride 2; otherwise everything uses stride 1.
        filters: number of filters for every convolution in the block.
        kernal_size: kernel size of the two main-path convolutions. The
            (misspelled) parameter name is kept so existing callers that pass
            it by keyword keep working.

    Returns:
        The block's output tensor.
    """
    first_stride = 2 if downsample else 1

    # Keras spells these arguments `kernel_size` and `strides`.
    y = Conv2D(kernel_size=kernal_size,
               strides=first_stride,
               filters=filters,
               padding='same')(x)
    y = relu_bn(y)

    # Only the first convolution may downsample. Striding here as well would
    # shrink the main path twice while the shortcut shrinks once, and Add()
    # would then receive tensors of different spatial size.
    y = Conv2D(kernel_size=kernal_size,
               strides=1,
               filters=filters,
               padding='same')(y)

    if downsample:
        # Project the shortcut with a 1x1 stride-2 convolution so it uses the
        # same stride and filter count as the main path.
        x = Conv2D(kernel_size=1,
                   strides=2,
                   filters=filters,
                   padding="same")(x)

    y = Add()([x, y])
    return relu_bn(y)


In [None]:
def create_res_net():
    """Build and compile the residual network used in this notebook.

    The model takes inputs of shape (32, 32, 3), normalizes them, applies a
    3x3 stem convolution, then four stages of residual blocks (2, 5, 5 and 2
    blocks). Every stage except the first starts with a downsampling block,
    and the filter count doubles after each stage, starting at 64. The head is
    average pooling, flatten and a 10-unit softmax layer.

    Returns:
        A ``Model`` compiled with the Adam optimizer, sparse categorical
        cross-entropy loss and the accuracy metric.
    """
    inputs = Input(shape=(32, 32, 3))

    num_filters = 64

    y = BatchNormalization()(inputs)
    # Keras spells these arguments `kernel_size` and `strides`; the filter
    # count comes from the local `num_filters`.
    y = Conv2D(kernel_size=3,
               strides=1,
               filters=num_filters,
               padding="same")(y)
    y = relu_bn(y)

    num_blocks_list = [2, 5, 5, 2]

    for stage_index, num_blocks in enumerate(num_blocks_list):
        for block_index in range(num_blocks):
            # Downsample on the first block of every stage except the first.
            starts_new_stage = (block_index == 0 and stage_index != 0)
            y = residual_block(y, starts_new_stage, filters=num_filters)
        num_filters *= 2

    # The file imports AveragePooling2D; there is no `AveragePooling` name.
    y = AveragePooling2D(4)(y)
    y = Flatten()(y)

    outputs = Dense(10, activation='softmax')(y)

    model = Model(inputs, outputs)
    # `metrics` is the compile() argument that takes the list of metrics.
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

In [None]:
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
import datetime
import os

# Load the CIFAR-10 train and test splits.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

model = create_res_net() # or create_plain_net()
model.summary()

# A timestamped run name keeps checkpoints and logs of different runs apart.
timestr = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
name = 'cifar-10_res_net_30-'+timestr # or 'cifar-10_plain_net_30-'+timestr

checkpoint_path = "checkpoints/"+name+"/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# os.makedirs creates the missing parent "checkpoints/" directory too and
# raises on real failures; a shelled-out `mkdir` does neither reliably.
os.makedirs(checkpoint_dir, exist_ok=True)

# save model after each epoch
# NOTE(review): some newer Keras releases reject checkpoint paths that do not
# end in ".keras" / ".h5" — confirm the ".ckpt" suffix against the installed
# version.
cp_callback = ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1
)
tensorboard_callback = TensorBoard(
    log_dir='tensorboard_logs/'+name,
    histogram_freq=1
)

# The test split is passed as validation data, so validation metrics are
# computed on (x_test, y_test).
model.fit(
    x=x_train,
    y=y_train,
    epochs=20,
    verbose=1,
    validation_data=(x_test, y_test),
    batch_size=128,
    callbacks=[cp_callback, tensorboard_callback]
)