In [1]:
import torch
import torch.nn as nn

class ConvertTo3Channels(nn.Module):
    def __init__(self):
        super(ConvertTo3Channels, self).__init__()
        self.conv = nn.Conv2d(512, 3, kernel_size=1)

    def forward(self, x):
        x = x.view(x.size(0), -1, 32, 32)
        return self.conv(x)

# Example usage
input_tensor = torch.randn(1, 512, 32, 32)  # Example input tensor

# Convert to 3 channels
converter = ConvertTo3Channels()
output_tensor = converter(input_tensor)
print("Output tensor shape:", output_tensor.shape)

Output tensor shape: torch.Size([1, 3, 32, 32])


In [2]:
output_tensor

tensor([[[[-2.7423e-01, -5.3842e-01, -8.9325e-02,  ...,  1.2988e-03,
           -3.5259e-01, -2.5830e-01],
          [ 2.1624e-01,  4.8532e-01,  6.4083e-01,  ..., -2.8504e-01,
           -7.1230e-01, -1.0108e-01],
          [-1.5990e-01,  1.6821e-01,  1.3357e-01,  ...,  5.1624e-01,
           -9.0361e-01, -2.9006e-01],
          ...,
          [ 1.2752e-01, -1.5059e-02, -5.7059e-01,  ...,  3.1145e-01,
           -5.3018e-01,  3.4071e-01],
          [-2.6404e-01, -2.0116e-01,  4.6474e-01,  ...,  5.6449e-01,
           -4.7560e-01, -4.2593e-01],
          [ 3.4443e-01, -5.0213e-01,  5.0383e-01,  ..., -7.3862e-01,
           -6.0495e-01, -8.9527e-01]],

         [[-8.2730e-02, -2.3951e-01, -1.1211e+00,  ...,  1.3933e-01,
            2.3821e-02,  7.3373e-01],
          [-9.5581e-01,  1.1010e-01,  3.0707e-01,  ...,  5.1498e-01,
            4.7180e-01, -7.0015e-02],
          [ 3.9114e-01, -6.1494e-01,  9.4509e-01,  ...,  4.6940e-01,
           -1.2980e-01,  1.9701e-01],
          ...,
     

In [3]:
import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)

class ChannelSpatialAttention(nn.Module):
    def __init__(self, in_planes, ratio=16, kernel_size=7):
        super(ChannelSpatialAttention, self).__init__()
        self.channel_attention = ChannelAttention(in_planes, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        channel_out = self.channel_attention(x)
        spatial_out = self.spatial_attention(x)
        out = torch.mul(channel_out, spatial_out)
        return out

class CombinedAttention(nn.Module):
    def __init__(self, in_channels, ratio=16, kernel_size=7):
        super(CombinedAttention, self).__init__()
        self.attention = ChannelSpatialAttention(in_channels, ratio, kernel_size)
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.tanh = nn.Tanh()
        self.conv_final = nn.Conv2d(in_channels + in_channels, in_channels, kernel_size=1)  # Updated to include input channels

    def forward(self, x):
        att = self.attention(x)
        conv_out = self.conv(att)
        tanh_out = self.tanh(conv_out)
        combined = torch.cat([tanh_out, x], dim=1)  # Concatenate along channel dimension
        final_out = self.conv_final(combined)  # Apply conv to match input size
        return final_out

# Example usage:
feature = torch.randn(8, 128 , 128 , 128)  # Input feature
model = CombinedAttention(in_channels=128)
output = model(feature)
print("Output shape:", output.shape)


Output shape: torch.Size([8, 128, 128, 128])


In [16]:
import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, in_channels, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)  # Updated convolution layer
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x_avg = torch.mean(x, dim=1, keepdim=True)
        print("Avg" , x_avg.shape)
        x_max, _ = torch.max(x, dim=1, keepdim=True)
        print("max" , x_max.shape)
        x_concat = torch.cat([x_avg, x_max], dim=1)
        print("con" , x_concat.shape)
        print("conv" , self.conv(x_concat))
        return self.sigmoid(self.conv(x_concat))

class ChannelSpatialAttention(nn.Module):
    def __init__(self, in_channels, ratio=16, kernel_size=7):
        super(ChannelSpatialAttention, self).__init__()
        self.channel_attention = ChannelAttention(in_channels, ratio)
        self.spatial_attention = SpatialAttention(in_channels * 2, kernel_size)  # Updated input channels

    def forward(self, x):
        channel_out = self.channel_attention(x)
        spatial_out = self.spatial_attention(x)
        out = torch.mul(channel_out, spatial_out)
        return out

class CombinedAttention(nn.Module):
    def __init__(self, in_channels, ratio=16, kernel_size=7):
        super(CombinedAttention, self).__init__()
        self.attention = ChannelSpatialAttention(in_channels, ratio, kernel_size)
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.tanh = nn.Tanh()
        self.conv_final = nn.Conv2d(in_channels * 2, in_channels, kernel_size=1)  # Updated to include input channels

    def forward(self, x):
        att = self.attention(x)
        conv_out = self.conv(att)
        tanh_out = self.tanh(conv_out)
        combined = torch.cat([tanh_out, x], dim=1)  # Concatenate along channel dimension
        final_out = self.conv_final(combined)  # Apply conv to match input size
        return final_out

# Example usage:
feature = torch.randn(8, 256 , 256 , 256)  # Input feature
in_channels = feature.size(1)  # Get the number of input channels dynamically
model = CombinedAttention(in_channels=in_channels)
output = model(feature)
print("Output shape:", output.shape)

Avg torch.Size([8, 1, 256, 256])
max torch.Size([8, 1, 256, 256])
con torch.Size([8, 2, 256, 256])
conv tensor([[[[-5.4023e-01, -8.9021e-01, -6.2424e-01,  ..., -4.7003e-01,
           -2.9704e-01, -2.3351e-01],
          [-8.3550e-01, -8.9075e-01, -1.4888e-01,  ..., -2.1433e-01,
            4.5876e-01,  6.7162e-01],
          [-1.5441e-01, -8.4979e-02,  7.6736e-01,  ...,  9.7400e-01,
            1.3031e+00,  9.8229e-01],
          ...,
          [-1.3581e-01, -1.9889e-01,  8.7171e-01,  ...,  1.3067e+00,
            1.9177e+00,  1.6084e+00],
          [ 1.6394e-02,  1.1206e-01,  8.9854e-01,  ...,  1.4135e+00,
            1.7978e+00,  1.7870e+00],
          [ 4.1042e-01,  4.9643e-01,  1.0271e+00,  ...,  1.6706e+00,
            1.6909e+00,  1.6232e+00]]],


        [[[-7.2294e-01, -8.4463e-01, -4.2732e-01,  ..., -5.5632e-01,
           -2.9744e-01, -9.8651e-02],
          [-7.8191e-01, -8.7735e-01, -3.6199e-01,  ..., -1.0555e-02,
            4.7615e-01,  4.6984e-01],
          [-1.2327e-0