In [11]:
from typing import Optional

import torch 
import torch.nn as nn
import torch.nn.functional as F

class SAM(nn.Module):
    """
    Spatial attention gate for 2D feature maps using a fixed 7x7 convolution.

    The input is reduced along dim 1 with a max and a mean, the two maps are
    stacked into a 2-channel tensor, convolved down to a single channel,
    passed through a sigmoid and multiplied back onto the input.
    """

    def __init__(self, bias=False):
        """
        :param bias: (bool) Whether the convolution carries a bias term.
        """
        super().__init__()
        self.bias = bias
        # kernel 7 / padding 3 / stride 1 keeps the spatial size of the pooled maps.
        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3, dilation=1, bias=self.bias)

    def forward(self, x):
        """
        Apply the spatial gate.

        :param x: (torch.Tensor) Input whose dim 1 is the channel axis; the 2D convolution
                  implies a layout of (batch_size, channels, height, width).

        :return: (torch.Tensor) Gated tensor with the same shape as the input.
        """
        # Local names avoid shadowing the built-in ``max``.
        max_map = x.max(dim=1, keepdim=True).values
        avg_map = x.mean(dim=1, keepdim=True)
        pooled = torch.cat((max_map, avg_map), dim=1)

        # torch.sigmoid avoids the deprecation warning that F.sigmoid emits on older PyTorch releases.
        gate = torch.sigmoid(self.conv(pooled))
        return gate * x

class SpatialAttention2D(nn.Module):
    """
    Spatial attention gate for 2D feature maps with a configurable kernel size.

    Channel-wise max and mean maps are stacked, reduced to one channel by a
    stride-1 convolution, squashed by a sigmoid and multiplied onto the input.
    """

    def __init__(self, kernel_size: Optional[int] = 7, bias: Optional[bool] = False) -> None:
        """
        :param kernel_size: (int) Side length of the convolution kernel. Must be odd.
        :param bias: (bool) Whether the convolution carries a bias term.
        """
        super().__init__()
        assert kernel_size % 2 == 1, "The kernel size must be odd."
        self.kernel_size = kernel_size
        self.stride = 1
        # Half-kernel padding with stride 1 keeps the spatial size of the pooled maps.
        self.padding = kernel_size // 2
        self.bias = bias

        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, stride=self.stride, padding=self.padding, dilation=1, bias=self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the spatial gate.

        :param x: (torch.Tensor) Input whose dim 1 is the channel axis; the 2D convolution
                  implies a layout of (batch_size, channels, height, width).

        :return: (torch.Tensor) Gated tensor with the same shape as the input.
        """
        # Local names avoid shadowing the built-in ``max``.
        max_map = x.max(dim=1, keepdim=True).values
        avg_map = x.mean(dim=1, keepdim=True)
        pooled = torch.cat((max_map, avg_map), dim=1)

        # torch.sigmoid avoids the deprecation warning that F.sigmoid emits on older PyTorch releases.
        gate = torch.sigmoid(self.conv(pooled))
        return gate * x
    
    
class SpatialAttention1D(nn.Module):
    """
    Spatial (temporal) attention gate for 1D sequences with a configurable kernel size.

    Channel-wise max and mean maps are stacked, reduced to one channel by a
    stride-1 1D convolution, squashed by a sigmoid and multiplied onto the input.
    """

    def __init__(self, kernel_size: Optional[int] = 7, bias: Optional[bool] = False) -> None:
        """
        :param kernel_size: (int) Length of the convolution kernel. Must be odd.
        :param bias: (bool) Whether the convolution carries a bias term.
        """
        super().__init__()
        assert kernel_size % 2 == 1, "The kernel size must be odd."
        self.kernel_size = kernel_size
        self.stride = 1
        # Half-kernel padding with stride 1 keeps the sequence length of the pooled maps.
        self.padding = kernel_size // 2
        self.bias = bias

        self.conv = nn.Conv1d(in_channels=2, out_channels=1, kernel_size=kernel_size, stride=self.stride, padding=self.padding, dilation=1, bias=self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the attention gate.

        :param x: (torch.Tensor) Input whose dim 1 is the channel axis; the 1D convolution
                  implies a layout of (batch_size, channels, seq_len).

        :return: (torch.Tensor) Gated tensor with the same shape as the input.
        """
        # Local names avoid shadowing the built-in ``max``.
        max_map = x.max(dim=1, keepdim=True).values
        avg_map = x.mean(dim=1, keepdim=True)
        pooled = torch.cat((max_map, avg_map), dim=1)

        # torch.sigmoid avoids the deprecation warning that F.sigmoid emits on older PyTorch releases.
        gate = torch.sigmoid(self.conv(pooled))
        return gate * x
    

class SpatialAttention(nn.Module):
    """
    Spatial attention gate for time series (1D) or image (2D) inputs.

    Channel-wise max and mean maps are stacked, reduced to one channel by a
    stride-1 convolution of the requested dimensionality, squashed by a
    sigmoid and multiplied onto the input.
    """

    def __init__(self, n_dims: int, kernel_size: Optional[int] = 7, bias: Optional[bool] = False) -> None:
        """
        :param n_dims: (int) The dimension of input data, either 1 (Conv1d) or 2 (Conv2d).
        :param kernel_size: (int) Size of the convolution kernel. Must be odd.
        :param bias: (bool) Whether the convolution carries a bias term.
        """
        super().__init__()
        assert n_dims in [1, 2], "The dimension of input data must be either 1 or 2."

        assert kernel_size % 2 == 1, "The kernel size must be odd."
        self.kernel_size = kernel_size
        self.stride = 1
        # Half-kernel padding with stride 1 keeps the size of the pooled maps.
        self.padding = kernel_size // 2
        self.bias = bias

        # Keyword arguments shared by the 1D and 2D convolution constructors.
        parameters = {
            'in_channels': 2,
            'out_channels': 1,
            'kernel_size': self.kernel_size,
            'stride': self.stride,
            'padding': self.padding,
            'dilation': 1,
            'bias': self.bias
        }

        self.conv = nn.Conv1d(**parameters) if n_dims == 1 else nn.Conv2d(**parameters)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the attention gate.

        :param x: (torch.Tensor)
                  1D Time Series: Input tensor of shape (batch_size, channels, seq_len);
                  2D Image: Input tensor of shape (batch_size, channels, height, width).

        :return: (torch.Tensor) Gated tensor with the same shape as the input.
        """
        # Local names avoid shadowing the built-in ``max``.
        max_map = x.max(dim=1, keepdim=True).values
        avg_map = x.mean(dim=1, keepdim=True)
        pooled = torch.cat((max_map, avg_map), dim=1)

        # torch.sigmoid avoids the deprecation warning that F.sigmoid emits on older PyTorch releases.
        gate = torch.sigmoid(self.conv(pooled))
        return gate * x
    

class CAM(nn.Module):
    """
    Channel attention gate for 2D feature maps.

    Global max- and average-pooled channel descriptors are passed through one
    shared two-layer bottleneck MLP, summed, squashed by a sigmoid and
    multiplied onto the input channel by channel.
    """

    def __init__(self, channels, r):
        """
        :param channels: (int) Number of input channels.
        :param r: (int) Reduction ratio; the hidden layer has ``channels // r`` units.
        """
        super().__init__()
        self.channels = channels
        self.r = r
        self.linear = nn.Sequential(
            nn.Linear(in_features=self.channels, out_features=self.channels//self.r, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=self.channels//self.r, out_features=self.channels, bias=True))

    def forward(self, x):
        """
        Apply the channel gate.

        :param x: (torch.Tensor) Input tensor of shape (batch_size, channels, height, width).

        :return: (torch.Tensor) Gated tensor with the same shape as the input.
        """
        b, c, _, _ = x.size()

        # Local names avoid shadowing the built-in ``max``.
        max_desc = F.adaptive_max_pool2d(x, output_size=1).view(b, c)
        avg_desc = F.adaptive_avg_pool2d(x, output_size=1).view(b, c)

        # The same MLP scores both descriptors; their sum forms the gate logits.
        linear_max = self.linear(max_desc).view(b, c, 1, 1)
        linear_avg = self.linear(avg_desc).view(b, c, 1, 1)

        # torch.sigmoid avoids the deprecation warning that F.sigmoid emits on older PyTorch releases.
        gate = torch.sigmoid(linear_max + linear_avg)
        return gate * x
    
    
class ChannelAttention2D(nn.Module):
    """
    Channel attention gate for 2D feature maps.

    Global max- and average-pooled channel descriptors are passed through one
    shared two-layer bottleneck MLP, summed, squashed by a sigmoid and
    multiplied onto the input channel by channel.
    """

    def __init__(self, n_channels: int, r):
        """
        :param n_channels: (int) Number of input channels.
        :param r: (int) Reduction ratio; the hidden layer has ``n_channels // r`` units.
        """
        super().__init__()
        self.channels = n_channels
        self.r = r
        self.linear = nn.Sequential(
            nn.Linear(in_features=self.channels, out_features=self.channels//self.r, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=self.channels//self.r, out_features=self.channels, bias=True))

    def forward(self, x):
        """
        Apply the channel gate.

        :param x: (torch.Tensor) Input tensor of shape (batch_size, channels, height, width).

        :return: (torch.Tensor) Gated tensor with the same shape as the input.
        """
        b, c, _, _ = x.size()

        # Local names avoid shadowing the built-in ``max``.
        max_desc = F.adaptive_max_pool2d(x, output_size=1).view(b, c)
        avg_desc = F.adaptive_avg_pool2d(x, output_size=1).view(b, c)

        # The same MLP scores both descriptors; their sum forms the gate logits.
        linear_max = self.linear(max_desc).view(b, c, 1, 1)
        linear_avg = self.linear(avg_desc).view(b, c, 1, 1)

        # torch.sigmoid avoids the deprecation warning that F.sigmoid emits on older PyTorch releases.
        gate = torch.sigmoid(linear_max + linear_avg)
        return gate * x
    
class CBAM(nn.Module):
    """
    Convolutional block attention for 2D feature maps.

    Runs the channel gate (``CAM``) followed by the spatial gate (``SAM``)
    and adds the input back as a residual connection.
    """

    def __init__(self, channels, r):
        """
        :param channels: (int) Number of input channels, forwarded to ``CAM``.
        :param r: (int) Reduction ratio, forwarded to ``CAM``.
        """
        super().__init__()
        self.channels, self.r = channels, r
        # The spatial gate is built without a bias term.
        self.sam = SAM(bias=False)
        self.cam = CAM(channels=self.channels, r=self.r)

    def forward(self, x):
        """
        :param x: (torch.Tensor) Input tensor passed to ``CAM`` and then ``SAM``.

        :return: (torch.Tensor) Attention-refined tensor plus the original input.
        """
        # Channel gate first, spatial gate second.
        refined = self.sam(self.cam(x))
        return refined + x

In [12]:
# Smoke test: SpatialAttention must return a tensor of the same shape as its input.

# 2D case: random input of shape (10, 3, 224, 224).
sam = SpatialAttention(n_dims=2)
x = torch.rand(size=(10, 3, 224, 224))

x = sam(x)
print(x.shape)

# 1D case: random input of shape (10, 3, 224); `sam` and `x` are rebound here.
sam = SpatialAttention(n_dims=1)
x = torch.rand(size=(10, 3, 224))
x = sam(x)
print(x.shape)

torch.Size([10, 3, 224, 224])
torch.Size([10, 3, 224])


In [3]:
# Smoke test: SpatialAttention1D with default arguments on a (10, 3, 224) input.
# NOTE(review): this cell's execution count (3) is lower than that of the cells above it;
# re-run the notebook from a fresh kernel to confirm it still works in order.
sam = SpatialAttention1D()
x = torch.rand(size=(10, 3, 224))

x = sam(x)
# Last expression of the cell: displays the output shape.
x.shape

torch.Size([10, 3, 224])

In [None]:
from typing import Optional

import torch
from torch import nn


class SEAttention1D(nn.Module):
    """
    Squeeze-and-Excitation channel attention for 1D (time series) inputs.

    Each channel is summarised by its average over the sequence axis, the
    summaries are mapped through a bottleneck MLP ending in a sigmoid, and
    the resulting per-channel weights rescale the input.
    Reference: "Squeeze-and-Excitation Networks" by Jie Hu, Li Shen, et al.
    URL: https://arxiv.org/abs/1709.01507
    """

    def __init__(
        self, n_channels: int, reduction: Optional[int] = 8, bias: bool = False
    ) -> None:
        """
        :param n_channels: (int) Number of channels in the input time series.
        :param reduction: (int) Divisor applied to ``n_channels`` to size the hidden layer.
        :param bias: (bool) Whether the two linear layers carry bias terms.
        """
        super().__init__()

        # Squeeze: collapse the sequence axis to length 1 by averaging.
        self.avg_pool = nn.AdaptiveAvgPool1d(1)

        # Excitation: n_channels -> n_channels // reduction -> n_channels, then sigmoid.
        hidden = n_channels // reduction
        self.fc = nn.Sequential(
            nn.Linear(n_channels, hidden, bias=bias),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, n_channels, bias=bias),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Rescale every channel of ``x`` by its learned attention weight.

        :param x: (torch.Tensor) Input tensor of shape (batch_size, channels, seq_len)

        :return: (torch.Tensor) Output tensor of the same shape as input
        """
        n_batch, n_ch, _ = x.size()

        # One scalar descriptor per (sample, channel).
        squeezed = self.avg_pool(x).view(n_batch, n_ch)

        # Per-channel weights in (0, 1), shaped to broadcast over the sequence axis.
        weights = self.fc(squeezed).view(n_batch, n_ch, 1)

        return x * weights.expand_as(x)


class SEAttention2D(nn.Module):
    """
    2D Squeeze-and-Excitation Attention for Image Analysis.
    This module adaptively recalibrates channel-wise feature responses by explicitly modeling interdependencies between channels.
    Reference: "Squeeze-and-Excitation Networks" by Jie Hu, Li Shen, et al.
    URL: https://arxiv.org/abs/1709.01507
    """

    def __init__(
        self, n_channels: int, reduction: Optional[int] = 4, bias: bool = False
    ) -> None:
        """
        :param n_channels: (int) The number of input channels of the image / feature map.
        :param reduction: (int) The reduction ratio for the intermediate layer in the SE block.
        :param bias: (bool) Whether to include bias terms in the linear layers.
        """
        super().__init__()
        # Global average pooling layer to squeeze the spatial dimensions
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Fully connected layers for the excitation operation
        self.fc = nn.Sequential(
            nn.Linear(n_channels, n_channels // reduction, bias=bias),
            nn.ReLU(inplace=True),
            nn.Linear(n_channels // reduction, n_channels, bias=bias),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the SEAttention2D module.

        :param x: (torch.Tensor) Input tensor of shape (batch_size, channels, height, width)

        :return: (torch.Tensor) Output tensor of the same shape as input
        """
        # Get the batch size and number of channels
        batch_size, channels, _, _ = x.size()

        # Perform the Squeeze operation: one averaged descriptor per (sample, channel)
        y = self.avg_pool(x).view(batch_size, channels)

        # Perform the Excitation operation; reshape so the weights broadcast over height and width
        y = self.fc(y).view(batch_size, channels, 1, 1)

        # Scale the input tensor with the recalibrated weights
        return x * y.expand_as(x)
    
    
class SEAttention(nn.Module):
    """
    Squeeze-and-Excitation channel attention for 1D (time series) or 2D (image) inputs.

    Each channel is summarised by global average pooling, the summaries are
    mapped through a bottleneck MLP ending in a sigmoid, and the resulting
    per-channel weights rescale the input.
    Reference: "Squeeze-and-Excitation Networks" by Jie Hu, Li Shen, et al.
    URL: https://arxiv.org/abs/1709.01507
    """

    def __init__(
        self, n_dims: int, n_channels: int, reduction: Optional[int] = 4, bias: bool = False
    ) -> None:
        """
        Build the pooling layer and excitation MLP for the requested input dimensionality.

        :param n_dims: (int) Dimensionality of the input data: 1 (time series) or 2 (image).
        :param n_channels: (int) Number of channels in the input.
        :param reduction: (int) Divisor applied to ``n_channels`` to size the hidden layer.
        :param bias: (bool) Whether the two linear layers carry bias terms.
        """
        super().__init__()

        # Only 1D and 2D inputs are supported.
        assert n_dims in (1, 2), "The dimension of input data must be either 1 or 2."
        self.n_dims = n_dims

        # Squeeze: global average pooling over the trailing axis/axes.
        if n_dims == 2:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        else:
            self.avg_pool = nn.AdaptiveAvgPool1d(1)

        # Excitation: n_channels -> n_channels // reduction -> n_channels, then sigmoid.
        hidden = n_channels // reduction
        self.fc = nn.Sequential(
            nn.Linear(n_channels, hidden, bias=bias),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, n_channels, bias=bias),
            nn.Sigmoid(),
        )

        # Trailing singleton axes appended to the weights so they broadcast over the input.
        if n_dims == 2:
            self.view_shape = (1, 1)
        else:
            self.view_shape = (1,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Rescale every channel of ``x`` by its learned attention weight.

        :param x: (torch.Tensor)
                  1D Time Series: Input tensor of shape (batch_size, channels, seq_len);
                  2D Image: Input tensor of shape (batch_size, channels, height, width).

        :return: (torch.Tensor) Output tensor of the same shape as input
        """
        n_batch, n_ch = x.size()[:2]

        # One scalar descriptor per (sample, channel).
        squeezed = self.avg_pool(x).view(n_batch, n_ch)

        # Per-channel weights in (0, 1), reshaped for broadcasting.
        weights = self.fc(squeezed).view(n_batch, n_ch, *self.view_shape)

        return x * weights.expand_as(x)
    
    
    

# Smoke test: SEAttention must return a tensor of the same shape as its input.
# NOTE(review): the recorded output below this cell also lists intermediate sizes
# ([16, 16] and [16, 16, 1, 1]) that SEAttention.forward does not print, so the
# output looks stale; re-run the cell to refresh it.

# 2D case: random input of shape (16, 16, 224, 224).
x = torch.rand(size=(16, 16, 224, 224))

se = SEAttention(n_dims=2, n_channels=16)
print(se(x).shape)

# 1D case: random input of shape (16, 16, 224); `x` and `se` are rebound here.
x = torch.rand(size=(16, 16, 224))
se = SEAttention(n_dims=1, n_channels=16)
print(se(x).shape)

torch.Size([16, 16])
torch.Size([16, 16, 1, 1])
torch.Size([16, 16, 224, 224])
torch.Size([16, 16])
torch.Size([16, 16, 1])
torch.Size([16, 16, 224])


In [17]:
class ChannelAttention(nn.Module):
    """
    Channel attention gate for time series (1D) or image (2D) inputs.

    Global max- and average-pooled channel descriptors are passed through one
    shared two-layer bottleneck MLP, summed, squashed by a sigmoid and
    multiplied onto the input channel by channel.
    """

    def __init__(self, n_dims: int, n_channels: int, reduction: Optional[int] = 4) -> None:
        """
        Initialize the channel attention module.

        :param n_dims: (int) The dimension of input data, either 1 (time series) or 2 (image).
        :param n_channels: (int) The number of input channels.
        :param reduction: (int) The reduction ratio; the hidden layer has ``n_channels // reduction`` units.
        """
        super().__init__()

        assert n_dims in [1, 2], "The dimension of input data must be either 1 or 2."

        self.n_channels = n_channels
        self.reduction = reduction

        # Bottleneck MLP shared by the max-pooled and average-pooled descriptors
        self.linear = nn.Sequential(
            nn.Linear(
                in_features=self.n_channels,
                out_features=self.n_channels // self.reduction,
                bias=True,
            ),
            nn.ReLU(inplace=True),
            nn.Linear(
                in_features=self.n_channels // self.reduction,
                out_features=self.n_channels,
                bias=True,
            ),
        )

        # Trailing singleton axes used to reshape the MLP output for broadcasting
        self.view_shape = (1, 1) if n_dims == 2 else (1,)

        # Global pooling layers matching the input dimensionality
        self.adaptive_max_pool = nn.AdaptiveMaxPool2d(1) if n_dims == 2 else nn.AdaptiveMaxPool1d(1)
        self.adaptive_avg_pool = nn.AdaptiveAvgPool2d(1) if n_dims == 2 else nn.AdaptiveAvgPool1d(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the channel gate.

        :param x: (torch.Tensor)
                  1D Time Series: Input tensor of shape (batch_size, channels, seq_len);
                  2D Image: Input tensor of shape (batch_size, channels, height, width).

        :return: (torch.Tensor) Gated tensor with the same shape as the input.
        """
        batch_size, n_channels = x.size()[:2]

        # Local names avoid shadowing the built-in ``max``.
        max_desc = self.adaptive_max_pool(x).view(batch_size, n_channels)
        avg_desc = self.adaptive_avg_pool(x).view(batch_size, n_channels)

        # The same MLP scores both descriptors; their sum forms the gate logits.
        linear_max = self.linear(max_desc).view(batch_size, n_channels, *self.view_shape)
        linear_avg = self.linear(avg_desc).view(batch_size, n_channels, *self.view_shape)

        # torch.sigmoid avoids the deprecation warning that F.sigmoid emits on older PyTorch releases.
        gate = torch.sigmoid(linear_max + linear_avg)
        return gate * x
    

# Smoke test: ChannelAttention must return a tensor of the same shape as its input.

# 2D case: random input of shape (16, 16, 224, 224).
ca = ChannelAttention(n_dims=2, n_channels=16)
x = torch.rand(size=(16, 16, 224, 224))
x = ca(x)
print(x.shape)

# 1D case: random input of shape (16, 16, 224); `ca` and `x` are rebound here.
ca = ChannelAttention(n_dims=1, n_channels=16)
x = torch.rand(size=(16, 16, 224))
x = ca(x)
print(x.shape)

torch.Size([16, 16, 224, 224])
torch.Size([16, 16, 224])


In [18]:
class ConvBlockAttention(nn.Module):
    """
    Convolutional Block Attention Module (CBAM) for 1D (time series) or 2D (image) inputs.

    The input is refined by a channel attention gate and then by a spatial
    attention gate; an optional residual connection adds the input back.

    Reference: "CBAM: Convolutional Block Attention Module" by Sanghyun Woo, et al.

    URL: https://arxiv.org/abs/1807.06521

    Also see: `ChannelAttention` and `SpatialAttention` classes.
    """

    def __init__(
        self,
        n_dims: int,
        n_channels: int,
        reduction: Optional[int] = 4,
        kernel_size: Optional[int] = 7,
    ) -> None:
        """
        Build the channel and spatial attention sub-modules.

        :param n_dims: (int) Dimensionality of the input data: 1 (time series) or 2 (image).
        :param n_channels: (int) Number of input channels, forwarded to the channel attention block.
        :param reduction: (int) Reduction ratio forwarded to the channel attention block.
        :param kernel_size: (int) Convolution kernel size forwarded to the spatial attention block. Must be odd.
        """
        super().__init__()

        # Channel gate is created first, then the spatial gate.
        self.channel_attention = ChannelAttention(
            n_dims=n_dims, n_channels=n_channels, reduction=reduction
        )
        self.spatial_attention = SpatialAttention(
            n_dims=n_dims, kernel_size=kernel_size
        )

    def forward(self, x: torch.Tensor, with_residual: bool = True) -> torch.Tensor:
        """
        Refine ``x`` with channel attention followed by spatial attention.

        :param x: (torch.Tensor)
                  1D Time Series: Input tensor of shape (batch_size, channels, seq_len);
                  2D Image: Input tensor of shape (batch_size, channels, height, width).
        :param with_residual: (bool) When true, the input is added to the refined output.

        :return: (torch.Tensor) Output tensor of the same shape as input.
        """
        refined = self.spatial_attention(self.channel_attention(x))

        if not with_residual:
            return refined
        return refined + x
    
    
# Smoke test: ConvBlockAttention (default with_residual=True) must preserve the input shape.

# 2D case: random input of shape (16, 16, 224, 224).
cbam = ConvBlockAttention(n_dims=2, n_channels=16)
x = torch.rand(size=(16, 16, 224, 224))
x = cbam(x)
print(x.shape)

# 1D case: random input of shape (16, 16, 224); `cbam` and `x` are rebound here.
cbam = ConvBlockAttention(n_dims=1, n_channels=16)
x = torch.rand(size=(16, 16, 224))
x = cbam(x)
print(x.shape)

torch.Size([16, 16, 224, 224])
torch.Size([16, 16, 224])
