In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [29]:
class ConvBlock(tf.keras.Model):
    """Two stacked 3x3 convolutions, each followed by batch normalization.

    A ReLU is applied after the first conv/BN pair only; the output of the
    second pair is returned without an activation.

    Args:
        channels: Number of output filters used by both convolutions.
        first_strides: Stride of the first convolution. The second
            convolution always uses stride 1.
    """

    def __init__(self, channels, first_strides=1):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(
            channels, kernel_size=3, strides=first_strides, padding='same')
        self.bn1 = tf.keras.layers.BatchNormalization()

        self.conv2 = tf.keras.layers.Conv2D(
            channels, kernel_size=3, strides=1, padding='same')
        self.bn2 = tf.keras.layers.BatchNormalization()

    def call(self, input_tensor, training=False):
        """Run conv -> BN -> ReLU -> conv -> BN on `input_tensor`.

        Args:
            input_tensor: Input feature map.
            training: Forwarded explicitly to both batch-normalization
                layers so their mode does not depend on implicit
                propagation from the caller.

        Returns:
            The output of the second batch-normalization layer.
        """
        x = self.conv1(input_tensor)
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv2(x)
        x = self.bn2(x, training=training)
        return x

In [30]:
class ResBlock(tf.keras.Model):
    """Residual block: a ConvBlock plus a shortcut connection, then ReLU.

    With `down_sample=True` the first convolution of the ConvBlock uses
    stride 2 and the shortcut is passed through a 1x1, stride-2 convolution
    followed by batch normalization so that both branches can be added.
    Otherwise the input itself is used as the shortcut.

    Args:
        channels: Number of output filters.
        down_sample: Whether to use stride 2 and a projected shortcut.
    """

    def __init__(self, channels, down_sample=False):
        super().__init__()
        self.down_sample = down_sample

        self.conv_block1 = ConvBlock(channels, first_strides=2 if down_sample else 1)

        # The projection layers only exist when down-sampling is requested.
        if self.down_sample:
            self.down_sample_conv = tf.keras.layers.Conv2D(
                channels, kernel_size=1, strides=2, padding='same')
            self.down_bn = tf.keras.layers.BatchNormalization()

    def call(self, input_tensor, training=False):
        """Apply the residual block.

        Args:
            input_tensor: Input feature map.
            training: Forwarded explicitly to the ConvBlock and to the
                shortcut's batch normalization.

        Returns:
            relu(conv_block(input) + shortcut).
        """
        x = self.conv_block1(input_tensor, training=training)

        # Keep the parameter untouched; build the shortcut under its own name.
        shortcut = input_tensor
        if self.down_sample:
            shortcut = self.down_sample_conv(shortcut)
            shortcut = self.down_bn(shortcut, training=training)

        return tf.nn.relu(x + shortcut)

In [4]:
class ResNet34(tf.keras.Model):
    """ResNet-34-style classifier assembled from ResBlock units.

    Layout: 7x7 stride-2 convolution, batch normalization, ReLU, 3x3 stride-2
    max pooling, then four stages of 3, 4, 6 and 3 residual blocks with
    64, 128, 256 and 512 channels. The first block of stages 2-4
    down-samples. Global average pooling and a dense layer produce the
    output.

    Args:
        num_classes: Number of units in the final dense layer. Defaults to
            10, the value that was previously hard-coded.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(64, kernel_size=7, strides=2, padding='same')
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.maxpool1 = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')

        self.res_block1_1 = ResBlock(64)
        self.res_block1_2 = ResBlock(64)
        self.res_block1_3 = ResBlock(64)

        self.res_block2_1 = ResBlock(128, down_sample=True)
        self.res_block2_2 = ResBlock(128)
        self.res_block2_3 = ResBlock(128)
        self.res_block2_4 = ResBlock(128)

        self.res_block3_1 = ResBlock(256, down_sample=True)
        self.res_block3_2 = ResBlock(256)
        self.res_block3_3 = ResBlock(256)
        self.res_block3_4 = ResBlock(256)
        self.res_block3_5 = ResBlock(256)
        self.res_block3_6 = ResBlock(256)

        self.res_block4_1 = ResBlock(512, down_sample=True)
        self.res_block4_2 = ResBlock(512)
        self.res_block4_3 = ResBlock(512)

        self.avg = tf.keras.layers.GlobalAveragePooling2D()
        self.flat = tf.keras.layers.Flatten()
        self.fc = tf.keras.layers.Dense(num_classes)

    def call(self, input_tensor, training=False):
        """Compute the dense-layer outputs for a batch of inputs.

        Args:
            input_tensor: Input batch.
            training: Forwarded explicitly to the stem batch normalization
                and to every residual block.

        Returns:
            The output of the final dense layer (no activation is applied
            here).
        """
        x = self.conv1(input_tensor)
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        x = self.maxpool1(x)

        # Local tuple only: the blocks stay registered under their original
        # attribute names, this just fixes the order in which they run.
        res_blocks = (
            self.res_block1_1, self.res_block1_2, self.res_block1_3,
            self.res_block2_1, self.res_block2_2, self.res_block2_3,
            self.res_block2_4,
            self.res_block3_1, self.res_block3_2, self.res_block3_3,
            self.res_block3_4, self.res_block3_5, self.res_block3_6,
            self.res_block4_1, self.res_block4_2, self.res_block4_3,
        )
        for block in res_blocks:
            x = block(x, training=training)

        x = self.avg(x)
        x = self.flat(x)
        x = self.fc(x)
        return x

In [5]:
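# Expected parameter counts for a 64-channel ResBlock fed a 64-channel input:
# conv (3x3, 64 -> 64): 3*3*64*64 + 64 = 36928
# bn (64 channels): 4*64 = 256
# conv (3x3, 64 -> 64): 36928
# bn (64 channels): 256
# down conv (1x1, 64 -> 64): 64*64 + 64 = 4160
# down bn (64 channels): 256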

In [6]:
# Instantiate the network; no input shape has been supplied at this point.
model = ResNet34()

In [7]:
# Build the model for inputs of shape (batch, 32, 32, 3); the batch size is left unspecified.
model.build(input_shape=(None,32,32,3))

In [8]:
# Print the per-layer parameter table for the built model.
model.summary()

Model: "res_net34"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              multiple                  9472      
_________________________________________________________________
batch_normalization (BatchNo multiple                  256       
_________________________________________________________________
max_pooling2d (MaxPooling2D) multiple                  0         
_________________________________________________________________
res_block (ResBlock)         multiple                  74368     
_________________________________________________________________
res_block_1 (ResBlock)       multiple                  74368     
_________________________________________________________________
res_block_2 (ResBlock)       multiple                  74368     
_________________________________________________________________
res_block_3 (ResBlock)       multiple                  23

In [26]:
# Stand-alone down-sampling residual block with 128 output channels, used to inspect it in isolation.
res_block = ResBlock(128, down_sample=True)

In [20]:
# Wrap the block in a functional model with a symbolic 8x8x64 input so that
# summary() can report a concrete output shape instead of "multiple".
inputs = tf.keras.layers.Input(shape=(8,8,64))
x = res_block(inputs)
temp_model = tf.keras.Model(inputs=inputs, outputs=x)

In [21]:
# Print the parameter table of the functional wrapper around the residual block.
temp_model.summary()

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 8, 8, 64)]        0         
_________________________________________________________________
res_block_16 (ResBlock)      (None, 4, 4, 128)         231296    
Total params: 231,296
Trainable params: 230,528
Non-trainable params: 768
_________________________________________________________________


In [27]:
# Build the block directly for inputs of shape (batch, 8, 8, 64).
res_block.build(input_shape=(None,8,8,64))

In [28]:
# Print the block's own parameter table, broken down by its sub-layers.
res_block.summary()

Model: "res_block_17"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv_block_17 (ConvBlock)    multiple                  156928    
_________________________________________________________________
conv2d_41 (Conv2D)           multiple                  8320      
_________________________________________________________________
batch_normalization_41 (Batc multiple                  512       
Total params: 165,760
Trainable params: 164,992
Non-trainable params: 768
_________________________________________________________________
