<a href="https://colab.research.google.com/github/Abhilash11Addanki/DeepLearning/blob/main/Custom%20Models/ResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.python.keras.utils.vis_utils import plot_model

In [2]:
class IdentityBlock(tf.keras.Model):
  """Residual "identity" block: two same-padded Conv/BN stages plus a skip connection.

  Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``. The final ReLU is
  applied only *after* the residual addition (as in the ResNet basic block);
  the residual branch itself is not passed through a ReLU before the add, so
  it can both raise and lower the activations it is added to.

  Args:
    filters: number of output channels of both convolutions. Because the input
      tensor is added back unchanged, this must equal the input's channel
      count, otherwise the Add layer cannot combine the two tensors.
    kernel_size: kernel size passed to both Conv2D layers ('same' padding and
      the default stride of 1 keep the spatial size unchanged).
    **kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name=...``).
  """

  def __init__(self, filters, kernel_size, **kwargs):
    super().__init__(**kwargs)
    self.conv1 = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')
    self.bn1 = tf.keras.layers.BatchNormalization()
    self.conv2 = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')
    self.bn2 = tf.keras.layers.BatchNormalization()
    # A single stateless ReLU layer, reused for both activations in call().
    self.act = tf.keras.layers.Activation('relu')
    self.add = tf.keras.layers.Add()

  def call(self, input_tensor, training=None):
    """Apply the block to ``input_tensor``.

    Args:
      input_tensor: 4-D image batch whose channel count equals ``filters``
        (assumed channels-last, the Conv2D default — confirm for other setups).
      training: forwarded to the BatchNormalization layers; Keras fills it in
        automatically during ``fit``/``evaluate``/``predict``.

    Returns:
      Tensor with the same shape as ``input_tensor``.
    """
    x = self.conv1(input_tensor)
    x = self.bn1(x, training=training)
    x = self.act(x)
    x = self.conv2(x)
    x = self.bn2(x, training=training)
    # No activation here: the raw BN output is summed with the skip connection,
    # and the non-linearity is applied once, after the addition.
    x = self.add([x, input_tensor])
    return self.act(x)

In [3]:
class ResNet(tf.keras.Model):
  """Small ResNet-style classifier: conv stem, two identity blocks, softmax head.

  Args:
    num_classes: number of units in the final softmax Dense layer.
  """

  def __init__(self, num_classes):
    super().__init__()
    # Stem: 7x7 conv with 64 filters ('same' padding) -> BN -> ReLU -> 3x3 max pool.
    self.conv = tf.keras.layers.Conv2D(filters=64, kernel_size=7, padding='same')
    self.bn = tf.keras.layers.BatchNormalization()
    self.act = tf.keras.layers.Activation('relu')
    self.max_pool = tf.keras.layers.MaxPool2D(pool_size=(3, 3))
    # Two residual blocks; their 64 filters match the stem's 64 output channels,
    # which the blocks' skip additions require.
    self.id1 = IdentityBlock(64, 3)
    self.id2 = IdentityBlock(64, 3)
    # Head: average over spatial positions, then map to class probabilities.
    self.global_pool = tf.keras.layers.GlobalAveragePooling2D()
    self.classifier = tf.keras.layers.Dense(units=num_classes, activation='softmax')

  def call(self, inputs):
    """Run stem, both identity blocks and global pooling; return softmax probabilities."""
    features = inputs
    for layer in (self.conv, self.bn, self.act, self.max_pool,
                  self.id1, self.id2, self.global_pool):
      features = layer(features)
    return self.classifier(features)

In [4]:
def preprocess(features):
  """Turn one tfds example dict into an ``(image, label)`` pair for ``Model.fit``.

  The image is cast to float32 first, so dividing by 255 is a floating-point
  rescale (presumably mapping 0-255 pixel values to [0, 1] — confirm the dtype
  of ``features['image']``). The label is passed through unchanged.
  """
  image = tf.cast(features['image'], tf.float32)
  label = features['label']
  return image / 255, label

# Run configuration — tune here rather than inside the calls below.
NUM_CLASSES = 10         # MNIST digits 0-9
BATCH_SIZE = 32
EPOCHS = 1
SHUFFLE_BUFFER = 10_000  # number of examples held in the shuffle buffer
SEED = 42

# Seed TF's global RNG so weight initialisation is repeatable across runs.
tf.random.set_seed(SEED)

resnet = ResNet(NUM_CLASSES)
# sparse_categorical_crossentropy takes integer class ids (features['label'])
# rather than one-hot vectors, matching the softmax head of ResNet.
resnet.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])

dataset = tfds.load('mnist',split=tfds.Split.TRAIN)
dataset = (
  dataset
  # Shuffle before map so the buffer holds the raw examples rather than
  # float32 copies; reshuffled every epoch (tf.data default).
  .shuffle(SHUFFLE_BUFFER, seed=SEED)
  .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
  .batch(BATCH_SIZE)
  # Overlap input preparation with training steps.
  .prefetch(tf.data.AUTOTUNE)
)
history = resnet.fit(dataset,epochs=EPOCHS)



In [5]:
# Layer / parameter table for the subclassed model. Run after `resnet.fit` (previous
# cell): a subclassed Keras model has no weights until it has been called on data.
# Its "Output Shape" column reads "multiple" for these layers, as the output below shows.
resnet.summary()

Model: "res_net"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              multiple                  3200      
_________________________________________________________________
batch_normalization (BatchNo multiple                  256       
_________________________________________________________________
activation (Activation)      multiple                  0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) multiple                  0         
_________________________________________________________________
identity_block (IdentityBloc multiple                  74368     
_________________________________________________________________
identity_block_1 (IdentityBl multiple                  74368     
_________________________________________________________________
global_average_pooling2d (Gl multiple                  0   