In [3]:
import torch
import torch.nn as nn

class TSCL(nn.Module):
    """Temporal/spatial convolutional classifier for multichannel signals.

    Pipeline (see ``forward``):
      1. three temporal conv branches (kernel widths 0.5, 0.25 and 0.125 of
         ``sampling_rate``), concatenated along time, then ``BN_t``;
      2. two spatial conv branches (full-height and half-height kernels),
         stacked along the height axis, then ``BN_s``;
      3. a 2-D depthwise-separable block (FFL1), averaged over time;
      4. three 1-D depthwise-separable blocks (FFL2..FFL4) over the
         length-``num_S`` sequence of spatial filters;
      5. a two-layer MLP head with batch norm, ReLU and dropout.
    """

    def conv_block(self, in_chan, out_chan, kernel, step, pool):
        """Build Conv2d (no bias) -> LeakyReLU -> AvgPool2d over width only."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_chan, out_channels=out_chan,
                      kernel_size=kernel, stride=step, bias=False),
            nn.LeakyReLU(),
            nn.AvgPool2d(kernel_size=(1, pool), stride=(1, pool))
        )

    def Sconv_block1(self, in_chan, out_chan, kernel, step, pool):
        """Build a 2-D depthwise-separable block.

        Args:
            kernel: ``(h, w)`` tuple, or an int interpreted as ``(1, kernel)``.
            step: stride along the width axis.
            pool: width of the trailing average pool.
        """
        kernel_h, kernel_w = kernel if isinstance(kernel, tuple) else (1, kernel)
        return nn.Sequential(
            # Depthwise: one filter per input channel (groups=in_chan).
            nn.Conv2d(in_chan, in_chan, (kernel_h, kernel_w),
                      stride=(1, step), padding=(0, kernel_w // 2),
                      groups=in_chan, bias=False),
            # Pointwise 1x1 mixing to out_chan channels.
            nn.Conv2d(in_chan, out_chan, (1, 1), bias=False),
            nn.BatchNorm2d(out_chan, eps=1e-05, momentum=0.1, affine=True,
                           track_running_stats=True),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(1, pool), stride=(1, pool))
        )

    def Sconv_block(self, in_chan, out_chan, kernel_size=3, pool=2):
        """Build a 1-D depthwise-separable block followed by max pooling.

        NOTE: padding is fixed at 2 regardless of ``kernel_size`` (so a
        kernel of 3 lengthens the sequence by 2). ``fc1``'s width is derived
        from these layers, so keep this as-is unless retraining.
        """
        return nn.Sequential(
            nn.Conv1d(in_channels=in_chan, out_channels=in_chan,
                      kernel_size=kernel_size, stride=1, padding=2, groups=in_chan, bias=False),
            nn.Conv1d(in_channels=in_chan, out_channels=out_chan,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm1d(num_features=out_chan, eps=1e-05, momentum=0.1, affine=True,
                           track_running_stats=True),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=pool, stride=pool, padding=0)
        )

    @staticmethod
    def _out_len(block, length):
        """Return the sequence length produced by the Conv1d/MaxPool1d layers of ``block``.

        Uses the standard PyTorch output-length formula (ceil_mode=False).
        """
        for layer in block:
            if isinstance(layer, (nn.Conv1d, nn.MaxPool1d)):
                # Conv1d stores these as 1-tuples, MaxPool1d as plain ints.
                k, s, p, d = (
                    v[0] if isinstance(v, tuple) else v
                    for v in (layer.kernel_size, layer.stride, layer.padding, layer.dilation)
                )
                length = (length + 2 * p - d * (k - 1) - 1) // s + 1
        return length

    def __init__(self, num_classes=3, input_size=(1, 4, 128), sampling_rate=128, num_T=128, num_S=128, hidden=32,
                 dropout_rate=0.5, num_c=2):
        """Create the network.

        Args:
            num_classes: unused (the output width is ``num_c``); kept for API compatibility.
            input_size: sequence whose element ``[1]`` is the number of input
                channels; the other elements are not read here.
            sampling_rate: scales the temporal kernel widths.
            num_T: number of filters per temporal branch.
            num_S: number of spatial filters; also the length of the 1-D
                sequence processed by FFL2..FFL4.
            hidden: unused; kept for API compatibility.
            dropout_rate: dropout probability after fc1 and fc2.
            num_c: number of outputs of the final linear layer.
        """
        super().__init__()

        self.inception_window = [0.5, 0.25, 0.125]
        self.pool = 8

        self.MTCL1 = self.conv_block(1, num_T, (1, int(self.inception_window[0] * sampling_rate)), 1, self.pool)
        self.MTCL2 = self.conv_block(1, num_T, (1, int(self.inception_window[1] * sampling_rate)), 1, self.pool)
        self.MTCL3 = self.conv_block(1, num_T, (1, int(self.inception_window[2] * sampling_rate)), 1, self.pool)

        # Full-height spatial kernel, and a half-height kernel strided by its own height.
        self.MSCL1 = self.conv_block(num_T, num_S, (int(input_size[1]), 1), 1, int(self.pool * 0.25))
        self.MSCL2 = self.conv_block(num_T, num_S, (int(input_size[1] * 0.5), 1), (int(input_size[1] * 0.5), 1),
                                     int(self.pool * 0.25))

        self.BN_t = nn.BatchNorm2d(num_T)
        self.BN_s = nn.BatchNorm2d(num_S)

        self.FFL1 = self.Sconv_block1(num_S, num_S, (3, 1), 1, 4)
        self.FFL2 = self.Sconv_block(1, 32, 3)
        self.FFL3 = self.Sconv_block(32, 64, 3)
        self.FFL4 = self.Sconv_block(64, 128, 3)

        # fc1's width follows from num_S: FFL1's output is averaged over time,
        # leaving a length-num_S sequence that FFL2..FFL4 then resize.
        seq_len = num_S
        for block in (self.FFL2, self.FFL3, self.FFL4):
            seq_len = self._out_len(block, seq_len)
        flat_features = self.FFL4[1].out_channels * seq_len  # 128 * 17 = 2176 for the defaults

        self.fc1 = nn.Linear(in_features=flat_features, out_features=512)
        self.batchnorm1 = nn.BatchNorm1d(num_features=512, eps=1e-05, momentum=0.1, affine=True,
                                         track_running_stats=True)
        self.activation1 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=512, out_features=128)
        self.batchnorm2 = nn.BatchNorm1d(num_features=128, eps=1e-05, momentum=0.1, affine=True,
                                         track_running_stats=True)
        self.activation2 = nn.ReLU()
        self.out = nn.Linear(in_features=128, out_features=num_c)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Compute class scores.

        Args:
            x: tensor assumed to be (batch, channels, time), with channels equal
                to ``input_size[1]`` (the smoke test uses (16, 4, 128)).

        Returns:
            Unnormalised scores of shape (batch, num_c).
        """
        x = x.unsqueeze(1)  # [batch, 1, channels, time]
        # Multi-scale temporal branches, concatenated along the time axis.
        out = torch.cat((self.MTCL1(x), self.MTCL2(x), self.MTCL3(x)), dim=-1)
        out = self.BN_t(out)
        # Spatial branches, stacked along the height axis.
        out = torch.cat((self.MSCL1(out), self.MSCL2(out)), dim=2)
        # Feed the normalised spatial features to FFL1 so BN_s actually takes effect.
        out = self.BN_s(out)
        out = self.FFL1(out)
        # Average over time, then drop the singleton height -> (batch, num_S).
        out = torch.squeeze(torch.mean(out, dim=-1), dim=-1)
        out = out.unsqueeze(1)  # (batch, 1, num_S): spatial filters as a 1-D sequence
        out = self.FFL2(out)
        out = self.FFL3(out)
        out = self.FFL4(out)
        out = torch.flatten(out, start_dim=1)
        out = self.fc1(out)
        out = self.batchnorm1(out)
        out = self.activation1(out)
        out = self.dropout(out)
        out = self.fc2(out)
        out = self.batchnorm2(out)
        out = self.activation2(out)
        out = self.dropout(out)
        out = self.out(out)

        return out

if __name__ == "__main__":
    # Smoke test: a batch of 16 samples, 4 channels, 128 time points.
    batch = torch.ones((16, 4, 128))
    net = TSCL(
        input_size=[1, 4, 128],
        sampling_rate=128,
        num_T=128,
        num_S=128,
        hidden=32,
        dropout_rate=0.5,
        num_c=2,
    )
    logits = net(batch)
    # Expected shape: (16, 2)
    print("输出形状:", logits.shape)

输出形状: torch.Size([16, 2])
