In [None]:
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import *

#https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer
# ResNetBlock class inherits Layer class from tensorflow!
class ResNetBlock(Layer):
  """Basic two-convolution residual block: relu(F(x) + x).

  F is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. Following He et al. (2016),
  the last ReLU is applied after the shortcut is added, not inside F.

  Args:
    out_channels: number of filters in both 3x3 convolutions.
    first_stride: stride of the first convolution. When it is not 1, that
      convolution uses 'valid' padding, so the spatial size becomes
      floor((H - 3) / first_stride) + 1.
    **kwargs: forwarded to `Layer.__init__` (e.g. `name`, `dtype`).

  The identity shortcut is only added when F preserves the input shape
  (same spatial size and channel count). Blocks that downsample or change
  the channel count therefore have no skip connection; there is no 1x1
  projection shortcut in this implementation.
  """

  def __init__(self, out_channels, first_stride=1, **kwargs):
    super().__init__(**kwargs)

    # 'same' keeps HxW for stride 1; strided blocks use 'valid' padding,
    # which fixes the feature-map sizes of every later stage.
    first_padding = 'same' if first_stride == 1 else 'valid'

    # Residual branch F(x). It deliberately ends at BatchNormalization so
    # that F(x) may be negative; the non-linearity comes after the addition.
    self.conv_sequence = Sequential([
        Conv2D(out_channels, 3, first_stride, padding=first_padding),
        BatchNormalization(),
        ReLU(),

        Conv2D(out_channels, 3, 1, padding='same'),
        BatchNormalization(),
    ])
    self.out_relu = ReLU()

  def call(self, inputs):
    x = self.conv_sequence(inputs)

    # Compare everything except the batch axis: convolutions never change it,
    # and it may be unknown (None) while the layer is traced, which could make
    # a full-shape comparison fail and silently drop the skip connection.
    if x.shape[1:] == inputs.shape[1:]:
      x = x + inputs  # identity skip connection

    return self.out_relu(x)

# Instantiate a single residual block with 4 output filters and show its repr.
layer = ResNetBlock(out_channels=4)
print(layer)

<__main__.ResNetBlock object at 0x7faefde90290>


In [None]:
#If you remove ResNet and self from the super() call and use super().__init__(), it still invokes the __init__() of the next class in the method resolution order (MRO) — here the parent class Model.
#Python fills in the current class (ResNet) and the instance (self) automatically from the context of the method it is called in.
#In practice, omitting the arguments and writing super().__init__() works correctly and is the idiomatic Python 3 form.
#This is known as "zero-argument super()". It simplifies the code, and it works just as well with multiple levels of inheritance (and with multiple inheritance), because super() always follows the MRO.
#The explicit two-argument form super(ResNet, self) is only required in Python 2, or when super() is called from a function that is not defined directly in the class body (e.g. a nested function, or a function attached to the class afterwards).
#Here it is written explicitly purely for illustration; both forms behave identically inside ResNet.__init__.
class ResNet(Model):
  """ResNet-18-style network with a single sigmoid output (binary classifier).

  Layout: 7x7 stride-2 conv stem (conv -> BN -> ReLU) + 3x3 stride-2 max-pool,
  then four stages of two ResNetBlocks each (64, 128, 256, 512 filters; every
  stage after the first opens with a stride-2 block), global average pooling,
  and a 1-unit sigmoid Dense head.

  Args:
    **kwargs: forwarded to `Model.__init__` (e.g. `name`).
  """

  def __init__(self, **kwargs):
    super(ResNet, self).__init__(**kwargs)  # explicit two-argument super(); see the comment above the class

    # Stem. BatchNormalization between conv and ReLU follows the standard
    # ResNet stem and the conv -> BN -> ReLU pattern used in every ResNetBlock.
    self.conv_1 = Sequential([Conv2D(64, 7, 2),
                              BatchNormalization(),
                              ReLU(),
                              MaxPooling2D(3, 2)])

    # Four stages of two blocks each. The first stage keeps stride 1 (the stem
    # already outputs 64 channels, so both blocks get an identity skip). Each
    # later stage opens with a stride-2 block whose output shape differs from
    # its input, so that block has no skip connection in ResNetBlock.
    blocks = []
    for stage_index, channels in enumerate((64, 128, 256, 512)):
      first_stride = 1 if stage_index == 0 else 2
      blocks += [ResNetBlock(channels, first_stride), ResNetBlock(channels)]
    self.resnet_chains = Sequential(blocks)

    self.out = Sequential([GlobalAveragePooling2D(),
                           Dense(1, activation='sigmoid')])

  def call(self, x):
    x = self.conv_1(x)
    x = self.resnet_chains(x)
    x = self.out(x)
    return x

# Instantiate the full network; its weights are only created once the model
# is built (see the build cell below).
model = ResNet()
print(model)

<__main__.ResNet object at 0x7faefd588350>


In [None]:
# Build every layer (creating its weights) for one 224x224 3-channel input.
model.build((1, 224, 224, 3))

In [None]:
# expand_nested=True lists the layers inside each nested Sequential (stem,
# residual stages, head) instead of collapsing each one into a single row
# with a "multiple" output shape, so per-block feature-map sizes are visible.
model.summary(expand_nested=True)

Model: "res_net"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 sequential_1 (Sequential)   multiple                  9472      
                                                                 
 sequential_10 (Sequential)  (1, 5, 5, 512)            11004672  
                                                                 
 sequential_11 (Sequential)  (1, 1)                    513       
                                                                 
Total params: 11,014,657
Trainable params: 11,006,977
Non-trainable params: 7,680
_________________________________________________________________
