In [57]:
import tensorflow as tf
from tensorflow import keras

In [60]:
class InceptionUnit(keras.layers.Layer):
    def __init__(self,filter_1,filter_2,filter_3,filter_4,filter_5,filter_6,activation='relu',**kwargs):
        super().__init__(**kwargs)
        
        self.direction_3 = [
            
            keras.layers.Conv2D(filter_3,1,strides=1,padding='same',activation = 'relu')
            
            
        ]
        self.direction_1 = [
            
            keras.layers.Conv2D(filter_1,1,strides = 1, padding='same',activation = 'relu'),
            keras.layers.Conv2D(filter_4,3,strides = 1 ,padding='same',activation = 'relu')
            
            
            
        ]
        
        self.direction_2 = [
            
            keras.layers.Conv2D(filter_2,1,strides = 1, padding='same',activation = 'relu'),
            keras.layers.Conv2D(filter_5,5,strides = 1 ,padding='same',activation = 'relu')
            
            
            
            
            
        ]
        
        self.direction_4 = [
            
            keras.layers.MaxPool2D(pool_size=3,strides=1,padding='same'),
            keras.layers.Conv2D(filter_6,1,strides=1,padding='same',activation = 'relu')
        ]
        
        
    def call(self,inputs):
        Z = inputs
        for layer in self.direction_1:
            Z_1 = layer(Z)
        for layer in self.direction_2:
            Z_2 = layer(Z)
            
        for layer in self.direction_3:
            Z_3 = layer(Z)
        for layer in self.direction_4:
            Z_4 = layer(Z)
        
        return tf.concat([Z_3,Z_1,Z_2,Z_4],axis=3)

class SEUnit(keras.layers.Layer):
    def __init__(self,units_1,units_2,**kwargs):
        super().__init__(**kwargs)
        
        self.main_layers = [
            
            keras.layers.GlobalAveragePooling2D(),
            keras.layers.Dense(units_1,activation='relu'),
            keras.layers.Dense(units_2,activation='sigmoid')
            
            
            
        ]
        
        
    def call(self,inputs):
        Z = inputs
        for layers in self.main_layers:
            Z = layers(Z)
            
            
    
        
        #print(tf.expand_dims(tf.expand_dims(Z,1),1).shape)
        # https://stackoverflow.com/questions/51900409/tensorflow-multiply-feature-map-in-batch-with-its-feature-mean-in-batch-n-h
        return(tf.expand_dims(tf.expand_dims(Z,1),1))
    
class SE_InceptionUnit(keras.layers.Layer):
    def __init__(self,strides_final,filters_1,filters_2,filters_3,filters_4,filters_5,filters_6,units_1,units_2,activation='relu',**kwargs):
        super().__init__(**kwargs)
        self.activation = keras.activations.get(activation)
        
        
        self.skip_layers = [
            
            
            
            
            InceptionUnit(filter_1= filters_1,filter_2=filters_2,filter_3=filters_3,filter_4 = filters_4,filter_5= filters_5,filter_6= filters_6,activation='relu'),
            
            SEUnit(units_1= units_1,units_2 = units_2)
            
        


        ]
        
        
        self.main_layers = [
            
            
            InceptionUnit(filter_1= filters_1,filter_2=filters_2,filter_3=filters_3,filter_4 = filters_4,filter_5= filters_5,filter_6= filters_6,activation='relu'),
            
            keras.layers.Conv2D(filters= units_2,kernel_size=1,strides = strides_final,activation = 'relu',padding='same')
            ## SE UNIT
            
            
        


        ]
        
        
        
    def call(self,inputs):
        Z = inputs
        for layer in self.main_layers:
            Z = layer(Z)
       
        Z2 = inputs
        for layer in self.skip_layers:
            Z2 = layer(Z2)
        
   
        return (Z*Z2)    
        
    
    
    

    


In [61]:
model = keras.models.Sequential()
model.add(keras.layers.Conv2D(64,kernel_size= 7,strides=2,activation = 'relu',input_shape=[229,229,3]))
model.add(keras.layers.MaxPool2D(pool_size=3,strides=2,padding='same'))




for i in range(3):
    if i == 2:
        model.add(SE_InceptionUnit(strides_final= 2,filters_1=128,filters_2=128,filters_3=256,filters_4=256,filters_5=256,filters_6=256,units_1=16,units_2=256,activation='relu'))

    else:
        model.add(SE_InceptionUnit(strides_final = 1,filters_1=128,filters_2=128,filters_3=256,filters_4=256,filters_5=256,filters_6=256,units_1=16,units_2=256,activation='relu'))

for i in range(4):
    if i == 3:
        model.add(SE_InceptionUnit(strides_final= 2,filters_1=256,filters_2=256,filters_3=512,filters_4=512,filters_5=512,filters_6=512,units_1=32,units_2=512,activation='relu'))

    else:
        model.add(SE_InceptionUnit(strides_final= 1,filters_1=256,filters_2=256,filters_3=512,filters_4=512,filters_5=512,filters_6=512,units_1=32,units_2=512,activation='relu'))


for i in range(6):
    if i == 5:
        model.add(SE_InceptionUnit(strides_final= 2,filters_1=512,filters_2=512,filters_3=1024,filters_4=1024,filters_5=1024,filters_6=1024,units_1=64,units_2=1024,activation='relu'))

    else:
        
        model.add(SE_InceptionUnit(strides_final= 1,filters_1=512,filters_2=512,filters_3=1024,filters_4=1024,filters_5=1024,filters_6=1024,units_1=64,units_2=1024,activation='relu'))


for i in range(3):
    if i == 2:
        model.add(SE_InceptionUnit(strides_final= 2,filters_1=1024,filters_2=1024,filters_3=2048,filters_4=2048,filters_5=2048,filters_6=2048,units_1=128,units_2=2048,activation='relu'))

    else:
        
        model.add(SE_InceptionUnit(strides_final= 1,filters_1=1024,filters_2=1024,filters_3=2048,filters_4=2048,filters_5=2048,filters_6=2048,units_1=128,units_2=2048,activation='relu'))

        


model.add(keras.layers.GlobalAveragePooling2D())

model.add(keras.layers.Dense(1000,activation='softmax'))





(None, 1, 1, 256)
(None, 1, 1, 256)
(None, 1, 1, 256)
(None, 1, 1, 512)
(None, 1, 1, 512)
(None, 1, 1, 512)
(None, 1, 1, 512)
(None, 1, 1, 1024)
(None, 1, 1, 1024)
(None, 1, 1, 1024)
(None, 1, 1, 1024)
(None, 1, 1, 1024)
(None, 1, 1, 1024)
(None, 1, 1, 2048)
(None, 1, 1, 2048)
(None, 1, 1, 2048)


In [49]:
model.build()
model.summary()

Model: "sequential_20"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_549 (Conv2D)         (None, 112, 112, 64)      9472      
                                                                 
 max_pooling2d_100 (MaxPooli  (None, 56, 56, 64)       0         
 ng2D)                                                           
                                                                 
 se__inception_unit_41 (SE_I  (None, 56, 56, 256)      1498128   
 nceptionUnit)                                                   
                                                                 
 se__inception_unit_42 (SE_I  (None, 56, 56, 256)      5135376   
 nceptionUnit)                                                   
                                                                 
 se__inception_unit_43 (SE_I  (None, 28, 28, 256)      5135376   
 nceptionUnit)                                       