In [164]:
import tensorflow as tf
from tensorflow.keras import layers, models

class ChannelAttention(layers.Layer):
    """CBAM channel-attention gate.

    Global max-pooled and global average-pooled descriptors of the input are
    each fed through the same bottleneck (1x1 conv to ``channel // reduction``
    filters, ReLU, 1x1 conv back to ``channel`` filters, no biases).  The two
    results are summed and passed through a sigmoid.

    Args:
        channel: number of filters of the gate's final 1x1 conv, i.e. the
            gate's channel count.  Presumably equal to the input's channel
            count so the gate can scale the input — TODO confirm at call sites.
        reduction: bottleneck reduction ratio.

    ``call`` returns a sigmoid gate of shape (batch, 1, 1, channel)
    (assumes a 4-D channels-last input).
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        self.maxpool = layers.GlobalMaxPooling2D()
        self.avgpool = layers.GlobalAveragePooling2D()
        hidden_units = channel // reduction
        # Shared MLP implemented with 1x1 convs; applied to both pooled maps.
        self.se = models.Sequential([
            layers.Conv2D(hidden_units, (1, 1), padding='same', use_bias=False),
            layers.ReLU(),
            layers.Conv2D(channel, (1, 1), padding='same', use_bias=False)
        ])
        self.sigmoid = layers.Activation('sigmoid')

    def call(self, x):
        def to_1x1_map(pooled):
            # (batch, channels) -> (batch, 1, 1, channels) so the 1x1 convs apply.
            return pooled[:, tf.newaxis, tf.newaxis, :]

        pooled_max = to_1x1_map(self.maxpool(x))
        pooled_avg = to_1x1_map(self.avgpool(x))

        combined_logits = self.se(pooled_max) + self.se(pooled_avg)
        return self.sigmoid(combined_logits)


class SpatialAttention(layers.Layer):
    """CBAM spatial-attention gate.

    The channel axis of the input is collapsed twice (max and mean), the two
    single-channel maps are stacked, and the stack is convolved with one
    ``kernel_size`` x ``kernel_size`` filter ('same' padding, no bias) followed
    by a sigmoid.

    Args:
        kernel_size: side length of the square convolution kernel.

    ``call`` returns a gate of shape (batch, height, width, 1)
    (assumes a channels-last input).
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = layers.Conv2D(1, (kernel_size, kernel_size), padding='same', use_bias=False)
        self.sigmoid = layers.Activation('sigmoid')

    def call(self, x):
        channel_pooled = tf.concat(
            [
                tf.reduce_max(x, axis=-1, keepdims=True),
                tf.reduce_mean(x, axis=-1, keepdims=True),
            ],
            axis=-1,
        )
        return self.sigmoid(self.conv(channel_pooled))


class CBAMBlock(layers.Layer):
    """Convolutional Block Attention Module with an identity shortcut.

    The input is scaled first by a channel gate and then by a spatial gate
    (computed on the channel-scaled features), and the original input is added
    back::

        gated   = x * ca(x)
        refined = gated * sa(gated)
        out     = refined + x

    Args:
        channel: channel count of the channel gate.  Its (batch, 1, 1, channel)
            output multiplies the input elementwise, so it must broadcast
            against the input's channel axis.
        reduction: bottleneck ratio of the channel gate.
        kernel_size: kernel side of the spatial gate's convolution.
            NOTE(review): the default of 49 is much larger than the 7x7 kernel
            used in the CBAM paper — confirm it is intentional.
    """

    def __init__(self, channel=512, reduction=16, kernel_size=49):
        super().__init__()
        self.ca = ChannelAttention(channel=channel, reduction=reduction)
        self.sa = SpatialAttention(kernel_size=kernel_size)

    def call(self, x):
        gated = x * self.ca(x)
        refined = gated * self.sa(gated)
        return refined + x
    
class GCModule(tf.keras.Model):
    """Global Context (GC) block — TensorFlow port of GCNet (arXiv:1904.11492).

    The block has three steps:

    1. Context modeling.  A 1x1 conv produces one attention logit per spatial
       position.  A softmax over all H*W positions turns these logits into
       weights, and the weighted sum of the input features is a
       (batch, 1, 1, C) global context vector.
    2. Transform.  The context vector goes through a 1x1 conv to
       ``channel // reduction`` filters, LayerNormalization, ReLU, and a 1x1
       conv back to ``channel`` filters.
    3. Fusion.  The transformed context is broadcast-added to every position
       of the input.

    Args:
        channel: filters of the final transform conv.  It should equal the
            input's channel count so the fusion add broadcasts cleanly.
        reduction: bottleneck ratio of the transform.

    Expects channels-last 4-D input (batch, H, W, C).  The batch and spatial
    sizes may be unknown while tracing; the channel count must be static,
    which the Conv2D layers require anyway.
    """

    def __init__(self, channel, reduction=16):
        super(GCModule, self).__init__()
        self.conv = layers.Conv2D(1, kernel_size=1)
        # Normalizes the (batch, 1, H*W) attention logits over positions (axis 2).
        self.softmax = layers.Softmax(axis=2)
        self.transform = models.Sequential([
            layers.Conv2D(channel // reduction, kernel_size=1),
            layers.LayerNormalization(),
            layers.ReLU(),
            layers.Conv2D(channel, kernel_size=1)
        ])

    def context_modeling(self, x):
        """Attention-pool ``x`` (batch, H, W, C) into a (batch, 1, 1, C) context vector."""
        # Read the batch size dynamically.  ``x.shape[0]`` is None while tracing
        # with an unknown batch size (e.g. model.fit on tf.data), which made the
        # previous static-shape reshapes fail.
        batch = tf.shape(x)[0]
        channels = x.shape[-1]

        features = tf.reshape(x, [batch, -1, channels])       # (b, H*W, C)
        logits = tf.reshape(self.conv(x), [batch, 1, -1])     # (b, 1, H*W)
        # Without this softmax the pooled context is an unnormalized sum whose
        # magnitude grows with H*W.
        weights = self.softmax(logits)                        # sums to 1 over H*W
        context = tf.matmul(weights, features)                # (b, 1, C)
        return tf.reshape(context, [batch, 1, 1, channels])

    def call(self, x):
        """Return ``x`` plus the transformed global context (same shape as ``x``)."""
        context = self.context_modeling(x)
        context_features = self.transform(context)
        return x + context_features

In [167]:
class CustomModel(tf.keras.Model):
    """Image classifier: valid-padding conv stages with GC blocks, a CBAM block, and an MLP head.

    Feature extractor: four stages, each made of two ReLU convs, a GC block,
    and a 2x2 max-pool.  These are followed by a final 1024-filter ReLU conv
    refined by a CBAM block.  For example, a 224x224 input yields a 7x7x1024
    map before flattening.

    Head: flatten -> dropout(0.5) -> Dense(1024, ReLU) -> Dense(512, ReLU)
    -> Dense(num_classes).  The output is raw logits (no softmax).

    Args:
        in_channels: unused; kept for API compatibility (Keras infers the
            input channel count on the first call).
        num_classes: number of output logits.
    """

    def __init__(self, in_channels=3, num_classes=5):
        super(CustomModel, self).__init__()
        self.conv1 = layers.Conv2D(8, 7, activation='relu')
        self.conv2 = layers.Conv2D(32, 7, activation='relu')
        self.conv3 = layers.Conv2D(64, 3, activation='relu')
        self.conv4 = layers.Conv2D(128, 3, activation='relu')
        self.conv5 = layers.Conv2D(192, 3, activation='relu')
        self.conv6 = layers.Conv2D(384, 3, activation='relu')
        self.conv7 = layers.Conv2D(512, 3, activation='relu')
        self.conv8 = layers.Conv2D(768, 3, activation='relu')
        self.conv9 = layers.Conv2D(1024, 3, activation='relu')
        # Each GC block's channel count matches the conv that feeds it
        # (conv2, conv4, conv6, conv8).
        self.context1 = GCModule(32)
        self.context2 = GCModule(128)
        self.context3 = GCModule(384)
        self.context4 = GCModule(768)
        self.cbam = CBAMBlock(1024)
        self.flat = layers.Flatten()
        self.dropout = layers.Dropout(0.5)
        # Stateless, so one instance is shared by all four stages.
        self.max_pool = layers.MaxPool2D(2)
        # ReLU on the hidden Dense layers.  Without a nonlinearity,
        # fc1 -> fc2 -> fc3 collapses into a single linear map.
        self.fc1 = layers.Dense(1024, activation='relu')
        self.fc2 = layers.Dense(512, activation='relu')
        self.fc3 = layers.Dense(num_classes)

    def call(self, x, training=None):
        """Return (batch, num_classes) logits for an image batch ``x`` (assumes channels-last)."""
        stages = (
            (self.conv1, self.conv2, self.context1),
            (self.conv3, self.conv4, self.context2),
            (self.conv5, self.conv6, self.context3),
            (self.conv7, self.conv8, self.context4),
        )
        for first_conv, second_conv, context in stages:
            x = first_conv(x)
            x = second_conv(x)
            x = context(x)
            x = self.max_pool(x)

        x = self.conv9(x)
        x = self.cbam(x)

        x = self.flat(x)
        x = self.dropout(x, training=training)
        x = self.fc1(x)
        x = self.fc2(x)
        return self.fc3(x)

In [168]:
# Smoke test: a single random 224x224 RGB image should yield one logit per class.
# Seed the Python/NumPy/TF RNGs so weight init and the input are reproducible
# across Restart & Run All.
tf.keras.utils.set_random_seed(0)

model = CustomModel(in_channels=3, num_classes=2)

# Create a random image with shape (1, 224, 224, 3)
random_image = tf.random.normal([1, 224, 224, 3])

output = model(random_image)
print(output.shape)
assert output.shape == (1, 2), f"unexpected output shape {output.shape}"

(1, 2)


In [None]:
# """ 
# PyTorch implementation of GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond

# As described in https://arxiv.org/pdf/1904.11492

# The GC module consists of three steps: (a) a context-modeling module that aggregates the features of all
# positions into a global context feature; (b) a feature-transform module that captures channel-wise
# interdependencies; and (c) a fusion module that merges the global context feature into the features of all positions.
# """




# import torch
# from torch import nn


# class GCModule(nn.Module):
#     def __init__(self, channel, reduction=16):
#         super().__init__()
#         self.conv = nn.Conv2d(channel, 1, kernel_size=1)
#         self.softmax = nn.Softmax(dim=2)
#         self.transform = nn.Sequential(
#             nn.Conv2d(channel, channel // reduction, kernel_size=1),
#             nn.LayerNorm([channel // reduction, 1, 1]),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(channel // reduction, channel, kernel_size=1)
#         )
    
#     def context_modeling(self, x):
#         b, c, h, w = x.shape
#         input_x = x
#         input_x = input_x.reshape(b, c, h * w)
#         context = self.conv(x)
#         context = context.reshape(b, 1, h * w).transpose(1, 2)
#         print(input_x.shape)
#         print(context.shape)
#         out = torch.matmul(input_x, context)
#         out = out.reshape(b, c, 1, 1)
#         return out
    
#     def forward(self, x):
#         context = self.context_modeling(x)
#         y = self.transform(context)
#         return x + y
    
# if __name__ == "__main__":
#     x = torch.randn(2, 64, 32, 32)
#     attn = GCModule(64)
#     y = attn(x)
#     # print(y.shape)

torch.Size([2, 64, 1024])
torch.Size([2, 1024, 1])


In [None]:
# import tensorflow as tf
# from tensorflow.keras import layers, models

# class GCModule(tf.keras.Model):
#     def __init__(self, channel, reduction=16):
#         super(GCModule, self).__init__()
#         self.conv = layers.Conv2D(1, kernel_size=1)
#         self.softmax = layers.Softmax(axis=2)
#         self.transform = models.Sequential([
#             layers.Conv2D(channel // reduction, kernel_size=1),
#             layers.LayerNormalization(),
#             layers.ReLU(),
#             layers.Conv2D(channel, kernel_size=1)
#         ])
    
#     def context_modeling(self, x):
#         b, h, w, c = x.shape
#         input_x = x
#         print(input_x.shape)
#         input_x = tf.reshape(input_x, (b, h * w, c))
#         context = self.conv(x)
#         context = tf.reshape(context, (b, h * w, 1))
#         context = tf.transpose(context, perm=[0, 2, 1])
#         print(context.shape)
#         out = tf.matmul(input_x, context,transpose_a=True,transpose_b=True)
#         out = tf.reshape(out, (b, 1, 1, c))
#         return out
    
#     def call(self, x):
#         context = self.context_modeling(x)
#         y = self.transform(context)
#         return x + y
# if __name__ == "__main__":
#     x = torch.randn(1, 32, 32, 64)
#     attn = GCModule(64)
#     y = attn(x)
#     print(y.shape)

(1, 32, 32, 64)
(1, 1, 1024)
(1, 32, 32, 64)
(1, 1, 1024)
(1, 32, 32, 64)
