In [1]:
### 包依赖
import pandas as pd
import numpy as np
import torch
import torch.nn as nn

In [3]:
### CNN+LSTM+自互注意力模型
#   Classification Model

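# Parameters:
# feature_size      number of input features (e.g. 21)
# temporal_size     number of time steps per sample (e.g. 101)
# cnn_kernel_size   length of the 1-D convolution kernel
# cnn_kernel_num    number of convolution kernels (Conv1d output channels = feature_size * cnn_kernel_num)
# lstm_layer        number of stacked LSTM layers
# self_att_dim      hidden width of the self-attention scoring layers
# inter_att_dim     hidden/output width of the inter-attention K and Q projections
# n                 per time step, keep the n highest-scoring features; per feature, keep the n highest-scoring time steps
# m                 for each selected (feature, time) position, keep the m most influential positions over all features and time steps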

class CnnLstmModel(nn.Module):
    """CNN + LSTM + self/inter-attention sequence classifier.

    Pipeline (B = batch, F = feature_size, T = temporal_size,
    K = cnn_kernel_num, Q = n * (F + T)):

    1. Conv1d over time:            [B, T, F] -> [B, F, T, K]
    2. LSTM per feature + residual: [B, F, T, K]
    3. Self-attention score:        [B, F, T]
    4. Top-n selection per feature and per time step: [B, Q, K]
    5. Inter-attention, top-m positions per selected vector: [B, Q, m, K]
    6. Conv2d stack + linear + softmax: [B, num_classes]

    Args:
        feature_size: Number of input features (F).
        temporal_size: Number of time steps (T).
        cnn_kernel_size: Conv1d kernel length. Must be odd so that the
            "same" padding keeps the temporal length unchanged.
        cnn_kernel_num: Number of kernels (K). Conv1d emits F * K channels.
            Must be >= 4 because two 2x2 max-pools are applied along it.
        lstm_layer: Number of stacked LSTM layers.
        self_att_dim: Hidden width of the self-attention scoring MLP.
        inter_att_dim: Hidden/output width of the K and Q projection MLPs.
        n: Number of top-scoring entries kept per feature and per time step.
            Must satisfy 1 <= n <= min(F, T).
        m: Number of most-attended positions gathered for each selected
            vector. Must satisfy 3 <= m <= F * T (m + 1 rows go through two
            2x2 max-pools).
        num_classes: Number of output classes. Defaults to 2.

    Raises:
        ValueError: If the hyper-parameters violate the constraints above.
    """

    def __init__(self, feature_size, temporal_size, cnn_kernel_size, cnn_kernel_num, lstm_layer,
                 self_att_dim, inter_att_dim, n, m, num_classes=2):
        super(CnnLstmModel, self).__init__()

        # Fail early with a readable message instead of an opaque shape
        # error deep inside forward().
        if cnn_kernel_size < 1 or cnn_kernel_size % 2 == 0:
            raise ValueError("cnn_kernel_size must be a positive odd integer, got %r" % (cnn_kernel_size,))
        if not 1 <= n <= min(feature_size, temporal_size):
            raise ValueError("n must be in [1, min(feature_size, temporal_size)], got %r" % (n,))
        if not 1 <= m <= feature_size * temporal_size:
            raise ValueError("m must be in [1, feature_size * temporal_size], got %r" % (m,))
        if m + 1 < 4 or cnn_kernel_num < 4:
            raise ValueError("m + 1 and cnn_kernel_num must both be >= 4 (two 2x2 max-pools are applied)")
        if num_classes < 1:
            raise ValueError("num_classes must be >= 1, got %r" % (num_classes,))

        # Hyper-parameters.
        self.feature_size = feature_size
        self.temporal_size = temporal_size
        self.cnn_kernel_size = cnn_kernel_size
        self.cnn_kernel_num = cnn_kernel_num
        self.lstm_layer = lstm_layer
        self.self_att_dim = self_att_dim
        self.inter_att_dim = inter_att_dim
        self.n = n
        self.m = m
        self.num_classes = num_classes
        # Number of vectors kept by the top-n selection; used as the Conv2d
        # channel count.
        self.dim = self.n * (self.feature_size + self.temporal_size)

        # CNN layer.
        # in:  [B, F, T]
        # out: [B, F * K, T]
        self.cnnin = self.feature_size
        self.cnnout = self.feature_size * self.cnn_kernel_num
        # Integer division: Conv1d requires an int padding.
        self.cnn_padding = (self.cnn_kernel_size - 1) // 2
        self.cnn = nn.Conv1d(
            in_channels=self.cnnin,
            out_channels=self.cnnout,
            kernel_size=self.cnn_kernel_size,
            stride=1,
            padding=self.cnn_padding
        )

        # LSTM layer, applied independently to every feature's sequence.
        # in:  [B * F, T, K]
        # out: same shape (hidden_size == K so the residual add works).
        self.lstm = nn.LSTM(
            input_size=self.cnn_kernel_num,
            hidden_size=self.cnn_kernel_num,
            num_layers=self.lstm_layer,
            batch_first=True
        )

        # Self-attention scoring MLP.
        # in:  [B, F, T, K]
        # out: [B, F, T, 1]
        self.self_att_score = 1
        self.self_att = nn.Sequential(
            nn.Linear(self.cnn_kernel_num, self.self_att_dim),
            nn.ReLU(),
            nn.Linear(self.self_att_dim, self.self_att_score)
        )

        # Inter-attention projections.
        # K: [B, F, T, K] -> [B, F, T, inter_att_dim]
        self.k = nn.Sequential(
            nn.Linear(self.cnn_kernel_num, self.inter_att_dim),
            nn.ReLU(),
            nn.Linear(self.inter_att_dim, self.inter_att_dim)
        )
        # Q: [B, Q, K] -> [B, Q, inter_att_dim]
        self.q = nn.Sequential(
            nn.Linear(self.cnn_kernel_num, self.inter_att_dim),
            nn.ReLU(),
            nn.Linear(self.inter_att_dim, self.inter_att_dim)
        )

        # Convolutional head.
        # in:  [B, Q, m + 1, K]  (Q is treated as the channel axis)
        # out: [B, 256, 1, 1]
        self.conv_layers = nn.Sequential(
            # Block 1: keep the channel count.
            nn.Conv2d(
                in_channels=self.dim,
                out_channels=self.dim,
                kernel_size=(3, 3),
                padding=1                    # keeps the spatial size
            ),
            nn.BatchNorm2d(self.dim),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2)),

            # Block 2: double the channel count.
            nn.Conv2d(self.dim, self.dim * 2, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(self.dim * 2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2)),

            # Block 3: project to 256 channels, then global average pool.
            nn.Conv2d(self.dim * 2, 256, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        # Classification layer.
        self.fc = nn.Linear(256, num_classes)

        # Softmax over the class axis.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Classify a batch of multivariate time series.

        Args:
            x: Tensor of shape [B, temporal_size, feature_size].

        Returns:
            Tensor of shape [B, num_classes] with class probabilities
            (softmax already applied; do not feed it to a loss that applies
            log-softmax itself, such as ``nn.CrossEntropyLoss``).

        Raises:
            ValueError: If ``x`` does not have the expected shape.
        """
        if x.dim() != 3 or x.size(1) != self.temporal_size or x.size(2) != self.feature_size:
            raise ValueError(
                "expected input of shape [batch, %d, %d], got %s"
                % (self.temporal_size, self.feature_size, tuple(x.shape))
            )

        batch_size = x.size(0)
        num_feat = self.feature_size
        num_time = self.temporal_size
        num_kern = self.cnn_kernel_num
        num_query = self.dim

        ### CNN layer
        x = x.permute(0, 2, 1)
        # x: [B, F, T]
        x = self.cnn(x)
        # x: [B, F * K, T]
        x = x.reshape(batch_size, num_feat, num_kern, num_time)
        # x: [B, F, K, T]
        cnn_out = x.permute(0, 1, 3, 2).contiguous()
        # cnn_out: [B, F, T, K]

        ### LSTM layer (each feature's sequence is an independent sample)
        lstm_in = cnn_out.reshape(batch_size * num_feat, num_time, num_kern)
        # nn.LSTM returns (output, (h_n, c_n)); only the per-step output is used.
        lstm_seq, _ = self.lstm(lstm_in)
        lstm_out = lstm_seq.reshape(batch_size, num_feat, num_time, num_kern)
        # lstm_out: [B, F, T, K]

        ### Residual connection between the CNN and LSTM outputs
        union_out = cnn_out + lstm_out
        # union_out: [B, F, T, K]

        ### Self-attention score (nn.Linear acts on the last axis)
        score = self.self_att(union_out).squeeze(-1)
        # score: [B, F, T]

        ## Per time step: keep the n highest-scoring features.
        _, temporal_topn_indices = torch.topk(score.permute(0, 2, 1), self.n, dim=2)
        # temporal_topn_indices: [B, T, n], values are feature indices
        temporal_expanded_indices = temporal_topn_indices.unsqueeze(-1).expand(-1, -1, -1, num_kern)
        # The indices address the feature axis, so gather from a [B, T, F, K] view.
        temporal_topn_values = torch.gather(union_out.permute(0, 2, 1, 3), 2, temporal_expanded_indices)
        # temporal_topn_values: [B, T, n, K]

        ## Per feature: keep the n highest-scoring time steps.
        _, feature_topn_indices = torch.topk(score, self.n, dim=2)
        # feature_topn_indices: [B, F, n], values are time indices
        feature_expanded_indices = feature_topn_indices.unsqueeze(-1).expand(-1, -1, -1, num_kern)
        feature_topn_values = torch.gather(union_out, 2, feature_expanded_indices)
        # feature_topn_values: [B, F, n, K]

        ## Stack both selections: feature-wise picks first, then time-wise picks.
        topn_values = torch.cat(
            [
                feature_topn_values.reshape(batch_size, num_feat * self.n, num_kern),
                temporal_topn_values.reshape(batch_size, num_time * self.n, num_kern),
            ],
            dim=1,
        )
        # topn_values: [B, Q, K] with Q = (F + T) * n

        ### Inter-attention layer
        k = self.k(union_out)
        # k: [B, F, T, inter_att_dim]
        q = self.q(topn_values)
        # q: [B, Q, inter_att_dim]
        inter_att_score = torch.einsum('bftd,bqd->bqft', k, q)
        # inter_att_score: [B, Q, F, T]

        # For every selected vector pick the m most-attended (feature, time)
        # positions. Flattening F x T row-major means flat index = f * T + t,
        # which matches the flattening of union_out below.
        flat_score = inter_att_score.reshape(batch_size, num_query, num_feat * num_time)
        _, topm_indices = torch.topk(flat_score, self.m, dim=2)
        # topm_indices: [B, Q, m]
        union_flat = union_out.reshape(batch_size, num_feat * num_time, num_kern)
        gather_indices = topm_indices.reshape(batch_size, num_query * self.m, 1).expand(-1, -1, num_kern)
        inter_att_result = torch.gather(union_flat, 1, gather_indices).reshape(
            batch_size, num_query, self.m, num_kern
        )
        # inter_att_result: [B, Q, m, K]

        topn_values_expanded = topn_values.unsqueeze(2)
        # [B, Q, K] -> [B, Q, 1, K]
        combined = torch.cat([topn_values_expanded, inter_att_result], dim=2)
        # combined: [B, Q, m + 1, K]

        ### Convolutional head + classifier
        conv2d_out = self.conv_layers(combined)
        # conv2d_out: [B, 256, 1, 1]
        conv2d_out = conv2d_out.view(conv2d_out.size(0), -1)
        fc_out = self.fc(conv2d_out)
        # fc_out: [B, num_classes]
        out = self.softmax(fc_out)

        return out
        