In [86]:
# Define the Channel Attention module
class ChannelAttention(tf.keras.layers.Layer):
    def __init__(self, out_planes, ratio=2):
        super(ChannelAttention, self).__init__()
        self.conv = tf.keras.layers.Conv2D(out_planes, 1, padding='same', use_bias=False)
        self.avg_pool = tf.keras.layers.GlobalAveragePooling2D()
        self.max_pool = tf.keras.layers.GlobalMaxPooling2D()
        self.fc11 = tf.keras.layers.Conv2D(out_planes // ratio, 1, padding='same', use_bias=False)
        self.fc12 = tf.keras.layers.Conv2D(out_planes, 1, padding='same', use_bias=False)
        self.fc21 = tf.keras.layers.Conv2D(out_planes // ratio, 1, padding='same', use_bias=False)
        self.fc22 = tf.keras.layers.Conv2D(out_planes, 1, padding='same', use_bias=False)
        self.relu1 = tf.keras.layers.ReLU()
        self.sigmoid = tf.keras.layers.Activation('sigmoid')
        self.reshape= tf.keras.layers.Reshape((1, 1, -1))

    def call(self, x):
        print('input to attention', x.shape)
        x = self.conv(x)
        print('chA conv1 ', x.shape)
        # x=self.avg_pool(x)
        # print('chA avgpool ', x.shape)
        # x = tf.keras.layers.Reshape((1, 1, -1))(x)
        # x=self.fc11(x)
        # print('chA fc11 ', x.shape)
        # x=self.relu1(x)
        # print('chA relu1 ', x.shape)
        # x=self.fc12(x)
        # print('chA fc12 ', x.shape)
        
        avg_out = self.fc12(self.relu1(self.fc11(self.reshape(self.avg_pool(x)))))
        max_out = self.fc22(self.relu1(self.fc21(self.reshape(self.max_pool(x)))))
        out = avg_out + max_out
        return x * self.sigmoid(out)

In [87]:
def conv3otherRelu(filters, kernel_size=None, stride=None, padding=None):
    if kernel_size is None:
        kernel_size = 3
    if stride is None:
        stride = 1
    if padding is None:
        padding = 'same'

    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(filters, kernel_size, stride, padding, use_bias=True),
        tf.keras.layers.LeakyReLU()
    ])

In [88]:
# Define the ACBlock module
class ACBlock(tf.keras.layers.Layer):
    def __init__(self, out_planes):
        super(ACBlock, self).__init__()
        self.squre = tf.keras.layers.Conv2D(out_planes, 3, padding='same', use_bias=False)
        self.cross_ver = tf.keras.layers.Conv2D(out_planes, (1, 3), padding='same', use_bias=False)
        self.cross_hor = tf.keras.layers.Conv2D(out_planes, (3, 1), padding='same', use_bias=False)
        self.bn = tf.keras.layers.BatchNormalization()
        self.relu = tf.keras.layers.ReLU()

    def call(self, x):
        print('ACBlock output shape', x.shape)
        x1 = self.squre(x)
        x2 = self.cross_ver(x)
        x3 = self.cross_hor(x)
        print('ACBlock output shape', self.relu(self.bn(x1 + x2 + x3)).shape)
        return self.relu(self.bn(x1 + x2 + x3))

In [89]:
# class ACBlock(tf.keras.layers.Layer):
#     def __init__(self, in_planes, out_planes):
#         super(ACBlock, self).__init__()
#         self.squre = tf.keras.layers.Conv2D(out_planes, kernel_size=3, padding='same', stride=1)
#         self.cross_ver = tf.keras.layers.Conv2D(out_planes, kernel_size=(1, 3), padding=(0, 1), stride=1)
#         self.cross_hor = tf.keras.layers.Conv2D(out_planes, kernel_size=(3, 1), padding=(1, 0), stride=1)
#         self.bn = tf.keras.layers.BatchNormalization()
#         self.ReLU = tf.keras.layers.ReLU()

#     def forward(self, x):
#         x1 = self.squre(x)
#         x2 = self.cross_ver(x)
#         x3 = self.cross_hor(x)
#         return self.ReLU(self.bn(x1 + x2 + x3))

In [92]:
class MACUNet(tf.keras.Model):
    def __init__(self, band_num, class_num):
        super(MACUNet, self).__init__()
        self.band_num = band_num
        self.class_num = class_num

        # channels = [32, 64, 128, 256, 512]
        channels = [16, 32, 64, 128, 256, 512]
        self.conv1 = tf.keras.Sequential(
           [ ACBlock(channels[0]),
            ACBlock(channels[0])]
        )
        self.conv12 = tf.keras.Sequential(
            [tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),
            ACBlock(channels[1])]
        )
        self.conv13 = tf.keras.Sequential(
            [tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),
            ACBlock(channels[2])]
        )
        self.conv14 = tf.keras.Sequential(
            [tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),
            ACBlock(channels[3])]
        )

        self.conv2 = tf.keras.Sequential(
           [tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),
            ACBlock(channels[1]),
            ACBlock(channels[1])]
        )
        self.conv23 = tf.keras.Sequential(
            [tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),
            ACBlock(channels[2])]
        )
        self.conv24 = tf.keras.Sequential(
            [tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),
            ACBlock(channels[3])]
        )

        self.conv3 = tf.keras.Sequential(
           [ tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),
            ACBlock(channels[2]),
            ACBlock(channels[2]),
            ACBlock(channels[2])]
        )
        self.conv34 = tf.keras.Sequential(
            [tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),
            ACBlock(channels[3])]
        )

        self.conv4 = tf.keras.Sequential(
            [tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),
            ACBlock(channels[3]),
            ACBlock(channels[3]),
            ACBlock(channels[3])]
        )

        self.conv5 = tf.keras.Sequential(
            [tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),
            ACBlock(channels[4]),
            ACBlock(channels[4]),
            ACBlock(channels[4])]
        )

        self.skblock4 = ChannelAttention(channels[3]*2, 16)
        self.skblock3 = ChannelAttention(channels[2]*2, 16)
        self.skblock2 = ChannelAttention(channels[1]*2, 16)
        self.skblock1 = ChannelAttention(channels[0]*2, 16)

        self.deconv4 = tf.keras.layers.Conv2DTranspose(channels[3], kernel_size=2, strides=(2, 2))
        self.deconv43 = tf.keras.layers.Conv2DTranspose(channels[2], kernel_size=2, strides=(2, 2))
        self.deconv42 = tf.keras.layers.Conv2DTranspose(channels[1], kernel_size=2, strides=(2, 2))
        self.deconv41 = tf.keras.layers.Conv2DTranspose(channels[0], kernel_size=2, strides=(2, 2))


        self.conv6 = tf.keras.Sequential(
           [ ACBlock( channels[3]),
            ACBlock(channels[3])]
        )

        self.deconv3 = tf.keras.layers.Conv2DTranspose(channels[2], kernel_size=2, strides=(2, 2))
        self.deconv32 = tf.keras.layers.Conv2DTranspose(channels[1], kernel_size=2, strides=(2, 2))
        self.deconv31 = tf.keras.layers.Conv2DTranspose(channels[0], kernel_size=2, strides=(2, 2))
        self.conv7 = tf.keras.Sequential(
            [ACBlock(channels[2]),
            ACBlock(channels[2])]
        )

        self.deconv2 = tf.keras.layers.Conv2DTranspose(channels[1], kernel_size=2, strides=(2, 2))
        self.deconv21 = tf.keras.layers.Conv2DTranspose(channels[0], kernel_size=2, strides=(2, 2))
        self.conv8 = tf.keras.Sequential(
            [ACBlock(channels[1]),
            ACBlock(channels[1])]
        )

        self.deconv1 = tf.keras.layers.Conv2DTranspose(channels[0], kernel_size=2, strides=(2, 2))
        self.conv9 = tf.keras.Sequential(
            [ACBlock(channels[0]),
            ACBlock(channels[0])]
        )

        self.conv10 = tf.keras.layers.Conv2D(self.class_num, kernel_size=1, strides=1)

    def call(self, x):
        print('input shape', x.shape)
        conv1 = self.conv1(x)
        print('conv1 shape', conv1.shape)
        conv12 = self.conv12(conv1)
        print('conv12 shape', conv12.shape)
        conv13 = self.conv13(conv12)
        print('conv13 shape', conv13.shape)
        conv14 = self.conv14(conv13)
        print('conv14 shape', conv14.shape)


        conv2 = self.conv2(conv1)
        print('conv2 shape', conv2.shape)
        conv23 = self.conv23(conv2)
        print('conv23 shape', conv23.shape)
        conv24 = self.conv24(conv23)
        print('conv24 shape', conv24.shape)

        conv3 = self.conv3(conv2)
        print('conv3 shape', conv3.shape)
        conv34 = self.conv34(conv3)
        print('conv34 shape', conv34.shape)
        conv4 = self.conv4(conv3)
        print('conv4 shape', conv4.shape)
        conv5 = self.conv5(conv4)
        print('conv5 shape', conv5.shape)
        
        deconv4 = self.deconv4(conv5)
        print('deconv4 shape', deconv4.shape)
        deconv43 = self.deconv43(deconv4)
        print('deconv43 shape', deconv43.shape)
        deconv42 = self.deconv42(deconv43)
        print('deconv42 shape', deconv42.shape)
        deconv41 = self.deconv41(deconv42)
        print('deconv41 shape', deconv41.shape)

        conv6 = tf.concat([deconv4, conv4, conv34, conv24, conv14], axis=-1)
        print('conv6 concat shape', conv6.shape)
        conv6 = self.skblock4(conv6)
        print('conv6 attention shape', conv6.shape)
        conv6 = self.conv6(conv6)
        del deconv4, conv4, conv34, conv24, conv14, conv5

        deconv3 = self.deconv3(conv6)
        deconv32 = self.deconv32(deconv3)
        deconv31 = self.deconv31(deconv32)

        conv7 = tf.concat([deconv3, deconv43, conv3, conv23, conv13], axis=-1)
        conv7 = self.skblock3(conv7)
        conv7 = self.conv7(conv7)
        del deconv3, deconv43, conv3, conv23, conv13, conv6

        deconv2 = self.deconv2(conv7)
        deconv21 = self.deconv21(deconv2)

        conv8 = tf.concat([deconv2, deconv42, deconv32, conv2, conv12], axis=-1)
        conv8 = self.skblock2(conv8)
        conv8 = self.conv8(conv8)
        del deconv2, deconv42, deconv32, conv2, conv12, conv7

        deconv1 = self.deconv1(conv8)
        conv9 = tf.concat([deconv1, deconv41, deconv31, deconv21, conv1], axis=-1)
        conv9 = self.skblock1(conv9)
        conv9 = self.conv9(conv9)
        # conv9 = self.seblock(conv9)
        del deconv1, deconv41, deconv31, deconv21, conv1, conv8

        output = self.conv10(conv9)
        print('output shape', output.shape)
        return output

In [95]:
import tensorflow as tf

# Define the input layer
inputs = tf.keras.layers.Input((512, 512, 3))

# Create an instance of your MACUNet model
macu_net = MACUNet(3, 1)

# Call the MACUNet model on the input tensor to get the output tensor
outputs = macu_net(inputs)

# Define the model with input and output tensors
model = tf.keras.Model(inputs, outputs)

# Display the model summary
model.summary()


input shape (None, 512, 512, 3)
ACBlock output shape (None, 512, 512, 3)
ACBlock output shape (None, 512, 512, 16)
ACBlock output shape (None, 512, 512, 16)
ACBlock output shape (None, 512, 512, 16)
ACBlock output shape (None, 512, 512, 3)
ACBlock output shape (None, 512, 512, 16)
ACBlock output shape (None, 512, 512, 16)
ACBlock output shape (None, 512, 512, 16)
conv1 shape (None, 512, 512, 16)
ACBlock output shape (None, 256, 256, 16)
ACBlock output shape (None, 256, 256, 32)
ACBlock output shape (None, 256, 256, 16)
ACBlock output shape (None, 256, 256, 32)
conv12 shape (None, 256, 256, 32)
ACBlock output shape (None, 128, 128, 32)
ACBlock output shape (None, 128, 128, 64)
ACBlock output shape (None, 128, 128, 32)
ACBlock output shape (None, 128, 128, 64)
conv13 shape (None, 128, 128, 64)
ACBlock output shape (None, 64, 64, 64)
ACBlock output shape (None, 64, 64, 128)
ACBlock output shape (None, 64, 64, 64)
ACBlock output shape (None, 64, 64, 128)
conv14 shape (None, 64, 64, 128)
AC