# Model Subclassing

Keras provides a mechanism for subclassing `tf.keras.Model`, allowing you to
define your own models that can be used with any high-level API
call that expects a `Model` object. The code below shows the construction of
a residual-style vision network, modeled after ResNet-50.

The process of model subclassing follows a pattern:

1. Define a `class` that inherits from `tf.keras.Model`
2. Define an `__init__` method that defines the layers to be used in the model
3. Define a `call` method that chains the defined layers together
   into a flow of information (Keras invokes `call` whenever the model
   instance itself is called)

Note that subclassed models can include other models as part of the computational
flow, allowing for great reuse of fundamental building blocks.

## Tail

As an example, we will begin by constructing a tail consisting of a $7 \times 7 / 2$
convolution, batch normalization, a ReLU, and a $2 \times 2 / 2$ max pool. The
`__init__` method defines the layers, and the `call` method describes the flow
of data through these layers.

In [1]:
import tensorflow as tf
from tensorflow.keras import layers

class Tail(tf.keras.Model):

    def __init__(self, Ni, *args, **kwargs):

        # Call parent constructor with *args, **kwargs
        super(Tail, self).__init__(*args, **kwargs)

        # Convolution
        self.conv = layers.Conv2D(
                filters=Ni,
                kernel_size=(7, 7),
                strides=2,
                use_bias=False,
                name='tail_conv')

        # Batch norm
        self.bn = layers.BatchNormalization(
                name='tail_bn')

        # ReLU
        self.relu = layers.ReLU(name='tail_relu')

        # Max pooling layer
        self.pool = layers.MaxPool2D(
                pool_size=(2, 2),
                strides=2,
                name='tail_pool')

    def call(self, inputs, **kwargs):
        # Convolution on the raw inputs
        x = self.conv(inputs, **kwargs)

        # Forward **kwargs so that batch norm receives the training state
        x = self.bn(x, **kwargs)

        x = self.relu(x, **kwargs)

        return self.pool(x, **kwargs)
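
As a quick sanity check on the tail, the spatial size it produces can be worked out by hand. The sketch below assumes the Keras default `padding='valid'` for both `Conv2D` and `MaxPool2D` (the `Tail` class sets no padding) and the $64 \times 64$ input used later in this section. `valid_output_size` is an illustrative helper, not part of Keras.

```python
def valid_output_size(size, kernel, stride):
    """Output length along one axis for a 'valid' (unpadded) window op."""
    return (size - kernel) // stride + 1

# 7x7 convolution with stride 2, then 2x2 max pool with stride 2
conv_out = valid_output_size(64, kernel=7, stride=2)
pool_out = valid_output_size(conv_out, kernel=2, stride=2)
print(conv_out, pool_out)  # 29 14
```

The resulting $14 \times 14$ size agrees with the input shape that the first bottleneck reports in the layer listing at the end of this section.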

## Basic Block

Next we define the fundamental 2D convolution block of ResNet,
i.e. batch norm, ReLU, convolution.

Note that the number of filters, the kernel size, and the strides
are parameterized and passed to the convolution layer, while the
remaining arguments `*args, **kwargs` are forwarded to the parent
`tf.keras.Model` constructor. This parameterization is important
as it enables the reuse of this model for the various
types of convolutions that we will need.

In [2]:
class ResnetBasic(tf.keras.Model):

    def __init__(self, filters, kernel_size, strides=(1,1), *args, **kwargs):
        super(ResnetBasic, self).__init__(*args, **kwargs)
        self.batch_norm = layers.BatchNormalization()
        self.relu = layers.ReLU()
        self.conv2d = layers.Conv2D(
                filters=filters,
                kernel_size=kernel_size,
                padding='same',
                activation=None,
                use_bias=False,
                strides=strides)

    def call(self, inputs, **kwargs):
        x = self.batch_norm(inputs, **kwargs)
        x = self.relu(x, **kwargs)
        return self.conv2d(x, **kwargs)

## Standard Bottleneck

We can use `ResnetBasic` to build a bottleneck layer. Again we leave the
number of feature maps parameterized so that we may reuse the
`Bottleneck` model at each level of downsampling.

Note that here we use a loop (a list comprehension) in the `__init__` method.
This helps define repeating structures, but recall that in `__init__` we are only
defining layers and not the flow of computation between them. The loop
in `call` is where we define the sequential flow through each of the layers
created in `__init__`.

In [3]:
class Bottleneck(tf.keras.Model):

    def __init__(self, Ni, *args, **kwargs):
        super(Bottleneck, self).__init__(*args, **kwargs)

        # Three residual convolution blocks
        kernels = [(1, 1), (3, 3), (1, 1)]
        feature_maps = [Ni // 4, Ni // 4, Ni]
        self.residual_filters = [
            ResnetBasic(N, K)
            for N, K in zip(feature_maps, kernels)
        ]

        # Merge operation
        self.merge = layers.Add()

    def call(self, inputs, **kwargs):

        # Residual forward pass
        res = inputs
        for res_layer in self.residual_filters:
            res = res_layer(res, **kwargs)

        # Combine residual pass with identity
        return self.merge([inputs, res], **kwargs)
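
The size of a bottleneck can be estimated directly from its layer definitions. The sketch below assumes each `BatchNormalization` layer holds four values per input channel (scale, offset, moving mean, moving variance) and each bias-free convolution holds `kernel * kernel * in_channels * filters` weights. The helper names are hypothetical and exist only for this illustration.

```python
def resnet_basic_params(in_channels, filters, kernel):
    """Parameters in one batch-norm -> ReLU -> conv block (conv has no bias)."""
    batch_norm = 4 * in_channels
    conv = kernel * kernel * in_channels * filters
    return batch_norm + conv

def bottleneck_params(ni):
    """Total parameters of a bottleneck with ni input and output feature maps."""
    kernels = [1, 3, 1]
    feature_maps = [ni // 4, ni // 4, ni]
    total, channels = 0, ni
    for filters, kernel in zip(feature_maps, kernels):
        total += resnet_basic_params(channels, filters, kernel)
        channels = filters  # the next block consumes this block's output
    return total

print(bottleneck_params(32), bottleneck_params(64))  # 1280 4736
```

These counts agree with the `Param #` column that `model.summary()` reports for the 32- and 64-channel bottlenecks later in this section.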

## Special Bottleneck

We can define the special bottleneck layer by subclassing
the `Bottleneck` class. We add a convolutional layer along
the main path and redefine the `call` method to include
this layer.

In [4]:
class SpecialBottleneck(Bottleneck):

    def __init__(self, Ni, *args, **kwargs):

        # Layers that also appear in standard bottleneck
        super(SpecialBottleneck, self).__init__(Ni, *args, **kwargs)

        # Add convolution layer along main path
        self.main = layers.Conv2D(
                Ni,
                (1, 1),
                padding='same',
                activation=None,
                use_bias=False)

    def call(self, inputs, **kwargs):

        # Residual forward pass
        res = inputs
        for res_layer in self.residual_filters:
            res = res_layer(res, **kwargs)

        # Convolution on main forward pass
        main = self.main(inputs, **kwargs)

        # Merge residual and main
        return self.merge([main, res])

## Downsampling

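Next we need to define the downsampling layer. It has the same
three-convolution residual path as the bottleneck, but the first convolution
uses a stride of 2 to halve the spatial resolution, and the block outputs
`2*Ni` feature maps. A strided $1 \times 1$ `ResnetBasic` block on the main
path brings that branch to the same shape so the two can be added.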

In [5]:
class Downsample(tf.keras.Model):

    def __init__(self, Ni, *args, **kwargs):
        super(Downsample, self).__init__(*args, **kwargs)

        # Three residual convolution blocks
        kernels = [(1, 1), (3, 3), (1, 1)]
        strides = [(2, 2), (1, 1), (1, 1)]
        feature_maps = [Ni // 2, Ni // 2, 2*Ni]

        self.residual_filters = [
            ResnetBasic(N, K, strides=S)
            for N, K, S in zip(feature_maps, kernels, strides)
        ]

        # Convolution on main path
        self.main = ResnetBasic(2*Ni, (1,1), strides=(2,2))

        # Merge operation for residual and main
        self.merge = layers.Add()

    def call(self, inputs, **kwargs):

        # Residual forward pass
        res = inputs
        for res_layer in self.residual_filters:
            res = res_layer(res, **kwargs)

        # Main forward pass
        main = self.main(inputs, **kwargs)

        # Merge residual and main
        return self.merge([main, res])
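
With `padding='same'`, a stride-2 convolution maps a spatial size `n` to `ceil(n / 2)`, so each `Downsample` block halves the resolution (rounding up) while doubling the feature maps. A minimal sketch of that bookkeeping, assuming a $14 \times 14$ input with 32 feature maps; `downsample_shape` is a hypothetical helper, not a Keras function.

```python
import math

def downsample_shape(height, width, channels):
    """Shape after a Downsample block: stride-2 'same' conv, twice the channels."""
    return math.ceil(height / 2), math.ceil(width / 2), 2 * channels

shape = (14, 14, 32)
trace = [shape]
for _ in range(3):
    shape = downsample_shape(*shape)
    trace.append(shape)

print(trace)
# [(14, 14, 32), (7, 7, 64), (4, 4, 128), (2, 2, 256)]
```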

## Final Model

Finally, we can assemble these blocks into the final model. Note that
a `levels` argument is added to the constructor to parameterize the number
of bottleneck repeats at each level of downsampling. The `levels` parameter
receives a list of integers where `levels[i]` gives the number of bottleneck
repeats at level `i`.

In [6]:
class Resnet(tf.keras.Model):

    def __init__(self, classes, filters, levels, *args, **kwargs):
        super(Resnet, self).__init__(*args, **kwargs)

        # List to hold the bottleneck and downsample blocks
        self.blocks = list()

        # Tail
        self.tail = Tail(filters)

        # Special bottleneck layer with convolution on main path
        self.level_0_special = SpecialBottleneck(filters)

        # Loop through levels and their parameterized repeat counts
        for level, repeats in enumerate(levels):
            for block in range(repeats):
                # Append a bottleneck block for each repeat
                name = 'bottleneck_%i_%i' % (level, block)
                layer = Bottleneck(filters, name=name)
                self.blocks.append(layer)

            # Downsample and double feature maps at end of level
            name = 'downsample_%i' % (level)
            layer = Downsample(filters, name=name)
            self.blocks.append(layer)
            filters *= 2

        self.level2_batch_norm = layers.BatchNormalization(name='final_bn')
        self.level2_relu = layers.ReLU(name='final_relu')

        # Decoder - global average pool and fully connected
        self.global_avg = layers.GlobalAveragePooling2D(
                name='GAP'
        )

        # Dense classification layer, with an L2 kernel regularizer
        # included to show how a regularized layer is declared
        self.dense = layers.Dense(
                classes,
                name='dense',
                kernel_regularizer=tf.keras.regularizers.l2(0.01),
                use_bias=True
        )


    def call(self, inputs, **kwargs):
        x = self.tail(inputs, **kwargs)
        x = self.level_0_special(x, **kwargs)

        # Loop over layers by level
        for layer in self.blocks:
            x = layer(x, **kwargs)

        # Final batch norm and ReLU before the decoder
        x = self.level2_batch_norm(x, **kwargs)
        x = self.level2_relu(x)

        # Decoder
        x = self.global_avg(x)
        return self.dense(x, **kwargs)

## Using the Model

Now we can construct the model. Here we define a four-level model
with bottleneck repeats given in `levels`.

In [7]:
levels = [4, 3, 6, 2]
num_classes = 100
width = 32
model = Resnet(num_classes, width, levels)
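
To see what the constructor loop builds for these arguments, the same iteration can be replayed without creating any layers. The sketch below mirrors the loop in `Resnet.__init__` and records each block's name and the number of feature maps it outputs; `block_schedule` is a hypothetical helper introduced only for this illustration.

```python
def block_schedule(filters, levels):
    """List (name, output feature maps) for each block, in call order."""
    schedule = []
    for level, repeats in enumerate(levels):
        for block in range(repeats):
            schedule.append(('bottleneck_%i_%i' % (level, block), filters))
        # Each level ends with a downsample that doubles the feature maps
        schedule.append(('downsample_%i' % level, 2 * filters))
        filters *= 2
    return schedule

schedule = block_schedule(32, [4, 3, 6, 2])
print(len(schedule))   # 19
print(schedule[4])     # ('downsample_0', 64)
print(schedule[-1])    # ('downsample_3', 512)
```

Note that a `Downsample` block is appended after every level, including the last one.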

Now we have constructed a model, but the model still knows nothing
about the actual input sizes it will be working with. We can define
an input layer and call the model on this input.

In [8]:
# tf.keras.InputLayer in TF 2.0
inputs = tf.keras.layers.InputLayer(
        input_shape=(64, 64, 3),
        batch_size=32,
        dtype=tf.float32
)
outputs = model(inputs.input)


Finally, we can print a summary of the model:

In [9]:
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
bottleneck_0_0 (Bottleneck)  (32, 14, 14, 32)          1280      
_________________________________________________________________
bottleneck_0_1 (Bottleneck)  (32, 14, 14, 32)          1280      
_________________________________________________________________
bottleneck_0_2 (Bottleneck)  (32, 14, 14, 32)          1280      
_________________________________________________________________
bottleneck_0_3 (Bottleneck)  (32, 14, 14, 32)          1280      
_________________________________________________________________
downsample_0 (Downsample)    (32, 7, 7, 64)            6272      
_________________________________________________________________
bottleneck_1_0 (Bottleneck)  (32, 7, 7, 64)            4736      
_________________________________________________________________
bottleneck_1_1 (Bottleneck)  (32, 7, 7, 64)            4736      
__________

Or we can iterate over layers and examine properties of
the individual `Layer` objects.

In [10]:
FMT = "%-22s : %15s -> %-15s"
for layer in model.layers:
    name = type(layer).__name__
    inp = layer.input_shape
    out = layer.output_shape
    msg = FMT % (name, inp, out)
    print(msg)

Bottleneck             : (32, 14, 14, 32) -> (32, 14, 14, 32)
Bottleneck             : (32, 14, 14, 32) -> (32, 14, 14, 32)
Bottleneck             : (32, 14, 14, 32) -> (32, 14, 14, 32)
Bottleneck             : (32, 14, 14, 32) -> (32, 14, 14, 32)
Downsample             : (32, 14, 14, 32) -> (32, 7, 7, 64) 
Bottleneck             :  (32, 7, 7, 64) -> (32, 7, 7, 64) 
Bottleneck             :  (32, 7, 7, 64) -> (32, 7, 7, 64) 
Bottleneck             :  (32, 7, 7, 64) -> (32, 7, 7, 64) 
Downsample             :  (32, 7, 7, 64) -> (32, 4, 4, 128)
Bottleneck             : (32, 4, 4, 128) -> (32, 4, 4, 128)
Bottleneck             : (32, 4, 4, 128) -> (32, 4, 4, 128)
Bottleneck             : (32, 4, 4, 128) -> (32, 4, 4, 128)
Bottleneck             : (32, 4, 4, 128) -> (32, 4, 4, 128)
Bottleneck             : (32, 4, 4, 128) -> (32, 4, 4, 128)
Bottleneck             : (32, 4, 4, 128) -> (32, 4, 4, 128)
Downsample             : (32, 4, 4, 128) -> (32, 2, 2, 256)
Bottleneck             : (32, 2

An upcoming demo will show how to train the constructed model.