In [157]:
import tensorflow as tf
from tensorflow import keras
import os
os.environ["CUDA_VISIBLE_DEVICES"]="2"

class Depthwise_Seperable_Conv(keras.layers.Layer):
    def __init__(self, nin, nout):
        super(Depthwise_Seperable_Conv, self).__init__()
        self.depthwise = keras.layers.Conv2D(nin, kernel_size=3, padding='same', groups=nin)
        self.pointwise = keras.layers.Conv2D(nout, kernel_size=1, padding='same')
        
    def call(self, inputs):
        x = self.depthwise(inputs)
        x = self.pointwise(x)
        return x
    
class LIST(keras.layers.Layer):
    def __init__(self, in_channel, out_channel, k=4, nb=2):
        super(LIST, self).__init__()
        self.squeeze = keras.layers.Conv2D( 
            in_channel//k, 
            kernel_size=(1, 1),
            padding='same')
        self.bn1 = keras.layers.BatchNormalization()
        self.stream1 = keras.layers.Conv2D(
            out_channel//nb,
            kernel_size=1,
            padding='same'
        )
        self.bn2 = keras.layers.BatchNormalization()
        self.stream2 = Depthwise_Seperable_Conv(in_channel//k, out_channel//nb)
        self.bn3 = keras.layers.BatchNormalization()
        
    def call(self, inputs):
        x = self.squeeze(inputs)
        x = self.bn1(x)
        x = keras.activations.relu(x)
        
        # stream 1
        x1 = self.stream1(x)
        x1 = self.bn2(x1)
        x1 = keras.activations.relu(x1)
        
        # stream 2
        x2 = self.stream2(x)
        x2 = self.bn3(x2)
        x2 = keras.activations.relu(x2)
        
        return keras.layers.Concatenate(axis=3)([x1, x2])

class Group_Conv_Dilated(keras.layers.Layer):
    def __init__(self, in_channel, groups, dilation_factor):
        super(Group_Conv_Dilated, self).__init__()
        self.num_groups = groups
        self.dilation_factor = dilation_factor
        self.partial = in_channel // groups
        
    def call(self, inputs):
        groups = tf.split(inputs, self.num_groups, axis=3)
#         print([group.shape for group in groups])
        dil_conv = []
        
        for group in groups:
            dil_conv.append(keras.layers.Conv2D(
                self.partial, 
                kernel_size=(3,3), 
                dilation_rate=self.dilation_factor, 
                padding='same')(group)
            )
            
        return keras.layers.Concatenate(axis=3)(dil_conv)

class Channel_Shuffle(keras.layers.Layer):
    def __init__(self, groups):
        super(Channel_Shuffle, self).__init__()
        self.groups = groups
        
    def call(self, inputs):
        shape = inputs.shape[1:]
        
        x = keras.layers.Reshape([shape[0], shape[1], self.groups, shape[2]//self.groups])(inputs)
        x = keras.layers.Permute([1, 2, 4, 3])(x)
        x = keras.layers.Reshape([shape[0], shape[1], shape[2]])(x)
        return x
    
class Group_Conv_Normal(keras.layers.Layer):
    def __init__(self, in_channel, groups):
        super(Group_Conv_Normal, self).__init__()
        self.num_groups = groups
        self.partial = in_channel // groups
        
    def call(self, inputs):
        groups = tf.split(inputs, self.num_groups, axis=3)
#         print([group.shape for group in groups])
        dil_conv = []
        
        for group in groups:
            dil_conv.append(keras.layers.Conv2D(
                self.partial, 
                kernel_size=(1,1),
                padding='same')(group)
            )
            
        return keras.layers.Concatenate(axis=3)(dil_conv)
    
class GSAT(keras.layers.Layer):
    def __init__(self, in_channel, num_groups=8, dilation_factor=2):
        super(GSAT, self).__init__()
        self.dil_conv = Group_Conv_Dilated(in_channel, num_groups, dilation_factor)
        self.shuffle = Channel_Shuffle(num_groups)
        self.norm_conv = Group_Conv_Normal(in_channel, num_groups)
        self.bn = keras.layers.BatchNormalization()
        self.add = keras.layers.Add()

    def call(self, inputs):
        skip = inputs;
        
        x = self.dil_conv(inputs)
        x = self.shuffle(x)
        x = self.norm_conv(x)
        x = self.bn(x)
        x = self.add([skip, x])
        
        return keras.activations.relu(x)

class UpSampling(keras.layers.Layer):
    def __init__(self, in_channel, out_channel, stride=2):
        super(UpSampling, self).__init__()
        self.list = LIST(in_channel, out_channel)
        self.stride = stride
        
    def call(self, inputs, concat=None):
        shape = inputs.shape
        x = tf.image.resize(
            inputs, 
            [shape[1] * self.stride, shape[2] * self.stride]
        )
        
        if concat is not None:
            print("Concatenating")
            x = keras.layers.Concatenate(axis=3)([concat, x])
        
        x = self.list(x)
        return x
    
class DownSampling(keras.layers.Layer):
    def __init__(self, in_channel, out_channel, stride=2):
        super(DownSampling, self).__init__()
        self.list = LIST(in_channel, out_channel)
        self.stride = stride
        
    def call(self, inputs):
        shape = inputs.shape
        x = tf.image.resize(
            inputs, 
            [shape[1] // self.stride, shape[2] // self.stride]
        )
        x = self.list(x)
        return x

In [161]:
# https://medium.com/analytics-vidhya/creating-mobilenetsv2-with-tensorflow-from-scratch-c85eb8605342
def expansion_block(x, t, filters, block_id):
    prefix = 'block_{}_'.format(block_id)
    total_filters = t*filters
    x = keras.layers.Conv2D(total_filters, 1, padding='same', use_bias=False, name=prefix+'expand')(x)
    x = keras.layers.BatchNormalization(name=prefix +'expand_bn')(x)
    x = keras.layers.ReLU(6,name = prefix +'expand_relu')(x)
    return x

def depthwise_block(x, stride, block_id):
    prefix = 'block_{}_'.format(block_id)
    x = DownSampling(x.shape[-1], x.shape[-1], stride=stride)(x)
#     x = keras.layers.DepthwiseConv2D(3,strides=(stride,stride),padding ='same', use_bias = False, name = prefix + 'depthwise_conv')(x)
    x = keras.layers.BatchNormalization(name=prefix +'dw_bn')(x)
    x = keras.layers.ReLU(6, name = prefix +'dw_relu')(x)
    return x

def projection_block(x, out_channels, block_id):
    prefix = 'block_{}_'.format(block_id)
    x = keras.layers.Conv2D(
        filters=out_channels,
        kernel_size = 1,
        padding='same',
        use_bias=False,
        name= prefix + 'compress'
    )(x)
    
    x = keras.layers.BatchNormalization(name=prefix +'compress_bn')(x)
    return x

def Bottleneck(x, t, filters, out_channels, stride, block_id, expand=False):
    y = expansion_block(x,t,filters,block_id)
    if expand:
        return y
    
    y = depthwise_block(y, stride, block_id)
    y = projection_block(y, out_channels,block_id)
    if y.shape[-1]==x.shape[-1]:
       y = keras.layers.Add()([x,y])
    return y

In [162]:
def MobileNetV2(input_shape=(224,224,3)):
    input_layer = keras.layers.Input(input_shape)
    
    x = keras.layers.Conv2D(32,3,strides=(2,2),padding='same', use_bias=False)(input_layer)
    x = keras.layers.BatchNormalization(name='conv1_bn')(x)
    x = keras.layers.ReLU(6, name='conv1_relu')(x)
    
    # 13 Bottlenecks
    x = depthwise_block(x, stride=1, block_id=0)
    x = projection_block(x, out_channels=16,block_id=0)
    x1 = Bottleneck(x, t = 6, filters = x.shape[-1], out_channels = 24, stride = 2,block_id = 1)
    x2 = Bottleneck(x1, t = 6, filters = x1.shape[-1], out_channels = 24, stride = 1,block_id = 2)
    x3 = Bottleneck(x2, t = 6, filters = x2.shape[-1], out_channels = 32, stride = 2,block_id = 3)
    x4 = Bottleneck(x3, t = 6, filters = x3.shape[-1], out_channels = 32, stride = 1,block_id = 4)
    x5 = Bottleneck(x4, t = 6, filters = x4.shape[-1], out_channels = 32, stride = 1,block_id = 5)
    x6 = Bottleneck(x5, t = 6, filters = x5.shape[-1], out_channels = 64, stride = 2,block_id = 6)
    x7 = Bottleneck(x6, t = 6, filters = x6.shape[-1], out_channels = 64, stride = 1,block_id = 7)
    x8 = Bottleneck(x7, t = 6, filters = x7.shape[-1], out_channels = 64, stride = 1,block_id = 8)
    x9 = Bottleneck(x8, t = 6, filters = x8.shape[-1], out_channels = 64, stride = 1,block_id = 9)
    x10 = Bottleneck(x9, t = 6, filters = x9.shape[-1], out_channels = 96, stride = 1,block_id = 10)
    x11 = Bottleneck(x10, t = 6, filters = x10.shape[-1], out_channels = 96, stride = 1,block_id = 11)
    x12 = Bottleneck(x10, t = 6, filters = x10.shape[-1], out_channels = 96, stride = 1,block_id = 12)
    x13_expand = Bottleneck(x11, t = 6, filters = x11.shape[-1], out_channels = 96, stride = 1,block_id = 13, expand=True)
    
    model = keras.Model(input_layer, x13_expand)
    return model
    
def OptimizedUNet(input_shape=(256,512,3)):
    base_model = MobileNetV2(input_shape)
    input_layer = base_model.input
    
    layer_names = ["block_6_expand_relu", "block_3_expand_relu", "block_1_expand_relu"]
    
    # Up Scaling
    x = base_model.output
    x = keras.layers.Conv2D(filters=512, kernel_size=(1,1), padding='same')(x)
    x = keras.layers.LeakyReLU(alpha=0.2)(x)
    x = UpSampling(x.shape[-1], 256)(x, concat=base_model.get_layer(layer_names[0]).output)
    x = keras.layers.LeakyReLU(alpha=0.2)(x)
    x = UpSampling(x.shape[-1], 128)(x, concat=base_model.get_layer(layer_names[1]).output)
    x = keras.layers.LeakyReLU(alpha=0.2)(x)
    x = UpSampling(x.shape[-1], 64)(x, concat=base_model.get_layer(layer_names[2]).output)
    x = keras.layers.LeakyReLU(alpha=0.2)(x)
    x = UpSampling(x.shape[-1], 32)(x, concat=base_model.layers[0].output)
    output = keras.layers.Conv2D(filters=1, activation='sigmoid', kernel_size=(3,3), padding='same')(x)

    model = keras.Model(input_layer, output)
    return model

In [163]:
model = OptimizedUNet((256,512,3))
model.summary()

Concatenating
Concatenating
Concatenating
Concatenating
Model: "model_51"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_76 (InputLayer)           [(None, 256, 512, 3) 0                                            
__________________________________________________________________________________________________
conv2d_2187 (Conv2D)            (None, 128, 256, 32) 864         input_76[0][0]                   
__________________________________________________________________________________________________
conv1_bn (BatchNormalization)   (None, 128, 256, 32) 128         conv2d_2187[0][0]                
__________________________________________________________________________________________________
conv1_relu (ReLU)               (None, 128, 256, 32) 0           conv1_bn[0][0]                   
___________________________________

In [152]:
def upsampling(input_tensor, n_filters, concat_layer, concat=True):
  '''
  Block of Decoder
  '''
  # Bilinear 2x upsampling layer
  x = keras.layers.UpSampling2D(size=(2,2), interpolation='bilinear')(input_tensor)
  # concatenation with encoder block 
  if concat:
    x = keras.layers.concatenate([x, concat_layer])
  # decreasing the depth filters by half
  x = keras.layers.Conv2D(filters=n_filters, kernel_size=(3,3), padding='same')(x)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Conv2D(filters=n_filters, kernel_size=(3,3), padding='same')(x)
  x = keras.layers.BatchNormalization()(x)
  return x

def MobileNet():
    base_model = keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=(256, 512, 3))
    for layer in base_model.layers:
        layer.trainable = True

    inputs = base_model.input
    x = base_model.get_layer("block_13_expand_relu").output

    names = ["input_6", "block_1_expand_relu", "block_3_expand_relu", "block_6_expand_relu"][::-1]
    bneck = keras.layers.Conv2D(filters=512, kernel_size=(1,1), padding='same')(x)
    x = keras.layers.LeakyReLU(alpha=0.2)(bneck)
    x = upsampling(x, 256, base_model.get_layer(names[0]).output)
    x = keras.layers.LeakyReLU(alpha=0.2)(x)
    x = upsampling(x, 128, base_model.get_layer(names[1]).output)
    x = keras.layers.LeakyReLU(alpha=0.2)(x)
    x = upsampling(x, 64, base_model.get_layer(names[2]).output)
    x = keras.layers.LeakyReLU(alpha=0.2)(x)
    x = upsampling(x, 32, base_model.layers[0].output)
    x = keras.layers.Conv2D(filters=1, activation='sigmoid', kernel_size=(3,3), padding='same')(x)

    model = keras.Model(inputs=inputs, outputs=x)
    return model

model = MobileNet()
model.summary()

Model: "model_45"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_73 (InputLayer)           [(None, 256, 512, 3) 0                                            
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 128, 256, 32) 864         input_73[0][0]                   
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 128, 256, 32) 128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 128, 256, 32) 0           bn_Conv1[0][0]                   
___________________________________________________________________________________________