# In [None]:
import keras._tf_keras.keras as keras
from keras._tf_keras.keras.models import Model
from keras._tf_keras.keras.layers import Input,Conv3D, MaxPooling3D, Dropout, UpSampling3D, concatenate, BatchNormalization, GlobalAveragePooling3D, multiply
import tensorflow as tf

# In [None]:
class CAM(keras.layers.Layer):
    """Channel attention map: a per-channel gate for a 3D feature map.

    The input is globally average-pooled with ``keepdims=True``, so the
    pooled spatial axes are kept as size-1 axes. It is then passed through
    two 1x1x1 convolutions: ReLU first, then sigmoid, so the output values
    lie in [0, 1]. Because the spatial axes have size 1, the result can be
    broadcast against a full-resolution feature map with ``num_filters``
    channels. (Assumes Keras' default channels_last data format; confirm
    if the project changes it.)

    Args:
        num_filters: Number of output channels of both 1x1x1 convolutions.
        **kwargs: Standard ``keras.layers.Layer`` keyword arguments
            (e.g. ``name``, ``dtype``), forwarded to the base class.
    """

    def __init__(self, num_filters, **kwargs):
        super().__init__(**kwargs)
        # Stored so get_config() can recreate the layer.
        self.num_filters = num_filters
        self.Glob = GlobalAveragePooling3D(keepdims=True)
        self.Conv1 = Conv3D(num_filters, (1, 1, 1), activation='relu', padding='same')
        self.Conv2 = Conv3D(num_filters, (1, 1, 1), activation='sigmoid', padding='same')

    def call(self, x):
        """Return the sigmoid channel gate computed from ``x``."""
        x = self.Glob(x)
        x = self.Conv1(x)
        x = self.Conv2(x)
        return x

    def get_config(self):
        """Return the layer config, including ``num_filters``."""
        config = super().get_config()
        config.update({"num_filters": self.num_filters})
        return config

class SAM(keras.layers.Layer):
    """Spatial attention branch producing a full-resolution feature map.

    Applies a 3x3x3 ReLU convolution, batch normalization, and then a
    1x1x1 ReLU convolution. All convolutions use ``padding='same'`` with
    default strides, so the spatial size of the input is preserved. The
    output has ``num_filters`` channels.

    NOTE(review): unlike CAM, the final activation here is ReLU, so this
    branch is not squashed to [0, 1]. Confirm this is intended for an
    attention map.

    Args:
        num_filters: Number of output channels of both convolutions.
        **kwargs: Standard ``keras.layers.Layer`` keyword arguments
            (e.g. ``name``, ``dtype``), forwarded to the base class.
            Previously these were accepted but silently discarded.
    """

    def __init__(self, num_filters, **kwargs):
        super().__init__(**kwargs)
        # Stored so get_config() can recreate the layer.
        self.num_filters = num_filters
        self.conv1 = Conv3D(num_filters, (3, 3, 3), activation='relu', padding='same')
        self.norm = BatchNormalization()
        self.conv2 = Conv3D(num_filters, (1, 1, 1), activation='relu', padding='same')

    def call(self, x):
        """Return the spatial attention features computed from ``x``."""
        x = self.conv1(x)
        x = self.norm(x)
        x = self.conv2(x)
        return x

    def get_config(self):
        """Return the layer config, including ``num_filters``."""
        config = super().get_config()
        config.update({"num_filters": self.num_filters})
        return config

class Attention(keras.layers.Layer):
    """Combined channel x spatial attention that reweights a 3D feature map.

    The attention map is the elementwise product of:
    - the CAM output, a per-channel gate broadcast over the spatial axes, and
    - the SAM output, which has full spatial resolution.

    The input feature map is then multiplied elementwise by that map, so the
    output has the same shape as the input.

    The input is multiplied directly by a map with ``num_filters`` channels,
    so the input's channel count must equal ``num_filters``.

    Args:
        num_filters: Channel count of the attention map (and of the input).
        **kwargs: Standard ``keras.layers.Layer`` keyword arguments
            (e.g. ``name``, ``dtype``), forwarded to the base class.
            Previously these were accepted but silently discarded.
    """

    def __init__(self, num_filters: int, **kwargs):
        super().__init__(**kwargs)
        # Stored so get_config() can recreate the layer.
        self.num_filters = num_filters
        self.cam = CAM(num_filters)
        self.sam = SAM(num_filters)

    def call(self, fm):
        """Return ``fm`` reweighted by the combined attention map."""
        channel_gate = self.cam(fm)
        spatial_map = self.sam(fm)
        # Plain tensor products broadcast the size-1 spatial axes of the
        # channel gate the same way the functional `multiply` merge did,
        # without constructing two new Multiply layers on every call.
        attention_map = channel_gate * spatial_map
        return fm * attention_map

    def get_config(self):
        """Return the layer config, including ``num_filters``."""
        config = super().get_config()
        config.update({"num_filters": self.num_filters})
        return config


# In [None]:
class MY_3D_Net(keras.Model):
    """3D U-Net-style encoder/decoder with an Attention block after each stage.

    Encoder: three stages, each Conv3D -> BatchNorm -> Attention, with
    64, 128 and 256 filters. Each stage is followed by 2x2x2 max pooling.

    Bottleneck: two Conv3D/BatchNorm pairs with 256 filters, then Attention.

    Decoder: three stages, each UpSampling3D of the previous stage's
    attention output -> concatenate with the same-resolution encoder
    attention output -> Conv3D -> BatchNorm -> Attention. These stages
    have 256, 128 and 8 filters. The model returns the 8-channel output
    of the last Attention block.

    There are three 2x pool/upsample pairs. The input's spatial dimensions
    should therefore presumably be divisible by 8 so that every upsampled
    tensor matches its skip connection (TODO confirm against the data
    pipeline).

    Args:
        input_shape: Forwarded unchanged as the ``input_shape`` keyword of
            the first Conv3D. Presumably the per-sample shape without the
            batch axis; confirm with callers.
        **kwargs: Standard ``keras.Model`` keyword arguments (e.g. ``name``),
            forwarded to the base class.
    """

    def __init__(self, input_shape, **kwargs):
        super().__init__(**kwargs)
        # ---- Encoder ----
        self.conv1 = Conv3D(64, (3, 3, 3), activation='relu', padding='same', input_shape=input_shape)
        self.norm1 = BatchNormalization()
        self.att1 = Attention(64)

        self.maxpool1 = MaxPooling3D(pool_size=(2, 2, 2))

        self.conv2 = Conv3D(128, (3, 3, 3), activation='relu', padding='same')
        self.norm2 = BatchNormalization()
        self.att2 = Attention(128)

        self.maxpool2 = MaxPooling3D(pool_size=(2, 2, 2))

        self.conv3 = Conv3D(256, (3, 3, 3), activation='relu', padding='same')
        self.norm3 = BatchNormalization()
        self.att3 = Attention(256)

        self.maxpool3 = MaxPooling3D(pool_size=(2, 2, 2))

        # ---- Bottleneck ----
        self.conv4 = Conv3D(256, (3, 3, 3), activation='relu', padding='same')
        self.norm4 = BatchNormalization()
        self.conv5 = Conv3D(256, (3, 3, 3), activation='relu', padding='same')
        self.norm5 = BatchNormalization()
        self.att4 = Attention(256)

        # ---- Decoder ----
        self.up1 = UpSampling3D(size=(2, 2, 2))

        self.conv6 = Conv3D(256, (3, 3, 3), activation='relu', padding='same')
        self.norm6 = BatchNormalization()
        self.att5 = Attention(256)

        self.up2 = UpSampling3D(size=(2, 2, 2))

        self.conv7 = Conv3D(128, (3, 3, 3), activation='relu', padding='same')
        self.norm7 = BatchNormalization()
        self.att6 = Attention(128)

        self.up3 = UpSampling3D(size=(2, 2, 2))

        self.conv8 = Conv3D(8, (3, 3, 3), activation='relu', padding='same')
        self.norm8 = BatchNormalization()
        # NOTE(review): `outputs` shadows the keras.Model `outputs` attribute.
        # The name is kept here to preserve existing weight/attribute names;
        # consider renaming it.
        self.outputs = Attention(8)

    def call(self, x):
        """Run the encoder, bottleneck and decoder on ``x``.

        Resolutions in the comments below are relative to the input's
        spatial size.
        """
        # ---- Encoder ----
        c1 = self.conv1(x)
        n1 = self.norm1(c1)
        at1 = self.att1(n1)  # full resolution, 64 channels

        m1 = self.maxpool1(at1)

        c2 = self.conv2(m1)
        n2 = self.norm2(c2)
        at2 = self.att2(n2)  # 1/2 resolution, 128 channels

        m2 = self.maxpool2(at2)

        c3 = self.conv3(m2)
        n3 = self.norm3(c3)
        at3 = self.att3(n3)  # 1/4 resolution, 256 channels

        m3 = self.maxpool3(at3)

        # ---- Bottleneck ----
        c4 = self.conv4(m3)
        n4 = self.norm4(c4)
        c5 = self.conv5(n4)
        n5 = self.norm5(c5)
        at4 = self.att4(n5)  # 1/8 resolution, 256 channels

        # ---- Decoder ----
        # Each stage upsamples the PREVIOUS stage's attention output, so it
        # lands at the resolution of its skip connection. Upsampling the
        # bottleneck `n5` every time left at4/at5/at6 unused. It also produced
        # 1/4-resolution tensors that could not be concatenated with at2
        # (1/2 resolution) or at1 (full resolution).
        up1 = self.up1(at4)  # 1/8 -> 1/4 resolution
        con1 = concatenate([up1, at3])

        c6 = self.conv6(con1)
        n6 = self.norm6(c6)
        at5 = self.att5(n6)

        up2 = self.up2(at5)  # 1/4 -> 1/2 resolution
        con2 = concatenate([up2, at2])

        c7 = self.conv7(con2)
        n7 = self.norm7(c7)
        at6 = self.att6(n7)

        up3 = self.up3(at6)  # 1/2 -> full resolution
        con3 = concatenate([up3, at1])

        c8 = self.conv8(con3)
        n8 = self.norm8(c8)

        return self.outputs(n8)