In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset,DataLoader,Subset
import numpy as np
import pandas as pd
import h5py
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import time
from sklearn.metrics import classification_report, confusion_matrix,precision_score, recall_score, f1_score
import seaborn as sns
from torch.cuda.amp import autocast, GradScaler
import random
import math

In [2]:
class ChannelAttention(nn.Module):
    """Channel attention gate for channels-first 1-D feature maps.

    The sequence axis is summarised per channel by global average pooling and
    global max pooling. Both summaries go through the same two-layer MLP, the
    results are added, squashed with a sigmoid, and used to rescale each
    channel of the input.

    Args:
        channels: Number of channels of the input tensor (size of dim 1).
        ratio: Reduction factor of the MLP bottleneck. The hidden width is
            ``channels // ratio``, clamped to at least 1.

    Raises:
        ValueError: If ``ratio`` is smaller than 1.
    """

    def __init__(self, channels, ratio=4):
        super(ChannelAttention, self).__init__()
        if ratio < 1:
            raise ValueError(f"ratio must be >= 1, got {ratio}")
        self.channels = channels
        self.ratio = ratio

        # Global average / max pooling: collapse the sequence axis to length 1.
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.max_pool = nn.AdaptiveMaxPool1d(1)

        # Without the clamp, channels < ratio gives a zero-width hidden layer,
        # and the gate collapses to a bias-only constant that ignores the input.
        hidden = max(1, channels // ratio)

        # MLP shared by the average-pooled and max-pooled descriptors.
        self.mlp = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels)
        )

    def forward(self, x):
        """Rescale every channel of ``x`` by its attention weight.

        Args:
            x: Tensor of shape (batch_size, channels, sequence_length).

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size, channels, _ = x.size()

        # Per-channel descriptors, each of shape (B, C).
        avg_out = self.avg_pool(x).view(batch_size, channels)
        max_out = self.max_pool(x).view(batch_size, channels)

        # Same MLP for both descriptors; sum, then squash into (0, 1).
        gate = torch.sigmoid(self.mlp(avg_out) + self.mlp(max_out))

        # (B, C) -> (B, C, 1) so the gate broadcasts along the sequence axis.
        return x * gate.view(batch_size, channels, 1)

class SpatialAttention(nn.Module):
    """Position-wise attention gate for channels-first 1-D feature maps.

    For every sequence position the channel mean and channel maximum are
    stacked into a 2-channel map, mixed by a kernel-size-1 convolution into a
    single channel, passed through a sigmoid, and multiplied back onto the
    input.
    """

    def __init__(self):
        super().__init__()
        # Pointwise mix of the two pooled maps into one gate channel.
        self.conv = nn.Conv1d(in_channels=2, out_channels=1, kernel_size=1, padding=0)

    def forward(self, x):
        """Rescale every position of ``x`` by its attention weight.

        Args:
            x: Tensor of shape (batch_size, channels, sequence_length).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Channel-wise statistics, each of shape (B, 1, L).
        mean_map = x.mean(dim=1, keepdim=True)
        peak_map = x.max(dim=1, keepdim=True).values

        # (B, 2, L) -> conv -> (B, 1, L) -> sigmoid gate.
        pooled = torch.cat((mean_map, peak_map), dim=1)
        gate = torch.sigmoid(self.conv(pooled))

        # The gate broadcasts over the channel axis.
        return x * gate

class CBAMBlock(nn.Module):
    """Sequential attention: channel attention first, then spatial attention.

    Args:
        channels: Number of channels of the input tensor.
        ratio: Reduction factor forwarded to the channel-attention MLP.
    """

    def __init__(self, channels, ratio=4):
        super().__init__()
        self.channel_attention = ChannelAttention(channels, ratio=ratio)
        self.spatial_attention = SpatialAttention()

    def forward(self, x):
        """Return ``x`` after channel attention followed by spatial attention."""
        channel_refined = self.channel_attention(x)
        return self.spatial_attention(channel_refined)

class ParallelAttentionBlock(nn.Module):
    """Parallel attention: channel and spatial attention on the same input, summed.

    Args:
        channels: Number of channels of the input tensor.
        ratio: Reduction factor forwarded to the channel-attention MLP.
    """

    def __init__(self, channels, ratio=4):
        super().__init__()
        self.channel_attention = ChannelAttention(channels, ratio=ratio)
        self.spatial_attention = SpatialAttention()

    def forward(self, x):
        """Return the element-wise sum of both attention branches applied to ``x``."""
        # Both branches see the untouched input, unlike the sequential CBAM ordering.
        return self.channel_attention(x) + self.spatial_attention(x)

class ConvLSTMBlock(nn.Module):
    """Conv1d -> ReLU -> BatchNorm1d, followed by an LSTM along the sequence axis.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the convolution; also the LSTM's
            input and hidden size.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm1d(out_channels)
        self.lstm = nn.LSTM(input_size=out_channels, hidden_size=out_channels, batch_first=True)

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch_size, channels, sequence_length).

        Returns:
            Tensor of shape (batch_size, out_channels, new_sequence_length).
        """
        features = self.bn(F.relu(self.conv(x)))

        # The LSTM is batch_first, so it wants (B, L, C); swap back afterwards.
        sequence_out, _ = self.lstm(features.transpose(1, 2))
        return sequence_out.transpose(1, 2)

class ImprovedPhaseNetLSTMParallelAttentionCBAM(nn.Module):
    """U-Net-style 1-D encoder/decoder with LSTMs and attention on 3-channel sequences.

    Encoder: four stages of Conv1d/ReLU/BatchNorm1d, each followed by an LSTM,
    with stride-4 convolutions between stages (channels 8 -> 11 -> 16 -> 22),
    plus a final stride-4 convolution.

    Decoder: four stride-4 transposed convolutions. Each encoder output is
    passed through a ``ParallelAttentionBlock`` and concatenated with the
    decoder feature map of the same length; the three later upsampled maps are
    additionally refined by a ``CBAMBlock``. A final LSTM, convolution and
    batch norm produce 3 output channels.

    Input:  (batch_size, sequence_length, 3)
    Output: (batch_size, sequence_length, 3), unnormalized (no softmax is
    applied). NOTE(review): presumably per-sample class scores as in PhaseNet;
    confirm against the training/loss code.

    Sequence length: the transposed convolutions are built with fixed
    ``output_padding`` values that only reproduce the encoder lengths when the
    input length has the form ``256 * k - 144`` (e.g. 112 or 6000). ``forward``
    therefore passes the matching skip-connection length as ``output_size``,
    which yields identical results for those lengths and makes every other
    length work as well.
    """

    def __init__(self):
        super(ImprovedPhaseNetLSTMParallelAttentionCBAM, self).__init__()

        # ---------------- Encoder ----------------
        # Stage 1: full resolution.
        self.conv1_1 = nn.Conv1d(3, 8, kernel_size=3, padding=1)
        self.bn1_1 = nn.BatchNorm1d(8)
        self.conv1_2 = nn.Conv1d(8, 8, kernel_size=3, padding=1)
        self.bn1_2 = nn.BatchNorm1d(8)
        self.lstm1 = nn.LSTM(8, 8, batch_first=True)

        # Downsampling 1 (stride 4).
        self.conv2_1 = nn.Conv1d(8, 8, kernel_size=3, stride=4, padding=1)
        self.bn2_1 = nn.BatchNorm1d(8)
        self.conv2_2 = nn.Conv1d(8, 11, kernel_size=3, padding=1)
        self.bn2_2 = nn.BatchNorm1d(11)
        self.lstm2 = nn.LSTM(11, 11, batch_first=True)

        # Downsampling 2 (stride 4).
        self.conv3_1 = nn.Conv1d(11, 11, kernel_size=3, stride=4, padding=1)
        self.bn3_1 = nn.BatchNorm1d(11)
        self.conv3_2 = nn.Conv1d(11, 16, kernel_size=3, padding=1)
        self.bn3_2 = nn.BatchNorm1d(16)
        self.lstm3 = nn.LSTM(16, 16, batch_first=True)

        # Downsampling 3 (stride 4).
        self.conv4_1 = nn.Conv1d(16, 16, kernel_size=3, stride=4, padding=1)
        self.bn4_1 = nn.BatchNorm1d(16)
        self.conv4_2 = nn.Conv1d(16, 22, kernel_size=3, padding=1)
        self.bn4_2 = nn.BatchNorm1d(22)
        self.lstm4 = nn.LSTM(22, 22, batch_first=True)

        # Downsampling 4 (stride 4): bottleneck input.
        self.conv5_1 = nn.Conv1d(22, 22, kernel_size=3, stride=4, padding=1)
        self.bn5_1 = nn.BatchNorm1d(22)

        # ---------------- Decoder ----------------
        # Upsampling 1. The output_padding values below are kept as the
        # defaults; forward() overrides them via output_size.
        self.conv6 = nn.Conv1d(22, 32, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm1d(32)
        self.deconv6 = nn.ConvTranspose1d(32, 22, kernel_size=3, stride=4, padding=1, output_padding=1)
        self.bn6_up = nn.BatchNorm1d(22)

        # Upsampling 2.
        self.parallel_att4 = ParallelAttentionBlock(22)
        self.conv7 = nn.Conv1d(44, 22, kernel_size=3, padding=1)  # 22 + 22 = 44
        self.bn7 = nn.BatchNorm1d(22)
        self.deconv7 = nn.ConvTranspose1d(22, 16, kernel_size=3, stride=4, padding=1, output_padding=2)
        self.bn7_up = nn.BatchNorm1d(16)
        self.cbam7 = CBAMBlock(16)

        # Upsampling 3.
        self.parallel_att3 = ParallelAttentionBlock(16)
        self.conv8 = nn.Conv1d(32, 16, kernel_size=3, padding=1)  # 16 + 16 = 32
        self.bn8 = nn.BatchNorm1d(16)
        self.deconv8 = nn.ConvTranspose1d(16, 11, kernel_size=3, stride=4, padding=1, output_padding=3)
        self.bn8_up = nn.BatchNorm1d(11)
        self.cbam8 = CBAMBlock(11)

        # Upsampling 4.
        self.parallel_att2 = ParallelAttentionBlock(11)
        self.conv9 = nn.Conv1d(22, 11, kernel_size=3, padding=1)  # 11 + 11 = 22
        self.bn9 = nn.BatchNorm1d(11)
        self.deconv9 = nn.ConvTranspose1d(11, 8, kernel_size=3, stride=4, padding=1, output_padding=3)
        self.bn9_up = nn.BatchNorm1d(8)
        self.cbam9 = CBAMBlock(8)

        # Output head.
        self.parallel_att1 = ParallelAttentionBlock(8)
        self.conv10 = nn.Conv1d(16, 8, kernel_size=3, padding=1)  # 8 + 8 = 16
        self.bn10 = nn.BatchNorm1d(8)
        self.lstm_final = nn.LSTM(8, 3, batch_first=True)
        self.conv_final = nn.Conv1d(3, 3, kernel_size=3, padding=1)
        self.bn_final = nn.BatchNorm1d(3)

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape (batch_size, sequence_length, 3).

        Returns:
            Tensor of shape (batch_size, sequence_length, 3). No softmax is
            applied, so the values are unnormalized.
        """
        # (B, L, C) -> (B, C, L) for the convolutions.
        x = x.transpose(1, 2)

        # ---------------- Encoder ----------------
        # Stage 1.
        x = F.relu(self.conv1_1(x))
        x = self.bn1_1(x)
        x1 = F.relu(self.conv1_2(x))
        x1 = self.bn1_2(x1)

        # The LSTMs are batch_first: (B, C, L) -> (B, L, C) and back.
        x1_lstm, _ = self.lstm1(x1.transpose(1, 2))
        x1 = x1_lstm.transpose(1, 2)

        # Downsampling 1.
        x2 = F.relu(self.conv2_1(x1))
        x2 = self.bn2_1(x2)
        x2 = F.relu(self.conv2_2(x2))
        x2 = self.bn2_2(x2)

        x2_lstm, _ = self.lstm2(x2.transpose(1, 2))
        x2 = x2_lstm.transpose(1, 2)

        # Downsampling 2.
        x3 = F.relu(self.conv3_1(x2))
        x3 = self.bn3_1(x3)
        x3 = F.relu(self.conv3_2(x3))
        x3 = self.bn3_2(x3)

        x3_lstm, _ = self.lstm3(x3.transpose(1, 2))
        x3 = x3_lstm.transpose(1, 2)

        # Downsampling 3.
        x4 = F.relu(self.conv4_1(x3))
        x4 = self.bn4_1(x4)
        x4 = F.relu(self.conv4_2(x4))
        x4 = self.bn4_2(x4)

        x4_lstm, _ = self.lstm4(x4.transpose(1, 2))
        x4 = x4_lstm.transpose(1, 2)

        # Downsampling 4.
        x5 = F.relu(self.conv5_1(x4))
        x5 = self.bn5_1(x5)

        # ---------------- Decoder ----------------
        # Every transposed convolution is asked for exactly the length of the
        # skip tensor it is concatenated with. A stride-4 encoder step maps
        # lengths 4m-3 .. 4m onto m, which is precisely the range of output
        # sizes a stride-4 transposed convolution may produce from length m,
        # so the requested size is always valid.

        # Upsampling 1.
        x6 = F.relu(self.conv6(x5))
        x6 = self.bn6(x6)
        x6 = self.deconv6(x6, output_size=[x4.size(-1)])
        x6 = self.bn6_up(x6)

        # Upsampling 2. Note the concat order here is (decoder, skip), while
        # the later stages use (skip, decoder); conv7's weights depend on it.
        x4_att = self.parallel_att4(x4)
        x7 = torch.cat([x6, x4_att], dim=1)
        x7 = F.relu(self.conv7(x7))
        x7 = self.bn7(x7)
        x7 = self.deconv7(x7, output_size=[x3.size(-1)])
        x7 = self.bn7_up(x7)
        x7 = self.cbam7(x7)

        # Upsampling 3.
        x3_att = self.parallel_att3(x3)
        x8 = torch.cat([x3_att, x7], dim=1)
        x8 = F.relu(self.conv8(x8))
        x8 = self.bn8(x8)
        x8 = self.deconv8(x8, output_size=[x2.size(-1)])
        x8 = self.bn8_up(x8)
        x8 = self.cbam8(x8)

        # Upsampling 4.
        x2_att = self.parallel_att2(x2)
        x9 = torch.cat([x2_att, x8], dim=1)
        x9 = F.relu(self.conv9(x9))
        x9 = self.bn9(x9)
        x9 = self.deconv9(x9, output_size=[x1.size(-1)])
        x9 = self.bn9_up(x9)
        x9 = self.cbam9(x9)

        # Output head.
        x1_att = self.parallel_att1(x1)
        x10 = torch.cat([x1_att, x9], dim=1)
        x10 = F.relu(self.conv10(x10))
        x10 = self.bn10(x10)

        # Final LSTM reduces 8 channels to 3.
        x10_lstm, _ = self.lstm_final(x10.transpose(1, 2))
        x10 = x10_lstm.transpose(1, 2)

        # Final convolution + batch norm.
        x10 = self.conv_final(x10)
        x10 = self.bn_final(x10)

        # Softmax over the channel axis is deliberately left disabled here:
        # x10 = F.softmax(x10, dim=1)

        # (B, C, L) -> (B, L, C) to match the input layout.
        output = x10.transpose(1, 2)

        return output

def create_improved_phasenet():
    """Build and return a fresh ``ImprovedPhaseNetLSTMParallelAttentionCBAM`` instance."""
    model = ImprovedPhaseNetLSTMParallelAttentionCBAM()
    return model