In [6]:
import tensorflow as tf
from tensorflow import keras



class SEUnit(keras.layers.Layer):
    def __init__(self,units_1,units_2,**kwargs):
        super().__init__(**kwargs)
        
        self.main_layers = [
            
            keras.layers.GlobalAveragePooling2D(),
            keras.layers.Dense(units_1,activation='relu'),
            keras.layers.Dense(units_2,activation='sigmoid')
            
            
            
        ]
        
        
    def call(self,inputs):
        Z = inputs
        for layers in self.main_layers:
            Z = layers(Z)
            
            
    
        
        print(tf.expand_dims(tf.expand_dims(Z,1),1).shape)
        # https://stackoverflow.com/questions/51900409/tensorflow-multiply-feature-map-in-batch-with-its-feature-mean-in-batch-n-h
        return(tf.expand_dims(tf.expand_dims(Z,1),1))
    

    
    
    
class SE_ResidualUnit(keras.layers.Layer):
    def __init__(self,strides,size_1,size_2,size_3,size_4,filters_1,filters_2,filters_3,filters_4,units_1,units_2,activation='relu',**kwargs):
        super().__init__(**kwargs)
        self.activation = keras.activations.get(activation)
        
        
        self.skip_layers_2 = [
            
            
            
            ### RESIDUAL MODULE
            
            keras.layers.Conv2D(filters_1,size_1,strides=strides,padding='same',use_bias=False),
            keras.layers.BatchNormalization(),
            self.activation,
            keras.layers.Conv2D(filters_1,size_2,strides=1,padding='same',use_bias=False),
            keras.layers.BatchNormalization(),
            self.activation,

            keras.layers.Conv2D(filters_3,size_3,strides=1,padding='same',use_bias= False),
            keras.layers.BatchNormalization(),
            self.activation, 
            
            ## SE UNIT
            
            SEUnit(units_1= units_1,units_2 = units_2)
            
        


        ]
        
        
        self.main_layers = [
            
            
            
            ### RESIDUAL MODULE
            
            keras.layers.Conv2D(filters_1,size_1,strides=strides,padding='same',use_bias=False),
            keras.layers.BatchNormalization(),
            self.activation,
            keras.layers.Conv2D(filters_1,size_2,strides=1,padding='same',use_bias=False),
            keras.layers.BatchNormalization(),
            self.activation,

            keras.layers.Conv2D(filters_3,size_3,strides=1,padding='same',use_bias= False),
            keras.layers.BatchNormalization(),            
            self.activation
            
            ## SE UNIT
            
            
        


        ]
        self.skip_layers = [
            
            keras.layers.Conv2D(filters_4,size_4,strides=strides,padding='same',use_bias = False),
            keras.layers.BatchNormalization(),
            self.activation
            
            
        ]
        
        
    def call(self,inputs):
        Z = inputs
        for layer in self.main_layers:
            Z = layer(Z)
        skip_Z_2 = inputs
        for layer in self.skip_layers_2:
            skip_Z_2 = layer(skip_Z_2)
        Z2 = inputs
        for layer in self.skip_layers:
            Z2 = layer(Z2)
        result_Z = (skip_Z_2*Z)+ Z2
        
   
        return (result_Z)    
    
    


In [7]:
model = keras.models.Sequential()
model.add(keras.layers.Conv2D(64, kernel_size = 7, strides=2, activation='relu',input_shape = [229,229,3]))
model.add(keras.layers.MaxPool2D(pool_size=2,strides=2))

for i in range(3):
    
    
    model.add(SE_ResidualUnit(strides = 1,size_1 = 1,size_2 = 3,size_3 = 1,size_4=1,filters_1=64,filters_2=64,filters_3=256,filters_4=256,units_1=16,units_2=256))
    
for i in range(4):
    if i == 0:
        
        model.add(SE_ResidualUnit(strides = 2,size_1 = 1,size_2 = 3,size_3 = 1,size_4=1,filters_1=128,filters_2=128,filters_3=512,filters_4=512,units_1=32,units_2=512))
    
    else:
        
        model.add(SE_ResidualUnit(strides = 1,size_1 = 1,size_2 = 3,size_3 = 1,size_4=1,filters_1=128,filters_2=128,filters_3=512,filters_4=512,units_1=32,units_2=512))

        
        
for i in range(6):
    if i == 0:
        
        model.add(SE_ResidualUnit(strides = 2,size_1 = 1,size_2 = 3,size_3 = 1,size_4=1,filters_1=256,filters_2=256,filters_3=1024,filters_4=1024,units_1=64,units_2=1024))
    
    else:
        model.add(SE_ResidualUnit(strides = 1,size_1 = 1,size_2 = 3,size_3 = 1,size_4=1,filters_1=256,filters_2=256,filters_3=1024,filters_4=1024,units_1=64,units_2=1024))
        
        
for i in range(3):
    if i == 0:
        
        model.add(SE_ResidualUnit(strides = 2,size_1 = 1,size_2 = 3,size_3 = 1,size_4=1,filters_1=512,filters_2=512,filters_3=2048,filters_4=2048,units_1=128,units_2=2048))
    
    else:
        model.add(SE_ResidualUnit(strides = 1,size_1 = 1,size_2 = 3,size_3 = 1,size_4=1,filters_1=512,filters_2=512,filters_3=2048,filters_4=2048,units_1=128,units_2=2048))

        
model.add(keras.layers.GlobalAveragePooling2D())

model.add(keras.layers.Dense(1000,activation = 'softmax'))


(None, 1, 1, 256)
(None, 1, 1, 256)
(None, 1, 1, 256)
(None, 1, 1, 512)
(None, 1, 1, 512)
(None, 1, 1, 512)
(None, 1, 1, 512)
(None, 1, 1, 1024)
(None, 1, 1, 1024)
(None, 1, 1, 1024)
(None, 1, 1, 1024)
(None, 1, 1, 1024)
(None, 1, 1, 1024)
(None, 1, 1, 2048)
(None, 1, 1, 2048)
(None, 1, 1, 2048)


In [8]:
model.summary()

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_120 (Conv2D)         (None, 112, 112, 64)      9472      
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 56, 56, 64)       0         
 2D)                                                             
                                                                 
 se__residual_unit_17 (SE_Re  (None, 56, 56, 256)      143632    
 sidualUnit)                                                     
                                                                 
 se__residual_unit_18 (SE_Re  (None, 56, 56, 256)      217360    
 sidualUnit)                                                     
                                                                 
 se__residual_unit_19 (SE_Re  (None, 56, 56, 256)      217360    
 sidualUnit)                                          