<a href="https://colab.research.google.com/github/aaalexlit/tf-advanced-techniques-spec/blob/main/course_1_custom_models/Week4_Implementing_ResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [16]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, \
Add, MaxPool2D, GlobalAveragePooling2D, Dense
from tensorflow.keras import Model
from tensorflow.keras.losses import SparseCategoricalCrossentropy

## First, implement the Identity Block
[Identity block image](https://www.notion.so/Week-4-3871c770f34e4f8ea38461fbbe046bbb#bce1979eb20f4556823653d3a8315990)

In [7]:
class IdentityBlock(Model):
  """ResNet identity block: two conv/BN stages plus a skip connection.

  The block input is added back to the output of the second conv/BN stage
  before the final ReLU. The shortcut has no projection, so the input must
  already have ``n_filters`` channels for the ``Add`` to succeed.

  Args:
    n_filters: number of filters in both ``Conv2D`` layers.
    kernel_size: kernel size for both ``Conv2D`` layers. 'same' padding (with
      the default stride of 1) keeps the spatial size unchanged, which the
      skip connection requires.
    **kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name``).
  """

  def __init__(self, n_filters, kernel_size, **kwargs):
    # Don't force an empty name on every block; when no ``name`` is given,
    # Keras derives one from the class name.
    super().__init__(**kwargs)

    self.conv1 = Conv2D(n_filters, kernel_size, padding='same')
    self.bn1 = BatchNormalization()

    self.conv2 = Conv2D(n_filters, kernel_size, padding='same')
    self.bn2 = BatchNormalization()

    # A ReLU layer has no weights, so one instance is reused for both
    # activations in `call`.
    self.act = Activation('relu')
    self.add = Add()

  def call(self, input_tensor):
    """Apply conv-BN-ReLU, then conv-BN, add the input back, and apply ReLU."""
    x = self.conv1(input_tensor)
    x = self.bn1(x)
    x = self.act(x)

    x = self.conv2(x)
    x = self.bn2(x)

    # Skip connection: `x` and `input_tensor` must have identical shapes.
    x = self.add([x, input_tensor])
    return self.act(x)

## Then, build the whole Mini ResNet model

In [23]:
class ResNet(Model):
  """Mini ResNet: stem conv -> max-pool -> two identity blocks -> classifier.

  Args:
    n_classes: number of output units in the final ``Dense`` layer.
    from_logits: if True, the classifier has no activation and emits raw
      scores. Otherwise it applies a softmax.
  """

  def __init__(self, n_classes, from_logits=False):
    super().__init__()

    # Stem: 7x7 convolution with 64 filters, batch norm, ReLU.
    self.conv = Conv2D(64, 7, padding='same')
    self.bn = BatchNormalization()
    self.act = Activation('relu')

    # 3x3 max-pool, then two identity blocks. Their 64 filters match the
    # stem's 64 output channels, as the blocks' skip connections require.
    self.max_pool = MaxPool2D((3, 3))
    self.id1a = IdentityBlock(64, 3)
    self.id1b = IdentityBlock(64, 3)

    self.global_pool = GlobalAveragePooling2D()

    # ``activation=None`` is Dense's default (linear), i.e. raw logits.
    head_activation = None if from_logits else 'softmax'
    self.classifier = Dense(n_classes, activation=head_activation)

  def call(self, inputs):
    """Run the forward pass and return the classifier output."""
    stem = self.act(self.bn(self.conv(inputs)))
    features = self.id1b(self.id1a(self.max_pool(stem)))
    return self.classifier(self.global_pool(features))

## Train the model on MNIST

In [24]:
SEED = 42
BATCH_SIZE = 32
SHUFFLE_BUFFER = 10_000

# Seed TF's global RNG so weight init and shuffling are reproducible.
tf.random.set_seed(SEED)


def preprocess(features):
  """Turn a TFDS MNIST example dict into an ``(image, label)`` pair.

  ``features['image']`` is cast to float32 and divided by 255. This assumes
  uint8 pixels in [0, 255], which become [0, 1]. ``features['label']`` is
  passed through unchanged.
  """
  return tf.cast(features['image'], tf.float32) / 255., features['label']


resnet = ResNet(10)  # 10 digit classes; softmax head (from_logits=False)
resnet.compile(optimizer='adam',
               loss=SparseCategoricalCrossentropy(),
               metrics=['accuracy'])

# Model.fit ignores its `shuffle` argument for tf.data inputs, so shuffle the
# raw training examples explicitly. Shuffling before `map` keeps the buffer
# holding the smaller uint8 images.
dataset = (tfds.load('mnist', split='train', data_dir='./data')
           .shuffle(SHUFFLE_BUFFER, seed=SEED)
           .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
           .batch(BATCH_SIZE)
           .prefetch(tf.data.AUTOTUNE))

# Validation order does not matter, so this split is not shuffled.
val_dataset = (tfds.load('mnist', split='test', data_dir='./data')
               .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
               .batch(BATCH_SIZE)
               .prefetch(tf.data.AUTOTUNE))

resnet.fit(dataset,
           validation_data=val_dataset,
           epochs=1)



<keras.callbacks.History at 0x7f324bd8c370>

In [25]:
# Same architecture, but the classifier emits raw logits and the loss is told
# so via `from_logits=True`. Keras documents this pairing as potentially more
# numerically stable than a softmax layer followed by the loss.
# `dataset` and `val_dataset` come from the previous cell. Re-iterating a
# tf.data pipeline restarts it, so there is no need to reload MNIST here.
resnet = ResNet(10, from_logits=True)
resnet.compile(optimizer='adam',
               loss=SparseCategoricalCrossentropy(from_logits=True),
               metrics=['accuracy'])

resnet.fit(dataset,
           validation_data=val_dataset,
           epochs=1)



<keras.callbacks.History at 0x7f324bb62310>