In [None]:
from typing import Optional

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

# Paper: FECAM: Frequency Enhanced Channel Attention Mechanism for Time Series Forecasting
# Paper URL: https://arxiv.org/abs/2212.01209

import torch.fft  # make sure the ``torch.fft`` submodule is loaded

# NOTE: the helpers below are intentionally always defined instead of doing
# ``from torch.fft import rfft, irfft``. The functions in ``torch.fft`` take
# ``(input, n, dim, norm)`` and return complex tensors, so a call written in
# the legacy ``rfft(x, 1)`` style would be interpreted as ``n=1`` and would
# not produce the trailing (real, imag) axis that callers in this file index.


def rfft(x, d):
    """
    Full FFT of ``x`` returned in a real-valued (real, imag) layout.

    :param x: the input tensor
    :param d: the transform is taken over dimension ``-d`` (``d=1`` means the last dimension)
    :return: a real tensor with one extra trailing axis of size 2; index 0 holds
             the real part and index 1 the imaginary part of the full spectrum
    """
    t = torch.fft.fft(x, dim=-d)
    return torch.stack((t.real, t.imag), -1)


def irfft(x, d):
    """
    Inverse of :func:`rfft` for a 3-D (real, imag) spectrum.

    :param x: a real tensor whose third axis has size 2 (real part, imaginary part)
    :param d: the inverse transform is taken over dimension ``-d`` of the complex tensor
    :return: the real part of the inverse FFT
    """
    t = torch.fft.ifft(torch.complex(x[:, :, 0], x[:, :, 1]), dim=-d)
    return t.real


def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)

    The transform is computed with a single length-N FFT of a reordered copy
    of the signal. ``torch.fft.fft`` is called directly so the result does not
    depend on which ``rfft`` helper is bound at module level.

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param x: the input signal; any number of leading dimensions is allowed
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension, same shape as ``x``
    """
    x_shape = x.shape
    N = x_shape[-1]
    x = x.contiguous().view(-1, N)

    # Reorder each row: even-indexed samples first, then the odd-indexed
    # samples in reverse order.
    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

    # Full complex spectrum of the reordered rows, shape (rows, N).
    Vc = torch.fft.fft(v, dim=1)

    # Phase factor exp(-i * pi * k / (2N)); only the real part of the product
    # with the spectrum is needed.
    k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V = Vc.real * W_r - Vc.imag * W_i

    if norm == "ortho":
        V[:, 0] /= np.sqrt(N) * 2
        V[:, 1:] /= np.sqrt(N / 2) * 2

    return 2 * V.view(*x_shape)


class FreqEnhancedAttention(nn.Module):
    """
    Frequency-enhanced attention: re-weights the input with gates computed
    from its DCT spectrum.

    The spectrum is taken over the last axis of a (B, C, L) input. The
    LayerNorm and the two Linear layers also act on that last axis, so
    ``n_channels`` must equal the size of the last axis of the tensor passed
    to :meth:`forward`.

    Reference: "FECAM: Frequency Enhanced Channel Attention Mechanism for Time Series Forecasting"

    URL: https://arxiv.org/abs/2212.01209
    """

    def __init__(
        self,
        n_dims: int,
        n_channels: int,
        reduction: Optional[int] = 2,
        dropout: Optional[float] = 0.1,
    ) -> None:
        """
        :param n_dims: (int) Accepted for signature parity with the other attention modules; not used here.
        :param n_channels: (int) Feature size the Linear/LayerNorm layers operate on (last axis of the input).
        :param reduction: (int) Reduction ratio of the bottleneck between the two Linear layers.
        :param dropout: (float) Dropout probability applied after the first Linear layer.
        """
        super().__init__()
        hidden = n_channels // reduction
        self.fc = nn.Sequential(
            nn.Linear(n_channels, hidden, bias=False),
            nn.Dropout(p=dropout),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, n_channels, bias=False),
            nn.Sigmoid(),
        )

        # The same LayerNorm instance is applied before and after ``fc``.
        self.dct_norm = nn.LayerNorm(n_channels, eps=1e-6)

    def forward(self, x):
        """
        :param x: (torch.Tensor) Input of shape (B, C, L) with L == ``n_channels``.

        :return: (torch.Tensor) ``x`` multiplied element-wise by the learned frequency weights.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected a 3-D input of shape (B, C, L), got {x.dim()} dimension(s)"
            )

        # ``dct`` transforms the last axis and keeps the leading axes, so one
        # call covers every channel at once.
        stack_dct = dct(x)

        lr_weight = self.dct_norm(stack_dct)
        lr_weight = self.fc(lr_weight)
        lr_weight = self.dct_norm(lr_weight)

        return x * lr_weight


if __name__ == "__main__":
    # Smoke test: the last axis (96) is the size the module's layers act on,
    # so it is passed as ``n_channels``; both constructor arguments are required.
    input_tensor = torch.rand(8, 7, 96)
    block = FreqEnhancedAttention(n_dims=1, n_channels=96)
    result = block(input_tensor)
    print("input_tensor.shape:", input_tensor.shape)
    print("result.shape:", result.shape)

In [None]:
from typing import Optional

import math
import torch
from torch import nn


class Mix(nn.Module):
    """
    Learnable convex combination of two feature tensors.

    A single scalar parameter ``w`` is passed through a sigmoid to obtain a
    mixing factor in (0, 1); the output is
    ``feature1 * factor + feature2 * (1 - factor)``.
    """

    def __init__(self, m: float = -0.80) -> None:
        """
        :param m: (float) Initial value of the raw (pre-sigmoid) mixing parameter.
        """
        super().__init__()
        # One float32 parameter of shape (1,), created once.
        self.w = nn.Parameter(
            torch.tensor([m], dtype=torch.float32), requires_grad=True
        )
        self.mix_block = nn.Sigmoid()

    def forward(self, feature1: torch.Tensor, feature2: torch.Tensor) -> torch.Tensor:
        """
        :param feature1: (torch.Tensor) First feature tensor, weighted by the mixing factor.
        :param feature2: (torch.Tensor) Second feature tensor, weighted by one minus the mixing factor.

        :return: (torch.Tensor) The blended features.
        """
        mix_factor = self.mix_block(self.w)
        weight1 = mix_factor.expand_as(feature1)
        weight2 = 1 - mix_factor.expand_as(feature2)
        return feature1 * weight1 + feature2 * weight2


#
class FineGrainedCAttention(nn.Module):
    """
    Adaptive Fine-Grained Channel Attention (FCA) module for Time Series or Image Data.
    This module captures fine-grained cross-channel interaction adaptively.

    Reference: Unsupervised Bidirectional Contrastive Reconstruction and Adaptive Fine-Grained Channel Attention Networks for image dehazing

    URL: https://www.sciencedirect.com/science/article/abs/pii/S0893608024002387
    """

    def __init__(
        self,
        n_dims: int,
        n_channels: int,
        b: float = 1.0,
        gamma: float = 2.0,
        bias: Optional[bool] = False,
    ) -> None:
        """
        :param n_dims: (int) The dimension of input data, either 1 (B, C, L) or 2 (B, C, H, W).
        :param n_channels: (int) The number of input channels.
        :param b: (float) Offset used when deriving the 1D convolution kernel size from ``n_channels``.
        :param gamma: (float) Divisor used when deriving the 1D convolution kernel size from ``n_channels``.
        :param bias: (bool) Whether the 1D convolution over the channel axis has a bias term.
        """
        super(FineGrainedCAttention, self).__init__()

        # Dimension assertion
        assert n_dims in [1, 2], "The dimension of input data must be either 1 or 2."
        self.n_dims = n_dims

        # Global average pooling that collapses every non-channel axis to size 1
        self.avg_pool = (
            nn.AdaptiveAvgPool2d(1) if n_dims == 2 else nn.AdaptiveAvgPool1d(1)
        )

        # 1D convolution over the channel axis; the kernel size is derived
        # from the channel count and forced to be odd so that
        # ``padding=kernel_size // 2`` preserves the length.
        t = int(abs((math.log(n_channels, 2) + b) / gamma))
        kernel_size = t if t % 2 else t + 1

        self.conv1 = nn.Conv1d(
            1, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=bias
        )

        # Point-wise convolution mixing all channels; it has to match the
        # rank of the pooled tensor, so pick Conv1d for 1D inputs.
        pointwise_conv = nn.Conv2d if n_dims == 2 else nn.Conv1d
        self.fc = pointwise_conv(n_channels, n_channels, 1, padding=0, bias=True)
        self.sigmoid = nn.Sigmoid()
        self.mix = Mix()

    def forward(self, x):
        """
        :param x: (torch.Tensor) Input of shape (B, C, L) when ``n_dims == 1`` or (B, C, H, W) when ``n_dims == 2``.

        :return: (torch.Tensor) ``x`` scaled per channel by the attention weights, same shape as ``x``.
        """
        batch_size, channels = x.shape[:2]

        pool = self.avg_pool(x)  # (B, C, 1) or (B, C, 1, 1)
        descriptor = pool.reshape(batch_size, 1, channels)  # (B, 1, C)

        # Local branch: 1D convolution across neighbouring channels.
        x1 = self.conv1(descriptor).transpose(-1, -2)  # (B, C, 1)
        # Global branch: point-wise convolution across all channels.
        x2 = self.fc(pool).reshape(batch_size, 1, channels)  # (B, 1, C)

        # Cross-correlate the two branches in both directions and reduce the
        # resulting (B, C, C) matrices to one value per channel.
        out1 = self.sigmoid(torch.matmul(x1, x2).sum(dim=1))  # (B, C)
        out2 = self.sigmoid(
            torch.matmul(x2.transpose(-1, -2), x1.transpose(-1, -2)).sum(dim=1)
        )  # (B, C)

        out = self.mix(out1, out2)  # (B, C)
        out = self.sigmoid(self.conv1(out.unsqueeze(1)))  # (B, 1, C)

        # Restore singleton trailing axes so the weights broadcast over x.
        out = out.reshape(batch_size, channels, *([1] * self.n_dims))

        return x * out


if __name__ == "__main__":
    # Smoke test on an image-shaped batch (B, C, H, W); the output is expected
    # to keep the input shape.
    sample = torch.rand(1, 64, 256, 256)
    block = FineGrainedCAttention(n_dims=2, n_channels=64)
    output = block(sample)
    print(output.size())

torch.Size([1, 64, 256, 256])


In [None]:
from typing import Optional

import torch
from torch import nn

from channel_attention.utils import create_conv_layer


class SEAttention(nn.Module):
    """
    Squeeze-and-Excitation channel attention for 1D (time series) or 2D (image) inputs.

    Each channel is summarised by global average pooling, the summaries are
    passed through a two-layer bottleneck ending in a sigmoid, and the
    resulting per-channel gates rescale the input.

    Reference: "Squeeze-and-Excitation Networks" by Jie Hu, Li Shen, et al.

    URL: https://arxiv.org/abs/1709.01507
    """

    def __init__(
        self,
        n_dims: int,
        n_channels: int,
        reduction: Optional[int] = 4,
        bias: bool = False,
    ) -> None:
        """
        Build the pooling layer and the excitation bottleneck.

        :param n_dims: (int) 1 for inputs of shape (batch, channels, seq_len),
                       2 for inputs of shape (batch, channels, height, width).
        :param n_channels: (int) Number of channels of the input.
        :param reduction: (int) The hidden width of the bottleneck is ``n_channels // reduction``.
        :param bias: (bool) Whether the two linear layers carry bias terms.
        """
        super().__init__()

        assert n_dims in [1, 2], "The dimension of input data must be either 1 or 2."
        self.n_dims = n_dims

        # Squeeze: collapse every non-channel axis to size 1.
        if n_dims == 2:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        else:
            self.avg_pool = nn.AdaptiveAvgPool1d(1)

        # Excitation: channels -> bottleneck -> channels, gated by a sigmoid.
        hidden = n_channels // reduction
        reduce_layer = nn.Linear(n_channels, hidden, bias=bias)
        expand_layer = nn.Linear(hidden, n_channels, bias=bias)
        self.fc = nn.Sequential(
            reduce_layer,
            nn.ReLU(inplace=True),
            expand_layer,
            nn.Sigmoid(),
        )

        # Singleton axes appended to the gates so they line up with the input.
        self.view_shape = (1, 1) if n_dims == 2 else (1,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Rescale every channel of ``x`` by its learned gate.

        :param x: (torch.Tensor) (batch_size, channels, seq_len) for 1D data or
                  (batch_size, channels, height, width) for 2D data.

        :return: (torch.Tensor) Tensor with the same shape as ``x``.
        """
        batch_size, channels = x.shape[:2]

        # Squeeze to one descriptor per channel: (batch_size, channels).
        descriptor = self.avg_pool(x).view(batch_size, channels)

        # Excite, then add the trailing singleton axes back.
        gates = self.fc(descriptor)
        gates = gates.view(batch_size, channels, *self.view_shape)

        return x * gates.expand_as(x)


class MultiSEAttention(nn.Module):
    """
    Multi-Branch Squeeze-and-Excitation Attention Module for Time Series (1D) or Image (2D) Analysis.
    This module enhances the representational power of the standard SE block by incorporating multiple branches and adaptive style assignment.
    """

    def __init__(
        self, n_dims: int, n_channels: int, reduction: int = 4, n_branches: int = 3
    ) -> None:
        """
        :param n_dims: (int) The dimension of input data, either 1 (time series) or 2 (image).
        :param n_channels: (int) The number of input channels.
        :param reduction: (int) The reduction ratio of the hidden layer inside the excitation branches.
        :param n_branches: (int) The number of parallel excitation branches.
        """
        super(MultiSEAttention, self).__init__()

        # Dimension assertion
        assert n_dims in [1, 2], "The dimension of input data must be either 1 or 2."
        self.n_dims = n_dims

        # Create the average pooling layer and activation function
        self.avg_pool = (
            nn.AdaptiveAvgPool2d(1) if n_dims == 2 else nn.AdaptiveAvgPool1d(1)
        )
        self.activation = nn.Sigmoid()

        # Store the reduction ratio, number of branches, and number of channels
        self.reduction = reduction
        self.n_branches = n_branches
        self.n_channels = n_channels
        new_channels = n_channels * n_branches

        # Layers for multi-branch excitation. ``groups=n_branches`` is passed
        # to both layers.
        # NOTE(review): ``create_conv_layer`` is a project helper; presumably it
        # builds an nn.Conv1d / nn.Conv2d according to ``n_dims`` -- confirm
        # against channel_attention.utils.
        self.fc = nn.Sequential(
            create_conv_layer(
                n_dims=n_dims,
                in_channels=new_channels,
                out_channels=new_channels // self.reduction,
                kernel_size=1,
                bias=True,
                groups=n_branches,
            ),
            nn.ReLU(inplace=True),
            create_conv_layer(
                n_dims=n_dims,
                in_channels=new_channels // self.reduction,
                out_channels=new_channels,
                kernel_size=1,
                bias=True,
                groups=n_branches,
            ),
        )

        # Style assignment layer
        self.style_assigner = nn.Linear(n_channels, n_branches, bias=False)

        # Repeat size for reshaping the output
        self.repeat_size = (1, 1) if n_dims == 2 else (1,)

    def _style_assignment(
        self, channel_mean: torch.Tensor, batch_size: int
    ) -> torch.Tensor:
        """
        Assign styles to each channel based on the channel mean.

        :param channel_mean: (torch.Tensor) The pooled mean of each channel, shape
                             (batch_size, n_channels, 1) or (batch_size, n_channels, 1, 1).
        :param batch_size: (int) The batch size of the input tensor.

        :return: (torch.Tensor) Style assignment probabilities for each branch, shape (batch_size, n_branches).
        """
        style_assignment = self.style_assigner(channel_mean.view(batch_size, -1))
        style_assignment = nn.functional.softmax(style_assignment, dim=1)
        return style_assignment

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the MultiSEAttention module.

        :param x: (torch.Tensor)
                  1D Time Series: Input tensor of shape (batch_size, channels, seq_len);
                  2D Image: Input tensor of shape (batch_size, channels, height, width).

        :return: (torch.Tensor) Output tensor of the same shape as input
        """
        # Squeeze: one descriptor per channel, B x C x 1 (x 1)
        avg_y = self.avg_pool(x)
        batch_size, n_channels = avg_y.shape[:2]

        # Soft assignment of the sample to the branches, B x N
        style_assignment = self._style_assignment(avg_y, batch_size=batch_size)

        # Feed a copy of the descriptor to every branch, B x NC x 1 (x 1)
        avg_y = avg_y.repeat(1, self.n_branches, *self.repeat_size)
        z = self.fc(avg_y)

        # Expand the branch probabilities to one weight per (branch, channel)
        # entry and append singleton axes so they broadcast against z.
        style_assignment = style_assignment.repeat_interleave(n_channels, dim=1)
        z = z * style_assignment.reshape(batch_size, -1, *self.repeat_size)

        # Weighted sum over the branch axis, B x C x 1 (x 1)
        z = torch.sum(
            z.view(batch_size, self.n_branches, n_channels, *self.repeat_size), dim=1
        )
        z = self.activation(z)

        return x * z

avg_y: torch.Size([1, 64, 1, 1])
style_assignment: torch.Size([1, 3])
avg_y repeated: torch.Size([1, 192, 1, 1])
z: torch.Size([1, 192, 1, 1])
torch.Size([1, 64, 256, 256])
avg_y: torch.Size([1, 64, 1])
style_assignment: torch.Size([1, 3])
avg_y repeated: torch.Size([1, 192, 1])
z: torch.Size([1, 192, 1])
torch.Size([1, 64, 256])
