<a href="https://colab.research.google.com/github/ReutFarkash/useful/blob/main/tensorflow_tutorial_8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[TensorFlow Tutorial 8 - Model Subclassing with Keras](https://www.youtube.com/watch?v=WcZ_1IAH_nM&ab_channel=AladdinPersson)<br>
Aladdin Persson

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

In [2]:
def _to_model_input(images):
  """Reshape raw MNIST images to (N, 28, 28, 1) float32 and scale them by 1/255."""
  # The reshape appends the channel axis (1 channel) that Conv2D layers expect.
  return images.reshape(-1, 28, 28, 1).astype("float32") / 255.0


(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = _to_model_input(x_train)
x_test = _to_model_input(x_test)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [3]:
# CNN -> BatchNorm -> ReLU (common structure)
# x10
class CNNBlock(layers.Layer):
  """Conv2D -> BatchNormalization -> ReLU building block.

  Args:
    out_channels: Number of filters of the convolution.
    kernel_size: Kernel size passed to `layers.Conv2D` (default 3). The
      convolution uses `padding='same'`.
    **kwargs: Standard `layers.Layer` keyword arguments (e.g. `name`,
      `dtype`), forwarded to the base class.
  """

  def __init__(self, out_channels, kernel_size=3, **kwargs):
    # Forward base-layer options instead of silently rejecting them.
    super(CNNBlock, self).__init__(**kwargs)
    self.conv = layers.Conv2D(out_channels, kernel_size, padding='same')
    self.bn = layers.BatchNormalization()

  def call(self, input_tensor, training=False):
    """Apply convolution, batch normalization and ReLU.

    Args:
      input_tensor: Input passed to the convolution.
      training: Forwarded to batch normalization to select its
        training/inference behaviour.

    Returns:
      The ReLU-activated, batch-normalized convolution output.
    """
    x = self.conv(input_tensor)
    # Shape trace kept from the tutorial. In the recorded runs it shows up a
    # handful of times per fit/evaluate call, not once per batch.
    print(x.shape)
    x = self.bn(x, training=training)
    return tf.nn.relu(x)

  def get_config(self):
    """Return the constructor arguments so the layer can be re-created."""
    config = super(CNNBlock, self).get_config()
    config.update({
        "out_channels": self.conv.filters,
        "kernel_size": self.conv.kernel_size,
    })
    return config

In [4]:
# Three CNN blocks of increasing width, followed by a flatten + 10-unit dense head.
model = keras.Sequential(
    [
     *(CNNBlock(width) for width in (32, 64, 128)),
     layers.Flatten(),
     layers.Dense(10),
    ]
)

In [5]:
# The loss is configured with from_logits=True, i.e. it receives the raw
# outputs of the final Dense layer rather than probabilities.
logit_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer=keras.optimizers.Adam(), loss=logit_loss, metrics=["accuracy"])

# Options shared by training and evaluation.
run_options = dict(batch_size=64, verbose=2)

model.fit(x_train, y_train, epochs=3, **run_options)
model.evaluate(x_test, y_test, **run_options)

Epoch 1/3
(None, 28, 28, 32)
(None, 28, 28, 64)
(None, 28, 28, 128)
(None, 28, 28, 32)
(None, 28, 28, 64)
(None, 28, 28, 128)
(None, 28, 28, 32)
(None, 28, 28, 64)
(None, 28, 28, 128)
938/938 - 11s - loss: 0.5878 - accuracy: 0.9463
Epoch 2/3
938/938 - 11s - loss: 0.0948 - accuracy: 0.9815
Epoch 3/3
938/938 - 11s - loss: 0.0354 - accuracy: 0.9896
(None, 28, 28, 32)
(None, 28, 28, 64)
(None, 28, 28, 128)
157/157 - 1s - loss: 0.0477 - accuracy: 0.9839


[0.04766438156366348, 0.9839000105857849]

In [6]:
class ResBlock(layers.Layer):
  """Three CNNBlocks with a 1x1-convolution skip connection, then max pooling.

  Args:
    channels: Indexable of three filter counts, one per CNNBlock. The
      skip-connection convolution uses `channels[1]` filters so its output
      can be added to the output of the second block.
  """

  def __init__(self, channels):
    super().__init__()
    first, second, third = channels[0], channels[1], channels[2]
    self.cnn1 = CNNBlock(first)
    self.cnn2 = CNNBlock(second)
    self.cnn3 = CNNBlock(third)
    self.pooling = layers.MaxPooling2D()
    # 1x1 convolution that projects the block input to `second` channels.
    self.identity_mapping = layers.Conv2D(second, 1, padding='same')

  def call(self, input_tensor, training=False):
    """Run the block; `training` is forwarded to every CNNBlock."""
    features = self.cnn1(input_tensor, training=training)
    features = self.cnn2(features, training=training)
    shortcut = self.identity_mapping(input_tensor)
    merged = self.cnn3(features + shortcut, training=training)
    return self.pooling(merged)

In [10]:
class ResNet_Like(keras.Model):
  """ResNet-style classifier: three ResBlocks, global average pooling, dense head.

  Args:
    num_classes: Number of units of the final Dense layer. The layer has no
      activation, so the model outputs raw scores (logits).
    **kwargs: Standard `keras.Model` keyword arguments (e.g. `name`),
      forwarded to the base class.
  """

  def __init__(self, num_classes=10, **kwargs):
    super(ResNet_Like, self).__init__(**kwargs)
    self.block1 = ResBlock([32, 32, 64])
    self.block2 = ResBlock([128, 128, 256])
    self.block3 = ResBlock([128, 256, 512])
    self.pool = layers.GlobalAveragePooling2D()
    self.classifier = layers.Dense(num_classes)

  def call(self, input_tensor, training=False):
    """Forward pass; `training` is forwarded to every ResBlock.

    Returns:
      The output of the final Dense layer (one value per class).
    """
    x = self.block1(input_tensor, training=training)
    x = self.block2(x, training=training)
    x = self.block3(x, training=training)
    x = self.pool(x)
    return self.classifier(x)

  def model(self, input_shape=(28, 28, 1)):
    """Wrap the network in a functional `keras.Model` for `summary()`.

    Routing a symbolic `keras.Input` through `call` yields a functional
    model whose summary lists the input layer and the individual blocks
    with their output shapes.

    Args:
      input_shape: Per-sample input shape (without the batch dimension).
        Defaults to (28, 28, 1), the shape the images are reshaped to in
        this notebook.

    Returns:
      A `keras.Model` mapping the symbolic input to this network's output.
    """
    inputs = keras.Input(shape=input_shape)
    return keras.Model(inputs=[inputs], outputs=self.call(inputs))

In [12]:
model = ResNet_Like(num_classes=10)

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

model.fit(x_train, y_train, batch_size=64, epochs=3, verbose=2)
# summary() prints the table itself and returns None, so it is called
# directly; wrapping it in print() only adds a stray "None" line.
model.model().summary()
model.evaluate(x_test, y_test, batch_size=64, verbose=2)

Epoch 1/3
(None, 28, 28, 32)
(None, 28, 28, 32)
(None, 28, 28, 64)
(None, 14, 14, 128)
(None, 14, 14, 128)
(None, 14, 14, 256)
(None, 7, 7, 128)
(None, 7, 7, 256)
(None, 7, 7, 512)
(None, 28, 28, 32)
(None, 28, 28, 32)
(None, 28, 28, 64)
(None, 14, 14, 128)
(None, 14, 14, 128)
(None, 14, 14, 256)
(None, 7, 7, 128)
(None, 7, 7, 256)
(None, 7, 7, 512)
938/938 - 22s - loss: 0.0853 - accuracy: 0.9749
Epoch 2/3
938/938 - 22s - loss: 0.0358 - accuracy: 0.9886
Epoch 3/3
938/938 - 22s - loss: 0.0277 - accuracy: 0.9912
(None, 28, 28, 32)
(None, 28, 28, 32)
(None, 28, 28, 64)
(None, 14, 14, 128)
(None, 14, 14, 128)
(None, 14, 14, 256)
(None, 7, 7, 128)
(None, 7, 7, 256)
(None, 7, 7, 512)
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
res_block_6 (ResBlock)     

[0.0571453720331192, 0.9842000007629395]

In [13]:
# summary() prints the table itself and returns None; wrapping it in print()
# only appended a stray "None" line to the output.
model.summary()

Model: "res_net__like_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
res_block_6 (ResBlock)       (None, 14, 14, 64)        28640     
_________________________________________________________________
res_block_7 (ResBlock)       (None, 7, 7, 256)         526976    
_________________________________________________________________
res_block_8 (ResBlock)       (None, 3, 3, 512)         1839744   
_________________________________________________________________
global_average_pooling2d_2 ( (None, 512)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 10)                5130      
Total params: 2,400,490
Trainable params: 2,397,418
Non-trainable params: 3,072
_________________________________________________________________
None
