# Using the model class to simplify complex architectures
***
In this lab, you will continue exploring Model subclassing by building a more complex architecture.

[Residual Networks](https://arxiv.org/abs/1512.03385) make use of skip connections to make deep models easier to train.

- There are branches as well as many repeating blocks of layers in this type of network.
- You can define a model class to help organize this more complex code and to make it easier to reuse your code when building the model.
- As before, you will inherit from the [Model Class](https://keras.io/api/models/model/) so that you can make use of the other built-in methods that Keras provides.
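The idea behind a skip connection can be sketched outside Keras. The NumPy snippet below is an illustration, not part of the lab. It follows the residual formulation in the linked paper, y = ReLU(F(x) + x) with F(x) = W2 · ReLU(W1 · x). Two small weight matrices (the hypothetical names `W1` and `W2`) stand in for the convolutions, and batch normalization is omitted. When the residual branch outputs zeros, the block passes a non-negative input through unchanged. This is one intuition for why skip connections make deep models easier to train: representing an identity mapping only requires driving the residual toward zero.

```python
import numpy as np


def relu(x):
  """Element-wise ReLU."""
  return np.maximum(x, 0.0)


def residual_block(x, W1, W2):
  """Sketch of a basic residual block: y = relu(F(x) + x), F(x) = W2 @ relu(W1 @ x).

  The matrices stand in for the two convolutions; batch normalization is omitted.
  W2 must map back to the size of x so that the skip addition is well defined.
  """
  f = W2 @ relu(W1 @ x)   # residual branch F(x)
  return relu(f + x)      # skip connection, then the final ReLU


rng = np.random.default_rng(0)
x = np.abs(rng.normal(size=4))   # a non-negative input, e.g. the output of a previous ReLU

# All-zero weights: the residual branch contributes nothing,
# so the block reduces to relu(x), which equals x for non-negative x.
y_identity = residual_block(x, np.zeros((8, 4)), np.zeros((4, 8)))

# Random weights: the block adds a correction on top of x.
W1 = rng.normal(size=(8, 4))
W2 = rng.normal(size=(4, 8))
y = residual_block(x, W1, W2)

print(np.allclose(y_identity, x))  # True
print(y.shape)                     # (4,)
```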

### Imports
***

```python
try:
  # Selects TensorFlow 2.x on Colab; the try/except skips it where the magic is unavailable.
  %tensorflow_version 2.x
except Exception:
  pass

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import Layer
from tensorflow.keras.utils import plot_model
```

### Implement Model subclasses
***
As shown in the lectures, you will first implement the identity block, which contains the skip connection (i.e., the ```add()``` operation below). This class will also inherit from the Model class and implement the ```__init__()``` and ```call()``` methods.

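```python
class IdentityBlock(tf.keras.Model):
  """Two conv/batch-norm stages plus a skip connection that adds the input back."""

  def __init__(self, filters, kernel_size):
    super().__init__()

    # 'same' padding keeps the spatial size, so the output can be added to the input.
    self.conv1 = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')
    self.bn1 = tf.keras.layers.BatchNormalization()

    self.conv2 = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')
    self.bn2 = tf.keras.layers.BatchNormalization()

    self.act = tf.keras.layers.Activation('relu')
    self.add = tf.keras.layers.Add()

  def call(self, input_tensor):
    # Residual branch: conv -> BN -> ReLU -> conv -> BN
    x = self.conv1(input_tensor)
    x = self.bn1(x)
    x = self.act(x)

    x = self.conv2(x)
    x = self.bn2(x)

    # Skip connection: add the unchanged input, then apply ReLU, as in the basic
    # residual block of He et al. `filters` must equal the number of input
    # channels so that both tensors passed to Add have the same shape.
    x = self.add([x, input_tensor])
    x = self.act(x)

    return x
```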

From there, you can build the rest of the ResNet model:
- The ```ResNet``` class below creates two instances of your ```IdentityBlock``` class, which inserts two identity blocks of layers into the network.

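```python
class ResNet(tf.keras.Model):
  """A small ResNet: a conv stem, two identity blocks, and a softmax classifier."""

  def __init__(self, num_classes):
    super().__init__()
    self.conv = tf.keras.layers.Conv2D(filters=64, kernel_size=7, padding='same')
    self.bn = tf.keras.layers.BatchNormalization()
    self.act = tf.keras.layers.Activation('relu')
    # Without a `strides` argument, pooling uses strides equal to the pool size (3).
    self.max_pool = tf.keras.layers.MaxPool2D((3, 3))

    # Both blocks use 64 filters to match the 64 channels produced by self.conv.
    self.id1a = IdentityBlock(filters=64, kernel_size=3)
    self.id1b = IdentityBlock(filters=64, kernel_size=3)

    self.global_pool = tf.keras.layers.GlobalAveragePooling2D()
    self.classifier = tf.keras.layers.Dense(units=num_classes, activation='softmax')

  def call(self, inputs):
    # Stem: conv -> BN -> ReLU -> max pooling
    x = self.conv(inputs)
    x = self.bn(x)
    x = self.act(x)
    x = self.max_pool(x)

    # Two residual blocks; each preserves the tensor shape.
    x = self.id1a(x)
    x = self.id1b(x)

    # Average each channel over height and width, then classify.
    x = self.global_pool(x)

    return self.classifier(x)
```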
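
To see how an MNIST image flows through `ResNet`, the shapes can be worked out by hand. The helpers below are a hypothetical sketch, not a Keras API. They assume 28x28x1 inputs and omit the batch dimension. They apply the standard output-size rules: a stride-1 convolution with 'same' padding keeps the spatial size, and 'valid' pooling produces floor((n - k) / s) + 1, where `MaxPool2D` defaults the stride to the pool size. Batch normalization and ReLU do not change shapes.

```python
def conv_same(shape, filters):
  """Conv2D with padding='same' and stride 1: spatial size kept, channels = filters."""
  h, w, _ = shape
  return (h, w, filters)


def pool_valid(shape, pool_size, strides=None):
  """MaxPool2D with padding='valid'; strides default to pool_size, as in Keras."""
  strides = strides or pool_size
  h, w, c = shape
  return ((h - pool_size) // strides + 1, (w - pool_size) // strides + 1, c)


def trace_resnet(input_shape=(28, 28, 1), num_classes=10):
  """Record the output shape (without batch dimension) after each stage of ResNet."""
  shapes = {'input': input_shape}
  x = conv_same(input_shape, 64)
  shapes['conv'] = x
  x = pool_valid(x, 3)
  shapes['max_pool'] = x
  # Identity blocks use 'same' padding and 64 filters, so the shape is unchanged.
  shapes['id1a'] = x
  shapes['id1b'] = x
  shapes['global_pool'] = (x[2],)   # one average per channel
  shapes['classifier'] = (num_classes,)
  return shapes


for name, shape in trace_resnet().items():
  print(f'{name:12s} {shape}')
# input        (28, 28, 1)
# conv         (28, 28, 64)
# max_pool     (9, 9, 64)
# id1a         (9, 9, 64)
# id1b         (9, 9, 64)
# global_pool  (64,)
# classifier   (10,)
```

The trace also shows why `filters=64` works for both identity blocks: each ```add()``` receives two tensors of shape (9, 9, 64).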

### Training the model
***
As mentioned before, inheriting the Model class allows you to make use of the other APIs that Keras provides, such as:

- training
- serialization
- evaluation
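
As a quick illustration of the inherited training and evaluation APIs, the sketch below defines a deliberately tiny subclassed model. It uses the hypothetical name `TinyClassifier`, is not part of the lab, and trains on random synthetic data instead of MNIST. Only `__init__()` and `call()` are written by hand; `compile()`, `fit()`, `evaluate()`, and `predict()` all come from `tf.keras.Model`.

```python
import numpy as np
import tensorflow as tf


class TinyClassifier(tf.keras.Model):
  """Hypothetical minimal subclassed model, used only to show the inherited APIs."""

  def __init__(self, num_classes):
    super().__init__()
    self.flatten = tf.keras.layers.Flatten()
    self.hidden = tf.keras.layers.Dense(16, activation='relu')
    self.classifier = tf.keras.layers.Dense(num_classes, activation='softmax')

  def call(self, inputs):
    x = self.flatten(inputs)
    x = self.hidden(x)
    return self.classifier(x)


# Synthetic stand-in data (random values, not MNIST): 64 "images" of 28x28x1.
rng = np.random.default_rng(0)
images = rng.random((64, 28, 28, 1)).astype('float32')
labels = rng.integers(0, 10, size=(64,))

model = TinyClassifier(num_classes=10)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# None of these methods are defined in TinyClassifier; they are inherited.
history = model.fit(images, labels, batch_size=16, epochs=1, verbose=0)
loss, accuracy = model.evaluate(images, labels, verbose=0)
probabilities = model.predict(images[:4], verbose=0)

print(probabilities.shape)  # (4, 10)
```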

You can instantiate a ResNet object and train it as usual, as shown below.
**Note**: If you have issues with training in the Coursera lab environment, you can also run this in Colab.

```python
def preprocess(features):
  # Scale uint8 pixels to floats in [0, 1] and return an (image, label) pair for fit().
  return tf.cast(features['image'], tf.float32) / 255., features['label']
```
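
Because `tfds.load` without `as_supervised=True` yields dictionaries, `preprocess` receives a dict with `'image'` and `'label'` keys. The sketch below checks that behavior without downloading anything. It builds a hypothetical stand-in dataset (`fake_mnist`) from random uint8 arrays shaped like MNIST and repeats the lab's `preprocess` so the block is self-contained.

```python
import numpy as np
import tensorflow as tf


def preprocess(features):
  # Same function as in the lab: scale pixels to [0, 1], return (image, label).
  return tf.cast(features['image'], tf.float32) / 255., features['label']


# Hypothetical stand-in for MNIST: dictionaries with uint8 28x28x1 images and int64 labels.
rng = np.random.default_rng(0)
fake_mnist = tf.data.Dataset.from_tensor_slices({
    'image': rng.integers(0, 256, size=(8, 28, 28, 1), dtype=np.uint8),
    'label': np.arange(8, dtype=np.int64),
})

# Same pipeline shape as the lab: map the preprocessing, then batch.
batched = fake_mnist.map(preprocess).batch(4)
images, labels = next(iter(batched))

print(images.shape, labels.numpy())
```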

```python
resnet = ResNet(num_classes=10)
# MNIST labels are integers 0-9, which pair with sparse categorical cross-entropy.
resnet.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
```

```python
# Download MNIST into ./data (reused on later runs), normalize it, and batch it.
dataset = tfds.load('mnist', split=tfds.Split.TRAIN, data_dir='./data')
dataset = dataset.map(preprocess).batch(32)

resnet.fit(dataset, epochs=1)
```

```
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to ./data/mnist/3.0.1...

local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.

Dataset mnist downloaded and prepared to ./data/mnist/3.0.1. Subsequent calls will reuse this data.

<tensorflow.python.keras.callbacks.History at 0x7f57aa592390>
```