In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, regularizers
from tensorflow.keras.datasets import mnist
import os

# Enable on-demand GPU memory allocation so TensorFlow does not reserve all
# device memory up front. Iterating (instead of indexing [0]) avoids an
# IndexError on CPU-only machines and applies the same setting to every GPU.
physical_devices = tf.config.list_physical_devices("GPU")
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

# Load MNIST, add a trailing channel axis -> (N, 28, 28, 1), and scale the
# pixel values by 1/255 as float32.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.0

In [7]:
class CNNBlock(layers.Layer):
    """Conv2D -> BatchNormalization -> ReLU building block.

    Args:
        out_channels: number of filters produced by the convolution.
        kernel_size: convolution kernel size (default 3). 'same' padding with
            the default stride of 1 keeps the input's spatial size.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, out_channels, kernel_size=3, **kwargs):
        super().__init__(**kwargs)
        self.conv = layers.Conv2D(out_channels, kernel_size, padding='same')
        self.bn = layers.BatchNormalization()

    # Keras reaches call() through the parent class's layers.Layer.__call__.
    # Batch normalization works differently in training mode and in inference
    # mode, so the `training` flag must be forwarded to it explicitly.
    def call(self, input_tensor, training=False):
        x = self.conv(input_tensor)
        x = self.bn(x, training=training)
        x = tf.nn.relu(x)
        return x
        
# model = keras.Sequential(
#     [
#         CNNBlock(32),
#         CNNBlock(64),
#         CNNBlock(128),
#         layers.Flatten(),
#         layers.Dense(10)
#     ]
# )

class ResBlock(layers.Layer):
    """Three stacked CNNBlocks with a 1x1-projected skip connection, then max pooling.

    The block input is projected to ``channels[1]`` filters by a 1x1
    convolution and added to the output of the second CNNBlock. The sum goes
    through the third CNNBlock and the result is max-pooled.

    Args:
        channels: sequence of three filter counts, one per CNNBlock.
    """

    def __init__(self, channels):
        super().__init__()
        first_width, mid_width, last_width = channels[0], channels[1], channels[2]
        # Attribute assignment order is kept stable so Keras tracks the
        # sub-layers (and their weights) in a consistent order.
        self.cnn1 = CNNBlock(first_width)
        self.cnn2 = CNNBlock(mid_width)
        self.cnn3 = CNNBlock(last_width)
        self.pooling = layers.MaxPooling2D()
        self.identity_mapping = layers.Conv2D(mid_width, kernel_size=1, padding='same')

    def call(self, input_tensor, training=False):
        hidden = input_tensor
        for stage in (self.cnn1, self.cnn2):
            hidden = stage(hidden, training=training)
        # Skip path: 1x1 projection so channel counts match for the addition.
        shortcut = self.identity_mapping(input_tensor)
        hidden = self.cnn3(hidden + shortcut, training=training)
        return self.pooling(hidden)
        
class ResNet_Like(keras.Model):
    """Small ResNet-style classifier: three ResBlocks, global average pooling, dense head.

    Args:
        num_classes: number of output units. The final Dense layer has no
            activation, so the outputs are logits (the loss in this notebook
            uses ``from_logits=True``).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.block1 = ResBlock([32, 32, 64])
        self.block2 = ResBlock([128, 128, 256])
        self.block3 = ResBlock([128, 256, 512])
        self.pool = layers.GlobalAveragePooling2D()
        self.classifier = layers.Dense(num_classes)

    def call(self, input_tensor, training=False):
        x = self.block1(input_tensor, training=training)
        x = self.block2(x, training=training)
        x = self.block3(x, training=training)
        x = self.pool(x)
        return self.classifier(x)

    def model(self, input_shape=(28, 28, 1)):
        """Wrap this subclassed model in a functional ``keras.Model``.

        The summary of a subclassed model does not list per-layer output
        shapes. Tracing ``call`` on a symbolic ``keras.Input`` produces a
        functional model whose ``summary()`` does.

        Args:
            input_shape: per-example input shape, excluding the batch
                dimension. Defaults to the MNIST shape used in this notebook.

        Returns:
            A ``keras.Model`` built from this model's layers (weights shared).
        """
        x = keras.Input(shape=input_shape)
        return keras.Model(inputs=[x], outputs=self.call(x))

# Seed TensorFlow's global RNG before any weights are created so that weight
# initialisation and fit()'s data shuffling are reproducible across kernel
# restarts (some GPU kernels may still be non-deterministic).
tf.random.set_seed(42)

model = ResNet_Like()
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(),
    metrics=['accuracy'],
)

model.fit(x_train, y_train, batch_size=64, epochs=1)
# Summarise the functional wrapper: the subclassed model's own summary does
# not show per-layer output shapes.
model.model().summary()
model.evaluate(x_test, y_test, batch_size=64)

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
res_block_9 (ResBlock)       (None, 14, 14, 64)        28640     
_________________________________________________________________
res_block_10 (ResBlock)      (None, 7, 7, 256)         526976    
_________________________________________________________________
res_block_11 (ResBlock)      (None, 3, 3, 512)         1839744   
_________________________________________________________________
global_average_pooling2d_3 ( (None, 512)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 10)                5130      
Total params: 2,400,490
Trainable params: 2,397,418
Non-trainable params: 3,072
_____________________________________________

[0.1462131291627884, 0.9527000188827515]