In [14]:
import torch
import torch.nn as nn
from torchsummary import summary

In [15]:
class SEBlock(nn.Module):
    def __init__(self, mode, channels, ratio):
        super(SEBlock, self).__init__()
        self.avg_pooling = nn.AdaptiveAvgPool2d(1)
        self.max_pooling = nn.AdaptiveMaxPool2d(1)
        if mode == "max":
            self.global_pooling = self.max_pooling
        elif mode == "avg":
            self.global_pooling = self.avg_pooling
        self.fc_layers = nn.Sequential(
            nn.Linear(in_features = channels, out_features = channels // ratio, bias = False),
            nn.ReLU(),
            nn.Linear(in_features = channels // ratio, out_features = channels, bias = False),
        )
        self.sigmoid = nn.Sigmoid()
     
    
    def forward(self, x):
        b, c, _, _ = x.shape
        v = self.global_pooling(x).view(b, c)
        v = self.fc_layers(v).view(b, c, 1, 1)
        v = self.sigmoid(v)
        return x * v


In [22]:
import torch
import torch.nn as nn
from torchsummary import summary

# 定义 SEBlock 类
class SEBlock(nn.Module):
    def __init__(self, pool_type, in_channels, reduction):
        super(SEBlock, self).__init__()
        self.global_pool = nn.AdaptiveAvgPool2d(1) if pool_type == "avg" else nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1)
        self.fc2 = nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        scale = self.global_pool(x)
        scale = self.fc1(scale)
        scale = nn.ReLU()(scale)
        scale = self.fc2(scale)
        scale = self.sigmoid(scale)
        return x * scale

# 定义 M_FANet 类
class M_FANet(nn.Module):
    def __init__(self, Ks, Ft, Kt, D, Nc):
        super(M_FANet, self).__init__()

        self.pointwiseconv2d = nn.Conv2d(in_channels=9, out_channels=1, kernel_size=1, stride=1, padding=0, bias=False)
        
        # 确保 padding 为整数
        padding = Ks // 2  # 使用整数进行padding计算
        self.conv2d1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(Ks, 1),  padding='same')
        
        self.TSconv = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=Ft, kernel_size=(1, Kt), padding=(0, Kt // 2)),  # padding 需要是一个元组
            nn.BatchNorm2d(Ft),  # 确保这里传递通道数
            nn.Conv2d(in_channels=Ft, out_channels=D * Ft, kernel_size=(30, 1), groups=Ft),
            nn.AvgPool2d(kernel_size=(1, 4))
        )
        
        self.SEBlock = SEBlock("avg", D * Ft, 16)
        
        self.TConv = nn.Sequential(
            nn.Conv2d(in_channels=D * Ft, out_channels=D * Ft, kernel_size=(1, 25), groups=D * Ft, padding=(0, 12)),
            nn.AvgPool2d(kernel_size=(1, 8))
        )
        
        self.classifier = nn.Conv2d(D * Ft, Nc, (1, 6))

    def forward(self, x):
        x = self.pointwiseconv2d(x)
        x = self.conv2d1(x)
        x = self.TSconv(x)
        x = self.SEBlock(x)
        x = self.TConv(x)
        x = self.classifier(x)
        return x

# 检查是否有可用的 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 实例化模型，假设超参数为 Ks=4, Ft=16, Kt=3, D=2, Nc=2
model = M_FANet(Ks=4, Ft=16, Kt=3, D=2, Nc=2).to(device)  # 将模型移动到 GPU

# 生成随机输入数据，假设输入形状为 (batch_size, channels, height, width)
input_data = torch.randn(1, 9, 30, 200).to(device)  # 将输入数据移动到 GPU

# 使用 torchsummary 打印模型结构
summary(model, input_data.size()[1:])  # 只传递输入的形状 (channels, height, width)

# 测试模型前向传播
output = model(input_data)
print("Output shape:", output.shape)





----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 1, 30, 200]               9
            Conv2d-2           [-1, 1, 30, 200]               5
            Conv2d-3          [-1, 16, 30, 200]              64
       BatchNorm2d-4          [-1, 16, 30, 200]              32
            Conv2d-5           [-1, 32, 1, 200]             992
         AvgPool2d-6            [-1, 32, 1, 50]               0
 AdaptiveAvgPool2d-7             [-1, 32, 1, 1]               0
            Conv2d-8              [-1, 2, 1, 1]              66
            Conv2d-9             [-1, 32, 1, 1]              96
          Sigmoid-10             [-1, 32, 1, 1]               0
          SEBlock-11            [-1, 32, 1, 50]               0
           Conv2d-12            [-1, 32, 1, 50]             832
        AvgPool2d-13             [-1, 32, 1, 6]               0
           Conv2d-14              [-1, 