In [None]:
import random

import keras
import tensorflow
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.utils import plot_model
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Lambda, ReLU, Add,Dropout, Activation, Flatten, Input, PReLU,SeparableConv2D, Conv2DTranspose,concatenate,Convolution2D,ZeroPadding2D,Add,MaxPool2D
from tensorflow.keras.layers import Conv2D,Conv2DTranspose, Activation,MaxPooling3D, MaxPooling2D, BatchNormalization, UpSampling2D,AveragePooling2D,GlobalMaxPooling2D,GlobalAveragePooling2D


class PWFS(Layer):
    """Parameter-free feature selection over three channel groups.

    The last axis of the input is split into three equal groups (so the
    channel count must be divisible by 3).  For every element, the layer takes
    the maximum and the median of the three group values and returns their
    average.  The output therefore has one third of the input channels.
    """

    def __init__(self, **kwargs):
        # Forward base-layer kwargs (name, trainable, dtype, ...) so the layer
        # can be rebuilt from ``get_config()`` by ``from_config``.
        super().__init__(**kwargs)

    def call(self, inputs):
        # Split the feature map channel-wise into 3 equal sub-groups.
        split1, split2, split3 = tf.split(inputs, num_or_size_splits=3, axis=-1)

        # Exact max / median of three via a min-max network.  This avoids the
        # float cancellation of ``a + b + c - min - max``.
        low12 = tf.minimum(split1, split2)
        high12 = tf.maximum(split1, split2)
        max_split = tf.maximum(high12, split3)
        median_values = tf.maximum(low12, tf.minimum(high12, split3))

        # Average the max and median sub-groups.
        return 0.5 * (max_split + median_values)

    def get_config(self):
        # No additional hyperparameters to configure.  The original referenced
        # an undefined class ``MFM`` here, which made serialization fail.
        return super().get_config()







def MassAtt(input_tensor, ratio=4):
    """Apply combined channel and spatial attention to a feature map.

    The result is ``channel_att_map * spatial_att_map * input_tensor``:

    * the channel map is a squeeze-and-excitation gate.  It is a global mean
      over axes 1 and 2, followed by two Dense layers, the last with sigmoid.
    * the spatial map is a single-channel sigmoid gate.  It is the channel
      mean passed through two stride-2 convolutions and two stride-2
      transposed convolutions.

    Args:
        input_tensor: Rank-4 Keras tensor.  It is assumed channels-last,
            ``(batch, H, W, C)``, with a statically known ``C``.  ``H`` and
            ``W`` must be divisible by 4 so that the spatial branch's two 2x
            downsamplings and two 2x upsamplings return to the input size.
        ratio: Channel reduction factor of the excitation bottleneck.

    Returns:
        Tensor with the same shape as ``input_tensor``.

    Raises:
        ValueError: If the channel dimension is unknown, or a statically
            known spatial dimension is not divisible by 4.
    """
    input_shape = input_tensor.get_shape().as_list()
    num_input_channels = input_shape[-1]
    if num_input_channels is None:
        raise ValueError("MassAtt requires a statically known channel dimension.")
    for dim in input_shape[1:3]:
        # Otherwise the decoder output (e.g. 8 for H=6) cannot broadcast back
        # against the input in the final product.
        if dim is not None and dim % 4:
            raise ValueError(
                f"MassAtt requires spatial dims divisible by 4, got {input_shape[1:3]}."
            )

    # Channel attention map: squeeze (global average) + excitation (two FC).
    squeeze = tf.reduce_mean(input_tensor, axis=[1, 2], keepdims=True)
    # Clamp to >= 1 so a small C or large ratio never yields a 0-unit Dense.
    bottleneck_units = max(1, num_input_channels // ratio)
    excitation = tf.keras.layers.Dense(units=bottleneck_units, activation='relu')(squeeze)
    channel_att_map = tf.keras.layers.Dense(units=num_input_channels, activation='sigmoid')(excitation)

    # Spatial attention map: channel mean -> small conv encoder/decoder.
    spatial_attention = tf.keras.layers.Lambda(
        lambda x: tf.reduce_mean(x, axis=-1, keepdims=True))(input_tensor)
    x = tf.keras.layers.Conv2D(filters=2, kernel_size=3, kernel_initializer='he_uniform',
                               activation='relu', strides=2, padding='same')(spatial_attention)
    x = tf.keras.layers.Conv2D(filters=4, kernel_size=3, kernel_initializer='he_uniform',
                               activation='relu', strides=2, padding='same')(x)
    x = Conv2DTranspose(4, (3, 3), activation='relu', padding='same', strides=2)(x)
    spatial_att_map = Conv2DTranspose(1, (3, 3), activation='sigmoid', padding='same', strides=2)(x)

    # Broadcast: (B, 1, 1, C) * (B, H, W, 1) * (B, H, W, C) -> (B, H, W, C).
    return channel_att_map * spatial_att_map * input_tensor





def _conv_bn_relu(x, filters, kernel_size, strides, padding, dilation_rate=1, separable=False):
    """Apply one (Separable)Conv2D -> BatchNormalization -> ReLU stage to ``x``."""
    if separable:
        # SeparableConv2D has separate depthwise and pointwise kernels.  The
        # inherited ``kernel_initializer`` argument is not applied to them, so
        # He-uniform init has to be requested for each kernel explicitly.
        conv = tf.keras.layers.SeparableConv2D(filters=filters,
                                               kernel_size=kernel_size,
                                               depthwise_initializer='he_uniform',
                                               pointwise_initializer='he_uniform',
                                               dilation_rate=dilation_rate,
                                               strides=strides,
                                               padding=padding)
    else:
        conv = tf.keras.layers.Conv2D(filters=filters,
                                      kernel_size=kernel_size,
                                      kernel_initializer='he_uniform',
                                      dilation_rate=dilation_rate,
                                      strides=strides,
                                      padding=padding)
    x = conv(x)
    x = BatchNormalization()(x)
    return ReLU()(x)


def two_path(input_tensor, filters, kernel_size, strides=(1, 1), padding='valid', seed=None):
    """Three-stage two-path (plain + dilated) convolution block with channel shuffle.

    The input channels are randomly permuted and split into two halves.
    * Stage 1: the "H" path applies a plain Conv2D to the first half, and the
      "L" path applies a dilation-2 Conv2D to the second half.
    * Stage 2: both paths read the concatenated stage-1 output and apply
      SeparableConv2D layers.
    * Stage 3: both paths read the concatenated stage-2 output and apply
      Conv2D layers.
    Every convolution is followed by BatchNormalization and ReLU, and each
    stage concatenates its two paths on the channel axis.

    Args:
        input_tensor: Rank-4 Keras tensor, assumed channels-last, with an even,
            statically known channel count.
        filters: Total output channels.  Each path gets ``filters // 2``, so an
            odd value yields ``filters - 1`` output channels.
        kernel_size: Kernel size of every convolution.
        strides: Applied in all three stages.  NOTE: Keras convolutions reject
            ``dilation_rate != 1`` combined with ``strides != 1``, so only
            ``(1, 1)`` works for the dilated L path.
        padding: NOTE: only ``'same'`` keeps the H and L paths the same spatial
            size.  With ``'valid'`` (the default) a 3x3 kernel shrinks the H
            path by 2 and the dilated L path by 4, and the concat fails.
        seed: Optional seed for the channel permutation.  ``None`` (default)
            draws from the global ``random`` state, as before.  Pass a fixed
            seed if the model will be rebuilt in code and weights reloaded;
            otherwise the rebuilt model routes different channels into the
            trained kernels.

    Returns:
        Tensor with ``2 * (filters // 2)`` channels.

    Raises:
        ValueError: If the channel count is unknown or odd.
    """
    input_channels = input_tensor.get_shape().as_list()[-1]
    if input_channels is None or input_channels % 2:
        raise ValueError(
            f"two_path requires an even, static channel count, got {input_channels}."
        )
    filters_per_group = filters // 2

    # Shuffle the channel indices (fixed at graph-build time), then split into two groups.
    channel_indices = list(range(input_channels))
    if seed is None:
        random.shuffle(channel_indices)
    else:
        random.Random(seed).shuffle(channel_indices)
    input_tensor_shuffled = tf.gather(input_tensor, channel_indices, axis=-1)
    group_h, group_l = tf.split(input_tensor_shuffled, 2, axis=-1)

    # Layers are created in the same order as before (H then L at each stage),
    # so auto-generated layer names are unchanged.
    # First stage: each path sees its own half of the shuffled input.
    x1 = tf.concat([
        _conv_bn_relu(group_h, filters_per_group, kernel_size, strides, padding),
        _conv_bn_relu(group_l, filters_per_group, kernel_size, strides, padding,
                      dilation_rate=2),
    ], axis=-1)

    # Second stage: separable convolutions over the full first-stage output.
    x2 = tf.concat([
        _conv_bn_relu(x1, filters_per_group, kernel_size, strides, padding,
                      separable=True),
        _conv_bn_relu(x1, filters_per_group, kernel_size, strides, padding,
                      dilation_rate=2, separable=True),
    ], axis=-1)

    # Third stage: plain convolutions over the full second-stage output.
    return tf.concat([
        _conv_bn_relu(x2, filters_per_group, kernel_size, strides, padding),
        _conv_bn_relu(x2, filters_per_group, kernel_size, strides, padding,
                      dilation_rate=2),
    ], axis=-1)

#********************************
#********************************



# Model hyperparameters.
INPUT_SHAPE = (64, 64, 1)
NUM_CLASSES = 7
DROPOUT_RATE = 0.4


def _conv_block(x, filters, dropout_rate=DROPOUT_RATE):
    """Run Conv -> SeparableConv -> Conv (each with BN), then 2x max-pool, ReLU and dropout."""
    x = Conv2D(filters=filters, kernel_size=(3, 3), kernel_initializer='he_uniform', padding='same')(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    # SeparableConv2D does not apply ``kernel_initializer`` to its two kernels,
    # so He-uniform init is set on the depthwise and pointwise kernels directly.
    x = SeparableConv2D(filters=filters, kernel_size=(3, 3),
                        depthwise_initializer='he_uniform',
                        pointwise_initializer='he_uniform',
                        padding='same')(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(filters=filters, kernel_size=(3, 3), kernel_initializer='he_uniform', padding='same')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(pool_size=2)(x)
    x = ReLU()(x)
    return Dropout(dropout_rate)(x)


def _two_path_attention_block(x, filters, dropout_rate=DROPOUT_RATE):
    """Run two_path -> MassAtt -> 1x1 Conv + BN, then 2x max-pool, ReLU and dropout."""
    x = two_path(x, filters=filters, kernel_size=3, strides=(1, 1), padding='same')
    x = MassAtt(x, ratio=4)
    x = Conv2D(filters, kernel_size=(1, 1), kernel_initializer='he_uniform', padding='same')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(pool_size=2)(x)
    x = ReLU()(x)
    return Dropout(dropout_rate)(x)


# NOTE: ``input`` shadows the builtin input(); the name is kept because later
# notebook cells may reference it.
input = tensorflow.keras.Input(shape=INPUT_SHAPE)

# Backbone: each block halves the spatial size with 'same' padding + 2x2 pooling.
b1 = _conv_block(input, 66)              # 64x64 -> 32x32
b2 = _two_path_attention_block(b1, 72)   # 32x32 -> 16x16
b3 = _conv_block(b2, 78)                 # 16x16 -> 8x8
b4 = _two_path_attention_block(b3, 84)   # 8x8   -> 4x4

# Multi-scale head.  PWFS needs channel counts divisible by 3 (66, 72, 78).
b1, b2, b3 = (PWFS()(b) for b in (b1, b2, b3))
b1, b2, b3, b4 = (GlobalAveragePooling2D()(b) for b in (b1, b2, b3, b4))

f = concatenate([b1, b2, b3, b4], axis=-1)

output = Dense(NUM_CLASSES, activation='softmax')(f)

model = tensorflow.keras.Model(inputs=input, outputs=output)

model.summary()


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_4 (InputLayer)        [(None, 64, 64, 1)]          0         []                            
                                                                                                  
 conv2d_9 (Conv2D)           (None, 64, 64, 66)           660       ['input_4[0][0]']             
                                                                                                  
 batch_normalization_12 (Ba  (None, 64, 64, 66)           264       ['conv2d_9[0][0]']            
 tchNormalization)                                                                                
                                                                                                  
 re_lu_12 (ReLU)             (None, 64, 64, 66)           0         ['batch_normalization_12[0