In [5]:
import os
class Config:
    """Hyper-parameters shared by the data pipeline, the T4 model and the trainer."""

    # ---- Data loading ----
    DEBUG_MODE = True
    DEBUG_SAMPLE_SIZE = 2000  # sample cap applied in debug mode (kept small to suit 3 s clips)
    BASE_DIR = os.getcwd()  # resolved once, when this class body executes
    LIBRISPEECH_PATH = os.path.join(BASE_DIR, "devDataset", "LibriSpeech")
    SAMPLE_RATE = 16000  # Hz
    DURATION = 3.0  # seconds per clip
    MAX_SAMPLES = int(SAMPLE_RATE * DURATION)  # samples per clip
    NOISE_TYPES = ["white", "babble"]
    SNR_LEVELS = [0, 5, 10]  # dB
    BATCH_SIZE = 8
    NUM_WORKERS = 2
    VALID_RATIO = 0.1
    # Model input length. Derived from MAX_SAMPLES so the clip length produced by the
    # dataset and the length the model pads/truncates to cannot drift apart.
    MAX_SEQ_LEN = MAX_SAMPLES

    # ---- Model ----
    INPUT_DIM = 1
    HIDDEN_DIM = 256
    EMBEDDING_DIM = 128
    # NOTE(review): the three values below are not referenced by the visible T4 model code.
    CHAOS_DIM = 64
    CHAOS_TIME_STEPS = 5
    ATTENTION_HEADS = 4

    # ---- Training ----
    EPOCHS = 200  # upper bound; early stopping (PATIENCE) may end training sooner
    LR = 0.01
    LR_DECAY = 0.95  # NOTE(review): not referenced by the visible Trainer
    WEIGHT_DECAY = 1e-5
    SAVE_INTERVAL = 10  # save a periodic checkpoint every N epochs
    VAL_INTERVAL = 1  # NOTE(review): not referenced; validation currently runs every epoch
    CHECKPOINT_DIR = "../checkpoints_T4"  # relative to the current working directory
    WARMUP_EPOCHS = 5
    GRAD_CLIP = 1.0  # max gradient norm
    CE_WEIGHT = 1.0
    PATIENCE = 10  # epochs without validation-accuracy improvement before stopping
    MIN_DELTA = 0.001  # minimum improvement, in accuracy percentage points
    GRADIENT_ACCUMULATION_STEPS = 2  # micro-batches per optimizer step
    ENABLE_MIXED_PRECISION = False

In [6]:
import os
import random
import librosa
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import soundfile as sf
from tqdm import tqdm

# 噪声生成与注入
class NoiseInjector:
    """Static helpers that synthesise noise and mix it into a waveform at a target SNR."""

    @staticmethod
    def generate_white_noise(length):
        """Return `length` standard-normal samples as a float32 array."""
        samples = np.random.randn(length)
        return samples.astype(np.float32)

    @staticmethod
    def generate_babble_noise(length, num_speakers=3):
        """
        Crude "babble": `num_speakers` one-second Gaussian bursts at random offsets,
        summed and divided by `num_speakers`.

        :param length: number of output samples
        :param num_speakers: number of bursts to overlay
        :return: float32 array of `length` samples
        """
        window = Config.SAMPLE_RATE
        babble = np.zeros(length, dtype=np.float32)
        for _ in range(num_speakers):
            lo = random.randint(0, max(0, length - window))
            hi = min(lo + window, length)
            burst = np.random.randn(hi - lo).astype(np.float32)
            babble[lo:hi] += burst
        return babble / num_speakers

    @staticmethod
    def add_noise(signal, noise_type="white", snr_db=10):
        """
        Mix synthetic noise into `signal` at `snr_db` dB SNR, then peak-normalise.

        :param signal: 1-D waveform array
        :param noise_type: "white"; any other value selects babble noise
        :param snr_db: desired signal-to-noise ratio in dB
        :return: the noisy waveform; `signal` itself when it is empty or near-silent
        """
        if len(signal) == 0:
            return signal

        signal_power = np.mean(signal ** 2)
        if signal_power < 1e-10:
            # Near-silent input: SNR is undefined, leave it alone.
            return signal
        signal_db = 10 * np.log10(signal_power)

        if noise_type == "white":
            noise = NoiseInjector.generate_white_noise(len(signal))
        else:
            noise = NoiseInjector.generate_babble_noise(len(signal))

        noise_power = np.mean(noise ** 2)
        noise_db = -100 if noise_power < 1e-10 else 10 * np.log10(noise_power)

        # Amplitude gain that brings the noise to (signal_db - snr_db) dB.
        gain = 10 ** ((signal_db - snr_db - noise_db) / 20)
        mixed = signal + noise * gain

        peak = np.max(np.abs(mixed))
        if peak > 1e-5:
            mixed = mixed / peak
        return mixed


# 数据集类
class SpeakerRecognitionDataset(Dataset):
    """
    LibriSpeech dev-clean + dev-other utterances labelled by speaker directory.

    The corpus is partitioned with a fixed seed into disjoint test / val / train
    subsets, and every split shares one speaker-to-index mapping built from the
    whole corpus, so label indices agree across splits.

    :param split: "train", "val" or "test" (any other value yields the val subset)
    :param add_noise: if True, corrupt every returned waveform with synthetic noise
    :param noise_type: noise type used for non-train splits when add_noise is True
    :param snr_db: SNR in dB used for non-train splits when add_noise is True
    """

    def __init__(self, split="train", add_noise=False, noise_type="white", snr_db=10):
        self.split = split
        self.add_noise = add_noise
        self.noise_type = noise_type
        self.snr_db = snr_db

        # Corpus-wide speaker list, filled by _load_dataset; the label map is built from it.
        self.all_speakers = []
        self.audio_paths, self.labels = self._load_dataset()
        self.speaker_to_idx = self._build_speaker_map()

        # Debug mode: cap the number of samples (train: full cap, val: half of it).
        if Config.DEBUG_MODE:
            limits = {"train": Config.DEBUG_SAMPLE_SIZE, "val": Config.DEBUG_SAMPLE_SIZE // 2}
            limit = limits.get(split)
            if limit is not None:
                self.audio_paths = self.audio_paths[:limit]
                self.labels = self.labels[:limit]

        self._validate_dataset()
        print(f"最终 {split} 数据集大小: {len(self.audio_paths)} 个样本")

    def _load_dataset(self):
        """
        Scan dev-clean / dev-other for .flac files and return (paths, labels) for this split.

        Labels are the speaker directory names. Also records the corpus-wide
        speaker list in `self.all_speakers`.
        """
        audio_paths, labels = [], []
        root = Config.LIBRISPEECH_PATH

        if not os.path.exists(root):
            print(f"错误: LibriSpeech路径不存在 - {root}")
            print("请确保数据集已下载并放置在正确位置")
            return [], []

        print(f"加载LibriSpeech数据集: {root}")

        for subset_dir in ["dev-clean", "dev-other"]:
            subset_path = os.path.join(root, subset_dir)
            if not os.path.isdir(subset_path):
                print(f"警告: 子集 {subset_dir} 不存在")
                continue

            print(f"处理子集: {subset_dir}")
            # os.listdir order is filesystem-dependent; sorting keeps the seeded split reproducible.
            for speaker_dir in sorted(os.listdir(subset_path)):
                speaker_path = os.path.join(subset_path, speaker_dir)
                if not os.path.isdir(speaker_path):
                    continue

                for chapter_dir in sorted(os.listdir(speaker_path)):
                    chapter_path = os.path.join(speaker_path, chapter_dir)
                    if not os.path.isdir(chapter_path):
                        continue

                    for file in sorted(os.listdir(chapter_path)):
                        if file.endswith(".flac"):
                            audio_paths.append(os.path.join(chapter_path, file))
                            labels.append(speaker_dir)

        self.all_speakers = sorted(set(labels))
        if not audio_paths:
            return audio_paths, labels

        # Hold out a test portion first so the test split never overlaps train/val.
        test_ratio = getattr(Config, "TEST_RATIO", Config.VALID_RATIO)
        rest_paths, test_paths, rest_labels, test_labels = train_test_split(
            audio_paths, labels, test_size=test_ratio, random_state=42
        )
        if self.split == "test":
            return test_paths, test_labels

        train_paths, val_paths, train_labels, val_labels = train_test_split(
            rest_paths, rest_labels, test_size=Config.VALID_RATIO, random_state=42
        )
        return (train_paths, train_labels) if self.split == "train" else (val_paths, val_labels)

    def _build_speaker_map(self):
        """Map speaker id -> class index, preferring the corpus-wide list so all splits agree."""
        unique_speakers = getattr(self, "all_speakers", None) or sorted(set(self.labels))
        print(f"找到 {len(unique_speakers)} 个不同的说话人")
        return {spk: idx for idx, spk in enumerate(unique_speakers)}

    @staticmethod
    def _check_audio_file(path):
        """Raise if `path` is missing, tiny, shorter than 2 s, or near-silent."""
        if not os.path.exists(path):
            raise FileNotFoundError("文件不存在")

        if os.path.getsize(path) < 1024:
            raise ValueError("文件太小可能已损坏")

        if path.endswith('.flac'):
            signal, sr = sf.read(path)
        else:
            signal, sr = librosa.load(path, sr=Config.SAMPLE_RATE, mono=True)

        # Require at least 2 s of audio, measured at the file's own sample rate.
        if len(signal) < sr * 2:
            raise ValueError(f"音频过短 ({len(signal)}采样点)，需要至少2秒音频")

        if np.max(np.abs(signal)) < 1e-5:
            raise ValueError("接近静音")

    def _validate_dataset(self):
        """Drop unreadable, too-short or silent files, printing each rejected path."""
        if len(self.audio_paths) == 0:
            print(f"警告: {self.split}数据集为空")
            return

        kept_paths, kept_labels = [], []
        for path, label in zip(self.audio_paths, self.labels):
            try:
                self._check_audio_file(path)
            except Exception as e:
                print(f"无效文件: {path} - {str(e)}")
                continue
            kept_paths.append(path)
            kept_labels.append(label)

        checked = len(self.audio_paths)
        self.audio_paths, self.labels = kept_paths, kept_labels
        print(f"有效文件: {len(kept_paths)}/{checked}")
        self.speaker_to_idx = self._build_speaker_map()

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        """
        Load one utterance as a fixed-length, peak-normalised waveform.

        :return: (FloatTensor of Config.MAX_SAMPLES samples, speaker class index)
        """
        path, speaker_id = self.audio_paths[idx], self.labels[idx]
        label = self.speaker_to_idx[speaker_id]

        try:
            if path.endswith('.flac'):
                signal, sr = sf.read(path, dtype='float32')
                if signal.ndim > 1:
                    # Multi-channel file: downmix to mono (sf.read returns [frames, channels]).
                    signal = signal.mean(axis=1)
                if sr != Config.SAMPLE_RATE:
                    signal = librosa.resample(signal, orig_sr=sr, target_sr=Config.SAMPLE_RATE)
            else:
                signal, sr = librosa.load(path, sr=Config.SAMPLE_RATE, mono=True)
        except Exception as e:
            print(f"加载音频错误 {path}: {str(e)}")
            signal = np.zeros(Config.MAX_SAMPLES, dtype=np.float32)

        # An empty file would otherwise divide by zero in the tiling branch below.
        if len(signal) == 0:
            signal = np.zeros(Config.MAX_SAMPLES, dtype=np.float32)

        if len(signal) > Config.MAX_SAMPLES:
            if self.split == "train":
                # Random crop as training augmentation.
                start_idx = random.randint(0, len(signal) - Config.MAX_SAMPLES)
            else:
                # Deterministic centre crop so evaluation is reproducible.
                start_idx = (len(signal) - Config.MAX_SAMPLES) // 2
            signal = signal[start_idx:start_idx + Config.MAX_SAMPLES]
        elif len(signal) < Config.MAX_SAMPLES:
            # Too short: repeat the clip until it fills MAX_SAMPLES.
            repeat_times = Config.MAX_SAMPLES // len(signal) + 1
            signal = np.tile(signal, repeat_times)[:Config.MAX_SAMPLES]

        max_val = np.max(np.abs(signal))
        if max_val > 1e-5:
            signal = signal / max_val

        if self.add_noise:
            if self.split == "train":
                # Augmentation: random noise condition per sample.
                noise_type = random.choice(Config.NOISE_TYPES)
                snr_db = random.choice(Config.SNR_LEVELS)
            else:
                # Evaluation: the fixed condition this dataset was configured with.
                noise_type, snr_db = self.noise_type, self.snr_db
            signal = NoiseInjector.add_noise(signal, noise_type, snr_db)

        return torch.FloatTensor(signal), label


# 数据加载器
def get_dataloaders(batch_size=None):
    """
    Build train / val / test / noisy-test DataLoaders over the LibriSpeech dev sets.

    :param batch_size: samples per batch; defaults to Config.BATCH_SIZE
    :return: dict with keys "train", "val", "test", "noisy_test", or None if
             construction failed (the error is printed)
    """
    import copy

    if batch_size is None:
        batch_size = Config.BATCH_SIZE

    print(f"创建T4数据加载器: 批大小={batch_size}")

    try:
        train_dataset = SpeakerRecognitionDataset(split="train")
        val_dataset = SpeakerRecognitionDataset(split="val")
        test_dataset = SpeakerRecognitionDataset(split="test")

        # Fall back to the test set if the training split is empty.
        if len(train_dataset) == 0 and len(test_dataset) > 0:
            print("警告: 训练集为空，使用测试集作为训练集")
            train_dataset = test_dataset

        if len(train_dataset) == 0 and len(val_dataset) == 0 and len(test_dataset) == 0:
            raise RuntimeError("所有数据集均为空，请检查数据集路径和加载逻辑")

        # The noisy test set covers the same files as the clean one; a shallow copy
        # avoids re-scanning and re-decoding every file just to flip the noise flag.
        noisy_test_dataset = copy.copy(test_dataset)
        noisy_test_dataset.add_noise = True
        noisy_test_dataset.noise_type = "white"
        noisy_test_dataset.snr_db = 10

        print(f"训练集: {len(train_dataset)} 样本")
        print(f"验证集: {len(val_dataset)} 样本")
        print(f"测试集: {len(test_dataset)} 样本")
        print(f"带噪声测试集: {len(noisy_test_dataset)} 样本")
        print(f"总说话人数: {len(train_dataset.speaker_to_idx)}")

        # Pinned memory only helps (and only avoids a warning) when a GPU is present.
        pin_memory = torch.cuda.is_available()

        def make_loader(dataset, shuffle, drop_last=False):
            # One place for the loader options shared by all four splits.
            return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                              num_workers=Config.NUM_WORKERS, pin_memory=pin_memory,
                              drop_last=drop_last)

        # A trailing training batch of exactly one sample makes BatchNorm fail in
        # training mode; drop it only in that case.
        n_train = len(train_dataset)
        drop_singleton = n_train > batch_size and n_train % batch_size == 1

        dataloaders = {
            "train": make_loader(train_dataset, shuffle=True, drop_last=drop_singleton),
            "val": make_loader(val_dataset, shuffle=False),
            "test": make_loader(test_dataset, shuffle=False),
            "noisy_test": make_loader(noisy_test_dataset, shuffle=False),
        }

        return dataloaders

    except Exception as e:
        print(f"创建数据加载器时出错: {e}")
        return None


# 测试代码
if __name__ == "__main__":
    # Smoke test: pull one training batch and report its shape and duration.
    dataloaders = get_dataloaders()
    if dataloaders and "train" in dataloaders:
        x, y = next(iter(dataloaders["train"]))
        print(f"音频数据形状: {x.shape}, 标签形状: {y.shape}")
        n_samples = x.shape[1]
        print(f"音频长度: {n_samples} 采样点 ({n_samples / Config.SAMPLE_RATE:.2f} 秒)")

创建T4数据加载器: 批大小=8
加载LibriSpeech数据集: /users/tianyuey/Project/devDataset/LibriSpeech
处理子集: dev-clean
处理子集: dev-other
找到 73 个不同的说话人
无效文件: /users/tianyuey/Project/devDataset/LibriSpeech/dev-other/1651/136854/1651-136854-0012.flac - 音频过短 (17040采样点)，需要至少2秒音频
无效文件: /users/tianyuey/Project/devDataset/LibriSpeech/dev-other/1255/90407/1255-90407-0010.flac - 音频过短 (25920采样点)，需要至少2秒音频
无效文件: /users/tianyuey/Project/devDataset/LibriSpeech/dev-other/5849/50873/5849-50873-0030.flac - 音频过短 (30560采样点)，需要至少2秒音频
无效文件: /users/tianyuey/Project/devDataset/LibriSpeech/dev-other/6841/88291/6841-88291-0043.flac - 音频过短 (31601采样点)，需要至少2秒音频
无效文件: /users/tianyuey/Project/devDataset/LibriSpeech/dev-other/1585/131718/1585-131718-0040.flac - 音频过短 (31120采样点)，需要至少2秒音频
无效文件: /users/tianyuey/Project/devDataset/LibriSpeech/dev-clean/2428/83699/2428-83699-0004.flac - 音频过短 (30080采样点)，需要至少2秒音频
无效文件: /users/tianyuey/Project/devDataset/LibriSpeech/dev-other/1255/138279/1255-138279-0020.flac - 音频过短 (25040采样点)，需要至少2秒音频
无效文件: /users

KeyboardInterrupt: 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class ChaoticStimulus(nn.Module):
    def __init__(self, input_dim, output_dim):
        """
        Simplified "chaotic stimulus" block: two Conv1d-BatchNorm-PReLU stages,
        plus a tanh-squashed Gaussian perturbation added in training mode only.

        :param input_dim: number of input channels
        :param output_dim: number of output channels
        """
        super().__init__()
        self.chaos_transform = nn.Sequential(
            nn.Conv1d(input_dim, output_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(output_dim),
            nn.PReLU(),
            nn.Conv1d(output_dim, output_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(output_dim),
            nn.PReLU()
        )

        # Learnable scale of the training-time perturbation (initialised to 0.1).
        self.chaos_factor = nn.Parameter(torch.tensor(0.1))

    def forward(self, x):
        """
        :param x: input features [batch_size, input_dim, seq_len]
        :return: transformed features [batch_size, output_dim, seq_len]
        """
        transformed = self.chaos_transform(x)

        if self.training:
            # Gaussian noise scaled by chaos_factor, squashed by tanh into (-1, 1).
            chaos_noise = torch.tanh(torch.randn_like(transformed) * self.chaos_factor)
            transformed = transformed + chaos_noise

        return transformed


class SimpleAttention(nn.Module):
    def __init__(self, input_dim):
        """
        Temporal attention: a 1x1 convolution scores every time step and a softmax
        over the time axis turns those scores into weights that rescale the input.

        :param input_dim: number of input channels
        """
        super().__init__()
        layers = [
            nn.Conv1d(input_dim, 1, kernel_size=1),  # one score per time step
            nn.Softmax(dim=2),                        # normalise across seq_len
        ]
        self.attention = nn.Sequential(*layers)

    def forward(self, x):
        """
        :param x: input features [batch_size, channels, seq_len]
        :return: attention-weighted features, same shape as x
        """
        # [batch_size, 1, seq_len], broadcast over the channel axis.
        weights = self.attention(x)
        return weights * x


class StatisticalPooling(nn.Module):
    """Collapse the time axis into per-channel mean and standard deviation."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """
        :param x: input features [batch_size, channels, seq_len]
        :return: [batch_size, channels * 2] — all channel means first, then the
                 standard deviations (torch.std default, Bessel-corrected)
        """
        stats = [x.mean(dim=2), x.std(dim=2)]
        return torch.cat(stats, dim=1)


# 复杂TDNN模块（来自完整模型）
class ComplexTDNN(nn.Module):
    def __init__(self, hidden_dim):
        """
        Three dilated Conv1d-BatchNorm-PReLU stages (dilation 1, 2, 3) that keep the
        channel count and sequence length unchanged.

        :param hidden_dim: number of channels in and out
        """
        super().__init__()

        def dilated_conv(dilation):
            # padding == dilation keeps seq_len unchanged for kernel_size=3, stride=1.
            return nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3,
                             dilation=dilation, padding=dilation)

        # Attribute names (and hence state_dict keys) are tdnn{k} / bn_tdnn{k} / prelu_tdnn{k}.
        self.tdnn1 = dilated_conv(1)
        self.bn_tdnn1 = nn.BatchNorm1d(hidden_dim)
        self.prelu_tdnn1 = nn.PReLU()

        self.tdnn2 = dilated_conv(2)
        self.bn_tdnn2 = nn.BatchNorm1d(hidden_dim)
        self.prelu_tdnn2 = nn.PReLU()

        self.tdnn3 = dilated_conv(3)
        self.bn_tdnn3 = nn.BatchNorm1d(hidden_dim)
        self.prelu_tdnn3 = nn.PReLU()

    def forward(self, x):
        """
        :param x: input features [batch_size, hidden_dim, seq_len]
        :return: features of the same shape
        """
        stages = (
            (self.tdnn1, self.bn_tdnn1, self.prelu_tdnn1),
            (self.tdnn2, self.bn_tdnn2, self.prelu_tdnn2),
            (self.tdnn3, self.bn_tdnn3, self.prelu_tdnn3),
        )
        for conv, norm, act in stages:
            x = act(norm(conv(x)))
        return x


# T4模型：简化模型+复杂TDNN
class CHiLAPModel_T4(nn.Module):
    def __init__(self, input_dim=Config.INPUT_DIM, hidden_dim=Config.HIDDEN_DIM,
                 embedding_dim=Config.EMBEDDING_DIM, num_classes=None):
        """
        Chaotic Hierarchical Attractor Propagation (C-HiLAP) model, T4 variant
        (simplified model + complex TDNN).

        :param input_dim: input channels (1 for raw waveforms)
        :param hidden_dim: channel width of the convolutional trunk
        :param embedding_dim: size of the speaker embedding
        :param num_classes: number of speakers; required
        :raises ValueError: if num_classes is None
        """
        super().__init__()

        # num_classes must come from the dataset (number of speakers).
        if num_classes is None:
            raise ValueError("必须指定num_classes（说话人数量），请从数据集获取后传入")

        # Remembered so forward() can tell channels-first from channels-last input.
        self.input_dim = input_dim

        # Feature extractor: three stride-2 convolutions, each halving (ceil) the time axis.
        self.feature_extractor = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(hidden_dim),
            nn.PReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(hidden_dim),
            nn.PReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(hidden_dim),
            nn.PReLU()
        )

        self.chaos_layer = ChaoticStimulus(hidden_dim, hidden_dim)
        self.attention = SimpleAttention(hidden_dim)
        # Complex (dilated) TDNN in place of the simplified variant.
        self.tdnn_block = ComplexTDNN(hidden_dim)
        self.pooling = StatisticalPooling()

        self.embedding = nn.Sequential(
            nn.Linear(hidden_dim * 2, embedding_dim),  # statistical pooling outputs channels * 2
            nn.BatchNorm1d(embedding_dim),
            nn.PReLU(),
            nn.Dropout(0.2)
        )

        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        """
        :param x: [batch_size, input_dim, seq_len], [batch_size, seq_len, input_dim],
                  or (when input_dim == 1) [batch_size, seq_len]
        :return: (embedding [batch_size, embedding_dim], logits [batch_size, num_classes])
        """
        if x.dim() == 2 and self.input_dim == 1:
            # Raw waveform batch as produced by the dataset: add the channel axis.
            x = x.unsqueeze(1)
        elif x.dim() == 3 and x.size(1) != self.input_dim and x.size(2) == self.input_dim:
            # Channels-last input: convert to [batch_size, channels, seq_len].
            x = x.permute(0, 2, 1)

        # Cap the sequence length to bound memory use.
        if x.size(2) > Config.MAX_SEQ_LEN:
            x = x[:, :, :Config.MAX_SEQ_LEN]

        x = self.feature_extractor(x)
        x = self.chaos_layer(x)
        x = self.attention(x)
        x = self.tdnn_block(x)
        x = self.pooling(x)

        embedding = self.embedding(x)
        logits = self.classifier(embedding)

        return embedding, logits


# 测试代码
if __name__ == "__main__":
    # Build a T4 model for 10 placeholder speakers and print its architecture.
    model = CHiLAPModel_T4(num_classes=10)
    print("T4模型结构:")
    print(model)

    # Random batch in channels-first layout: [batch_size, channels=1, seq_len].
    batch_size, seq_len = 2, Config.MAX_SEQ_LEN
    x = torch.randn(batch_size, 1, seq_len)

    print(f"\n测试前向传播:")
    print(f"输入形状: {x.shape}")

    try:
        embedding, logits = model(x)
    except Exception as e:
        print(f"前向传播错误: {e}")
        import traceback
        traceback.print_exc()
    else:
        print(f"嵌入向量形状: {embedding.shape}")
        print(f"分类输出形状: {logits.shape}")
        print("前向传播成功!")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import os
import gc  # 垃圾回收
import math

# 训练器类
class Trainer:
    """
    Cross-entropy trainer with per-epoch linear LR warm-up, plateau LR decay,
    gradient accumulation, optional mixed precision and accuracy-based early stopping.
    """

    def __init__(self, config=Config, model=None):
        """
        :param config: configuration class (defaults to Config)
        :param model: nn.Module returning (embedding, logits); required
        :raises ValueError: if model is None
        """
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"使用设备: {self.device}")

        if model is None:
            raise ValueError("初始化Trainer时必须传入model参数")
        self.model = model.to(self.device)

        self.ce_loss = nn.CrossEntropyLoss()

        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config.LR,
            weight_decay=config.WEIGHT_DECAY
        )

        # Plateau decay on validation loss; stepped only after warm-up has finished,
        # otherwise the warm-up scheduler would overwrite its reductions.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=3,
        )

        # Linear warm-up, stepped once per epoch: epoch k (1-based) trains at
        # LR * min(1, k / WARMUP_EPOCHS). Constructing LambdaLR applies k = 1 immediately.
        warmup_epochs = config.WARMUP_EPOCHS
        self.warmup_scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lr_lambda=lambda step: 1.0 if warmup_epochs <= 0 else min(1.0, (step + 1) / warmup_epochs)
        )

        if config.ENABLE_MIXED_PRECISION and torch.cuda.is_available():
            self.scaler = torch.cuda.amp.GradScaler()
            print("启用混合精度训练")
        else:
            self.scaler = None

        os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)

        # Early stopping is driven by validation accuracy.
        self.early_stop_counter = 0
        self.best_val_accuracy = 0.0
        print(f"模型预期输入长度: {config.MAX_SEQ_LEN}")

    def _prepare_batch(self, inputs, labels):
        """Move a batch to the device and shape inputs to [batch, 1, MAX_SEQ_LEN] (truncate / zero-pad)."""
        inputs, labels = inputs.to(self.device), labels.to(self.device)

        if inputs.dim() == 2:  # [batch, seq_len] -> [batch, 1, seq_len]
            inputs = inputs.unsqueeze(1)

        max_len = self.config.MAX_SEQ_LEN
        length = inputs.size(2)
        if length > max_len:
            inputs = inputs[:, :, :max_len]
        elif length < max_len:
            inputs = torch.nn.functional.pad(inputs, (0, max_len - length), value=0.0)
        return inputs, labels

    def _forward_backward(self, inputs, labels, accumulation_steps):
        """Forward + backward one micro-batch; returns (un-divided loss value, detached logits)."""
        if self.scaler is not None:
            with torch.cuda.amp.autocast():
                _, logits = self.model(inputs)
                loss = self.config.CE_WEIGHT * self.ce_loss(logits, labels)
            self.scaler.scale(loss / accumulation_steps).backward()
        else:
            _, logits = self.model(inputs)
            loss = self.config.CE_WEIGHT * self.ce_loss(logits, labels)
            (loss / accumulation_steps).backward()
        return loss.item(), logits.detach()

    def _optimizer_step(self):
        """Clip the fully accumulated gradients, apply one optimizer step, and reset gradients."""
        if self.scaler is not None:
            # unscale_ may be called only once per optimizer step, after accumulation.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.GRAD_CLIP)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.GRAD_CLIP)
            self.optimizer.step()
        self.optimizer.zero_grad()

    def train_one_epoch(self, dataloader, epoch):
        """
        Train for one epoch with gradient accumulation.

        :param dataloader: training DataLoader yielding (inputs, labels)
        :param epoch: current epoch number (for the progress bar)
        :return: (mean training loss, accuracy in %) over the processed batches
        """
        self.model.train()
        accumulation_steps = max(1, int(self.config.GRADIENT_ACCUMULATION_STEPS))
        num_batches = len(dataloader)
        total_loss = 0.0
        correct = 0
        total = 0
        processed = 0
        has_pending_grads = False  # gradients accumulated since the last optimizer step

        self.optimizer.zero_grad()

        progress_bar = tqdm(enumerate(dataloader), total=num_batches)
        for i, (inputs, labels) in progress_bar:
            # Step every `accumulation_steps` micro-batches, and on the last batch so
            # a trailing partial group is not silently discarded.
            step_now = (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches
            try:
                inputs, labels = self._prepare_batch(inputs, labels)

                # BatchNorm cannot normalise a single sample in training mode.
                if inputs.size(0) > 1:
                    batch_loss, logits = self._forward_backward(inputs, labels, accumulation_steps)
                    has_pending_grads = True

                    processed += 1
                    total_loss += batch_loss
                    _, predicted = logits.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()

                if step_now and has_pending_grads:
                    self._optimizer_step()
                    has_pending_grads = False

                if i % 10 == 0 and processed > 0:
                    progress_bar.set_description(
                        f"Epoch {epoch}, Loss: {total_loss / processed:.4f}, Acc: {100. * correct / total:.2f}%"
                    )

                if i % 50 == 0:
                    torch.cuda.empty_cache()
                    gc.collect()

            except RuntimeError as e:
                if "out of memory" not in str(e):
                    raise
                print(f"内存不足错误在批次 {i}: {e}")
                # Discard any partially accumulated gradients from the failed micro-batch.
                self.optimizer.zero_grad()
                has_pending_grads = False
                torch.cuda.empty_cache()
                gc.collect()

        avg_loss = total_loss / processed if processed else 0.0
        accuracy = 100. * correct / total if total else 0.0
        return avg_loss, accuracy

    def validate(self, dataloader):
        """
        Evaluate loss and accuracy without gradient tracking.

        :param dataloader: validation DataLoader
        :return: (mean loss, accuracy in %); (0.0, 0.0) if nothing was processed
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        processed = 0

        with torch.no_grad():
            for i, (inputs, labels) in enumerate(dataloader):
                try:
                    inputs, labels = self._prepare_batch(inputs, labels)

                    _, logits = self.model(inputs)
                    loss = self.config.CE_WEIGHT * self.ce_loss(logits, labels)

                    total_loss += loss.item()
                    processed += 1
                    _, predicted = logits.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()

                    if i % 20 == 0:
                        torch.cuda.empty_cache()
                        gc.collect()

                except RuntimeError as e:
                    if "out of memory" not in str(e):
                        raise
                    print(f"验证时内存不足: {e}")
                    torch.cuda.empty_cache()
                    gc.collect()

        avg_loss = total_loss / processed if processed else 0.0
        accuracy = 100. * correct / total if total else 0.0
        return avg_loss, accuracy

    def train(self, train_dataloader, val_dataloader):
        """
        Full training loop: train, validate, schedule LR, checkpoint, early-stop.

        :param train_dataloader: training DataLoader
        :param val_dataloader: validation DataLoader
        """
        print("开始训练...")
        print(f"训练集批次: {len(train_dataloader)}, 验证集批次: {len(val_dataloader)}")

        for epoch in range(1, self.config.EPOCHS + 1):
            try:
                train_loss, train_acc = self.train_one_epoch(train_dataloader, epoch)
                print(f"Epoch {epoch}/{self.config.EPOCHS}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")

                val_loss, val_acc = self.validate(val_dataloader)
                print(f"Epoch {epoch}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

                # Advance warm-up until it completes; afterwards hand over to the
                # plateau scheduler (stepping LambdaLR would overwrite its reductions).
                if epoch < self.config.WARMUP_EPOCHS:
                    self.warmup_scheduler.step()
                else:
                    self.scheduler.step(val_loss)
                current_lr = self.optimizer.param_groups[0]['lr']
                print(f"当前学习率: {current_lr:.6f}")

                # Early stopping on validation accuracy.
                if val_acc > self.best_val_accuracy + self.config.MIN_DELTA:
                    self.best_val_accuracy = val_acc
                    self.early_stop_counter = 0
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'val_loss': val_loss,
                        'val_accuracy': val_acc
                    }, os.path.join(self.config.CHECKPOINT_DIR, 'best_model.pth'))
                    print(f"保存最佳模型，验证准确率: {val_acc:.2f}%")
                else:
                    self.early_stop_counter += 1
                    if self.early_stop_counter >= self.config.PATIENCE:
                        print(f"早停于第 {epoch} 轮")
                        break

                if epoch % self.config.SAVE_INTERVAL == 0:
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'train_loss': train_loss,
                        'train_accuracy': train_acc
                    }, os.path.join(self.config.CHECKPOINT_DIR, f'model_epoch_{epoch}.pth'))

                torch.cuda.empty_cache()
                gc.collect()

            except KeyboardInterrupt:
                print("训练被用户中断")
                break
            except Exception as e:
                # Best-effort: keep training, but surface the full traceback.
                print(f"训练错误: {e}")
                import traceback
                traceback.print_exc()
                torch.cuda.empty_cache()
                gc.collect()
                continue

    def load_checkpoint(self, checkpoint_path):
        """
        Restore model (and, if present, optimizer) state from a checkpoint saved by `train`.

        :param checkpoint_path: path to a .pth file
        :return: epoch stored in the checkpoint (0 if absent)
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        if 'optimizer_state_dict' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint.get('epoch', 0)
        print(f"从第 {epoch} 轮加载检查点")
        return epoch

# 评估器类
class Evaluator:
    """Compute classification accuracy of a (embedding, logits) model on a DataLoader."""

    def __init__(self, model, config=Config):
        """
        :param model: nn.Module returning (embedding, logits)
        :param config: configuration class (defaults to Config)
        """
        self.model = model
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def evaluate_accuracy(self, dataloader):
        """
        :param dataloader: DataLoader yielding (inputs, labels)
        :return: accuracy in %; 0.0 if no sample was evaluated
        """
        # Re-assert eval mode: the model may have been put back into training mode
        # since this evaluator was created.
        self.model.eval()
        correct = 0
        total = 0
        max_len = self.config.MAX_SEQ_LEN

        with torch.no_grad():
            for i, (inputs, labels) in enumerate(dataloader):
                try:
                    inputs, labels = inputs.to(self.device), labels.to(self.device)

                    if inputs.dim() == 2:  # [batch, seq_len] -> [batch, 1, seq_len]
                        inputs = inputs.unsqueeze(1)

                    # Truncate / zero-pad to MAX_SEQ_LEN, matching the training preprocessing.
                    length = inputs.size(2)
                    if length > max_len:
                        inputs = inputs[:, :, :max_len]
                    elif length < max_len:
                        inputs = torch.nn.functional.pad(inputs, (0, max_len - length), value=0.0)

                    _, logits = self.model(inputs)
                    _, predicted = logits.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()

                    if i % 20 == 0:
                        torch.cuda.empty_cache()
                        gc.collect()

                except RuntimeError as e:
                    if "out of memory" not in str(e):
                        raise
                    print(f"评估时内存不足: {e}")
                    torch.cuda.empty_cache()
                    gc.collect()

        return 100. * correct / total if total else 0.0

# 测试代码
if __name__ == "__main__":
    # End-to-end run: build loaders, smoke-test one forward pass, train, then
    # report test accuracy of the best checkpoint.
    batch_size = Config.BATCH_SIZE

    print("创建数据加载器...")
    try:
        dataloaders = get_dataloaders(batch_size=batch_size)

        if dataloaders is None:
            print("无法创建数据加载器，退出")
            # SystemExit is a BaseException, so the `except Exception` below does not catch it.
            raise SystemExit(1)

        print(f"训练集批次数: {len(dataloaders['train'])}")
        print(f"验证集批次数: {len(dataloaders['val'])}")
        print(f"测试集批次数: {len(dataloaders['test'])}")

        # The number of classes must match the dataset's speaker map.
        num_speakers = len(dataloaders["train"].dataset.speaker_to_idx)
        print(f"数据集中实际说话人数量（类别数）: {num_speakers}")

        print("创建模型和训练器...")
        # This notebook defines the T4 variant; a bare `CHiLAPModel` is not defined here.
        model = CHiLAPModel_T4(num_classes=num_speakers)
        trainer = Trainer(config=Config, model=model)

        print("测试前向传播...")
        try:
            x, y = next(iter(dataloaders["train"]))
            print(f"原始输入形状: {x.shape}, 标签形状: {y.shape}")

            if x.dim() == 2:  # [batch, seq_len] -> [batch, 1, seq_len]
                x = x.unsqueeze(1)

            if x.size(2) > Config.MAX_SEQ_LEN:
                x = x[:, :, :Config.MAX_SEQ_LEN]
            elif x.size(2) < Config.MAX_SEQ_LEN:
                pad_len = Config.MAX_SEQ_LEN - x.size(2)
                x = torch.nn.functional.pad(x, (0, pad_len), value=0.0)

            print(f"处理后输入形状: {x.shape}")

            x = x.to(trainer.device)
            y = y.to(trainer.device)

            # Eval mode: the smoke test must not update BatchNorm running statistics
            # or inject the training-only chaos noise (train() switches back to training mode).
            trainer.model.eval()
            with torch.no_grad():
                embeddings, logits = trainer.model(x)
                print(f"嵌入形状: {embeddings.shape}, 输出形状: {logits.shape}")
                print("前向传播测试成功!")
        except Exception as e:
            print(f"前向传播测试失败: {e}")
            import traceback
            traceback.print_exc()
            raise SystemExit(1)

        print("开始训练...")
        trainer.train(dataloaders["train"], dataloaders["val"])

        # Evaluate the early-stopping checkpoint saved during this run rather than
        # the last-epoch weights (best_val_accuracy > 0 only if a save happened).
        best_path = os.path.join(Config.CHECKPOINT_DIR, 'best_model.pth')
        if trainer.best_val_accuracy > 0 and os.path.exists(best_path):
            trainer.load_checkpoint(best_path)

        print("创建评估器...")
        evaluator = Evaluator(trainer.model)

        print("评估模型...")
        test_accuracy = evaluator.evaluate_accuracy(dataloaders["test"])
        print(f"测试集准确率: {test_accuracy:.2f}%")

    except Exception as e:
        print(f"程序执行错误: {e}")
        import traceback
        traceback.print_exc()