In [1]:
import tensorflow as tf

In [2]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Activation, Dense, Flatten, Conv2D, MaxPooling2D, 
    GlobalAveragePooling2D, AveragePooling2D, BatchNormalization, add)
import tensorflow.keras.regularizers as regulizers

In [11]:
class ConvWithBatchNorm(tf.keras.layers.Conv2D):
    
    def __init__(self, activation='relu', name='convbn', **kwargs):
        
        super().__init__(activation=None, name=name + '_c', **kwargs)
        self.batch_norm = BatchNormalization(axis = -1 , name = name + '_bn')
        self.activation =  Activation(activation , name = name + '_act') if activation is not None else None
        
        
        
    #def build( self, input_shape ):
        #pass
            
            
    def call(self, inputs, training=None):
        x = super().call(inputs)
        x = self.batch_norm(x, training=training)
        if self.activation is not None:
            x = self.activation(x)
        return x
    
    
    
    
  
        

In [12]:
import functools

class ResidualMerge(tf.keras.layers.Layer):
    
    def __init__(self, name = 'block' ,**kwargs ):
        
        super().__init__(name = name)
        self.shortcut = None
        self.kwargs = kwargs
        
    
    def build( self, input_shape ):
        
        x_shape = input_shape[0]
        x_residual_shape = input_shape[1]
        
        if x_shape == x_residual_shape:
            self.shortcut = functools.partial(tf.identity , name=self.name + '_shortcut')
            
        else:
            
            strides = (
            
                    int(round(x_shape[1]/x_residual_shape[1])),
                    int(round(x_shape[2]/x_residual_shape[2]))
            )
            x_residual_channels = x_residual_shape[3]
            
            self.shortcut = ConvWithBatchNorm(filters = x_residual_channels ,
                             kernel_size = (1,1) , strides = strides , activation=None , name = self.name + '_shortcut_c' ,**self.kwargs)
            
    def call(self , inputs  ):
        
        x , x_residual = inputs 
        x_shortcut = self.shortcut(x)
        x_merge = add([ x_shortcut , x_residual ])
        return x_merge
        
        
            

In [13]:
class BasicResidualBlock(tf.keras.Model):
    
    def __init__(self, filters=16, kernel_size=1, strides=1, activation='relu',
                 kernel_initializer='he_normal', kernel_regularizer=regulizers.l2(1e-4),
                 name='res_basic', **kwargs):
        
        
        super().__init__(name = name )
        
        self.conv1 = ConvWithBatchNorm(filters = filters , kernel_size=kernel_size ,
                      strides = strides,padding = 'same' ,activation=activation ,
                      kernel_initializer=kernel_initializer , kernel_regularizer=kernel_regularizer,
                      name = name + '_cb_1' , **kwargs)
        
        self.conv2 = ConvWithBatchNorm(filters = filters , kernel_size=kernel_size , strides=strides,
                                      padding = 'same' , activation=activation , kernel_initializer=kernel_initializer,
                                      kernel_regularizer=kernel_regularizer)
        '''
        self.conv3 = ConvWithBatchNorm( filters=filters , kernel_size=kernel_size , strides=strides,padding='same',
                                      kernel_initializer=kernel_initializer , kernel_regularizer=kernel_regularizer)
        
        '''
        
        self.merge = ResidualMerge(kernel_initializer=kernel_initializer , kernel_regularizer=kernel_regularizer,
                                  name = name)
        
        self.activation = Activation(activation , name = name + '_act')
        
    
    
    def build( self, input_shape ):
        pass
    
    
    def call(self, inputs , training = None ):
        
        x =  inputs
        
        x_residual = self.conv1(x , training = training)
        x_residual = self.conv2( x_residual , training = training)
        
        x_merge = self.merge([ x , x_residual])
        x_merge = self.activation(x_merge)
        
        return x_merge
        

In [14]:
class ResidualBlockWithBottleneck(tf.keras.Model):
    
    def __init__( self, filters = 16 , kernel_size = 1 , strides =   1, activation = 'relu',
               kernel_initializer='he_normal', kernel_regularizer=regulizers.l2(1e-4),
                 name='res_basic', **kwargs):
        
        
        super().__init__(name = name )
        
        self.conv0 = ConvWithBatchNorm( filters=filters , kernel_size=kernel_size , strides=strides,
                                      padding = 'same' , activation=activation , kernel_initializer=kernel_initializer,kernel_regularizer=kernel_regularizer)
        
        self.conv1 = ConvWithBatchNorm( filters=filters , kernel_size=kernel_size , strides=strides,
                                      padding = 'same' , activation=activation , kernel_initializer=kernel_initializer,
                                      kernel_regularizer= kernel_regularizer )
        
        self.conv2 = ConvWithBatchNorm(filters=filters , kernel_size=kernel_size , strides=strides , padding = 'same' , activation=activation,
                                      kernel_initializer=kernel_initializer , kernel_regularizer=kernel_regularizer)
        
        
        self.merge = ResidualMerge(kernel_initializer=kernel_initializer , kernel_regularizer=kernel_regularizer , name = name )
        self.activation = Activation(activation=activation , name = name + '_act')
        
        
        
    def buiild(self, input_shape):
        pass
    
    
    def call(self, inputs , training = None ):
        
        x = inputs
        x_residual = self.conv0(x , training = training)
        x_residual = self.conv1(x_residual , training =training)
        x_residual = self.conv2(x_residual , training = training)
        
        x_merge = self.merge([x , x_residual])
        x_merge = self.activation(x_merge)
        return x_merge

In [15]:
class ResidualMacroBlock(tf.keras.models.Sequential):
    """ Macro-block, chaining multiple residual blocks (as a Sequential model)"""

    def __init__(self, block_class=ResidualBlockWithBottleneck, repetitions=3, 
                 filters=16, kernel_size=1, strides=1, activation='relu',
                 kernel_initializer='he_normal', kernel_regularizer=regulizers.l2(1e-4),
                 name='res_macroblock', **kwargs):
        
        
        layer_lst = []
        for i in range(repetitions):
            block_class_object = block_class(filters =filters , kernel_size=kernel_size ,
                                            strides = strides if i==0 else 1 , name = "{}_{}".format(name , i ),
                                            kernel_initializer = kernel_initializer , kernel_regularizer = kernel_regularizer)
            layer_lst.append(block_class_object)
            
        
        super().__init__(layer_lst , name = name )
        
        
    
    def build(self, input_shape):
        pass
    
    
    
    #def call(self, input_shape):
        # always return tensor value
        #pass
    

In [16]:
class ResNet(tf.keras.models.Sequential):
    """ ResNet model for classification"""

    def __init__(self, input_shape, num_classes=1000, 
                 block_class=ResidualBlockWithBottleneck, repetitions=(2, 2, 2, 2),
                 kernel_initializer='he_normal', kernel_regularizer=regulizers.l2(1e-4),
                 name='resnet'):
        
        
        filters = 64
        strides = 2
        
        seq_lst  = []
        
        seq_lst.append(Input( shape = input_shape , name = 'input'))
        seq_lst.append(ConvWithBatchNorm( filters = filters , kernel_size = 7 , strides = strides , padding = 'same',
                     kernel_initializer = kernel_initializer , kernel_regularizer =kernel_regularizer,name = 'conv'))
        seq_lst.append(MaxPooling2D(pool_size = 3 , strides= strides, padding = 'same' , name = 'max_pool'))
        
        for i, repet in enumerate(repetitions):
            residual_block_obj = ResidualMacroBlock(block_class=block_class, repetitions=repet, filters=min(filters * (2 ** i), 1024), kernel_size=3, activation='relu',
                     strides=strides if i != 0 else 1, name='block_{}'.format(i),
                     kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer)
            seq_lst.append(residual_block_obj)
                
        
        seq_lst.append(GlobalAveragePooling2D(name='avg_pool'))
        seq_lst.append(  Dense(units=num_classes, kernel_initializer=kernel_initializer, activation='softmax'))
        
        
        super().__init__(seq_lst , name = name )
        
        
        
    
    def build(self, input_shape):
        pass
    
    
    #def call(self, input_shape):
        #pass
        
                
                
        

In [17]:
class ResNet18(ResNet):
    def __init__(self, input_shape, num_classes=1000, name='resnet18',kernel_initializer='he_normal', kernel_regularizer=regulizers.l2(1e-4)):
        
        super().__init__(input_shape, num_classes, block_class=BasicResidualBlock, repetitions=(2, 2, 2, 2),
                         kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer)
        

In [18]:
input_shape = [224, 224, 3]
model = ResNet18(input_shape, 1000)



In [19]:
model.summary()

Model: "resnet"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv_c (ConvWithBatchNorm)   (None, 112, 112, 64)      9728      
_________________________________________________________________
max_pool (MaxPooling2D)      (None, 56, 56, 64)        0         
_________________________________________________________________
block_0 (ResidualMacroBlock) (None, 56, 56, 64)        157568    
_________________________________________________________________
block_1 (ResidualMacroBlock) (None, 14, 14, 128)       544512    
_________________________________________________________________
block_2 (ResidualMacroBlock) (None, 4, 4, 256)         2170368   
_________________________________________________________________
block_3 (ResidualMacroBlock) (None, 1, 1, 512)         8666112   
_________________________________________________________________
avg_pool (GlobalAveragePooli (None, 512)               0    