In [3]:
import torch
from torch import nn
from torchsummary import summary

class Net(nn.Module):
    def __init__(self, n_chans, n_classes):
        super(Net, self).__init__()

        self.temp_conv1 = nn.Conv2d(n_chans, n_chans, kernel_size=(1, 2),  groups=n_chans)
        self.temp_conv2 = nn.Conv2d(n_chans, n_chans, kernel_size=(1, 2),  groups=n_chans)
        self.temp_conv3 = nn.Conv2d(n_chans, n_chans, kernel_size=(1, 2),  groups=n_chans)
        self.temp_conv4 = nn.Conv2d(n_chans, n_chans, kernel_size=(1, 2),  groups=n_chans)
        self.temp_conv5 = nn.Conv2d(n_chans, n_chans, kernel_size=(1, 2),  groups=n_chans)
        self.temp_conv6 = nn.Conv2d(n_chans, n_chans, kernel_size=(1, 2),  groups=n_chans)

        self.chpool1 = nn.Sequential(
            nn.Conv2d(n_chans, 32, kernel_size=(1, 4), groups=1),
            nn.BatchNorm2d(32),  # 更改为 nn.BatchNorm2d
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, kernel_size=(1, 4), groups=1),
            nn.BatchNorm2d(32),  # 更改为 nn.BatchNorm2d
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, kernel_size=(1, 4), groups=1),
            nn.BatchNorm2d(32),  # 更改为 nn.BatchNorm2d
            nn.LeakyReLU(0.01),
            nn.AdaptiveAvgPool2d((1, 1))  # 修改为 nn.AdaptiveAvgPool2d((1, 1))
        )

        # 其他的 chpool 代码保持不变...
        self.chpool2 = nn.Sequential(
            nn.Conv2d(n_chans, 32, kernel_size=(1, 4), groups=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, kernel_size=(1, 4), groups=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, kernel_size=(1, 4), groups=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.01),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        self.chpool3 = nn.Sequential(
            nn.Conv2d(n_chans, 32, kernel_size=(1, 4), groups=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, kernel_size=(1, 4), groups=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, kernel_size=(1, 4), groups=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.01),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        self.chpool4 = nn.Sequential(
            nn.Conv2d(n_chans, 32, kernel_size=(1, 4), groups=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, kernel_size=(1, 4), groups=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, kernel_size=(1, 4), groups=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.01),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        self.chpool5 = nn.Sequential(
            nn.Conv2d(n_chans, 32, kernel_size=(1, 4), groups=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, kernel_size=(1, 4), groups=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, kernel_size=(1, 4), groups=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.01),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        self.classifier = nn.Sequential(
            nn.Linear(160, 64),
            nn.LeakyReLU(0.01),
            nn.Linear(64, 32),
            nn.Sigmoid(),
            nn.Linear(32, n_classes)
        )

    def forward(self, x, training=True):
        temp_x = self.temp_conv1(x)
        temp_w1 = self.temp_conv2(temp_x)
        temp_w2 = self.temp_conv3(temp_w1)
        temp_w3 = self.temp_conv4(temp_w2)
        temp_w4 = self.temp_conv5(temp_w3)
        temp_w5 = self.temp_conv6(temp_w4)

        w1 = self.chpool1(temp_w1)
        w2 = self.chpool2(temp_w2)
        w3 = self.chpool3(temp_w3)
        w4 = self.chpool4(temp_w4)
        w5 = self.chpool5(temp_w5)

        concat_vector = torch.cat([w1, w2, w3, w4, w5], 1)
        concat_vector = concat_vector.view(concat_vector.size(0), -1)  # Flatten
        classes = nn.functional.log_softmax(self.classifier(concat_vector), dim=1)

        return classes

# 创建模型实例
model = Net(n_chans=1, n_classes=1)
model = model.cuda()

# 使用summary查看结构
summary(model, (1, 30, 200))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 1, 30, 199]               3
            Conv2d-2           [-1, 1, 30, 198]               3
            Conv2d-3           [-1, 1, 30, 197]               3
            Conv2d-4           [-1, 1, 30, 196]               3
            Conv2d-5           [-1, 1, 30, 195]               3
            Conv2d-6           [-1, 1, 30, 194]               3
            Conv2d-7          [-1, 32, 30, 195]             160
       BatchNorm2d-8          [-1, 32, 30, 195]              64
         LeakyReLU-9          [-1, 32, 30, 195]               0
           Conv2d-10          [-1, 32, 30, 192]           4,128
      BatchNorm2d-11          [-1, 32, 30, 192]              64
        LeakyReLU-12          [-1, 32, 30, 192]               0
           Conv2d-13          [-1, 32, 30, 189]           4,128
      BatchNorm2d-14          [-1, 32, 