In [1]:
import os
class Config:
    # 数据加载器配置
    DEBUG_MODE = True
    DEBUG_SAMPLE_SIZE = 5000
    BASE_DIR = os.getcwd()
    LIBRISPEECH_PATH = os.path.join(BASE_DIR, "devDataset", "LibriSpeech")
    SAMPLE_RATE = 16000
    DURATION = 1.0
    MAX_SAMPLES = int(SAMPLE_RATE * DURATION)
    NOISE_TYPES = ["white", "babble"]
    SNR_LEVELS = [0, 5, 10]
    NUM_WORKERS = 2
    VALID_RATIO = 0.1
    MAX_SEQ_LEN = 16000
    
    # 模型配置
    INPUT_DIM = 1
    HIDDEN_DIM = 256
    EMBEDDING_DIM = 128
    CHAOS_DIM = 64
    CHAOS_TIME_STEPS = 5
    ATTENTION_HEADS = 4
    
    # 训练配置
    EPOCHS = 500
    LR = 0.0008
    LR_DECAY = 0.95
    WEIGHT_DECAY = 1e-5
    SAVE_INTERVAL = 10
    VAL_INTERVAL = 1
    CHECKPOINT_DIR = "./checkpoints_T2"

    # 优化参数
    WARMUP_EPOCHS = 5
    GRAD_CLIP = 1.0

    # 损失函数权重
    CE_WEIGHT = 1.0
    ATTENTION_REG_WEIGHT = 0.01  # 注意力正则化权重（新增）

    # 早停参数
    PATIENCE = 15
    MIN_DELTA = 0.001

    # 内存优化参数
    GRADIENT_ACCUMULATION_STEPS = 4
    BATCH_SIZE = 8     # 减小批次大小，复杂注意力需要更多内存
    ENABLE_MIXED_PRECISION = False

    # 注意力机制特定参数
    ATTENTION_MAX_SEQ_LEN = 250  # 注意力机制处理的最大序列长度
    ATTENTION_HEAD_DIM = 64       # 注意力头维度
    ATTENTION_HEADS = 8           # 注意力头数量
    
    # 序列处理参数
    PADDING_MODE = 'constant'     # 填充模式
    PADDING_VALUE = 0.0           # 填充值
    
    BIFURCATION_THRESHOLD = 0.5  # 分岔阈值
    ATTENTION_DROPOUT = 0.1  # 注意力dropout率

    # 注意力监控参数
    ATTENTION_VIZ_INTERVAL = 5  # 注意力可视化间隔（epoch）
    ATTENTION_VIZ_DIR = "./attention_viz"  # 注意力可视化保存目录

In [2]:
import os
import random
import librosa
import numpy as np
import os
import random
import librosa
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import soundfile as sf
from tqdm import tqdm


# 噪声生成与注入
class NoiseInjector:
    @staticmethod
    def generate_white_noise(length):
        return np.random.randn(length).astype(np.float32)

    @staticmethod
    def generate_babble_noise(length, num_speakers=3):
        noise = np.zeros(length, dtype=np.float32)
        for _ in range(num_speakers):
            start = random.randint(0, max(0, length - Config.SAMPLE_RATE))
            end = min(start + Config.SAMPLE_RATE, length)
            noise[start:end] += np.random.randn(end - start).astype(np.float32)
        return noise / num_speakers

    @staticmethod
    def add_noise(signal, noise_type="white", snr_db=10):
        if len(signal) == 0:
            return signal

        signal_power = np.mean(signal ** 2)
        if signal_power < 1e-10:
            return signal

        signal_db = 10 * np.log10(signal_power)

        if noise_type == "white":
            noise = NoiseInjector.generate_white_noise(len(signal))
        elif noise_type == "babble":
            noise = NoiseInjector.generate_babble_noise(len(signal))
        else:
            raise ValueError(f"不支持的噪声类型：{noise_type}")

        noise_power = np.mean(noise ** 2)
        noise_db = -100 if noise_power < 1e-10 else 10 * np.log10(noise_power)

        target_noise_db = signal_db - snr_db
        noise_scale = 10 ** ((target_noise_db - noise_db) / 20)
        noisy_signal = signal + noise * noise_scale

        # 归一化
        max_val = np.max(np.abs(noisy_signal))
        if max_val > 1e-5:
            noisy_signal = noisy_signal / max_val

        return noisy_signal

# 数据集类
class SpeakerRecognitionDataset(Dataset):
    def __init__(self, split="train", add_noise=False, noise_type="white", snr_db=10):
        self.split = split
        self.add_noise = add_noise
        self.noise_type = noise_type
        self.snr_db = snr_db

        # 加载数据集
        self.audio_paths, self.labels = self._load_dataset()
        self.speaker_to_idx = self._build_speaker_map()

        # 调试模式
        if Config.DEBUG_MODE:
            if split == "train":
                self.audio_paths = self.audio_paths[:Config.DEBUG_SAMPLE_SIZE]
                self.labels = self.labels[:Config.DEBUG_SAMPLE_SIZE]
            elif split == "val":
                self.audio_paths = self.audio_paths[:min(Config.DEBUG_SAMPLE_SIZE // 2, len(self.audio_paths))]
                self.labels = self.labels[:min(Config.DEBUG_SAMPLE_SIZE // 2, len(self.labels))]

        # 验证数据集
        self._validate_dataset()
        print(f"最终 {split} 数据集大小: {len(self.audio_paths)} 个样本")

    def _load_dataset(self):
        audio_paths = []
        labels = []

        root = Config.LIBRISPEECH_PATH
        if not os.path.exists(root):
            print(f"错误: LibriSpeech路径不存在 - {root}")
            return [], []

        print(f"加载LibriSpeech数据集: {root}")

        # 只遍历 dev-clean 和 dev-other
        for subset_dir in ["dev-clean", "dev-other"]:
            subset_path = os.path.join(root, subset_dir)
            if not os.path.isdir(subset_path):
                continue

            print(f"处理子集: {subset_dir}")
            for speaker_dir in os.listdir(subset_path):
                speaker_path = os.path.join(subset_path, speaker_dir)
                if not os.path.isdir(speaker_path):
                    continue

                for chapter_dir in os.listdir(speaker_path):
                    chapter_path = os.path.join(speaker_path, chapter_dir)
                    if not os.path.isdir(chapter_path):
                        continue

                    for file in os.listdir(chapter_path):
                        if file.endswith(".flac"):
                            full_path = os.path.join(chapter_path, file)
                            audio_paths.append(full_path)
                            labels.append(speaker_dir)

            print(f"在 {subset_dir} 中找到 {len(audio_paths)} 个.flac文件")

        # 分割训练/验证
        if self.split != "test" and len(audio_paths) > 0:
            train_paths, val_paths, train_labels, val_labels = train_test_split(
                audio_paths, labels, test_size=Config.VALID_RATIO, random_state=42
            )
            if self.split == "train":
                return train_paths, train_labels
            else:
                return val_paths, val_labels

        return audio_paths, labels

    def _build_speaker_map(self):
        unique_speakers = sorted(set(self.labels))
        print(f"找到 {len(unique_speakers)} 个不同的说话人")
        return {speaker: idx for idx, speaker in enumerate(unique_speakers)}

    def _validate_dataset(self):
        if len(self.audio_paths) == 0:
            print(f"警告: {self.split}数据集为空")
            return

        valid_count = 0
        invalid_indices = []
        for i in range(len(self.audio_paths) - 1, -1, -1):
            path = self.audio_paths[i]
            try:
                if not os.path.exists(path):
                    raise FileNotFoundError("文件不存在")

                if os.path.getsize(path) < 1024:
                    raise ValueError("文件太小可能已损坏")

                if path.endswith('.flac'):
                    signal, sr = sf.read(path)
                else:
                    signal, sr = librosa.load(path, sr=Config.SAMPLE_RATE, mono=True)

                if len(signal) < Config.SAMPLE_RATE // 2:
                    raise ValueError("音频过短")

                if np.max(np.abs(signal)) < 1e-5:
                    raise ValueError("接近静音")

                valid_count += 1
            except Exception as e:
                invalid_indices.append(i)
                print(f"无效文件: {path} - {str(e)}")

        for i in invalid_indices:
            self.audio_paths.pop(i)
            self.labels.pop(i)

        print(f"有效文件: {valid_count}/{len(self.audio_paths) + len(invalid_indices)}")
        print(f"移除 {len(invalid_indices)} 个无效文件")
        self.speaker_to_idx = self._build_speaker_map()

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        path = self.audio_paths[idx]
        speaker_id = self.labels[idx]
        label = self.speaker_to_idx[speaker_id]

        try:
            if path.endswith('.flac'):
                signal, sr = sf.read(path)
                if sr != Config.SAMPLE_RATE:
                    signal = librosa.resample(signal, orig_sr=sr, target_sr=Config.SAMPLE_RATE)
            else:
                signal, sr = librosa.load(path, sr=Config.SAMPLE_RATE, mono=True)
        except Exception as e:
            print(f"加载音频错误 {path}: {str(e)}")
            signal = np.zeros(Config.MAX_SAMPLES, dtype=np.float32)

        # 确保音频长度正确
        if len(signal) > Config.MAX_SAMPLES:
            # 随机裁剪而不是固定裁剪开头
            start = random.randint(0, len(signal) - Config.MAX_SAMPLES)
            signal = signal[start:start + Config.MAX_SAMPLES]
        elif len(signal) < Config.MAX_SAMPLES:
            pad_len = Config.MAX_SAMPLES - len(signal)
            signal = np.pad(signal, (0, pad_len), 
                           mode=Config.PADDING_MODE, 
                           constant_values=Config.PADDING_VALUE)
        
        # 为注意力机制创建长度掩码
        seq_length = min(len(signal), Config.MAX_SAMPLES)
        attention_mask = np.ones(Config.MAX_SAMPLES, dtype=np.float32)
        if seq_length < Config.MAX_SAMPLES:
            attention_mask[seq_length:] = 0.0
        
        # 归一化
        max_val = np.max(np.abs(signal))
        if max_val > 1e-5:
            signal = signal / max_val
        
        # 添加噪声
        if self.add_noise and self.split == "train":
            noise_type = random.choice(Config.NOISE_TYPES)
            snr_db = random.choice(Config.SNR_LEVELS)
            signal = NoiseInjector.add_noise(signal, noise_type, snr_db)
        
        # 返回信号、标签和注意力掩码
        return torch.FloatTensor(signal), label, torch.FloatTensor(attention_mask)

# 数据加载器
def get_dataloaders(batch_size=None):
    if batch_size is None:
        batch_size = Config.BATCH_SIZE

    train_dataset = SpeakerRecognitionDataset(split="train")
    val_dataset = SpeakerRecognitionDataset(split="val")
    test_dataset = SpeakerRecognitionDataset(split="test")

    if len(train_dataset) == 0 and len(test_dataset) > 0:
        print("警告: 训练集为空，使用测试集作为训练集")
        train_dataset = test_dataset

    noisy_test_dataset = SpeakerRecognitionDataset(
        split="test", add_noise=True, noise_type="white", snr_db=5
    )

    print(f"训练集: {len(train_dataset)} 样本")
    print(f"验证集: {len(val_dataset)} 样本")
    print(f"测试集: {len(test_dataset)} 样本")
    print(f"带噪声测试集: {len(noisy_test_dataset)} 样本")
    print(f"总说话人数: {len(train_dataset.speaker_to_idx)}")

    # 创建自定义collate函数处理注意力掩码
    def collate_fn(batch):
        signals, labels, attention_masks = zip(*batch)
        signals = torch.stack(signals)
        labels = torch.tensor(labels)
        attention_masks = torch.stack(attention_masks)
        return signals, labels, attention_masks
    
    # 创建数据加载器时使用自定义collate函数
    dataloaders = {
        "train": DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                           num_workers=Config.NUM_WORKERS, pin_memory=True,
                           collate_fn=collate_fn),
        "val": DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                         num_workers=Config.NUM_WORKERS, pin_memory=True,
                         collate_fn=collate_fn),
        "test": DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                          num_workers=Config.NUM_WORKERS, pin_memory=True,
                          collate_fn=collate_fn),
        "noisy_test": DataLoader(noisy_test_dataset, batch_size=batch_size, shuffle=False,
                                num_workers=Config.NUM_WORKERS, pin_memory=True,
                                collate_fn=collate_fn),
    }
    
    return dataloaders

if __name__ == "__main__":
    print("测试数据加载器...")
    dataloaders = get_dataloaders()
    x, y, attention_mask = next(iter(dataloaders["train"]))
    print(f"音频数据形状: {x.shape}")
    print(f"标签数据形状: {y.shape}")
    print(f"注意力掩码形状: {attention_mask.shape}")

测试数据加载器...
加载LibriSpeech数据集: /users/tianyuey/Project/devDataset/LibriSpeech
处理子集: dev-clean
在 dev-clean 中找到 2703 个.flac文件
处理子集: dev-other
在 dev-other 中找到 5567 个.flac文件
找到 73 个不同的说话人
有效文件: 5000/5000
移除 0 个无效文件
找到 73 个不同的说话人
最终 train 数据集大小: 5000 个样本
加载LibriSpeech数据集: /users/tianyuey/Project/devDataset/LibriSpeech
处理子集: dev-clean
在 dev-clean 中找到 2703 个.flac文件
处理子集: dev-other
在 dev-other 中找到 5567 个.flac文件
找到 73 个不同的说话人
有效文件: 557/557
移除 0 个无效文件
找到 73 个不同的说话人
最终 val 数据集大小: 557 个样本
加载LibriSpeech数据集: /users/tianyuey/Project/devDataset/LibriSpeech
处理子集: dev-clean
在 dev-clean 中找到 2703 个.flac文件
处理子集: dev-other
在 dev-other 中找到 5567 个.flac文件
找到 73 个不同的说话人
有效文件: 5567/5567
移除 0 个无效文件
找到 73 个不同的说话人
最终 test 数据集大小: 5567 个样本
加载LibriSpeech数据集: /users/tianyuey/Project/devDataset/LibriSpeech
处理子集: dev-clean
在 dev-clean 中找到 2703 个.flac文件
处理子集: dev-other
在 dev-other 中找到 5567 个.flac文件
找到 73 个不同的说话人
有效文件: 5567/5567
移除 0 个无效文件
找到 73 个不同的说话人
最终 test 数据集大小: 5567 个样本
训练集: 5000 样本
验证集: 557 样本
测试集: 5567 样本
带噪声测试集: 55

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


# 简化混沌激励模块
class ChaoticStimulus(nn.Module):
    def __init__(self, input_dim, output_dim):
        """
        简化的混沌激励模块
        :param input_dim: 输入维度
        :param output_dim: 输出维度
        """
        super().__init__()
        self.chaos_transform = nn.Sequential(
            nn.Conv1d(input_dim, output_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(output_dim),
            nn.PReLU(),
            nn.Conv1d(output_dim, output_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(output_dim),
            nn.PReLU()
        )

        # 混沌扰动参数
        self.chaos_factor = nn.Parameter(torch.tensor(0.1))

    def forward(self, x):
        """
        前向传播
        :param x: 输入特征 [batch_size, channels, seq_len]
        :return: 混沌处理后的特征 [batch_size, channels, seq_len]
        """
        # 常规特征变换
        transformed = self.chaos_transform(x)

        # 添加混沌扰动
        batch_size, channels, seq_len = transformed.size()
        if self.training:  # 仅在训练时添加混沌扰动
            # 生成与特征相同形状的混沌噪声
            chaos_noise = torch.randn_like(transformed) * self.chaos_factor
            # 应用非线性激活增强混沌特性
            chaos_noise = torch.tanh(chaos_noise)
            transformed = transformed + chaos_noise

        return transformed


# 复杂分岔注意力机制
class BifurcationAttention(nn.Module):
    def __init__(self, input_dim, num_heads=Config.ATTENTION_HEADS,
                 threshold=Config.BIFURCATION_THRESHOLD):
        """
        分岔注意力机制
        :param input_dim: 输入维度
        :param num_heads: 注意力头数
        :param threshold: 分岔阈值
        """
        super().__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads

        # 确保输入维度可以被头数整除
        assert input_dim % num_heads == 0, f"input_dim ({input_dim}) must be divisible by num_heads ({num_heads})"
        self.head_dim = input_dim // num_heads
        self.threshold = threshold

        # 注意力投影
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)

        # 输出投影
        self.out_proj = nn.Linear(input_dim, input_dim)

        # 分岔控制参数
        self.bifurcation_param = nn.Parameter(torch.tensor(0.5))

        # Dropout层
        self.dropout = nn.Dropout(Config.ATTENTION_DROPOUT)

        self.max_seq_len = Config.ATTENTION_MAX_SEQ_LEN

    def forward(self, x, attention_mask=None):
        """
        前向传播
        :param x: 输入特征 [batch_size, seq_len, input_dim]
        :param attention_mask: 注意力掩码 [batch_size, seq_len]
        :return: 注意力加权后的特征 [batch_size, seq_len, input_dim]
        """
        batch_size, seq_len, _ = x.size()

        # 限制序列长度防止内存溢出
        if seq_len > self.max_seq_len:
            x = x[:, :self.max_seq_len, :]
            seq_len = self.max_seq_len
            if attention_mask is not None:
                attention_mask = attention_mask[:, :self.max_seq_len]

        # 线性投影
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # 分割头 [batch_size, num_heads, seq_len, head_dim]
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # 计算注意力分数
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # 应用注意力掩码（如果提供）
        if attention_mask is not None:
            # 将注意力掩码扩展到注意力头的维度
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # [batch_size, 1, 1, seq_len]
            scores = scores.masked_fill(attention_mask == 0, -1e9)

        # 应用分岔控制
        # 分岔函数: f(r) = r * sin(π * r)
        r = torch.sigmoid(self.bifurcation_param)
        bifurcation_factor = r * torch.sin(np.pi * r)

        # 当分岔因子接近阈值时，系统动态变化加剧
        if torch.abs(bifurcation_factor - self.threshold) < 0.1:
            # 添加随机扰动模拟混沌行为
            scores = scores + 0.05 * torch.randn_like(scores, device=scores.device)

        # 应用softmax获取注意力权重
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # 加权求和
        context = torch.matmul(attention_weights, v)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.input_dim)

        # 输出投影
        output = self.out_proj(context)

        return output, attention_weights


# 统计池化层
class StatisticalPooling(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        """
        前向传播
        :param x: 输入特征 [batch_size, channels, seq_len]
        :return: 池化后的特征 [batch_size, channels*2]
        """
        # 计算均值和标准差
        mean = torch.mean(x, dim=2)
        std = torch.std(x, dim=2)

        # 拼接均值和标准差
        return torch.cat((mean, std), dim=1)


# 完整的C-HiLAP模型（简化版+复杂注意力）
class CHiLAPModel(nn.Module):
    def __init__(self, input_dim=Config.INPUT_DIM, hidden_dim=Config.HIDDEN_DIM,
                 embedding_dim=Config.EMBEDDING_DIM, num_classes=None):
        """
        混沌层次吸引子传播(C-HiLAP)模型 - 简化版+复杂注意力
        """
        super().__init__()

        # 若未传入num_classes，可设置一个默认值（但实际使用时必须从数据集获取后传入）
        if num_classes is None:
            raise ValueError("必须指定num_classes（说话人数量），请从数据集获取后传入")

        # 特征提取层
        self.feature_extractor = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(hidden_dim),
            nn.PReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(hidden_dim),
            nn.PReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(hidden_dim),
            nn.PReLU()
        )

        # 混沌激励模块
        self.chaos_layer = ChaoticStimulus(hidden_dim, hidden_dim)

        # 复杂注意力层
        self.attention = BifurcationAttention(hidden_dim)

        # TDNN层
        self.tdnn_block = nn.Sequential(
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, dilation=1, padding=1),
            nn.BatchNorm1d(hidden_dim),
            nn.PReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, dilation=2, padding=2),
            nn.BatchNorm1d(hidden_dim),
            nn.PReLU()
        )

        # 池化层
        self.pooling = StatisticalPooling()

        # 嵌入层
        self.embedding = nn.Sequential(
            nn.Linear(hidden_dim * 2, embedding_dim),  # 统计池化输出channels*2
            nn.BatchNorm1d(embedding_dim),
            nn.PReLU(),
            nn.Dropout(0.2)
        )

        # 分类器
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, x, attention_mask=None):
        """
        前向传播
        :param x: 输入特征 [batch_size, channels, seq_len] 或 [batch_size, seq_len, channels]
        :param attention_mask: 注意力掩码 [batch_size, seq_len]
        :return: 嵌入向量和分类结果
        """
        # 检查输入维度并转换为正确的格式 [batch_size, channels, seq_len]
        if x.dim() == 3:
            # 如果是 [batch_size, seq_len, channels] 格式
            if x.size(1) > x.size(2):  # 序列长度应该大于通道数
                x = x.permute(0, 2, 1)  # 转换为 [batch_size, channels, seq_len]

        # 限制序列长度防止内存溢出
        seq_len = x.size(2)
        if seq_len > Config.MAX_SEQ_LEN:
            x = x[:, :, :Config.MAX_SEQ_LEN]
            if attention_mask is not None:
                attention_mask = attention_mask[:, :Config.MAX_SEQ_LEN]

        # 特征提取
        x = self.feature_extractor(x)

        # 混沌处理
        x = self.chaos_layer(x)

        # 转置维度以适应注意力模块 [batch_size, channels, seq_len] -> [batch_size, seq_len, channels]
        x = x.permute(0, 2, 1)

        # 调整注意力掩码的长度以匹配当前特征序列长度
        if attention_mask is not None:
            # 当前特征序列长度
            current_seq_len = x.size(1)
            # 原始注意力掩码的长度
            original_seq_len = attention_mask.size(1)
            if original_seq_len != current_seq_len:
                # 使用自适应平均池化1d来调整掩码长度
                attention_mask = attention_mask.unsqueeze(1)  # 添加通道维度
                attention_mask = F.adaptive_avg_pool1d(attention_mask, current_seq_len)
                attention_mask = attention_mask.squeeze(1)    # 移除通道维度

        # 注意力加权
        x, attention_weights = self.attention(x, attention_mask)

        # 转置回原始维度 [batch_size, seq_len, channels] -> [batch_size, channels, seq_len]
        x = x.permute(0, 2, 1)

        # TDNN处理
        x = self.tdnn_block(x)

        # 池化
        x = self.pooling(x)

        # 嵌入向量
        embedding = self.embedding(x)

        # 分类
        logits = self.classifier(embedding)

        return embedding, logits, attention_weights


# 注意力可视化工具
class AttentionVisualizer:
    def __init__(self):
        pass
    
    @staticmethod
    def plot_attention_weights(attention_weights, save_path=None):
        """
        绘制注意力权重热图
        :param attention_weights: 注意力权重 [batch_size, num_heads, seq_len, seq_len]
        :param save_path: 保存路径
        """
        try:
            import matplotlib.pyplot as plt
            import seaborn as sns
            
            # 取第一个样本和第一个注意力头
            attn = attention_weights[0, 0].detach().cpu().numpy()
            
            plt.figure(figsize=(10, 8))
            sns.heatmap(attn, cmap='viridis')
            plt.title("Attention Weights Heatmap")
            plt.xlabel("Key Position")
            plt.ylabel("Query Position")
            
            if save_path:
                plt.savefig(save_path)
                print(f"注意力权重热图已保存到: {save_path}")
            else:
                plt.show()
                
        except ImportError:
            print("无法导入matplotlib或seaborn，跳过注意力可视化")
    
    @staticmethod
    def plot_bifurcation_parameter(model, save_path=None):
        """
        绘制分岔参数随时间的变化
        :param model: 模型实例
        :param save_path: 保存路径
        """
        try:
            import matplotlib.pyplot as plt
            
            # 获取分岔参数
            bifurcation_param = model.attention.bifurcation_param.detach().cpu().numpy()
            r = 1 / (1 + np.exp(-bifurcation_param))  # sigmoid逆变换
            bifurcation_factor = r * np.sin(np.pi * r)
            
            plt.figure(figsize=(10, 6))
            plt.plot([bifurcation_factor], 'ro-')
            plt.axhline(y=model.attention.threshold, color='r', linestyle='--', label='Threshold')
            plt.title("Bifurcation Parameter")
            plt.xlabel("Training Step")
            plt.ylabel("Bifurcation Factor")
            plt.legend()
            
            if save_path:
                plt.savefig(save_path)
                print(f"分岔参数图已保存到: {save_path}")
            else:
                plt.show()
                
        except ImportError:
            print("无法导入matplotlib，跳过分岔参数可视化")


# 测试代码
if __name__ == "__main__":
    # 创建模型实例
    model = CHiLAPModel(num_classes=10)  # 添加num_classes参数

    print("模型结构:")
    print(model)

    # 生成随机输入（使用较小的序列长度进行测试）
    batch_size = 2
    seq_len = Config.MAX_SEQ_LEN
    # 正确的输入格式：[batch_size, channels, seq_len]
    x = torch.randn(batch_size, 1, seq_len)
    seq_len_feat = seq_len // 8  # feature extractor stride=2 x3
    attention_mask = torch.ones(batch_size, seq_len_feat)

    print(f"\n测试前向传播:")
    print(f"输入形状: {x.shape}")
    print(f"注意力掩码形状: {attention_mask.shape}")

    # 前向传播
    try:
        embedding, logits, attention_weights = model(x, attention_mask)
        print(f"嵌入向量形状: {embedding.shape}")
        print(f"分类输出形状: {logits.shape}")
        print(f"注意力权重形状: {attention_weights.shape}")
        print("前向传播成功!")
        
        # 测试注意力可视化
        visualizer = AttentionVisualizer()
        visualizer.plot_attention_weights(attention_weights.unsqueeze(0), "attention_heatmap.png")
        visualizer.plot_bifurcation_parameter(model, "bifurcation_parameter.png")
        
    except Exception as e:
        print(f"前向传播错误: {e}")
        import traceback

        traceback.print_exc()

模型结构:
CHiLAPModel(
  (feature_extractor): Sequential(
    (0): Conv1d(1, 256, kernel_size=(5,), stride=(2,), padding=(2,))
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
    (3): Conv1d(256, 256, kernel_size=(5,), stride=(2,), padding=(2,))
    (4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): PReLU(num_parameters=1)
    (6): Conv1d(256, 256, kernel_size=(5,), stride=(2,), padding=(2,))
    (7): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): PReLU(num_parameters=1)
  )
  (chaos_layer): ChaoticStimulus(
    (chaos_transform): Sequential(
      (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): PReLU(num_parameters=1)
      (3): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
      (4): BatchNorm

Traceback (most recent call last):
  File "/run/nvme/job_5060361/tmp/ipykernel_4099380/1893251236.py", line 377, in <module>
    visualizer.plot_attention_weights(attention_weights.unsqueeze(0), "attention_heatmap.png")
  File "/run/nvme/job_5060361/tmp/ipykernel_4099380/1893251236.py", line 300, in plot_attention_weights
    sns.heatmap(attn, cmap='viridis')
  File "/usr/local/lib/python3.12/site-packages/seaborn/matrix.py", line 446, in heatmap
    plotter = _HeatMapper(data, vmin, vmax, cmap, center, robust, annot, fmt,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/seaborn/matrix.py", line 110, in __init__
    data = pd.DataFrame(plot_data)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib64/python3.12/site-packages/pandas/core/frame.py", line 827, in __init__
    mgr = ndarray_to_mgr(
          ^^^^^^^^^^^^^^^
  File "/usr/local/lib64/python3.12/site-packages/pandas/core/internals/construction.p

<Figure size 1000x800 with 0 Axes>

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import os
import gc
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns

# 训练器类
class Trainer:
    def __init__(self, config=Config, model=None):
        """初始化训练器"""
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"使用设备: {self.device}")

        # 创建模型
        if model is None:
            raise ValueError("初始化Trainer时必须传入model参数")
        self.model = model.to(self.device)  # 使用传入的模型

        # 定义损失函数
        self.ce_loss = nn.CrossEntropyLoss()

        # 定义优化器
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config.LR,
            weight_decay=config.WEIGHT_DECAY
        )

        # 学习率调度器
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=5,
        )

        # 学习率预热调度器
        self.warmup_scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lr_lambda=lambda epoch: min(1.0, epoch / config.WARMUP_EPOCHS)
        )

        # 混合精度训练
        if config.ENABLE_MIXED_PRECISION and torch.cuda.is_available():
            self.scaler = torch.cuda.amp.GradScaler()
            print("启用混合精度训练")
        else:
            self.scaler = None

        # 创建检查点目录和注意力可视化目录
        os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)
        os.makedirs(config.ATTENTION_VIZ_DIR, exist_ok=True)

        # 早停计数器
        self.early_stop_counter = 0
        self.best_val_accuracy = 0.0
        print(f"模型预期输入长度: {config.MAX_SEQ_LEN}")

        # 训练历史记录
        self.train_history = {
            'loss': [],
            'accuracy': [],
            'val_loss': [],
            'val_accuracy': [],
            'attention_entropy': []  # 新增：注意力熵记录
        }

    def attention_regularization(self, attention_weights):
        """
        注意力正则化损失 - 鼓励注意力分布更加集中
        :param attention_weights: 注意力权重 [batch_size, num_heads, seq_len, seq_len]
        :return: 正则化损失
        """
        # 计算注意力分布的熵
        entropy = -torch.sum(attention_weights * torch.log(attention_weights + 1e-10), dim=-1)
        mean_entropy = torch.mean(entropy)
        
        # 记录注意力熵
        self.train_history['attention_entropy'].append(mean_entropy.item())
        
        return mean_entropy

    def train_one_epoch(self, dataloader, epoch):
        """
        训练一个epoch
        :param dataloader: 训练数据加载器
        :param epoch: 当前epoch
        :return: 平均训练损失和准确率
        """
        self.model.train()
        total_loss = 0.0
        total_ce_loss = 0.0
        total_attn_reg_loss = 0.0
        correct = 0
        total = 0

        # 梯度累积
        accumulation_steps = self.config.GRADIENT_ACCUMULATION_STEPS
        self.optimizer.zero_grad()

        progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
        for i, (inputs, labels, attention_mask) in progress_bar:
            try:
                inputs, labels, attention_mask = inputs.to(self.device), labels.to(self.device), attention_mask.to(self.device)

                # 确保输入维度正确 [batch, 1, seq_len]
                if inputs.dim() == 2:  # [batch, seq_len]
                    inputs = inputs.unsqueeze(1)  # 添加通道维度 -> [batch, 1, seq_len]

                # 确保音频长度正确
                if inputs.size(2) > self.config.MAX_SEQ_LEN:
                    inputs = inputs[:, :, :self.config.MAX_SEQ_LEN]
                    attention_mask = attention_mask[:, :self.config.MAX_SEQ_LEN]
                elif inputs.size(2) < self.config.MAX_SEQ_LEN:
                    pad_len = self.config.MAX_SEQ_LEN - inputs.size(2)
                    inputs = torch.nn.functional.pad(inputs, (0, pad_len), value=0.0)
                    # 更新注意力掩码
                    attention_mask_padded = torch.ones(
                        inputs.size(0), self.config.MAX_SEQ_LEN, 
                        device=inputs.device
                    )
                    attention_mask_padded[:, :attention_mask.size(1)] = attention_mask
                    attention_mask = attention_mask_padded

                # 使用混合精度训练
                if self.scaler is not None:
                    with torch.cuda.amp.autocast():
                        # 前向传播 - 注意：复杂注意力模型返回三个值
                        embeddings, logits, attention_weights = self.model(inputs, attention_mask)

                        # 计算损失
                        ce = self.ce_loss(logits, labels)
                        attn_reg = self.attention_regularization(attention_weights)
                        loss = (self.config.CE_WEIGHT * ce + 
                               self.config.ATTENTION_REG_WEIGHT * attn_reg)

                        # 梯度累积
                        loss = loss / accumulation_steps

                    # 反向传播
                    self.scaler.scale(loss).backward()

                    # 梯度裁剪
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.GRAD_CLIP)

                    if (i + 1) % accumulation_steps == 0:
                        self.scaler.step(self.optimizer)  # 优化器步骤
                        self.scaler.update()
                        # 2. 再调用预热调度器
                        if epoch <= self.config.WARMUP_EPOCHS:
                            self.warmup_scheduler.step()
                        self.optimizer.zero_grad()

                else:
                    # 标准训练（不使用混合精度）
                    # 前向传播 - 注意：复杂注意力模型返回三个值
                    embeddings, logits, attention_weights = self.model(inputs, attention_mask)

                    # 计算损失
                    ce = self.ce_loss(logits, labels)
                    attn_reg = self.attention_regularization(attention_weights)
                    loss = (self.config.CE_WEIGHT * ce + 
                           self.config.ATTENTION_REG_WEIGHT * attn_reg)

                    # 梯度累积
                    loss = loss / accumulation_steps
                    loss.backward()

                    # 梯度裁剪
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.GRAD_CLIP)

                    if (i + 1) % accumulation_steps == 0:
                        # 1. 先更新参数
                        self.optimizer.step()
                        # 2. 再调用学习率预热调度器（仅在预热阶段）
                        if epoch <= self.config.WARMUP_EPOCHS:
                            self.warmup_scheduler.step()  # 移动到optimizer.step()之后
                        self.optimizer.zero_grad()

                # 统计
                total_loss += loss.item() * accumulation_steps
                total_ce_loss += ce.item()
                total_attn_reg_loss += attn_reg.item()
                _, predicted = logits.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

                # 更新进度条
                if i % 10 == 0:  # 减少更新频率
                    accuracy = 100. * correct / total
                    avg_loss = total_loss / (i + 1)
                    progress_bar.set_description(
                        f"Epoch {epoch}, Loss: {avg_loss:.4f}, Acc: {accuracy:.2f}%, "
                        f"CE: {total_ce_loss/(i+1):.4f}, AttnReg: {total_attn_reg_loss/(i+1):.4f}"
                    )

                # 手动垃圾回收
                if i % 50 == 0:
                    torch.cuda.empty_cache()
                    gc.collect()

            except RuntimeError as e:
                if "out of memory" in str(e):
                    print(f"内存不足错误在批次 {i}: {e}")
                    torch.cuda.empty_cache()
                    gc.collect()
                    continue
                else:
                    raise e

        avg_loss = total_loss / len(dataloader)
        accuracy = 100. * correct / total
        return avg_loss, accuracy

    def validate(self, dataloader):
        """
        验证模型性能
        :param dataloader: 验证数据加载器
        :return: 验证损失和准确率
        """
        self.model.eval()
        total_loss = 0.0
        total_ce_loss = 0.0
        total_attn_reg_loss = 0.0
        correct = 0
        total = 0
        all_labels = []
        all_predictions = []

        with torch.no_grad():
            for i, (inputs, labels, attention_mask) in enumerate(dataloader):
                try:
                    inputs, labels, attention_mask = inputs.to(self.device), labels.to(self.device), attention_mask.to(self.device)

                    # 确保输入维度正确 [batch, 1, seq_len]
                    if inputs.dim() == 2:  # [batch, seq_len]
                        inputs = inputs.unsqueeze(1)  # 添加通道维度 -> [batch, 1, seq_len]

                    # 确保音频长度正确
                    if inputs.size(2) > self.config.MAX_SEQ_LEN:
                        inputs = inputs[:, :, :self.config.MAX_SEQ_LEN]
                        attention_mask = attention_mask[:, :self.config.MAX_SEQ_LEN]
                    elif inputs.size(2) < self.config.MAX_SEQ_LEN:
                        pad_len = self.config.MAX_SEQ_LEN - inputs.size(2)
                        inputs = torch.nn.functional.pad(inputs, (0, pad_len), value=0.0)
                        # 更新注意力掩码
                        attention_mask_padded = torch.ones(inputs.size(0), self.config.MAX_SEQ_LEN, device=inputs.device)
                        attention_mask_padded[:, :attention_mask.size(1)] = attention_mask
                        attention_mask = attention_mask_padded

                    # 前向传播 - 注意：复杂注意力模型返回三个值
                    embeddings, logits, attention_weights = self.model(inputs, attention_mask)

                    # 计算损失
                    ce = self.ce_loss(logits, labels)
                    attn_reg = self.attention_regularization(attention_weights)
                    loss = (self.config.CE_WEIGHT * ce + 
                           self.config.ATTENTION_REG_WEIGHT * attn_reg)

                    total_loss += loss.item()
                    total_ce_loss += ce.item()
                    total_attn_reg_loss += attn_reg.item()

                    # 统计准确率
                    _, predicted = logits.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()
                    
                    # 收集所有预测和标签用于混淆矩阵
                    all_labels.extend(labels.cpu().numpy())
                    all_predictions.extend(predicted.cpu().numpy())

                    # 内存管理
                    if i % 20 == 0:
                        torch.cuda.empty_cache()
                        gc.collect()

                except RuntimeError as e:
                    if "out of memory" in str(e):
                        print(f"验证时内存不足: {e}")
                        torch.cuda.empty_cache()
                        gc.collect()
                        continue
                    else:
                        raise e

        avg_loss = total_loss / len(dataloader)
        accuracy = 100. * correct / total
        
        # 计算混淆矩阵
        cm = confusion_matrix(all_labels, all_predictions)
        
        return avg_loss, accuracy, cm

    def visualize_attention(self, dataloader, epoch):
        """
        可视化注意力权重
        :param dataloader: 数据加载器
        :param epoch: 当前epoch
        """
        self.model.eval()
        
        with torch.no_grad():
            # 获取一个批次的数据
            inputs, labels, attention_mask = next(iter(dataloader))
            inputs, labels, attention_mask = inputs.to(self.device), labels.to(self.device), attention_mask.to(self.device)
            
            # 确保输入维度正确
            if inputs.dim() == 2:
                inputs = inputs.unsqueeze(1)
                
            # 前向传播获取注意力权重
            _, _, attention_weights = self.model(inputs, attention_mask)
            
            # 可视化注意力权重
            try:
                # 取第一个样本和第一个注意力头
                attn = attention_weights[0, 0].cpu().numpy()
                
                plt.figure(figsize=(10, 8))
                sns.heatmap(attn, cmap='viridis')
                plt.title(f"Attention Weights - Epoch {epoch}")
                plt.xlabel("Key Position")
                plt.ylabel("Query Position")
                
                # 保存图像
                save_path = os.path.join(self.config.ATTENTION_VIZ_DIR, f"attention_epoch_{epoch}.png")
                plt.savefig(save_path)
                plt.close()
                
                print(f"注意力权重可视化已保存到: {save_path}")
                
            except Exception as e:
                print(f"注意力可视化失败: {e}")

    def plot_training_history(self):
        """绘制训练历史"""
        try:
            # 创建子图
            fig, axes = plt.subplots(2, 2, figsize=(15, 10))
            
            # 绘制损失曲线
            axes[0, 0].plot(self.train_history['loss'], label='Training Loss')
            axes[0, 0].plot(self.train_history['val_loss'], label='Validation Loss')
            axes[0, 0].set_title('Loss Curve')
            axes[0, 0].set_xlabel('Epoch')
            axes[0, 0].set_ylabel('Loss')
            axes[0, 0].legend()
            axes[0, 0].grid(True)
            
            # 绘制准确率曲线
            axes[0, 1].plot(self.train_history['accuracy'], label='Training Accuracy')
            axes[0, 1].plot(self.train_history['val_accuracy'], label='Validation Accuracy')
            axes[0, 1].set_title('Accuracy Curve')
            axes[0, 1].set_xlabel('Epoch')
            axes[0, 1].set_ylabel('Accuracy (%)')
            axes[0, 1].legend()
            axes[0, 1].grid(True)
            
            # 绘制注意力熵曲线
            axes[1, 0].plot(self.train_history['attention_entropy'])
            axes[1, 0].set_title('Attention Entropy')
            axes[1, 0].set_xlabel('Epoch')
            axes[1, 0].set_ylabel('Entropy')
            axes[1, 0].grid(True)
            
            # 留空用于其他可视化
            axes[1, 1].axis('off')
            
            plt.tight_layout()
            
            # 保存图像
            save_path = os.path.join(self.config.CHECKPOINT_DIR, "training_history.png")
            plt.savefig(save_path)
            plt.close()
            
            print(f"训练历史可视化已保存到: {save_path}")
            
        except Exception as e:
            print(f"训练历史可视化失败: {e}")

    def train(self, train_dataloader, val_dataloader):
        """
        完整训练流程
        :param train_dataloader: 训练数据加载器
        :param val_dataloader: 验证数据加载器
        """
        print("开始训练T2模型（简化模型+复杂注意力）...")
        print(f"训练集批次数: {len(train_dataloader)}, 验证集批次数: {len(val_dataloader)}")

        for epoch in range(1, self.config.EPOCHS + 1):
            try:
                # 应用学习率预热
                if epoch <= self.config.WARMUP_EPOCHS:
                    self.warmup_scheduler.step()

                # 训练一个epoch
                train_loss, train_acc = self.train_one_epoch(train_dataloader, epoch)
                self.train_history['loss'].append(train_loss)
                self.train_history['accuracy'].append(train_acc)
                
                print(f"Epoch {epoch}/{self.config.EPOCHS}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")

                # 验证
                val_loss, val_acc, val_cm = self.validate(val_dataloader)
                self.train_history['val_loss'].append(val_loss)
                self.train_history['val_accuracy'].append(val_acc)
                
                print(f"Epoch {epoch}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

                # 更新学习率
                self.scheduler.step(val_loss)
                current_lr = self.optimizer.param_groups[0]['lr']
                print(f"当前学习率: {current_lr:.6f}")

                # 定期可视化注意力权重
                if epoch % self.config.ATTENTION_VIZ_INTERVAL == 0:
                    self.visualize_attention(val_dataloader, epoch)

                # 早停检查 (基于验证准确率)
                if val_acc > self.best_val_accuracy + self.config.MIN_DELTA:
                    self.best_val_accuracy = val_acc
                    self.early_stop_counter = 0
                    # 保存最佳模型
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'val_loss': val_loss,
                        'val_accuracy': val_acc,
                        'confusion_matrix': val_cm
                    }, os.path.join(self.config.CHECKPOINT_DIR, 'best_model.pth'))
                    print(f"保存最佳模型，验证准确率: {val_acc:.2f}%")
                else:
                    self.early_stop_counter += 1
                    if self.early_stop_counter >= self.config.PATIENCE:
                        print(f"早停于第 {epoch} 轮")
                        break

                # 定期保存模型
                if epoch % self.config.SAVE_INTERVAL == 0:
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'train_loss': train_loss,
                        'train_accuracy': train_acc,
                        'confusion_matrix': val_cm
                    }, os.path.join(self.config.CHECKPOINT_DIR, f'model_epoch_{epoch}.pth'))

                # 清理内存
                torch.cuda.empty_cache()
                gc.collect()

            except KeyboardInterrupt:
                print("训练被用户中断")
                break
            except Exception as e:
                print(f"训练错误: {e}")
                torch.cuda.empty_cache()
                gc.collect()
                continue

        # 训练完成后绘制训练历史
        self.plot_training_history()

    def load_checkpoint(self, checkpoint_path):
        """
        加载检查点
        :param checkpoint_path: 检查点路径
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        if 'optimizer_state_dict' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint.get('epoch', 0)
        print(f"从第 {epoch} 轮加载检查点")
        return epoch


# 评估器类
class Evaluator:
    def __init__(self, model, config=Config):
        """初始化评估器"""
        self.model = model
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def evaluate_accuracy(self, dataloader):
        """
        评估模型准确率
        :param dataloader: 数据加载器
        :return: 准确率
        """
        correct = 0
        total = 0

        with torch.no_grad():
            for i, (inputs, labels, attention_mask) in enumerate(dataloader):
                try:
                    inputs, labels, attention_mask = inputs.to(self.device), labels.to(self.device), attention_mask.to(self.device)

                    # 确保输入维度正确
                    if inputs.dim() == 2:  # [batch, seq_len]
                        inputs = inputs.unsqueeze(1)  # 添加通道维度 -> [batch, 1, seq_len]

                    # 截断过长的序列
                    if inputs.size(2) > self.config.MAX_SEQ_LEN:
                        inputs = inputs[:, :, :self.config.MAX_SEQ_LEN]
                        attention_mask = attention_mask[:, :self.config.MAX_SEQ_LEN]

                    # 前向传播 - 注意：复杂注意力模型返回三个值
                    _, logits, _ = self.model(inputs, attention_mask)
                    _, predicted = logits.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()

                    # 内存管理
                    if i % 20 == 0:
                        torch.cuda.empty_cache()
                        gc.collect()

                except RuntimeError as e:
                    if "out of memory" in str(e):
                        print(f"评估时内存不足: {e}")
                        torch.cuda.empty_cache()
                        gc.collect()
                        continue
                    else:
                        raise e

        accuracy = 100. * correct / total
        return accuracy

    def analyze_attention_patterns(self, dataloader, num_samples=5):
        """
        分析注意力模式
        :param dataloader: 数据加载器
        :param num_samples: 分析的样本数量
        :return: 注意力模式分析结果
        """
        self.model.eval()
        attention_patterns = []
        
        with torch.no_grad():
            for i, (inputs, labels, attention_mask) in enumerate(dataloader):
                if i >= num_samples:
                    break
                    
                inputs, labels, attention_mask = inputs.to(self.device), labels.to(self.device), attention_mask.to(self.device)
                
                # 确保输入维度正确
                if inputs.dim() == 2:
                    inputs = inputs.unsqueeze(1)
                    
                # 前向传播获取注意力权重
                _, _, attention_weights = self.model(inputs, attention_mask)
                
                # 分析注意力模式
                for j in range(attention_weights.size(0)):  # 遍历批次中的每个样本
                    sample_attention = attention_weights[j]
                    
                    # 计算注意力集中度（每个头的平均熵）
                    attention_entropy = -torch.sum(sample_attention * torch.log(sample_attention + 1e-10), dim=-1)
                    mean_entropy = torch.mean(attention_entropy).item()
                    
                    # 计算注意力对角线强度（关注局部模式的程度）
                    diag_strength = torch.mean(torch.diagonal(sample_attention, dim1=-2, dim2=-1)).item()
                    
                    attention_patterns.append({
                        'sample_id': i * self.config.BATCH_SIZE + j,
                        'label': labels[j].item(),
                        'mean_entropy': mean_entropy,
                        'diag_strength': diag_strength,
                        'attention_weights': sample_attention.cpu().numpy()
                    })
        
        return attention_patterns


# 内存友好的数据加载器创建函数
def create_small_dataloaders(dataset_name, batch_size=4):
    """创建小批次的数据加载器以节省内存"""
    try:
        train_dataset = SpeakerRecognitionDataset(split="train")
        val_dataset = SpeakerRecognitionDataset(split="val")
        test_dataset = SpeakerRecognitionDataset(split="test")

        return {
            "train": DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                num_workers=0, pin_memory=False),
            "val": DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                              num_workers=0, pin_memory=False),
            "test": DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                               num_workers=0, pin_memory=False),
        }
    except Exception as e:
        print(f"数据加载器创建错误: {e}")
        return None


# 测试代码
if __name__ == "__main__":
    # 设置较小的批次大小
    batch_size = Config.BATCH_SIZE

    print("创建数据加载器...")
    try:
        # 使用内存友好的数据加载器
        dataloaders = create_small_dataloaders("librispeech", batch_size=batch_size)

        if dataloaders is None:
            print("无法创建数据加载器，退出")
            exit(1)

        print(f"训练集批次数: {len(dataloaders['train'])}")
        print(f"验证集批次数: {len(dataloaders['val'])}")
        print(f"测试集批次数: {len(dataloaders['test'])}")

        # 关键修改：从训练集中获取实际说话人数量（类别数）
        num_speakers = len(dataloaders["train"].dataset.speaker_to_idx)
        print(f"数据集中实际说话人数量（类别数）: {num_speakers}")

        # 创建训练器前，先初始化模型并传入正确的num_classes
        print("创建模型和训练器...")
        
        config = Config()
        # 1. 初始化模型（传入正确的num_classes）
        model = CHiLAPModel(num_classes=num_speakers)
        # 2. 将模型传入Trainer，确保优化器初始化时模型已存在
        trainer = Trainer(config=config, model=model)  # 传入model参数

        # 测试一个批次
        print("测试前向传播...")
        try:
            # 注意：现在数据加载器返回三个值（输入、标签、注意力掩码）
            x, y, mask = next(iter(dataloaders["train"]))
            print(f"原始输入形状: {x.shape}, 标签形状: {y.shape}, 掩码形状: {mask.shape}")

            # 处理输入维度
            if x.dim() == 2:  # [batch, seq_len]
                x = x.unsqueeze(1)  # 添加通道维度 -> [batch, 1, seq_len]

            # 确保音频长度正确
            if x.size(2) > Config.MAX_SEQ_LEN:
                x = x[:, :, :Config.MAX_SEQ_LEN]
                mask = mask[:, :Config.MAX_SEQ_LEN]
            elif x.size(2) < Config.MAX_SEQ_LEN:
                pad_len = Config.MAX_SEQ_LEN - x.size(2)
                x = torch.nn.functional.pad(x, (0, pad_len), value=0.0)
                # 更新注意力掩码
                mask_padded = torch.ones(x.size(0), Config.MAX_SEQ_LEN)
                mask_padded[:, :mask.size(1)] = mask
                mask = mask_padded

            print(f"处理后输入形状: {x.shape}, 掩码形状: {mask.shape}")

            x = x.to(trainer.device)
            y = y.to(trainer.device)
            mask = mask.to(trainer.device)

            with torch.no_grad():
                embeddings, logits, attention_weights = trainer.model(x, mask)
                print(f"嵌入形状: {embeddings.shape}, 输出形状: {logits.shape}, 注意力权重形状: {attention_weights.shape}")
                print("前向传播测试成功!")
        except Exception as e:
            print(f"前向传播测试失败: {e}")
            import traceback

            traceback.print_exc()
            exit(1)

        # 开始训练
        print("开始训练...")
        trainer.train(dataloaders["train"], dataloaders["val"])

        # 创建评估器
        print("创建评估器...")
        evaluator = Evaluator(trainer.model)

        # 评估模型
        print("评估模型...")
        test_accuracy = evaluator.evaluate_accuracy(dataloaders["test"])
        print(f"测试集准确率: {test_accuracy:.2f}%")
        
        # 分析注意力模式
        print("分析注意力模式...")
        attention_patterns = evaluator.analyze_attention_patterns(dataloaders["test"], num_samples=3)
        for pattern in attention_patterns:
            print(f"样本 {pattern['sample_id']} (标签: {pattern['label']}): "
                  f"平均熵={pattern['mean_entropy']:.4f}, 对角线强度={pattern['diag_strength']:.4f}")

    except Exception as e:
        print(f"程序执行错误: {e}")
        import traceback

        traceback.print_exc()

创建数据加载器...
加载LibriSpeech数据集: /users/tianyuey/Project/devDataset/LibriSpeech
处理子集: dev-clean
在 dev-clean 中找到 2703 个.flac文件
处理子集: dev-other
在 dev-other 中找到 5567 个.flac文件
找到 73 个不同的说话人
有效文件: 5000/5000
移除 0 个无效文件
找到 73 个不同的说话人
最终 train 数据集大小: 5000 个样本
加载LibriSpeech数据集: /users/tianyuey/Project/devDataset/LibriSpeech
处理子集: dev-clean
在 dev-clean 中找到 2703 个.flac文件
处理子集: dev-other
在 dev-other 中找到 5567 个.flac文件
找到 73 个不同的说话人
有效文件: 557/557
移除 0 个无效文件
找到 73 个不同的说话人
最终 val 数据集大小: 557 个样本
加载LibriSpeech数据集: /users/tianyuey/Project/devDataset/LibriSpeech
处理子集: dev-clean
在 dev-clean 中找到 2703 个.flac文件
处理子集: dev-other
在 dev-other 中找到 5567 个.flac文件
找到 73 个不同的说话人
有效文件: 5567/5567
移除 0 个无效文件
找到 73 个不同的说话人
最终 test 数据集大小: 5567 个样本
训练集批次数: 625
验证集批次数: 70
测试集批次数: 696
数据集中实际说话人数量（类别数）: 73
创建模型和训练器...
使用设备: cuda
模型预期输入长度: 16000
测试前向传播...
原始输入形状: torch.Size([8, 16000]), 标签形状: torch.Size([8]), 掩码形状: torch.Size([8, 16000])
处理后输入形状: torch.Size([8, 1, 16000]), 掩码形状: torch.Size([8, 16000])




嵌入形状: torch.Size([8, 128]), 输出形状: torch.Size([8, 73]), 注意力权重形状: torch.Size([8, 8, 250, 250])
前向传播测试成功!
开始训练...
开始训练T2模型（简化模型+复杂注意力）...
训练集批次数: 625, 验证集批次数: 70


Epoch 1, Loss: 4.2342, Acc: 3.58%, CE: 4.1949, AttnReg: 3.9282: 100%|██████████| 625/625 [00:58<00:00, 10.76it/s]


Epoch 1/500, Train Loss: 4.2339, Train Acc: 3.58%
Epoch 1, Val Loss: 4.0890, Val Acc: 4.67%
当前学习率: 0.000800
保存最佳模型，验证准确率: 4.67%


Epoch 2, Loss: 3.9014, Acc: 6.72%, CE: 3.8817, AttnReg: 1.9697: 100%|██████████| 625/625 [00:57<00:00, 10.92it/s]


Epoch 2/500, Train Loss: 3.8994, Train Acc: 6.74%
Epoch 2, Val Loss: 3.6824, Val Acc: 11.49%
当前学习率: 0.000800
保存最佳模型，验证准确率: 11.49%


Epoch 3, Loss: 3.6652, Acc: 10.19%, CE: 3.6505, AttnReg: 1.4632: 100%|██████████| 625/625 [00:57<00:00, 10.94it/s]


Epoch 3/500, Train Loss: 3.6644, Train Acc: 10.16%
Epoch 3, Val Loss: 3.4354, Val Acc: 14.72%
当前学习率: 0.000800
保存最佳模型，验证准确率: 14.72%


Epoch 4, Loss: 3.5418, Acc: 11.63%, CE: 3.5300, AttnReg: 1.1738: 100%|██████████| 625/625 [00:57<00:00, 10.93it/s]


Epoch 4/500, Train Loss: 3.5429, Train Acc: 11.62%
Epoch 4, Val Loss: 3.3310, Val Acc: 16.88%
当前学习率: 0.000800
保存最佳模型，验证准确率: 16.88%


Epoch 5, Loss: 3.4461, Acc: 13.29%, CE: 3.4365, AttnReg: 0.9625: 100%|██████████| 625/625 [00:57<00:00, 10.92it/s]


Epoch 5/500, Train Loss: 3.4461, Train Acc: 13.24%
Epoch 5, Val Loss: 3.2694, Val Acc: 16.34%
当前学习率: 0.000800
注意力权重可视化已保存到: ./attention_viz/attention_epoch_5.png


Epoch 6, Loss: 3.3360, Acc: 15.82%, CE: 3.3275, AttnReg: 0.8481: 100%|██████████| 625/625 [00:57<00:00, 10.85it/s]


Epoch 6/500, Train Loss: 3.3353, Train Acc: 15.82%
Epoch 6, Val Loss: 3.0644, Val Acc: 22.98%
当前学习率: 0.000800
保存最佳模型，验证准确率: 22.98%


Epoch 7, Loss: 3.2936, Acc: 16.45%, CE: 3.2862, AttnReg: 0.7389: 100%|██████████| 625/625 [00:57<00:00, 10.95it/s]


Epoch 7/500, Train Loss: 3.2938, Train Acc: 16.44%
Epoch 7, Val Loss: 3.0683, Val Acc: 19.93%
当前学习率: 0.000800


Epoch 8, Loss: 3.2213, Acc: 17.98%, CE: 3.2136, AttnReg: 0.7708: 100%|██████████| 625/625 [00:57<00:00, 10.93it/s]


Epoch 8/500, Train Loss: 3.2226, Train Acc: 17.96%
Epoch 8, Val Loss: 2.9636, Val Acc: 25.85%
当前学习率: 0.000800
保存最佳模型，验证准确率: 25.85%


Epoch 9, Loss: 3.1382, Acc: 19.30%, CE: 3.1313, AttnReg: 0.6865: 100%|██████████| 625/625 [00:59<00:00, 10.53it/s]


Epoch 9/500, Train Loss: 3.1360, Train Acc: 19.32%
Epoch 9, Val Loss: 2.9718, Val Acc: 25.49%
当前学习率: 0.000800


Epoch 10, Loss: 3.0647, Acc: 20.57%, CE: 3.0589, AttnReg: 0.5832: 100%|██████████| 625/625 [01:02<00:00, 10.05it/s]


Epoch 10/500, Train Loss: 3.0638, Train Acc: 20.60%
Epoch 10, Val Loss: 2.7731, Val Acc: 27.47%
当前学习率: 0.000800
注意力权重可视化已保存到: ./attention_viz/attention_epoch_10.png
保存最佳模型，验证准确率: 27.47%


Epoch 11, Loss: 3.0556, Acc: 21.22%, CE: 3.0502, AttnReg: 0.5411: 100%|██████████| 625/625 [01:01<00:00, 10.09it/s]


Epoch 11/500, Train Loss: 3.0562, Train Acc: 21.28%
Epoch 11, Val Loss: 2.8233, Val Acc: 25.67%
当前学习率: 0.000800


Epoch 12, Loss: 3.0021, Acc: 21.86%, CE: 2.9958, AttnReg: 0.6249: 100%|██████████| 625/625 [01:02<00:00, 10.04it/s]


Epoch 12/500, Train Loss: 3.0019, Train Acc: 21.84%
Epoch 12, Val Loss: 2.7916, Val Acc: 26.39%
当前学习率: 0.000800


Epoch 13, Loss: 2.9749, Acc: 22.46%, CE: 2.9686, AttnReg: 0.6275: 100%|██████████| 625/625 [01:02<00:00,  9.94it/s]


Epoch 13/500, Train Loss: 2.9748, Train Acc: 22.44%
Epoch 13, Val Loss: 2.7858, Val Acc: 27.29%
当前学习率: 0.000800


Epoch 14, Loss: 2.9147, Acc: 24.54%, CE: 2.9088, AttnReg: 0.5962: 100%|██████████| 625/625 [01:01<00:00, 10.12it/s]


Epoch 14/500, Train Loss: 2.9134, Train Acc: 24.62%
Epoch 14, Val Loss: 2.6714, Val Acc: 29.26%
当前学习率: 0.000800
保存最佳模型，验证准确率: 29.26%


Epoch 15, Loss: 2.8515, Acc: 25.04%, CE: 2.8457, AttnReg: 0.5712: 100%|██████████| 625/625 [01:03<00:00,  9.86it/s]


Epoch 15/500, Train Loss: 2.8536, Train Acc: 25.06%
Epoch 15, Val Loss: 2.5783, Val Acc: 33.93%
当前学习率: 0.000800
注意力权重可视化已保存到: ./attention_viz/attention_epoch_15.png
保存最佳模型，验证准确率: 33.93%


Epoch 16, Loss: 2.8573, Acc: 25.54%, CE: 2.8518, AttnReg: 0.5440: 100%|██████████| 625/625 [01:01<00:00, 10.11it/s]


Epoch 16/500, Train Loss: 2.8590, Train Acc: 25.48%
Epoch 16, Val Loss: 2.6434, Val Acc: 30.52%
当前学习率: 0.000800


Epoch 17, Loss: 2.7976, Acc: 26.23%, CE: 2.7921, AttnReg: 0.5443: 100%|██████████| 625/625 [01:01<00:00, 10.14it/s]


Epoch 17/500, Train Loss: 2.7985, Train Acc: 26.16%
Epoch 17, Val Loss: 2.5877, Val Acc: 32.50%
当前学习率: 0.000800


Epoch 18, Loss: 2.8190, Acc: 26.77%, CE: 2.8133, AttnReg: 0.5633: 100%|██████████| 625/625 [01:02<00:00, 10.07it/s]


Epoch 18/500, Train Loss: 2.8203, Train Acc: 26.74%
Epoch 18, Val Loss: 2.4838, Val Acc: 32.85%
当前学习率: 0.000800


Epoch 19, Loss: 2.7478, Acc: 28.24%, CE: 2.7425, AttnReg: 0.5335: 100%|██████████| 625/625 [01:01<00:00, 10.19it/s]


Epoch 19/500, Train Loss: 2.7461, Train Acc: 28.26%
Epoch 19, Val Loss: 2.4131, Val Acc: 33.39%
当前学习率: 0.000800


Epoch 20, Loss: 2.7245, Acc: 29.93%, CE: 2.7191, AttnReg: 0.5417: 100%|██████████| 625/625 [01:02<00:00, 10.01it/s]


Epoch 20/500, Train Loss: 2.7259, Train Acc: 29.94%
Epoch 20, Val Loss: 2.3193, Val Acc: 37.70%
当前学习率: 0.000800
注意力权重可视化已保存到: ./attention_viz/attention_epoch_20.png
保存最佳模型，验证准确率: 37.70%


Epoch 21, Loss: 2.6658, Acc: 29.67%, CE: 2.6602, AttnReg: 0.5650: 100%|██████████| 625/625 [01:00<00:00, 10.25it/s]


Epoch 21/500, Train Loss: 2.6681, Train Acc: 29.64%
Epoch 21, Val Loss: 2.3934, Val Acc: 39.14%
当前学习率: 0.000800
保存最佳模型，验证准确率: 39.14%


Epoch 22, Loss: 2.6564, Acc: 30.31%, CE: 2.6506, AttnReg: 0.5798: 100%|██████████| 625/625 [01:02<00:00, 10.07it/s]


Epoch 22/500, Train Loss: 2.6594, Train Acc: 30.24%
Epoch 22, Val Loss: 2.3863, Val Acc: 35.73%
当前学习率: 0.000800


Epoch 23, Loss: 2.6215, Acc: 31.48%, CE: 2.6160, AttnReg: 0.5404: 100%|██████████| 625/625 [01:00<00:00, 10.26it/s]


Epoch 23/500, Train Loss: 2.6234, Train Acc: 31.44%
Epoch 23, Val Loss: 2.4161, Val Acc: 34.47%
当前学习率: 0.000800


Epoch 24, Loss: 2.6309, Acc: 31.12%, CE: 2.6256, AttnReg: 0.5243: 100%|██████████| 625/625 [01:01<00:00, 10.16it/s]


Epoch 24/500, Train Loss: 2.6331, Train Acc: 31.04%
Epoch 24, Val Loss: 2.3870, Val Acc: 35.01%
当前学习率: 0.000800


Epoch 25, Loss: 2.5631, Acc: 32.93%, CE: 2.5581, AttnReg: 0.4943: 100%|██████████| 625/625 [00:57<00:00, 10.89it/s]


Epoch 25/500, Train Loss: 2.5621, Train Acc: 32.94%
Epoch 25, Val Loss: 2.3061, Val Acc: 40.22%
当前学习率: 0.000800
注意力权重可视化已保存到: ./attention_viz/attention_epoch_25.png
保存最佳模型，验证准确率: 40.22%


Epoch 26, Loss: 2.5537, Acc: 31.76%, CE: 2.5486, AttnReg: 0.5108: 100%|██████████| 625/625 [00:57<00:00, 10.87it/s]


Epoch 26/500, Train Loss: 2.5536, Train Acc: 31.76%
Epoch 26, Val Loss: 2.2504, Val Acc: 42.73%
当前学习率: 0.000800
保存最佳模型，验证准确率: 42.73%


Epoch 27, Loss: 2.5250, Acc: 33.37%, CE: 2.5195, AttnReg: 0.5507: 100%|██████████| 625/625 [00:57<00:00, 10.83it/s]


Epoch 27/500, Train Loss: 2.5216, Train Acc: 33.46%
Epoch 27, Val Loss: 2.2780, Val Acc: 38.06%
当前学习率: 0.000800


Epoch 28, Loss: 2.4994, Acc: 33.62%, CE: 2.4943, AttnReg: 0.5096: 100%|██████████| 625/625 [00:57<00:00, 10.86it/s]


Epoch 28/500, Train Loss: 2.4969, Train Acc: 33.72%
Epoch 28, Val Loss: 2.1630, Val Acc: 43.27%
当前学习率: 0.000800
保存最佳模型，验证准确率: 43.27%


Epoch 29, Loss: 2.4345, Acc: 35.12%, CE: 2.4291, AttnReg: 0.5466: 100%|██████████| 625/625 [00:57<00:00, 10.90it/s]


Epoch 29/500, Train Loss: 2.4349, Train Acc: 35.14%
Epoch 29, Val Loss: 2.1390, Val Acc: 42.01%
当前学习率: 0.000800


Epoch 30, Loss: 2.4749, Acc: 34.54%, CE: 2.4696, AttnReg: 0.5349: 100%|██████████| 625/625 [00:57<00:00, 10.90it/s]


Epoch 30/500, Train Loss: 2.4730, Train Acc: 34.52%
Epoch 30, Val Loss: 2.1167, Val Acc: 42.19%
当前学习率: 0.000800
注意力权重可视化已保存到: ./attention_viz/attention_epoch_30.png


Epoch 31, Loss: 2.4590, Acc: 35.06%, CE: 2.4535, AttnReg: 0.5481: 100%|██████████| 625/625 [00:57<00:00, 10.88it/s]


Epoch 31/500, Train Loss: 2.4599, Train Acc: 35.06%
Epoch 31, Val Loss: 2.2573, Val Acc: 41.83%
当前学习率: 0.000800


Epoch 32, Loss: 2.4180, Acc: 35.69%, CE: 2.4122, AttnReg: 0.5777: 100%|██████████| 625/625 [00:57<00:00, 10.87it/s]


Epoch 32/500, Train Loss: 2.4178, Train Acc: 35.72%
Epoch 32, Val Loss: 2.1262, Val Acc: 41.47%
当前学习率: 0.000800


Epoch 33, Loss: 2.3848, Acc: 36.43%, CE: 2.3793, AttnReg: 0.5540: 100%|██████████| 625/625 [00:57<00:00, 10.87it/s]


Epoch 33/500, Train Loss: 2.3855, Train Acc: 36.44%
Epoch 33, Val Loss: 2.1889, Val Acc: 41.29%
当前学习率: 0.000800


Epoch 34, Loss: 2.4009, Acc: 36.07%, CE: 2.3955, AttnReg: 0.5478: 100%|██████████| 625/625 [00:57<00:00, 10.87it/s]


Epoch 34/500, Train Loss: 2.4002, Train Acc: 36.08%
Epoch 34, Val Loss: 2.0857, Val Acc: 41.29%
当前学习率: 0.000800


Epoch 35, Loss: 2.3882, Acc: 36.88%, CE: 2.3828, AttnReg: 0.5487: 100%|██████████| 625/625 [00:57<00:00, 10.86it/s]


Epoch 35/500, Train Loss: 2.3888, Train Acc: 36.90%
Epoch 35, Val Loss: 2.0996, Val Acc: 42.19%
当前学习率: 0.000800
注意力权重可视化已保存到: ./attention_viz/attention_epoch_35.png


Epoch 36, Loss: 2.3471, Acc: 37.64%, CE: 2.3415, AttnReg: 0.5535: 100%|██████████| 625/625 [00:57<00:00, 10.89it/s]


Epoch 36/500, Train Loss: 2.3511, Train Acc: 37.62%
Epoch 36, Val Loss: 2.1240, Val Acc: 40.93%
当前学习率: 0.000800


Epoch 37, Loss: 2.3553, Acc: 38.02%, CE: 2.3493, AttnReg: 0.5981: 100%|██████████| 625/625 [00:57<00:00, 10.90it/s]


Epoch 37/500, Train Loss: 2.3532, Train Acc: 38.08%
Epoch 37, Val Loss: 2.0234, Val Acc: 45.60%
当前学习率: 0.000800
保存最佳模型，验证准确率: 45.60%


Epoch 38, Loss: 2.3025, Acc: 38.87%, CE: 2.2968, AttnReg: 0.5646: 100%|██████████| 625/625 [00:57<00:00, 10.91it/s]


Epoch 38/500, Train Loss: 2.3041, Train Acc: 38.80%
Epoch 38, Val Loss: 1.9612, Val Acc: 45.42%
当前学习率: 0.000800


Epoch 39, Loss: 2.3110, Acc: 38.00%, CE: 2.3052, AttnReg: 0.5796: 100%|██████████| 625/625 [00:57<00:00, 10.88it/s]


Epoch 39/500, Train Loss: 2.3128, Train Acc: 38.02%
Epoch 39, Val Loss: 2.0015, Val Acc: 44.52%
当前学习率: 0.000800


Epoch 40, Loss: 2.2600, Acc: 39.55%, CE: 2.2545, AttnReg: 0.5524: 100%|██████████| 625/625 [00:57<00:00, 10.86it/s]


Epoch 40/500, Train Loss: 2.2573, Train Acc: 39.62%
Epoch 40, Val Loss: 1.9331, Val Acc: 46.32%
当前学习率: 0.000800
注意力权重可视化已保存到: ./attention_viz/attention_epoch_40.png
保存最佳模型，验证准确率: 46.32%


Epoch 41, Loss: 2.2728, Acc: 39.61%, CE: 2.2670, AttnReg: 0.5727: 100%|██████████| 625/625 [00:57<00:00, 10.85it/s]


Epoch 41/500, Train Loss: 2.2734, Train Acc: 39.60%
Epoch 41, Val Loss: 1.9429, Val Acc: 47.94%
当前学习率: 0.000800
保存最佳模型，验证准确率: 47.94%


Epoch 42, Loss: 2.2470, Acc: 39.96%, CE: 2.2412, AttnReg: 0.5762: 100%|██████████| 625/625 [00:57<00:00, 10.84it/s]


Epoch 42/500, Train Loss: 2.2470, Train Acc: 39.96%
Epoch 42, Val Loss: 1.9363, Val Acc: 47.76%
当前学习率: 0.000800


Epoch 43, Loss: 2.2183, Acc: 40.98%, CE: 2.2128, AttnReg: 0.5495: 100%|██████████| 625/625 [00:57<00:00, 10.88it/s]


Epoch 43/500, Train Loss: 2.2171, Train Acc: 40.96%
Epoch 43, Val Loss: 1.8461, Val Acc: 49.55%
当前学习率: 0.000800
保存最佳模型，验证准确率: 49.55%


Epoch 44, Loss: 2.2377, Acc: 40.44%, CE: 2.2318, AttnReg: 0.5898: 100%|██████████| 625/625 [00:57<00:00, 10.88it/s]


Epoch 44/500, Train Loss: 2.2348, Train Acc: 40.50%
Epoch 44, Val Loss: 2.0995, Val Acc: 40.57%
当前学习率: 0.000800


Epoch 45, Loss: 2.2052, Acc: 41.67%, CE: 2.1996, AttnReg: 0.5603: 100%|██████████| 625/625 [00:57<00:00, 10.87it/s]


Epoch 45/500, Train Loss: 2.2051, Train Acc: 41.62%
Epoch 45, Val Loss: 1.9144, Val Acc: 48.83%
当前学习率: 0.000800
注意力权重可视化已保存到: ./attention_viz/attention_epoch_45.png


Epoch 46, Loss: 2.1668, Acc: 41.65%, CE: 2.1614, AttnReg: 0.5384: 100%|██████████| 625/625 [00:57<00:00, 10.87it/s]


Epoch 46/500, Train Loss: 2.1672, Train Acc: 41.66%
Epoch 46, Val Loss: 1.9136, Val Acc: 46.50%
当前学习率: 0.000800


Epoch 47, Loss: 2.1449, Acc: 42.59%, CE: 2.1395, AttnReg: 0.5431: 100%|██████████| 625/625 [00:57<00:00, 10.87it/s]


Epoch 47/500, Train Loss: 2.1440, Train Acc: 42.58%
Epoch 47, Val Loss: 1.8094, Val Acc: 49.55%
当前学习率: 0.000800


Epoch 48, Loss: 2.1439, Acc: 41.81%, CE: 2.1380, AttnReg: 0.5875: 100%|██████████| 625/625 [00:57<00:00, 10.85it/s]


Epoch 48/500, Train Loss: 2.1463, Train Acc: 41.80%
Epoch 48, Val Loss: 1.9198, Val Acc: 47.58%
当前学习率: 0.000800


Epoch 49, Loss: 2.1324, Acc: 42.83%, CE: 2.1270, AttnReg: 0.5449: 100%|██████████| 625/625 [00:57<00:00, 10.86it/s]


Epoch 49/500, Train Loss: 2.1315, Train Acc: 42.82%
Epoch 49, Val Loss: 1.9069, Val Acc: 47.58%
当前学习率: 0.000800


Epoch 50, Loss: 2.1807, Acc: 42.45%, CE: 2.1753, AttnReg: 0.5445: 100%|██████████| 625/625 [00:57<00:00, 10.88it/s]


Epoch 50/500, Train Loss: 2.1809, Train Acc: 42.48%
Epoch 50, Val Loss: 1.9032, Val Acc: 46.50%
当前学习率: 0.000800
注意力权重可视化已保存到: ./attention_viz/attention_epoch_50.png


Epoch 51, Loss: 2.0909, Acc: 44.40%, CE: 2.0853, AttnReg: 0.5557: 100%|██████████| 625/625 [00:57<00:00, 10.86it/s]


Epoch 51/500, Train Loss: 2.0919, Train Acc: 44.42%
Epoch 51, Val Loss: 1.7638, Val Acc: 50.09%
当前学习率: 0.000800
保存最佳模型，验证准确率: 50.09%


Epoch 52, Loss: 2.1607, Acc: 42.57%, CE: 2.1550, AttnReg: 0.5718: 100%|██████████| 625/625 [00:57<00:00, 10.85it/s]


Epoch 52/500, Train Loss: 2.1600, Train Acc: 42.56%
Epoch 52, Val Loss: 1.8406, Val Acc: 50.63%
当前学习率: 0.000800
保存最佳模型，验证准确率: 50.63%


Epoch 53, Loss: 2.1417, Acc: 42.23%, CE: 2.1361, AttnReg: 0.5589: 100%|██████████| 625/625 [00:57<00:00, 10.89it/s]


Epoch 53/500, Train Loss: 2.1414, Train Acc: 42.26%
Epoch 53, Val Loss: 1.8220, Val Acc: 49.01%
当前学习率: 0.000800


Epoch 54, Loss: 2.1467, Acc: 42.61%, CE: 2.1409, AttnReg: 0.5889: 100%|██████████| 625/625 [00:57<00:00, 10.84it/s]


Epoch 54/500, Train Loss: 2.1448, Train Acc: 42.66%
Epoch 54, Val Loss: 1.8691, Val Acc: 47.04%
当前学习率: 0.000800


Epoch 55, Loss: 2.0949, Acc: 44.24%, CE: 2.0891, AttnReg: 0.5820: 100%|██████████| 625/625 [00:57<00:00, 10.87it/s]


Epoch 55/500, Train Loss: 2.0987, Train Acc: 44.20%
Epoch 55, Val Loss: 1.8847, Val Acc: 48.83%
当前学习率: 0.000800
注意力权重可视化已保存到: ./attention_viz/attention_epoch_55.png


Epoch 56, Loss: 2.0963, Acc: 44.16%, CE: 2.0904, AttnReg: 0.5963: 100%|██████████| 625/625 [00:57<00:00, 10.84it/s]


Epoch 56/500, Train Loss: 2.0953, Train Acc: 44.16%
Epoch 56, Val Loss: 1.7805, Val Acc: 51.71%
当前学习率: 0.000800
保存最佳模型，验证准确率: 51.71%


Epoch 57, Loss: 2.0506, Acc: 44.24%, CE: 2.0446, AttnReg: 0.5969: 100%|██████████| 625/625 [00:57<00:00, 10.84it/s]


Epoch 57/500, Train Loss: 2.0509, Train Acc: 44.26%
Epoch 57, Val Loss: 1.7929, Val Acc: 52.06%
当前学习率: 0.000400
保存最佳模型，验证准确率: 52.06%


Epoch 58, Loss: 2.0336, Acc: 45.29%, CE: 2.0276, AttnReg: 0.6013: 100%|██████████| 625/625 [00:57<00:00, 10.86it/s]


Epoch 58/500, Train Loss: 2.0341, Train Acc: 45.30%
Epoch 58, Val Loss: 1.6416, Val Acc: 56.73%
当前学习率: 0.000400
保存最佳模型，验证准确率: 56.73%


Epoch 59, Loss: 1.9845, Acc: 46.40%, CE: 1.9783, AttnReg: 0.6200: 100%|██████████| 625/625 [00:57<00:00, 10.79it/s]


Epoch 59/500, Train Loss: 1.9837, Train Acc: 46.38%
Epoch 59, Val Loss: 1.7016, Val Acc: 51.35%
当前学习率: 0.000400


Epoch 60, Loss: 1.9738, Acc: 46.54%, CE: 1.9676, AttnReg: 0.6167: 100%|██████████| 625/625 [00:57<00:00, 10.80it/s]


Epoch 60/500, Train Loss: 1.9706, Train Acc: 46.68%
Epoch 60, Val Loss: 1.7076, Val Acc: 52.06%
当前学习率: 0.000400
注意力权重可视化已保存到: ./attention_viz/attention_epoch_60.png


Epoch 61, Loss: 1.9181, Acc: 47.60%, CE: 1.9120, AttnReg: 0.6075: 100%|██████████| 625/625 [00:57<00:00, 10.80it/s]


Epoch 61/500, Train Loss: 1.9256, Train Acc: 47.50%
Epoch 61, Val Loss: 1.5845, Val Acc: 59.43%
当前学习率: 0.000400
保存最佳模型，验证准确率: 59.43%


Epoch 62, Loss: 1.8923, Acc: 48.41%, CE: 1.8861, AttnReg: 0.6158: 100%|██████████| 625/625 [00:57<00:00, 10.78it/s]


Epoch 62/500, Train Loss: 1.8906, Train Acc: 48.46%
Epoch 62, Val Loss: 1.6352, Val Acc: 52.42%
当前学习率: 0.000400


Epoch 63, Loss: 1.9004, Acc: 50.14%, CE: 1.8940, AttnReg: 0.6379: 100%|██████████| 625/625 [00:57<00:00, 10.78it/s]


Epoch 63/500, Train Loss: 1.9017, Train Acc: 50.04%
Epoch 63, Val Loss: 1.7019, Val Acc: 52.96%
当前学习率: 0.000400


Epoch 64, Loss: 1.8867, Acc: 48.39%, CE: 1.8802, AttnReg: 0.6424: 100%|██████████| 625/625 [00:57<00:00, 10.81it/s]


Epoch 64/500, Train Loss: 1.8857, Train Acc: 48.38%
Epoch 64, Val Loss: 1.6462, Val Acc: 53.14%
当前学习率: 0.000400


Epoch 65, Loss: 1.8756, Acc: 49.36%, CE: 1.8693, AttnReg: 0.6318: 100%|██████████| 625/625 [00:57<00:00, 10.84it/s]


Epoch 65/500, Train Loss: 1.8728, Train Acc: 49.46%
Epoch 65, Val Loss: 1.5944, Val Acc: 54.94%
当前学习率: 0.000400
注意力权重可视化已保存到: ./attention_viz/attention_epoch_65.png


Epoch 66, Loss: 1.8941, Acc: 49.21%, CE: 1.8878, AttnReg: 0.6301: 100%|██████████| 625/625 [00:57<00:00, 10.80it/s]


Epoch 66/500, Train Loss: 1.8935, Train Acc: 49.24%
Epoch 66, Val Loss: 1.6325, Val Acc: 53.50%
当前学习率: 0.000400


Epoch 67, Loss: 1.9132, Acc: 48.37%, CE: 1.9067, AttnReg: 0.6450: 100%|██████████| 625/625 [00:57<00:00, 10.85it/s]


Epoch 67/500, Train Loss: 1.9117, Train Acc: 48.42%
Epoch 67, Val Loss: 1.6447, Val Acc: 56.73%
当前学习率: 0.000200


Epoch 68, Loss: 1.8375, Acc: 50.10%, CE: 1.8313, AttnReg: 0.6259: 100%|██████████| 625/625 [00:57<00:00, 10.83it/s]


Epoch 68/500, Train Loss: 1.8392, Train Acc: 50.08%
Epoch 68, Val Loss: 1.5375, Val Acc: 56.73%
当前学习率: 0.000200


Epoch 69, Loss: 1.8295, Acc: 50.58%, CE: 1.8229, AttnReg: 0.6556: 100%|██████████| 625/625 [00:57<00:00, 10.81it/s]


Epoch 69/500, Train Loss: 1.8273, Train Acc: 50.62%
Epoch 69, Val Loss: 1.4909, Val Acc: 59.07%
当前学习率: 0.000200


Epoch 70, Loss: 1.7663, Acc: 52.05%, CE: 1.7599, AttnReg: 0.6391: 100%|██████████| 625/625 [00:57<00:00, 10.81it/s]


Epoch 70/500, Train Loss: 1.7661, Train Acc: 52.02%
Epoch 70, Val Loss: 1.6113, Val Acc: 53.68%
当前学习率: 0.000200
注意力权重可视化已保存到: ./attention_viz/attention_epoch_70.png


Epoch 71, Loss: 1.8060, Acc: 50.99%, CE: 1.7995, AttnReg: 0.6564: 100%|██████████| 625/625 [00:58<00:00, 10.75it/s]


Epoch 71/500, Train Loss: 1.8088, Train Acc: 50.94%
Epoch 71, Val Loss: 1.4632, Val Acc: 59.96%
当前学习率: 0.000200
保存最佳模型，验证准确率: 59.96%


Epoch 72, Loss: 1.7989, Acc: 51.99%, CE: 1.7924, AttnReg: 0.6434: 100%|██████████| 625/625 [00:57<00:00, 10.81it/s]


Epoch 72/500, Train Loss: 1.7974, Train Acc: 52.04%
Epoch 72, Val Loss: 1.6409, Val Acc: 53.68%
当前学习率: 0.000200


Epoch 73, Loss: 1.7730, Acc: 52.36%, CE: 1.7662, AttnReg: 0.6855: 100%|██████████| 625/625 [00:57<00:00, 10.81it/s]


Epoch 73/500, Train Loss: 1.7783, Train Acc: 52.26%
Epoch 73, Val Loss: 1.5379, Val Acc: 54.40%
当前学习率: 0.000200


Epoch 74, Loss: 1.7543, Acc: 51.59%, CE: 1.7475, AttnReg: 0.6788: 100%|██████████| 625/625 [00:57<00:00, 10.90it/s]


Epoch 74/500, Train Loss: 1.7560, Train Acc: 51.56%
Epoch 74, Val Loss: 1.5071, Val Acc: 57.63%
当前学习率: 0.000200


Epoch 75, Loss: 1.7580, Acc: 51.57%, CE: 1.7513, AttnReg: 0.6743: 100%|██████████| 625/625 [00:57<00:00, 10.93it/s]


Epoch 75/500, Train Loss: 1.7585, Train Acc: 51.52%
Epoch 75, Val Loss: 1.4991, Val Acc: 59.25%
当前学习率: 0.000200
注意力权重可视化已保存到: ./attention_viz/attention_epoch_75.png


Epoch 76, Loss: 1.8024, Acc: 51.45%, CE: 1.7956, AttnReg: 0.6813: 100%|██████████| 625/625 [01:00<00:00, 10.38it/s]


Epoch 76/500, Train Loss: 1.8013, Train Acc: 51.50%
Epoch 76, Val Loss: 1.4953, Val Acc: 57.99%
当前学习率: 0.000200


Epoch 77, Loss: 1.7757, Acc: 51.99%, CE: 1.7688, AttnReg: 0.6881: 100%|██████████| 625/625 [00:57<00:00, 10.80it/s]


Epoch 77/500, Train Loss: 1.7747, Train Acc: 52.02%
Epoch 77, Val Loss: 1.4489, Val Acc: 58.17%
当前学习率: 0.000200


Epoch 78, Loss: 1.7751, Acc: 52.03%, CE: 1.7684, AttnReg: 0.6702: 100%|██████████| 625/625 [00:57<00:00, 10.83it/s]


Epoch 78/500, Train Loss: 1.7744, Train Acc: 52.08%
Epoch 78, Val Loss: 1.5143, Val Acc: 57.99%
当前学习率: 0.000200


Epoch 79, Loss: 1.7381, Acc: 53.02%, CE: 1.7312, AttnReg: 0.6890: 100%|██████████| 625/625 [00:57<00:00, 10.85it/s]


Epoch 79/500, Train Loss: 1.7362, Train Acc: 53.04%
Epoch 79, Val Loss: 1.5328, Val Acc: 56.37%
当前学习率: 0.000200


Epoch 80, Loss: 1.7594, Acc: 52.13%, CE: 1.7524, AttnReg: 0.6903: 100%|██████████| 625/625 [00:57<00:00, 10.84it/s]


Epoch 80/500, Train Loss: 1.7590, Train Acc: 52.18%
Epoch 80, Val Loss: 1.3988, Val Acc: 60.68%
当前学习率: 0.000200
注意力权重可视化已保存到: ./attention_viz/attention_epoch_80.png
保存最佳模型，验证准确率: 60.68%


Epoch 81, Loss: 1.7165, Acc: 53.42%, CE: 1.7096, AttnReg: 0.6861: 100%|██████████| 625/625 [01:03<00:00,  9.89it/s]


Epoch 81/500, Train Loss: 1.7187, Train Acc: 53.34%
Epoch 81, Val Loss: 1.4355, Val Acc: 59.43%
当前学习率: 0.000200


Epoch 82, Loss: 1.7210, Acc: 53.70%, CE: 1.7141, AttnReg: 0.6857: 100%|██████████| 625/625 [00:59<00:00, 10.52it/s]


Epoch 82/500, Train Loss: 1.7204, Train Acc: 53.68%
Epoch 82, Val Loss: 1.5482, Val Acc: 56.73%
当前学习率: 0.000200


Epoch 83, Loss: 1.7024, Acc: 54.43%, CE: 1.6954, AttnReg: 0.6952: 100%|██████████| 625/625 [00:58<00:00, 10.61it/s]


Epoch 83/500, Train Loss: 1.7007, Train Acc: 54.44%
Epoch 83, Val Loss: 1.4442, Val Acc: 57.99%
当前学习率: 0.000200


Epoch 84, Loss: 1.7371, Acc: 52.70%, CE: 1.7301, AttnReg: 0.7092: 100%|██████████| 625/625 [00:58<00:00, 10.61it/s]


Epoch 84/500, Train Loss: 1.7357, Train Acc: 52.74%
Epoch 84, Val Loss: 1.6399, Val Acc: 56.19%
当前学习率: 0.000200


Epoch 85, Loss: 1.7094, Acc: 53.70%, CE: 1.7023, AttnReg: 0.7081: 100%|██████████| 625/625 [00:57<00:00, 10.78it/s]


Epoch 85/500, Train Loss: 1.7064, Train Acc: 53.76%
Epoch 85, Val Loss: 1.4580, Val Acc: 60.50%
当前学习率: 0.000200
注意力权重可视化已保存到: ./attention_viz/attention_epoch_85.png


Epoch 86, Loss: 1.7487, Acc: 52.70%, CE: 1.7416, AttnReg: 0.7099: 100%|██████████| 625/625 [01:05<00:00,  9.52it/s]


Epoch 86/500, Train Loss: 1.7463, Train Acc: 52.70%
Epoch 86, Val Loss: 1.3705, Val Acc: 60.50%
当前学习率: 0.000200


Epoch 87, Loss: 1.7140, Acc: 53.48%, CE: 1.7068, AttnReg: 0.7249: 100%|██████████| 625/625 [00:59<00:00, 10.54it/s]


Epoch 87/500, Train Loss: 1.7136, Train Acc: 53.52%
Epoch 87, Val Loss: 1.3854, Val Acc: 61.76%
当前学习率: 0.000200
保存最佳模型，验证准确率: 61.76%


Epoch 88, Loss: 1.7190, Acc: 53.64%, CE: 1.7118, AttnReg: 0.7243: 100%|██████████| 625/625 [00:57<00:00, 10.82it/s]


Epoch 88/500, Train Loss: 1.7186, Train Acc: 53.66%
Epoch 88, Val Loss: 1.4973, Val Acc: 57.99%
当前学习率: 0.000200


Epoch 89, Loss: 1.7293, Acc: 53.48%, CE: 1.7219, AttnReg: 0.7387: 100%|██████████| 625/625 [00:57<00:00, 10.83it/s]


Epoch 89/500, Train Loss: 1.7300, Train Acc: 53.50%
Epoch 89, Val Loss: 1.5469, Val Acc: 54.94%
当前学习率: 0.000200


Epoch 90, Loss: 1.7076, Acc: 54.23%, CE: 1.7001, AttnReg: 0.7472: 100%|██████████| 625/625 [00:58<00:00, 10.73it/s]


Epoch 90/500, Train Loss: 1.7054, Train Acc: 54.20%
Epoch 90, Val Loss: 1.4043, Val Acc: 58.89%
当前学习率: 0.000200
注意力权重可视化已保存到: ./attention_viz/attention_epoch_90.png


Epoch 91, Loss: 1.6974, Acc: 53.66%, CE: 1.6899, AttnReg: 0.7447: 100%|██████████| 625/625 [01:09<00:00,  8.99it/s]


Epoch 91/500, Train Loss: 1.6938, Train Acc: 53.76%
Epoch 91, Val Loss: 1.4112, Val Acc: 58.35%
当前学习率: 0.000200


Epoch 92, Loss: 1.6910, Acc: 54.17%, CE: 1.6835, AttnReg: 0.7457: 100%|██████████| 625/625 [01:01<00:00, 10.23it/s]


Epoch 92/500, Train Loss: 1.6905, Train Acc: 54.14%
Epoch 92, Val Loss: 1.4557, Val Acc: 57.63%
当前学习率: 0.000100


Epoch 93, Loss: 1.6498, Acc: 55.86%, CE: 1.6423, AttnReg: 0.7489: 100%|██████████| 625/625 [00:58<00:00, 10.76it/s]


Epoch 93/500, Train Loss: 1.6508, Train Acc: 55.84%
Epoch 93, Val Loss: 1.3097, Val Acc: 59.96%
当前学习率: 0.000100


Epoch 94, Loss: 1.6649, Acc: 54.63%, CE: 1.6574, AttnReg: 0.7494: 100%|██████████| 625/625 [00:58<00:00, 10.75it/s]


Epoch 94/500, Train Loss: 1.6666, Train Acc: 54.58%
Epoch 94, Val Loss: 1.3219, Val Acc: 61.94%
当前学习率: 0.000100
保存最佳模型，验证准确率: 61.94%


Epoch 95, Loss: 1.6629, Acc: 54.75%, CE: 1.6553, AttnReg: 0.7592: 100%|██████████| 625/625 [00:58<00:00, 10.77it/s]


Epoch 95/500, Train Loss: 1.6613, Train Acc: 54.86%
Epoch 95, Val Loss: 1.5209, Val Acc: 56.91%
当前学习率: 0.000100
注意力权重可视化已保存到: ./attention_viz/attention_epoch_95.png


Epoch 96, Loss: 1.6534, Acc: 54.95%, CE: 1.6459, AttnReg: 0.7517: 100%|██████████| 625/625 [00:57<00:00, 10.81it/s]


Epoch 96/500, Train Loss: 1.6542, Train Acc: 54.94%
Epoch 96, Val Loss: 1.3890, Val Acc: 60.86%
当前学习率: 0.000100


Epoch 97, Loss: 1.6312, Acc: 55.70%, CE: 1.6238, AttnReg: 0.7488: 100%|██████████| 625/625 [00:57<00:00, 10.82it/s]


Epoch 97/500, Train Loss: 1.6337, Train Acc: 55.64%
Epoch 97, Val Loss: 1.3640, Val Acc: 62.66%
当前学习率: 0.000100
保存最佳模型，验证准确率: 62.66%


Epoch 98, Loss: 1.6547, Acc: 55.35%, CE: 1.6470, AttnReg: 0.7680: 100%|██████████| 625/625 [00:57<00:00, 10.81it/s]


Epoch 98/500, Train Loss: 1.6524, Train Acc: 55.38%
Epoch 98, Val Loss: 1.3329, Val Acc: 62.12%
当前学习率: 0.000100


Epoch 99, Loss: 1.6229, Acc: 54.51%, CE: 1.6152, AttnReg: 0.7751: 100%|██████████| 625/625 [00:57<00:00, 10.80it/s]


Epoch 99/500, Train Loss: 1.6237, Train Acc: 54.48%
Epoch 99, Val Loss: 1.4090, Val Acc: 64.81%
当前学习率: 0.000050
保存最佳模型，验证准确率: 64.81%


Epoch 100, Loss: 1.6224, Acc: 55.39%, CE: 1.6147, AttnReg: 0.7697: 100%|██████████| 625/625 [00:57<00:00, 10.80it/s]


Epoch 100/500, Train Loss: 1.6232, Train Acc: 55.42%
Epoch 100, Val Loss: 1.3686, Val Acc: 63.02%
当前学习率: 0.000050
注意力权重可视化已保存到: ./attention_viz/attention_epoch_100.png


Epoch 101, Loss: 1.5757, Acc: 57.29%, CE: 1.5681, AttnReg: 0.7589: 100%|██████████| 625/625 [00:57<00:00, 10.79it/s]


Epoch 101/500, Train Loss: 1.5753, Train Acc: 57.32%
Epoch 101, Val Loss: 1.5030, Val Acc: 58.17%
当前学习率: 0.000050


Epoch 102, Loss: 1.5948, Acc: 56.20%, CE: 1.5871, AttnReg: 0.7710: 100%|██████████| 625/625 [00:57<00:00, 10.79it/s]


Epoch 102/500, Train Loss: 1.5960, Train Acc: 56.18%
Epoch 102, Val Loss: 1.4514, Val Acc: 58.35%
当前学习率: 0.000050


Epoch 103, Loss: 1.6142, Acc: 56.14%, CE: 1.6065, AttnReg: 0.7647: 100%|██████████| 625/625 [00:57<00:00, 10.81it/s]


Epoch 103/500, Train Loss: 1.6139, Train Acc: 56.12%
Epoch 103, Val Loss: 1.3217, Val Acc: 62.12%
当前学习率: 0.000050


Epoch 104, Loss: 1.5786, Acc: 57.25%, CE: 1.5708, AttnReg: 0.7798: 100%|██████████| 625/625 [00:58<00:00, 10.77it/s]


Epoch 104/500, Train Loss: 1.5807, Train Acc: 57.22%
Epoch 104, Val Loss: 1.3462, Val Acc: 62.12%
当前学习率: 0.000050


Epoch 105, Loss: 1.6116, Acc: 56.02%, CE: 1.6038, AttnReg: 0.7815: 100%|██████████| 625/625 [00:58<00:00, 10.77it/s]


Epoch 105/500, Train Loss: 1.6120, Train Acc: 55.98%
Epoch 105, Val Loss: 1.3496, Val Acc: 61.76%
当前学习率: 0.000025
注意力权重可视化已保存到: ./attention_viz/attention_epoch_105.png


Epoch 106, Loss: 1.6162, Acc: 55.84%, CE: 1.6082, AttnReg: 0.7962: 100%|██████████| 625/625 [00:57<00:00, 10.79it/s]


Epoch 106/500, Train Loss: 1.6162, Train Acc: 55.86%
Epoch 106, Val Loss: 1.4227, Val Acc: 58.53%
当前学习率: 0.000025


Epoch 107, Loss: 1.5919, Acc: 56.52%, CE: 1.5840, AttnReg: 0.7851: 100%|██████████| 625/625 [00:57<00:00, 10.80it/s]


Epoch 107/500, Train Loss: 1.5928, Train Acc: 56.48%
Epoch 107, Val Loss: 1.4070, Val Acc: 58.71%
当前学习率: 0.000025


Epoch 108, Loss: 1.6129, Acc: 55.68%, CE: 1.6049, AttnReg: 0.8004: 100%|██████████| 625/625 [00:57<00:00, 10.78it/s]


Epoch 108/500, Train Loss: 1.6133, Train Acc: 55.68%
Epoch 108, Val Loss: 1.3856, Val Acc: 62.12%
当前学习率: 0.000025


Epoch 109, Loss: 1.5847, Acc: 55.78%, CE: 1.5768, AttnReg: 0.7901: 100%|██████████| 625/625 [00:57<00:00, 10.81it/s]


Epoch 109/500, Train Loss: 1.5863, Train Acc: 55.74%
Epoch 109, Val Loss: 1.2740, Val Acc: 63.20%
当前学习率: 0.000025


Epoch 110, Loss: 1.5832, Acc: 57.02%, CE: 1.5753, AttnReg: 0.7954: 100%|██████████| 625/625 [01:09<00:00,  9.05it/s]


Epoch 110/500, Train Loss: 1.5810, Train Acc: 57.06%
Epoch 110, Val Loss: 1.4691, Val Acc: 58.17%
当前学习率: 0.000025
注意力权重可视化已保存到: ./attention_viz/attention_epoch_110.png


Epoch 111, Loss: 1.5488, Acc: 58.15%, CE: 1.5409, AttnReg: 0.7898: 100%|██████████| 625/625 [00:58<00:00, 10.60it/s]


Epoch 111/500, Train Loss: 1.5495, Train Acc: 58.06%
Epoch 111, Val Loss: 1.2864, Val Acc: 61.76%
当前学习率: 0.000025


Epoch 112, Loss: 1.6282, Acc: 55.84%, CE: 1.6201, AttnReg: 0.8134: 100%|██████████| 625/625 [00:57<00:00, 10.79it/s]


Epoch 112/500, Train Loss: 1.6296, Train Acc: 55.80%
Epoch 112, Val Loss: 1.5091, Val Acc: 58.53%
当前学习率: 0.000025


Epoch 113, Loss: 1.6147, Acc: 55.90%, CE: 1.6067, AttnReg: 0.8010: 100%|██████████| 625/625 [00:58<00:00, 10.74it/s]


Epoch 113/500, Train Loss: 1.6148, Train Acc: 55.94%
Epoch 113, Val Loss: 1.2805, Val Acc: 64.09%
当前学习率: 0.000025


Epoch 114, Loss: 1.6136, Acc: 56.10%, CE: 1.6057, AttnReg: 0.7920: 100%|██████████| 625/625 [00:58<00:00, 10.74it/s]


Epoch 114/500, Train Loss: 1.6129, Train Acc: 56.18%
Epoch 114, Val Loss: 1.3716, Val Acc: 63.55%
当前学习率: 0.000025
早停于第 114 轮
训练历史可视化已保存到: ./checkpoints_T2/training_history.png
创建评估器...
评估模型...
测试集准确率: 64.83%
分析注意力模式...
样本 0 (标签: 51): 平均熵=0.7785, 对角线强度=0.0020
样本 1 (标签: 51): 平均熵=0.5801, 对角线强度=0.0016
样本 2 (标签: 51): 平均熵=0.4537, 对角线强度=0.0023
样本 3 (标签: 51): 平均熵=0.4665, 对角线强度=0.0023
样本 4 (标签: 51): 平均熵=3.0330, 对角线强度=0.0036
样本 5 (标签: 51): 平均熵=0.7846, 对角线强度=0.0021
样本 6 (标签: 51): 平均熵=0.7193, 对角线强度=0.0040
样本 7 (标签: 51): 平均熵=0.3430, 对角线强度=0.0037
样本 8 (标签: 51): 平均熵=0.3234, 对角线强度=0.0015
样本 9 (标签: 51): 平均熵=0.8505, 对角线强度=0.0026
样本 10 (标签: 51): 平均熵=0.4510, 对角线强度=0.0022
样本 11 (标签: 51): 平均熵=2.4415, 对角线强度=0.0033
样本 12 (标签: 51): 平均熵=0.5113, 对角线强度=0.0016
样本 13 (标签: 51): 平均熵=0.4706, 对角线强度=0.0026
样本 14 (标签: 51): 平均熵=0.4680, 对角线强度=0.0027
样本 15 (标签: 51): 平均熵=0.4532, 对角线强度=0.0044
样本 16 (标签: 51): 平均熵=0.9774, 对角线强度=0.0020
样本 17 (标签: 51): 平均熵=0.4740, 对角线强度=0.0026
样本 18 (标签: 51): 平均熵=0.2884, 对角线强度=0.0030
样本 19 (标签: 5