<a href="https://colab.research.google.com/github/ShimilSBabu/Tensorflow-Model-Sub-Classing-API-Training/blob/main/resnet_imitation_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [19]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import Layer, Conv2D, BatchNormalization, Activation, Add, MaxPool2D, GlobalAveragePooling2D, Dense
from tensorflow.keras import Model

In [12]:
class conv_block(Model):
  """Conv2D -> BatchNormalization -> ReLU unit.

  Args:
    filters: number of output channels produced by the convolution.
    kernel_size: spatial size of the convolution kernel (default 3).
      'same' padding keeps the spatial dimensions of the input unchanged.
  """

  def __init__(self, filters, kernel_size = 3):
    super().__init__()
    # Sub-layer attribute names are left as-is: Keras tracks sub-layers by
    # attribute, so renaming them would change saved-weight keys.
    self.conv = Conv2D(filters, kernel_size, padding = 'same')
    self.batch_norm = BatchNormalization()

  def call(self, input_tensor, training = False):
    """Convolve, batch-normalise (train vs. inference mode via `training`), then ReLU."""
    normalized = self.batch_norm(self.conv(input_tensor), training = training)
    return tf.nn.relu(normalized)

In [14]:
class ResBlock(Layer):
  """Three stacked conv_blocks with a learned shortcut, followed by max pooling.

  The shortcut is a 3x3 Conv2D projection of the block input to
  ``channels[1]`` filters. It is added to the output of the second
  conv_block, and the sum then goes through the third conv_block before
  being downsampled by ``MaxPool2D`` (default pool size).

  Args:
    channels: indexable of filter counts; entries 0, 1 and 2 set the widths
      of the three conv_blocks (entry 1 also sets the shortcut width).
  """

  def __init__(self, channels):
    super().__init__()
    # Construction order is kept identical so auto-generated layer names match.
    self.conv_1 = conv_block(channels[0])
    self.conv_2 = conv_block(channels[1])
    self.conv_3 = conv_block(channels[2])
    self.pooling = MaxPool2D()
    # Projects the block input so its channel count matches conv_2's output,
    # making the element-wise addition in call() shape-compatible.
    self.identity_mapping = Conv2D(channels[1], 3, padding = 'same')

  def call(self, input_tensor, training = False):
    """Run the residual path, merge the shortcut, refine and downsample."""
    main_path = self.conv_1(input_tensor, training = training)
    main_path = self.conv_2(main_path, training = training)
    merged = main_path + self.identity_mapping(input_tensor)
    refined = self.conv_3(merged, training = training)
    return self.pooling(refined)


In [27]:
class resnet_imitation(Model):
  """Small ResNet-style classifier: three ResBlocks, global average pooling, dense head.

  The Dense head has no activation, so the model outputs raw logits; pair it
  with a loss built with ``from_logits=True``.

  Args:
    num_classes: number of output logits (default 10).
    **kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name``).
  """

  def __init__(self, num_classes = 10, **kwargs):
    super(resnet_imitation, self).__init__(**kwargs)
    self.block_1 = ResBlock([32, 32, 64])
    self.block_2 = ResBlock([128, 128, 256])
    self.block_3 = ResBlock([128 ,256 , 512])
    self.pool = GlobalAveragePooling2D()
    self.classifier = Dense(num_classes)

  def call(self, input_tensor, training = False):
    """Forward pass; returns logits of shape (batch, num_classes)."""
    x = self.block_1(input_tensor, training = training)
    x = self.block_2(x, training = training)
    x = self.block_3(x, training = training)
    x = self.pool(x)

    return self.classifier(x)

  def model(self, input_shape = (28, 28, 1)):
    """Wrap this subclassed model in a functional ``Model``.

    Calling the layers on a symbolic ``Input`` lets ``summary()`` report
    concrete per-layer output shapes, which a subclassed model cannot.

    Args:
      input_shape: per-example input shape, excluding the batch dimension.
        Defaults to (28, 28, 1), the single-channel 28x28 MNIST images.

    Returns:
      A ``tf.keras.Model`` built from this model's layers (weights are shared).
    """
    inputs = tf.keras.Input(shape = input_shape)
    return Model(inputs = [inputs], outputs = self.call(inputs, training = False))

In [28]:
# Fix the global TensorFlow seed before the model creates its weights so that
# initialisation is repeatable across kernel restarts. (Bit-exact GPU runs
# additionally need tf.config.experimental.enable_op_determinism().)
tf.random.set_seed(42)
model = resnet_imitation(num_classes = 10)

In [29]:
# The capitalised string 'Accuracy' resolves to tf.keras.metrics.Accuracy, which
# checks y_pred == y_true element-wise. With integer labels and a
# (batch, num_classes) logits output that comparison is meaningless, so use
# SparseCategoricalAccuracy, which compares each label to argmax(logits).
model.compile(
    optimizer = tf.keras.optimizers.Adam(),
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True),
    metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name = 'accuracy')]
)

In [21]:
def preprocess(features):
  """Turn a TFDS example dict into an ``(image, label)`` pair for ``model.fit``.

  The image is cast to float32 and divided by 255 (assumes 8-bit pixel values,
  as in TFDS MNIST); the label is passed through unchanged.
  """
  image = tf.cast(features['image'], tf.float32)
  label = features['label']
  return image / 255., label

In [None]:
# Training input pipeline:
#  - map:      normalise pixels, convert each example dict to (image, label)
#  - cache:    keep preprocessed examples in memory after the first epoch
#              (~190 MB of float32 images for the 60k MNIST training set)
#  - shuffle:  reshuffle every epoch so batches are not seen in a fixed order
#  - prefetch: overlap input preparation with training steps
mnist_train_raw = tfds.load('mnist', split = tfds.Split.TRAIN)
dataset = (
    mnist_train_raw
    .map(preprocess, num_parallel_calls = tf.data.AUTOTUNE)
    .cache()
    .shuffle(10_000)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Train on the batched training pipeline for 20 epochs; the returned History
# object is the cell's displayed output.
model.fit(x = dataset, epochs = 20)

In [None]:
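# Alternative (not run): train from in-memory arrays; x_train / y_train are not defined in this notebook.
# model.fit(x_train, y_train, batch_size = 64, epochs = 20, verbose = 2)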

In [30]:
# Summarise the functional wrapper from model.model() so per-layer output
# shapes are shown (a subclassed model's own summary cannot report them).
model.model().summary(line_length = None)

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 res_block_9 (ResBlock)      (None, 14, 14, 64)        28896     
                                                                 
 res_block_10 (ResBlock)     (None, 7, 7, 256)         592512    
                                                                 
 res_block_11 (ResBlock)     (None, 3, 3, 512)         2364032   
                                                                 
 global_average_pooling2d_2   (None, 512)              0         
 (GlobalAveragePooling2D)                                        
                                                                 
 dense_2 (Dense)             (None, 10)                5130      
                                                             