<a href="https://colab.research.google.com/github/ShimilSBabu/Tensorflow-Model-Sub-Classing-API-Training/blob/main/resnet_imitation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Importing libraries

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import Layer, Conv2D, BatchNormalization, Activation, Add, MaxPool2D, GlobalAveragePooling2D, Dense
from tensorflow.keras import Model

In [2]:
class identity_block(Model):
  """Residual identity block: two conv + batch-norm stages plus a skip connection.

  Computes relu(bn_2(conv_2(relu(bn_1(conv_1(x))))) + x).
  Both convolutions use 'same' padding, so spatial size is preserved. The skip
  connection adds the block's input unchanged, so the input's channel count
  must equal `filters` for the Add to be valid.

  Args:
    filters: number of output channels of each Conv2D.
    kernel_size: kernel size passed to both Conv2D layers.
  """
  def __init__(self, filters, kernel_size):
    super(identity_block, self).__init__()

    self.conv_1 = Conv2D(filters, kernel_size, padding = 'same')
    self.batch_norm_1 = BatchNormalization()

    self.conv_2 = Conv2D(filters, kernel_size, padding = 'same')
    self.batch_norm_2 = BatchNormalization()

    # A single ReLU layer instance is reused at both activation points.
    self.activation = Activation('relu')
    self.add = Add()

  def call(self, input_tensor):
    """Run both conv stages, add the skip connection, and apply ReLU."""
    # First stage.
    x = self.conv_1(input_tensor)
    x = self.batch_norm_1(x)
    x = self.activation(x)

    # Second stage must consume the first stage's output. Feeding
    # input_tensor here would make conv_1/batch_norm_1 dead code.
    x = self.conv_2(x)
    x = self.batch_norm_2(x)

    # Skip connection, then the final non-linearity.
    x = self.add([x, input_tensor])

    x = self.activation(x)

    return x

In [3]:
class resnet_imitation(Model):
  """Small ResNet-style classifier.

  Pipeline: conv stem -> two identity blocks -> global average pooling ->
  softmax Dense head.

  Args:
    num_classes: number of units in the final softmax Dense layer.
  """

  def __init__(self, num_classes):
    super().__init__()
    # Stem: 7x7 conv with 64 filters -> BN -> ReLU -> 3x3 max-pool.
    self.conv = Conv2D(filters = 64, kernel_size = 7, padding = 'same')
    self.batch_norm = BatchNormalization()
    self.activation = Activation('relu')
    self.maxpool = MaxPool2D(pool_size = (3, 3))

    # Residual blocks at 64 channels to match the stem's output width;
    # identity_block adds its input back, so the channel counts must agree.
    self.id_block_1 = identity_block(64, 3)
    self.id_block_2 = identity_block(64, 3)

    # Head: collapse spatial dims, then classify.
    self.global_pool = GlobalAveragePooling2D()
    self.classifier = Dense(num_classes, activation = 'softmax')

  def call(self, inputs):
    """Return softmax class probabilities for a batch of inputs."""
    pipeline = (
        self.conv,
        self.batch_norm,
        self.activation,
        self.maxpool,
        self.id_block_1,
        self.id_block_2,
        self.global_pool,
    )
    features = inputs
    for layer in pipeline:
      features = layer(features)
    return self.classifier(features)

In [4]:
def preprocess(features):
  """Turn a tfds example dict into an (image, label) pair for Keras.

  The 'image' entry is cast to float32 and divided by 255. The 'label' entry
  is returned unchanged.
  """
  image = features['image']
  label = features['label']
  scaled_image = tf.cast(image, tf.float32) / 255.
  return scaled_image, label

In [5]:
# Build and compile the model; 10 output units, one per MNIST digit class.
resnet = resnet_imitation(10)
resnet.compile(optimizer = 'adam', loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])

# Input pipeline. Shuffle before map so the buffer holds the raw (uncast)
# examples. shuffle() reshuffles every epoch by default, so successive epochs
# no longer see identical batch order. Parallel map and prefetch let input
# preparation overlap with training.
dataset = tfds.load('mnist', split = tfds.Split.TRAIN)
dataset = (
    dataset
    .shuffle(10_000)
    .map(preprocess, num_parallel_calls = tf.data.AUTOTUNE)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)
resnet.fit(dataset, epochs = 20)

Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to ~/tensorflow_datasets/mnist/3.0.1...


Dl Completed...:   0%|          | 0/4 [00:00<?, ? file/s]

Dataset mnist downloaded and prepared to ~/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
Epoch 1/20




Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x7faea4080f50>