In [1]:
import tensorflow as tf 
import tensorflow.keras.layers as layers
import tensorflow.keras.models as model

In [None]:
!pip install tensorflow --upgrade

# BASNET ATTEMPT 1


In [28]:
class ResBlock(tf.keras.Model):
    def __init__(self,channel,stride=1):
        super(ResBlock,self).__init__(name='res_block')
        self.flag = (stride != 1)
        self.cn1 = layers.Conv2D(channel,3,stride,padding='same')
        self.bn1 = layers.BatchNormalization()
        self.cn2 = layers.Conv2D(channel,3,padding='same')
        self.bn2 = layers.BatchNormalization()
        self.relu = layers.ReLU()
        if self.flag:
            self.bn3 = layers.BatchNormalization()
            self.cn3 = layers.Conv2D(channel,1,stride)
            
    def call(self,x):
        x1 = self.cn1(x)
        x1 = self.bn1(x1)
        x1 = self.relu(x1)
        x1 = self.cn2(x1)
        x1 = self.bn2(x1)
        if self.flag:
            x1 = self.cn3(x1)
            x1 = self.bn3(x1)
        x1 = layers.add([x,x1])
        x1 = self.relu(x1)
        return x1
            
        
                


In [29]:
class RefUnet(tf.keras.Model):
    def __init__(self,channel):
        super(RefUnet,self).__init__()
        
        self.conv0 = layers.Conv2D(channel,3,padding='same')

        self.conv1 = layers.Conv2D(64,3,padding='same')
        self.bn1 = layers.BatchNormalization()
        self.relu1 = layers.ReLU()

        self.pool1 = layers.MaxPool2D(2,2)

        self.conv2 = layers.Conv2D(64,3,padding='same')
        self.bn2 = layers.BatchNormalization(64)
        self.relu2 = layers.ReLU()

        self.pool2 = layers.MaxPool2D(2,2)

        self.conv3 = layers.Conv2D(64,3,padding='same')
        self.bn3 = layers.BatchNormalization(64)
        self.relu3 = layers.ReLU()

        self.pool3 = layers.MaxPool2D(2,2)

        self.conv4 = layers.Conv2D(64,3,padding='same')
        self.bn4 = layers.BatchNormalization(64)
        self.relu4 = layers.ReLU()

        self.pool4 = layers.MaxPool2D(2,2)

        #####

        self.conv5 = layers.Conv2D(64,3,padding='same')
        self.bn5 = layers.BatchNormalization(64)
        self.relu5 = layers.ReLU()

        #####

        self.conv_d4 = layers.Conv2D(64,3,padding='same')
        self.bn_d4 = layers.BatchNormalization(64)
        self.relu_d4 = layers.ReLU()

        self.conv_d3 = layers.Conv2D(64,3,padding='same')
        self.bn_d3 = layers.BatchNormalization(64)
        self.relu_d3 = layers.ReLU()

        self.conv_d2 = layers.Conv2D(64,3,padding='same')
        self.bn_d2 = layers.BatchNormalization(64)
        self.relu_d2 = layers.ReLU()

        self.conv_d1 = layers.Conv2D(64,3,padding='same')
        self.bn_d1 = layers.BatchNormalization(64)
        self.relu_d1 = layers.ReLU()

        self.conv_d0 = layers.Conv2D(1,3,padding='same')

        self.upscore2 = layers.UpSampling2D(2,interpolation='bilinear')
        
        def call(self,x):
            hx = x
            hx = self.conv0(hx)

            hx1 = self.relu1(self.bn1(self.conv1(hx)))
            hx = self.pool1(hx1)

            hx2 = self.relu2(self.bn2(self.conv2(hx)))
            hx = self.pool2(hx2)

            hx3 = self.relu3(self.bn3(self.conv3(hx)))
            hx = self.pool3(hx3)

            hx4 = self.relu4(self.bn4(self.conv4(hx)))
            hx = self.pool4(hx4)

            hx5 = self.relu5(self.bn5(self.conv5(hx)))

            hx = self.upscore2(hx5)

            d4 = self.relu_d4(self.bn_d4(self.conv_d4(torch.cat((hx,hx4),1))))
            hx = self.upscore2(d4)

            d3 = self.relu_d3(self.bn_d3(self.conv_d3(torch.cat((hx,hx3),1))))
            hx = self.upscore2(d3)

            d2 = self.relu_d2(self.bn_d2(self.conv_d2(torch.cat((hx,hx2),1))))
            hx = self.upscore2(d2)

            d1 = self.relu_d1(self.bn_d1(self.conv_d1(torch.cat((hx,hx1),1))))

            residual = self.conv_d0(d1)

            return x + residual
        
        
    

In [42]:
class BasNet(tf.keras.Model):
    def __init__(self):
        super(BasNet,self).__init__(name='res_net_encoder')
        
        '''
        Input layer 
        '''
        
        self.conv1 = layers.Conv2D(64,3,1,padding='same')
        self.bn = layers.BatchNormalization()
        self.relu = layers.ReLU()
        '''
        ResNetwork Encoder
        '''
        
        # stage 1
        self.conv2_1 = ResBlock(64)
        self.conv2_2 = ResBlock(64)
        self.conv2_3 = ResBlock(64)

        # stage 2
        self.conv3_1 = ResBlock(128, 2)
        self.conv3_2 = ResBlock(128)
        self.conv3_3 = ResBlock(128)
        self.conv3_4 = ResBlock(128)

        # stage 3
        self.conv4_1 = ResBlock(256, 2)
        self.conv4_2 = ResBlock(256)
        self.conv4_3 = ResBlock(256)
        self.conv4_4 = ResBlock(256)
        self.conv4_5 = ResBlock(256)
        self.conv4_6 = ResBlock(256)
        
        # stage 4
        self.conv5_1 = ResBlock(512, 2)
        self.conv5_2 = ResBlock(512)
        self.conv5_3 = ResBlock(512)
        
        self.max_pool = layers.MaxPool2D(2,2)
        
        # stage 5
        self.resb6_1 = ResBlock(512)
        self.resb6_2 = ResBlock(512)
        self.resb6_3 = ResBlock(512)
        
        self.max_pool2 = layers.MaxPool2D(2,2)
        
        # stage 6
        self.resb7_1 = ResBlock(512)
        self.resb7_2 = ResBlock(512)
        self.resb7_3 = ResBlock(512)
        
        
        '''
        Bridge
        '''
        self.convbg_1 = layers.Conv2D(512,3,dilation_rate=2 , padding='same')
        self.bnbg_1 = layers.BatchNormalization()
        self.relubg_1 = layers.ReLU()
        self.convbg_m = layers.Conv2D(512,3,dilation_rate=2 , padding='same')
        self.bnbg_m = layers.BatchNormalization()
        self.relubg_m = layers.ReLU()
        self.convbg_2 = layers.Conv2D(512,3,dilation_rate=2 , padding='same')
        self.bnbg_2 = layers.BatchNormalization()
        self.relubg_2 = layers.ReLU()
        
        '''
        Decoder
        '''
        #stage 6 
        self.conv6d_1 = layers.Conv2D(512,3,padding='same')
        self.bn6d_1 = layers.BatchNormalization()
        self.relu6d_1 = layers.ReLU()
        
        self.conv6d_2 = layers.Conv2D(512,3,padding='same')
        self.bn6d_2 = layers.BatchNormalization()
        self.relu6d_2 = layers.ReLU()
        
        self.conv6d_3 = layers.Conv2D(512,3,padding='same')
        self.bn6d_3 = layers.BatchNormalization()
        self.relu6d_3 = layers.ReLU()
        
        #stage 5
        self.conv5d_1 = layers.Conv2D(512,3,padding='same')
        self.bn5d_1 = layers.BatchNormalization()
        self.relu5d_1 = layers.ReLU()
        
        self.conv5d_2 = layers.Conv2D(512,3,padding='same')
        self.bn5d_2 = layers.BatchNormalization()
        self.relu5d_2 = layers.ReLU()
        
        self.conv5d_3 = layers.Conv2D(512,3,padding='same')
        self.bn5d_3 = layers.BatchNormalization()
        self.relu5d_3 = layers.ReLU()
        
        #stage 4 
        self.conv4d_1 = layers.Conv2D(512,3,padding='same')
        self.bn4d_1 = layers.BatchNormalization()
        self.relu4d_1 = layers.ReLU()
        
        self.conv4d_2 = layers.Conv2D(512,3,padding='same')
        self.bn4d_2 = layers.BatchNormalization()
        self.relu4d_2 = layers.ReLU()
        
        self.conv4d_3 = layers.Conv2D(256,3,padding='same')
        self.bn4d_3 = layers.BatchNormalization()
        self.relu4d_3 = layers.ReLU()
        
        #stage 3
        self.conv3d_1 = layers.Conv2D(256,3,padding='same')
        self.bn3d_1 = layers.BatchNormalization()
        self.relu3d_1 = layers.ReLU()
        
        self.conv3d_2 = layers.Conv2D(256,3,padding='same')
        self.bn3d_2 = layers.BatchNormalization()
        self.relu3d_2 = layers.ReLU()
        
        self.conv3d_3 = layers.Conv2D(128,3,padding='same')
        self.bn3d_3 = layers.BatchNormalization()
        self.relu3d_3 = layers.ReLU()
        
        #stage 2
        self.conv2d_1 = layers.Conv2D(128,3,padding='same')
        self.bn2d_1 = layers.BatchNormalization()
        self.relu2d_1 = layers.ReLU()
        
        self.conv2d_2 = layers.Conv2D(128,3,padding='same')
        self.bn2d_2 = layers.BatchNormalization()
        self.relu2d_2 = layers.ReLU()
        
        self.conv2d_3 = layers.Conv2D(64,3,padding='same')
        self.bn2d_3 = layers.BatchNormalization()
        self.relu2d_3 = layers.ReLU()
        
        #stage1
        self.conv1d_1 = layers.Conv2D(64,3,padding='same')
        self.bn1d_1 = layers.BatchNormalization()
        self.relu1d_1 = layers.ReLU()
        
        self.conv1d_2 = layers.Conv2D(64,3,padding='same')
        self.bn1d_2 = layers.BatchNormalization()
        self.relu1d_2 = layers.ReLU()
        
        self.conv1d_3 = layers.Conv2D(64,3,padding='same')
        self.bn1d_3 = layers.BatchNormalization()
        self.relu1d_3 = layers.ReLU()
        
          ## -------------Bilinear Upsampling--------------
            
        self.upscore6 = layers.UpSampling2D(32,interpolation='bilinear')
        self.upscore5 = layers.UpSampling2D(16,interpolation='bilinear')
        self.upscore4 = layers.UpSampling2D(8,interpolation='bilinear')
        self.upscore3 = layers.UpSampling2D(4,interpolation='bilinear')
        self.upscore2 = layers.UpSampling2D(2,interpolation='bilinear')
        
        ## -------------Side Output--------------
        self.outconvb = layers.Conv2D(1,3,padding='same')
        self.outconv6 = layers.Conv2D(1,3,padding='same')
        self.outconv5 = layers.Conv2D(1,3,padding='same')
        self.outconv4 = layers.Conv2D(1,3,padding='same')
        self.outconv3 = layers.Conv2D(1,3,padding='same')
        self.outconv2 = layers.Conv2D(1,3,padding='same')
        self.outconv1 = layers.Conv2D(1,3,padding='same')
        
        ## Refine module 
        self.refNet = RefUnet(64)
        
    
    def call(self, x):
        hx = x
        #--------------Encoder-------------
        hx = self.conv1(hx)
        hx = self.bn(hx)
        hx = self.relu(hx)
        # stage1
        h1 = self.conv2_1(hx)
        h1 = self.conv2_2(h1)
        h1 = self.conv2_3(h1)
        #stage2 
        h2 = self.conv3_1(h1)
        h2 = self.conv3_2(h2)
        h2 = self.conv3_3(h2)
        h2 = self.conv3_4(h2)
        #stage 3
        h3 = self.conv4_1(h2)
        h3 = self.conv4_2(h3)
        h3 = self.conv4_3(h3)
        h3 = self.conv4_4(h3)
        h3 = self.conv4_5(h3)
        h3 = self.conv4_6(h3)
        #stage 4 
        h4 = self.conv5_1(h3)
        h4 = self.conv5_2(h4)
        h4 = self.conv5_3(h4)
        
        hx = self.max_pool(h4)
        
        hx = resb6_1(hx)
        hx = resb6_2(hx)
        h5 = resb6_3(hx)
        
        hx = self.max_pool(h5)
        
        hx = resb7_1(hx)
        hx = resb7_2(hx)
        h6 = resb7_3(hx)
        
        #------------Bridge--------------
        
        hx = self.relubg_1(self.bnbg_1(self.convbg_1(h6))) 
        hx = self.relubg_m(self.bnbg_m(self.convbg_m(hx)))
        hbg = self.relubg_2(self.bnbg_2(self.convbg_2(hx)))
        
        ## -------------Decoder-------------
        hx = self.relu6d_1(self.bn6d_1(self.conv6d_1(tf.concat((hbg,h6),1))))
        hx = self.relu6d_m(self.bn6d_m(self.conv6d_m(hx)))
        hd6 = self.relu6d_2(self.bn6d_2(self.conv6d_2(hx)))

        hx = self.upscore2(hd6) # 8 -> 16

        hx = self.relu5d_1(self.bn5d_1(self.conv5d_1(tf.concat((hx,h5),1))))
        hx = self.relu5d_m(self.bn5d_m(self.conv5d_m(hx)))
        hd5 = self.relu5d_2(self.bn5d_2(self.conv5d_2(hx)))

        hx = self.upscore2(hd5) # 16 -> 32

        hx = self.relu4d_1(self.bn4d_1(self.conv4d_1(tf.concat((hx,h4),1))))
        hx = self.relu4d_m(self.bn4d_m(self.conv4d_m(hx)))
        hd4 = self.relu4d_2(self.bn4d_2(self.conv4d_2(hx)))

        hx = self.upscore2(hd4) # 32 -> 64

        hx = self.relu3d_1(self.bn3d_1(self.conv3d_1(tf.concat((hx,h3),1))))
        hx = self.relu3d_m(self.bn3d_m(self.conv3d_m(hx)))
        hd3 = self.relu3d_2(self.bn3d_2(self.conv3d_2(hx)))

        hx = self.upscore2(hd3) # 64 -> 128

        hx = self.relu2d_1(self.bn2d_1(self.conv2d_1(tf.concat((hx,h2),1))))
        hx = self.relu2d_m(self.bn2d_m(self.conv2d_m(hx)))
        hd2 = self.relu2d_2(self.bn2d_2(self.conv2d_2(hx)))

        hx = self.upscore2(hd2) # 128 -> 256

        hx = self.relu1d_1(self.bn1d_1(self.conv1d_1(tf.concat((hx,h1),1))))
        hx = self.relu1d_m(self.bn1d_m(self.conv1d_m(hx)))
        hd1 = self.relu1d_2(self.bn1d_2(self.conv1d_2(hx)))
        
        db = self.outconvb(hbg)
        db = self.upscore6(db) # 8->256

        d6 = self.outconv6(hd6)
        d6 = self.upscore6(d6) # 8->256

        d5 = self.outconv5(hd5)
        d5 = self.upscore5(d5) # 16->256

        d4 = self.outconv4(hd4)
        d4 = self.upscore4(d4) # 32->256

        d3 = self.outconv3(hd3)
        d3 = self.upscore3(d3) # 64->256

        d2 = self.outconv2(hd2)
        d2 = self.upscore2(d2) # 128->256

        d1 = self.outconv1(hd1) # 256

        ## -------------Refine Module-------------
        dout = self.refNet(d1) # 256
        
        return tf.keras.activations.sigmoid(dout), tf.keras.activations.sigmoid(d1), tf.keras.activations.sigmoid(d2), tf.keras.activations.sigmoid(d3), tf.keras.activations.sigmoid(d4), tf.keras.activations.sigmoid(d5), tf.keras.activations.sigmoid(d6), tf.keras.activations.sigmoid(db)

In [43]:
model = BasNet()

In [45]:
model.build([1,224,224,3])

ValueError: Exception encountered when calling layer "res_block" (type ResBlock).

in user code:

    File "C:\Users\shiva\AppData\Local\Temp\ipykernel_10224\3447930896.py", line 23, in call  *
        x1 = layers.add([x,x1])
    File "E:\MACHINELEARNING\TENSORFLOW_P\tensorflow\lib\site-packages\keras\layers\merge.py", line 791, in add  **
        return Add(**kwargs)(inputs)
    File "E:\MACHINELEARNING\TENSORFLOW_P\tensorflow\lib\site-packages\keras\utils\traceback_utils.py", line 67, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "E:\MACHINELEARNING\TENSORFLOW_P\tensorflow\lib\site-packages\keras\layers\merge.py", line 78, in _compute_elemwise_op_output_shape
        raise ValueError(

    ValueError: Inputs have incompatible shapes. Received shapes (224, 224, 64) and (56, 56, 128)


Call arguments received:
  • x=tf.Tensor(shape=(1, 224, 224, 64), dtype=float32)

In [47]:
res = ResBlock(3)

In [49]:
res.build([1,224,224,3])

In [50]:
res.summary()

Model: "res_block"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_521 (Conv2D)         multiple                  84        
                                                                 
 batch_normalization_493 (Ba  multiple                 12        
 tchNormalization)                                               
                                                                 
 conv2d_522 (Conv2D)         multiple                  84        
                                                                 
 batch_normalization_494 (Ba  multiple                 12        
 tchNormalization)                                               
                                                                 
 re_lu_318 (ReLU)            multiple                  0         
                                                                 
Total params: 192
Trainable params: 180
Non-trainable par