### Config

In [1]:
# 导入配置
from config import Config

### Data_loader.py

In [2]:
import os
import random
import librosa
import numpy as np
import os
import random
import librosa
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import soundfile as sf
from tqdm import tqdm


# 噪声生成与注入
class NoiseInjector:
    @staticmethod
    def generate_white_noise(length):
        return np.random.randn(length).astype(np.float32)

    @staticmethod
    def generate_babble_noise(length, num_speakers=3):
        noise = np.zeros(length, dtype=np.float32)
        for _ in range(num_speakers):
            start = random.randint(0, max(0, length - Config.SAMPLE_RATE))
            end = min(start + Config.SAMPLE_RATE, length)
            noise[start:end] += np.random.randn(end - start).astype(np.float32)
        return noise / num_speakers

    @staticmethod
    def add_noise(signal, noise_type="white", snr_db=10):
        if len(signal) == 0:
            return signal

        signal_power = np.mean(signal ** 2)
        if signal_power < 1e-10:
            return signal

        signal_db = 10 * np.log10(signal_power)

        if noise_type == "white":
            noise = NoiseInjector.generate_white_noise(len(signal))
        elif noise_type == "babble":
            noise = NoiseInjector.generate_babble_noise(len(signal))
        else:
            raise ValueError(f"不支持的噪声类型：{noise_type}")

        noise_power = np.mean(noise ** 2)
        noise_db = -100 if noise_power < 1e-10 else 10 * np.log10(noise_power)

        target_noise_db = signal_db - snr_db
        noise_scale = 10 ** ((target_noise_db - noise_db) / 20)
        noisy_signal = signal + noise * noise_scale

        # 归一化
        max_val = np.max(np.abs(noisy_signal))
        if max_val > 1e-5:
            noisy_signal = noisy_signal / max_val

        return noisy_signal

# 数据集类
class SpeakerRecognitionDataset(Dataset):
    def __init__(self, split="train", add_noise=False, noise_type="white", snr_db=10):
        self.split = split
        self.add_noise = add_noise
        self.noise_type = noise_type
        self.snr_db = snr_db

        # 加载数据集
        self.audio_paths, self.labels = self._load_dataset()
        self.speaker_to_idx = self._build_speaker_map()

        # 调试模式
        if Config.DEBUG_MODE:
            if split == "train":
                self.audio_paths = self.audio_paths[:Config.DEBUG_SAMPLE_SIZE]
                self.labels = self.labels[:Config.DEBUG_SAMPLE_SIZE]
            elif split == "val":
                self.audio_paths = self.audio_paths[:min(Config.DEBUG_SAMPLE_SIZE // 2, len(self.audio_paths))]
                self.labels = self.labels[:min(Config.DEBUG_SAMPLE_SIZE // 2, len(self.labels))]

        # 验证数据集
        self._validate_dataset()
        print(f"最终 {split} 数据集大小: {len(self.audio_paths)} 个样本")

    def _load_dataset(self):
        audio_paths = []
        labels = []

        root = Config.LIBRISPEECH_PATH
        if not os.path.exists(root):
            print(f"错误: LibriSpeech路径不存在 - {root}")
            return [], []

        print(f"加载LibriSpeech数据集: {root}")

        # 只遍历 dev-clean 和 dev-other
        for subset_dir in ["dev-clean", "dev-other"]:
            subset_path = os.path.join(root, subset_dir)
            if not os.path.isdir(subset_path):
                continue

            print(f"处理子集: {subset_dir}")
            for speaker_dir in os.listdir(subset_path):
                speaker_path = os.path.join(subset_path, speaker_dir)
                if not os.path.isdir(speaker_path):
                    continue

                for chapter_dir in os.listdir(speaker_path):
                    chapter_path = os.path.join(speaker_path, chapter_dir)
                    if not os.path.isdir(chapter_path):
                        continue

                    for file in os.listdir(chapter_path):
                        if file.endswith(".flac"):
                            full_path = os.path.join(chapter_path, file)
                            audio_paths.append(full_path)
                            labels.append(speaker_dir)

            print(f"在 {subset_dir} 中找到 {len(audio_paths)} 个.flac文件")

        # 分割训练/验证
        if self.split != "test" and len(audio_paths) > 0:
            train_paths, val_paths, train_labels, val_labels = train_test_split(
                audio_paths, labels, test_size=Config.VALID_RATIO, random_state=42
            )
            if self.split == "train":
                return train_paths, train_labels
            else:
                return val_paths, val_labels

        return audio_paths, labels

    def _build_speaker_map(self):
        unique_speakers = sorted(set(self.labels))
        print(f"找到 {len(unique_speakers)} 个不同的说话人")
        return {speaker: idx for idx, speaker in enumerate(unique_speakers)}

    def _validate_dataset(self):
        if len(self.audio_paths) == 0:
            print(f"警告: {self.split}数据集为空")
            return

        valid_count = 0
        invalid_indices = []
        for i in range(len(self.audio_paths) - 1, -1, -1):
            path = self.audio_paths[i]
            try:
                if not os.path.exists(path):
                    raise FileNotFoundError("文件不存在")

                if os.path.getsize(path) < 1024:
                    raise ValueError("文件太小可能已损坏")

                if path.endswith('.flac'):
                    signal, sr = sf.read(path)
                else:
                    signal, sr = librosa.load(path, sr=Config.SAMPLE_RATE, mono=True)

                if len(signal) < Config.SAMPLE_RATE // 2:
                    raise ValueError("音频过短")

                if np.max(np.abs(signal)) < 1e-5:
                    raise ValueError("接近静音")

                valid_count += 1
            except Exception as e:
                invalid_indices.append(i)
                print(f"无效文件: {path} - {str(e)}")

        for i in invalid_indices:
            self.audio_paths.pop(i)
            self.labels.pop(i)

        print(f"有效文件: {valid_count}/{len(self.audio_paths) + len(invalid_indices)}")
        print(f"移除 {len(invalid_indices)} 个无效文件")
        self.speaker_to_idx = self._build_speaker_map()

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        path = self.audio_paths[idx]
        speaker_id = self.labels[idx]
        label = self.speaker_to_idx[speaker_id]

        try:
            if path.endswith('.flac'):
                signal, sr = sf.read(path)
                if sr != Config.SAMPLE_RATE:
                    signal = librosa.resample(signal, orig_sr=sr, target_sr=Config.SAMPLE_RATE)
            else:
                signal, sr = librosa.load(path, sr=Config.SAMPLE_RATE, mono=True)
        except Exception as e:
            print(f"加载音频错误 {path}: {str(e)}")
            signal = np.zeros(Config.MAX_SAMPLES, dtype=np.float32)

        if len(signal) > Config.MAX_SAMPLES:
            signal = signal[:Config.MAX_SAMPLES]
        elif len(signal) < Config.MAX_SAMPLES:
            signal = np.pad(signal, (0, Config.MAX_SAMPLES - len(signal)), mode='constant')

        max_val = np.max(np.abs(signal))
        if max_val > 1e-5:
            signal = signal / max_val

        if self.add_noise and self.split == "train":
            noise_type = random.choice(Config.NOISE_TYPES)
            snr_db = random.choice(Config.SNR_LEVELS)
            signal = NoiseInjector.add_noise(signal, noise_type, snr_db)

        return torch.FloatTensor(signal), label

# 数据加载器
def get_dataloaders(batch_size=None):
    if batch_size is None:
        batch_size = Config.BATCH_SIZE

    train_dataset = SpeakerRecognitionDataset(split="train")
    val_dataset = SpeakerRecognitionDataset(split="val")
    test_dataset = SpeakerRecognitionDataset(split="test")

    if len(train_dataset) == 0 and len(test_dataset) > 0:
        print("警告: 训练集为空，使用测试集作为训练集")
        train_dataset = test_dataset

    noisy_test_dataset = SpeakerRecognitionDataset(
        split="test", add_noise=True, noise_type="white", snr_db=5
    )

    print(f"训练集: {len(train_dataset)} 样本")
    print(f"验证集: {len(val_dataset)} 样本")
    print(f"测试集: {len(test_dataset)} 样本")
    print(f"带噪声测试集: {len(noisy_test_dataset)} 样本")
    print(f"总说话人数: {len(train_dataset.speaker_to_idx)}")

    dataloaders = {
        "train": DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                            num_workers=Config.NUM_WORKERS, pin_memory=True),
        "val": DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                          num_workers=Config.NUM_WORKERS, pin_memory=True),
        "test": DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                           num_workers=Config.NUM_WORKERS, pin_memory=True),
        "noisy_test": DataLoader(noisy_test_dataset, batch_size=batch_size, shuffle=False,
                                 num_workers=Config.NUM_WORKERS, pin_memory=True),
    }
    return dataloaders

if __name__ == "__main__":
    print("测试数据加载器...")
    dataloaders = get_dataloaders()
    x, y = next(iter(dataloaders["train"]))
    print(f"音频数据形状: {x.shape}")
    print(f"标签数据形状: {y.shape}")

测试数据加载器...
加载LibriSpeech数据集: /users/tianyuey/Project/devDataset/LibriSpeech
处理子集: dev-clean
在 dev-clean 中找到 2703 个.flac文件
处理子集: dev-other
在 dev-other 中找到 5567 个.flac文件
找到 73 个不同的说话人
有效文件: 5000/5000
移除 0 个无效文件
找到 73 个不同的说话人
最终 train 数据集大小: 5000 个样本
加载LibriSpeech数据集: /users/tianyuey/Project/devDataset/LibriSpeech
处理子集: dev-clean
在 dev-clean 中找到 2703 个.flac文件
处理子集: dev-other
在 dev-other 中找到 5567 个.flac文件
找到 73 个不同的说话人
有效文件: 557/557
移除 0 个无效文件
找到 73 个不同的说话人
最终 val 数据集大小: 557 个样本
加载LibriSpeech数据集: /users/tianyuey/Project/devDataset/LibriSpeech
处理子集: dev-clean
在 dev-clean 中找到 2703 个.flac文件
处理子集: dev-other
在 dev-other 中找到 5567 个.flac文件
找到 73 个不同的说话人
有效文件: 5567/5567
移除 0 个无效文件
找到 73 个不同的说话人
最终 test 数据集大小: 5567 个样本
加载LibriSpeech数据集: /users/tianyuey/Project/devDataset/LibriSpeech
处理子集: dev-clean
在 dev-clean 中找到 2703 个.flac文件
处理子集: dev-other
在 dev-other 中找到 5567 个.flac文件
找到 73 个不同的说话人
有效文件: 5567/5567
移除 0 个无效文件
找到 73 个不同的说话人
最终 test 数据集大小: 5567 个样本
训练集: 5000 样本
验证集: 557 样本
测试集: 5567 样本
带噪声测试集: 55

### c-hilap-model.py

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# 导入配置
from config import Config

# 简化混沌激励模块
class ChaoticStimulus(nn.Module):
    def __init__(self, input_dim, output_dim):
        """
        简化的混沌激励模块
        :param input_dim: 输入维度
        :param output_dim: 输出维度
        """
        super().__init__()
        self.chaos_transform = nn.Sequential(
            nn.Conv1d(input_dim, output_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(output_dim),
            nn.PReLU(),
            nn.Conv1d(output_dim, output_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(output_dim),
            nn.PReLU()
        )

        # 混沌扰动参数
        self.chaos_factor = nn.Parameter(torch.tensor(0.1))

    def forward(self, x):
        """
        前向传播
        :param x: 输入特征 [batch_size, channels, seq_len]
        :return: 混沌处理后的特征 [batch_size, channels, seq_len]
        """
        # 常规特征变换
        transformed = self.chaos_transform(x)

        # 添加混沌扰动
        batch_size, channels, seq_len = transformed.size()
        if self.training:  # 仅在训练时添加混沌扰动
            # 生成与特征相同形状的混沌噪声
            chaos_noise = torch.randn_like(transformed) * self.chaos_factor
            # 应用非线性激活增强混沌特性
            chaos_noise = torch.tanh(chaos_noise)
            transformed = transformed + chaos_noise

        return transformed

# 简化注意力机制
class SimpleAttention(nn.Module):
    def __init__(self, input_dim):
        """
        简化的注意力机制
        :param input_dim: 输入维度
        """
        super().__init__()
        self.attention = nn.Sequential(
            nn.Conv1d(input_dim, 1, kernel_size=1),
            nn.Softmax(dim=2)
        )

    def forward(self, x):
        """
        前向传播
        :param x: 输入特征 [batch_size, channels, seq_len]
        :return: 注意力加权后的特征 [batch_size, channels, seq_len]
        """
        # 计算注意力权重 [batch_size, 1, seq_len]
        attn_weights = self.attention(x)

        # 应用注意力权重
        return x * attn_weights

# 统计池化层
class StatisticalPooling(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        """
        前向传播
        :param x: 输入特征 [batch_size, channels, seq_len]
        :return: 池化后的特征 [batch_size, channels*2]
        """
        # 计算均值和标准差
        mean = torch.mean(x, dim=2)
        std = torch.std(x, dim=2)

        # 拼接均值和标准差
        return torch.cat((mean, std), dim=1)

# 完整的C-HiLAP模型（简化版）
class CHiLAPModel(nn.Module):
    def __init__(self, input_dim=Config.INPUT_DIM, hidden_dim=Config.HIDDEN_DIM,
                 embedding_dim=Config.EMBEDDING_DIM, num_classes=None):
        """
        混沌层次吸引子传播(C-HiLAP)模型 - 简化版
        """
        super().__init__()

        # 若未传入num_classes，可设置一个默认值（但实际使用时必须从数据集获取后传入）
        if num_classes is None:
            raise ValueError("必须指定num_classes（说话人数量），请从数据集获取后传入")

        # 特征提取层
        self.feature_extractor = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(hidden_dim),
            nn.PReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(hidden_dim),
            nn.PReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(hidden_dim),
            nn.PReLU()
        )

        # 混沌激励模块
        self.chaos_layer = ChaoticStimulus(hidden_dim, hidden_dim)

        # 注意力层
        self.attention = SimpleAttention(hidden_dim)

        # TDNN层
        self.tdnn_block = nn.Sequential(
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, dilation=1, padding=1),
            nn.BatchNorm1d(hidden_dim),
            nn.PReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, dilation=2, padding=2),
            nn.BatchNorm1d(hidden_dim),
            nn.PReLU()
        )

        # 池化层
        self.pooling = StatisticalPooling()

        # 嵌入层
        self.embedding = nn.Sequential(
            nn.Linear(hidden_dim * 2, embedding_dim),  # 统计池化输出channels*2
            nn.BatchNorm1d(embedding_dim),
            nn.PReLU(),
            nn.Dropout(0.2)
        )

        # 分类器
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        """
        前向传播
        :param x: 输入特征 [batch_size, channels, seq_len] 或 [batch_size, seq_len, channels]
        :return: 嵌入向量和分类结果
        """
        # 检查输入维度并转换为正确的格式 [batch_size, channels, seq_len]
        if x.dim() == 3:
            # 如果是 [batch_size, seq_len, channels] 格式
            if x.size(1) > x.size(2):  # 序列长度应该大于通道数
                x = x.permute(0, 2, 1)  # 转换为 [batch_size, channels, seq_len]

        # 限制序列长度防止内存溢出
        seq_len = x.size(2)
        if seq_len > Config.MAX_SEQ_LEN:
            x = x[:, :, :Config.MAX_SEQ_LEN]

        # 特征提取
        x = self.feature_extractor(x)

        # 混沌处理
        x = self.chaos_layer(x)

        # 注意力加权
        x = self.attention(x)

        # TDNN处理
        x = self.tdnn_block(x)

        # 池化
        x = self.pooling(x)

        # 嵌入向量
        embedding = self.embedding(x)

        # 分类
        logits = self.classifier(embedding)

        return embedding, logits

# 测试代码
if __name__ == "__main__":
    # 创建模型实例
    model = CHiLAPModel(num_classes=73)  # 需要传入num_classes参数

    print("模型结构:")
    print(model)

    # 生成随机输入（使用较小的序列长度进行测试）
    batch_size = 2
    seq_len = Config.MAX_SEQ_LEN
    # 正确的输入格式：[batch_size, channels, seq_len]
    x = torch.randn(batch_size, 1, seq_len)

    print(f"\n测试前向传播:")
    print(f"输入形状: {x.shape}")

    # 前向传播
    try:
        embedding, logits = model(x)
        print(f"嵌入向量形状: {embedding.shape}")
        print(f"分类输出形状: {logits.shape}")
        print("前向传播成功!")
    except Exception as e:
        print(f"前向传播错误: {e}")
        import traceback
        traceback.print_exc()

模型结构:
CHiLAPModel(
  (feature_extractor): Sequential(
    (0): Conv1d(1, 256, kernel_size=(5,), stride=(2,), padding=(2,))
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
    (3): Conv1d(256, 256, kernel_size=(5,), stride=(2,), padding=(2,))
    (4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): PReLU(num_parameters=1)
    (6): Conv1d(256, 256, kernel_size=(5,), stride=(2,), padding=(2,))
    (7): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): PReLU(num_parameters=1)
  )
  (chaos_layer): ChaoticStimulus(
    (chaos_transform): Sequential(
      (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): PReLU(num_parameters=1)
      (3): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
      (4): BatchNorm

### trainer.py

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import os
import gc  # 垃圾回收
import math

# 导入配置
from config import Config


# 训练器类
class Trainer:
    def __init__(self, config=Config, model=None):
        """初始化训练器"""
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"使用设备: {self.device}")

        # 创建模型
        if model is None:
            raise ValueError("初始化Trainer时必须传入model参数")
        self.model = model.to(self.device)  # 使用传入的模型

        # 定义损失函数
        self.ce_loss = nn.CrossEntropyLoss()

        # 定义优化器
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config.LR,
            weight_decay=config.WEIGHT_DECAY
        )

        # 学习率调度器
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=3,
        )

        # 学习率预热调度器
        self.warmup_scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lr_lambda=lambda epoch: min(1.0, epoch / config.WARMUP_EPOCHS)
        )

        # 混合精度训练
        if config.ENABLE_MIXED_PRECISION and torch.cuda.is_available():
            self.scaler = torch.cuda.amp.GradScaler()
            print("启用混合精度训练")
        else:
            self.scaler = None

        # 创建检查点目录
        os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)

        # 早停计数器
        self.early_stop_counter = 0
        self.best_val_accuracy = 0.0  # 改为基于准确率早停
        print(f"模型预期输入长度: {config.MAX_SEQ_LEN}")

    def train_one_epoch(self, dataloader, epoch):
        """
        训练一个epoch
        :param dataloader: 训练数据加载器
        :param epoch: 当前epoch
        :return: 平均训练损失和准确率
        """
        self.model.train()
        total_loss = 0.0
        total_ce_loss = 0.0
        correct = 0
        total = 0

        # 梯度累积
        accumulation_steps = self.config.GRADIENT_ACCUMULATION_STEPS
        self.optimizer.zero_grad()

        progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
        for i, (inputs, labels) in progress_bar:
            try:
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                # 确保输入维度正确 [batch, 1, seq_len]
                if inputs.dim() == 2:  # [batch, seq_len]
                    inputs = inputs.unsqueeze(1)  # 添加通道维度 -> [batch, 1, seq_len]

                # 确保音频长度正确
                if inputs.size(2) > self.config.MAX_SEQ_LEN:
                    inputs = inputs[:, :, :self.config.MAX_SEQ_LEN]
                elif inputs.size(2) < self.config.MAX_SEQ_LEN:
                    pad_len = self.config.MAX_SEQ_LEN - inputs.size(2)
                    inputs = torch.nn.functional.pad(inputs, (0, pad_len), value=0.0)

                # 使用混合精度训练
                if self.scaler is not None:
                    with torch.cuda.amp.autocast():
                        # 前向传播
                        embeddings, logits = self.model(inputs)

                        # 计算损失
                        ce = self.ce_loss(logits, labels)
                        loss = self.config.CE_WEIGHT * ce

                        # 梯度累积
                        loss = loss / accumulation_steps

                    # 反向传播
                    self.scaler.scale(loss).backward()

                    # 梯度裁剪
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.GRAD_CLIP)

                    if (i + 1) % accumulation_steps == 0:
                        self.scaler.step(self.optimizer)  # 优化器步骤
                        self.scaler.update()
                        # 2. 再调用预热调度器
                        if epoch <= self.config.WARMUP_EPOCHS:
                            self.warmup_scheduler.step()
                        self.optimizer.zero_grad()

                else:
                    # 标准训练（不使用混合精度）
                    embeddings, logits = self.model(inputs)

                    # 计算损失
                    ce = self.ce_loss(logits, labels)
                    loss = self.config.CE_WEIGHT * ce

                    # 梯度累积
                    loss = loss / accumulation_steps
                    loss.backward()

                    # 梯度裁剪
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.GRAD_CLIP)

                    if (i + 1) % accumulation_steps == 0:
                        # 1. 先更新参数
                        self.optimizer.step()
                        # 2. 再调用学习率预热调度器（仅在预热阶段）
                        if epoch <= self.config.WARMUP_EPOCHS:
                            self.warmup_scheduler.step()  # 移动到optimizer.step()之后
                        self.optimizer.zero_grad()

                # 统计
                total_loss += loss.item() * accumulation_steps
                total_ce_loss += ce.item()
                _, predicted = logits.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

                # 更新进度条
                if i % 10 == 0:  # 减少更新频率
                    accuracy = 100. * correct / total
                    avg_loss = total_loss / (i + 1)
                    progress_bar.set_description(
                        f"Epoch {epoch}, Loss: {avg_loss:.4f}, Acc: {accuracy:.2f}%"
                    )

                # 手动垃圾回收
                if i % 50 == 0:
                    torch.cuda.empty_cache()
                    gc.collect()

            except RuntimeError as e:
                if "out of memory" in str(e):
                    print(f"内存不足错误在批次 {i}: {e}")
                    torch.cuda.empty_cache()
                    gc.collect()
                    continue
                else:
                    raise e

        avg_loss = total_loss / len(dataloader)
        accuracy = 100. * correct / total
        return avg_loss, accuracy

    def validate(self, dataloader):
        """
        验证模型性能
        :param dataloader: 验证数据加载器
        :return: 验证损失和准确率
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for i, (inputs, labels) in enumerate(dataloader):
                try:
                    inputs, labels = inputs.to(self.device), labels.to(self.device)

                    # 确保输入维度正确 [batch, 1, seq_len]
                    if inputs.dim() == 2:  # [batch, seq_len]
                        inputs = inputs.unsqueeze(1)  # 添加通道维度 -> [batch, 1, seq_len]

                    # 确保音频长度正确
                    if inputs.size(2) > self.config.MAX_SEQ_LEN:
                        inputs = inputs[:, :, :self.config.MAX_SEQ_LEN]
                    elif inputs.size(2) < self.config.MAX_SEQ_LEN:
                        pad_len = self.config.MAX_SEQ_LEN - inputs.size(2)
                        inputs = torch.nn.functional.pad(inputs, (0, pad_len), value=0.0)

                    # 前向传播
                    embeddings, logits = self.model(inputs)

                    # 计算损失
                    ce = self.ce_loss(logits, labels)
                    loss = self.config.CE_WEIGHT * ce

                    total_loss += loss.item()

                    # 统计准确率
                    _, predicted = logits.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()

                    # 内存管理
                    if i % 20 == 0:
                        torch.cuda.empty_cache()
                        gc.collect()

                except RuntimeError as e:
                    if "out of memory" in str(e):
                        print(f"验证时内存不足: {e}")
                        torch.cuda.empty_cache()
                        gc.collect()
                        continue
                    else:
                        raise e

        avg_loss = total_loss / len(dataloader)
        accuracy = 100. * correct / total
        return avg_loss, accuracy

    def train(self, train_dataloader, val_dataloader):
        """
        完整训练流程
        :param train_dataloader: 训练数据加载器
        :param val_dataloader: 验证数据加载器
        """
        print("开始训练...")
        print(f"训练集批次: {len(train_dataloader)}, 验证集批次: {len(val_dataloader)}")

        for epoch in range(1, self.config.EPOCHS + 1):
            try:
                # 应用学习率预热
                if epoch <= self.config.WARMUP_EPOCHS:
                    self.warmup_scheduler.step()

                # 训练一个epoch
                train_loss, train_acc = self.train_one_epoch(train_dataloader, epoch)
                print(f"Epoch {epoch}/{self.config.EPOCHS}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")

                # 验证
                val_loss, val_acc = self.validate(val_dataloader)
                print(f"Epoch {epoch}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

                # 更新学习率
                self.scheduler.step(val_loss)
                current_lr = self.optimizer.param_groups[0]['lr']
                print(f"当前学习率: {current_lr:.6f}")

                # 早停检查 (基于验证准确率)
                if val_acc > self.best_val_accuracy + self.config.MIN_DELTA:
                    self.best_val_accuracy = val_acc
                    self.early_stop_counter = 0
                    # 保存最佳模型
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'val_loss': val_loss,
                        'val_accuracy': val_acc
                    }, os.path.join(self.config.CHECKPOINT_DIR, 'best_model.pth'))
                    print(f"保存最佳模型，验证准确率: {val_acc:.2f}%")
                else:
                    self.early_stop_counter += 1
                    if self.early_stop_counter >= self.config.PATIENCE:
                        print(f"早停于第 {epoch} 轮")
                        break

                # 定期保存模型
                if epoch % self.config.SAVE_INTERVAL == 0:
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'train_loss': train_loss,
                        'train_accuracy': train_acc
                    }, os.path.join(self.config.CHECKPOINT_DIR, f'model_epoch_{epoch}.pth'))

                # 清理内存
                torch.cuda.empty_cache()
                gc.collect()

            except KeyboardInterrupt:
                print("训练被用户中断")
                break
            except Exception as e:
                print(f"训练错误: {e}")
                torch.cuda.empty_cache()
                gc.collect()
                continue

    def load_checkpoint(self, checkpoint_path):
        """
        加载检查点
        :param checkpoint_path: 检查点路径
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        if 'optimizer_state_dict' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint.get('epoch', 0)
        print(f"从第 {epoch} 轮加载检查点")
        return epoch

# 评估器类
class Evaluator:
    def __init__(self, model, config=Config):
        """初始化评估器"""
        self.model = model
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def evaluate_accuracy(self, dataloader):
        """
        评估模型准确率
        :param dataloader: 数据加载器
        :return: 准确率
        """
        correct = 0
        total = 0

        with torch.no_grad():
            for i, (inputs, labels) in enumerate(dataloader):
                try:
                    inputs, labels = inputs.to(self.device), labels.to(self.device)

                    # 确保输入维度正确
                    if inputs.dim() == 2:  # [batch, seq_len]
                        inputs = inputs.unsqueeze(1)  # 添加通道维度 -> [batch, 1, seq_len]

                    # 截断过长的序列
                    if inputs.size(2) > self.config.MAX_SEQ_LEN:
                        inputs = inputs[:, :, :self.config.MAX_SEQ_LEN]

                    _, logits = self.model(inputs)
                    _, predicted = logits.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()

                    # 内存管理
                    if i % 20 == 0:
                        torch.cuda.empty_cache()
                        gc.collect()

                except RuntimeError as e:
                    if "out of memory" in str(e):
                        print(f"评估时内存不足: {e}")
                        torch.cuda.empty_cache()
                        gc.collect()
                        continue
                    else:
                        raise e

        accuracy = 100. * correct / total
        return accuracy

# 测试代码
if __name__ == "__main__":
    # 设置较小的批次大小
    batch_size = Config.BATCH_SIZE

    print("创建数据加载器...")
    try:
        dataloaders = get_dataloaders(batch_size=batch_size)

        if dataloaders is None:
            print("无法创建数据加载器，退出")
            exit(1)

        print(f"训练集批次数: {len(dataloaders['train'])}")
        print(f"验证集批次数: {len(dataloaders['val'])}")
        print(f"测试集批次数: {len(dataloaders['test'])}")

        # 关键修改：从训练集中获取实际说话人数量（类别数）
        num_speakers = len(dataloaders["train"].dataset.speaker_to_idx)
        print(f"数据集中实际说话人数量（类别数）: {num_speakers}")

        # 创建训练器前，先初始化模型并传入正确的num_classes
        print("创建模型和训练器...")
        # 关键修改：确保在创建Trainer时使用正确的num_classes初始化模型
        # 1. 初始化模型（传入正确的num_classes）
        model = CHiLAPModel(num_classes=num_speakers)
        # 2. 将模型传入Trainer，确保优化器初始化时模型已存在
        trainer = Trainer(config=Config, model=model)  # 传入model参数

        # 测试一个批次
        print("测试前向传播...")
        try:
            x, y = next(iter(dataloaders["train"]))
            print(f"原始输入形状: {x.shape}, 标签形状: {y.shape}")

            # 处理输入维度
            if x.dim() == 2:  # [batch, seq_len]
                x = x.unsqueeze(1)  # 添加通道维度 -> [batch, 1, seq_len]

            # 确保音频长度正确
            if x.size(2) > Config.MAX_SEQ_LEN:
                x = x[:, :, :Config.MAX_SEQ_LEN]
            elif x.size(2) < Config.MAX_SEQ_LEN:
                pad_len = Config.MAX_SEQ_LEN - x.size(2)
                x = torch.nn.functional.pad(x, (0, pad_len), value=0.0)

            print(f"处理后输入形状: {x.shape}")

            x = x.to(trainer.device)
            y = y.to(trainer.device)

            with torch.no_grad():
                embeddings, logits = trainer.model(x)
                print(f"嵌入形状: {embeddings.shape}, 输出形状: {logits.shape}")
                print("前向传播测试成功!")
        except Exception as e:
            print(f"前向传播测试失败: {e}")
            import traceback
            traceback.print_exc()
            exit(1)

        # 开始训练
        print("开始训练...")
        trainer.train(dataloaders["train"], dataloaders["val"])

        # 创建评估器
        print("创建评估器...")
        evaluator = Evaluator(trainer.model)

        # 评估模型
        print("评估模型...")
        test_accuracy = evaluator.evaluate_accuracy(dataloaders["test"])
        print(f"测试集准确率: {test_accuracy:.2f}%")

    except Exception as e:
        print(f"程序执行错误: {e}")
        import traceback
        traceback.print_exc()

创建数据加载器...
加载LibriSpeech数据集: /users/tianyuey/Project/devDataset/LibriSpeech
处理子集: dev-clean
在 dev-clean 中找到 2703 个.flac文件
处理子集: dev-other
在 dev-other 中找到 5567 个.flac文件
找到 73 个不同的说话人
有效文件: 5000/5000
移除 0 个无效文件
找到 73 个不同的说话人
最终 train 数据集大小: 5000 个样本
加载LibriSpeech数据集: /users/tianyuey/Project/devDataset/LibriSpeech
处理子集: dev-clean
在 dev-clean 中找到 2703 个.flac文件
处理子集: dev-other
在 dev-other 中找到 5567 个.flac文件
找到 73 个不同的说话人
有效文件: 557/557
移除 0 个无效文件
找到 73 个不同的说话人
最终 val 数据集大小: 557 个样本
加载LibriSpeech数据集: /users/tianyuey/Project/devDataset/LibriSpeech
处理子集: dev-clean
在 dev-clean 中找到 2703 个.flac文件
处理子集: dev-other
在 dev-other 中找到 5567 个.flac文件
找到 73 个不同的说话人
有效文件: 5567/5567
移除 0 个无效文件
找到 73 个不同的说话人
最终 test 数据集大小: 5567 个样本
加载LibriSpeech数据集: /users/tianyuey/Project/devDataset/LibriSpeech
处理子集: dev-clean
在 dev-clean 中找到 2703 个.flac文件
处理子集: dev-other
在 dev-other 中找到 5567 个.flac文件
找到 73 个不同的说话人
有效文件: 5567/5567
移除 0 个无效文件
找到 73 个不同的说话人
最终 test 数据集大小: 5567 个样本
训练集: 5000 样本
验证集: 557 样本
测试集: 5567 样本
带噪声测试集: 55

Epoch 1, Loss: 4.2879, Acc: 3.20%: 100%|██████████| 313/313 [00:30<00:00, 10.21it/s]

Epoch 1/500, Train Loss: 4.2878, Train Acc: 3.20%





Epoch 1, Val Loss: 4.2943, Val Acc: 4.49%
当前学习率: 0.010000
保存最佳模型，验证准确率: 4.49%


Epoch 2, Loss: 3.9846, Acc: 5.43%: 100%|██████████| 313/313 [00:29<00:00, 10.55it/s]

Epoch 2/500, Train Loss: 3.9835, Train Acc: 5.44%





Epoch 2, Val Loss: 3.9189, Val Acc: 4.67%
当前学习率: 0.010000
保存最佳模型，验证准确率: 4.67%


Epoch 3, Loss: 3.8470, Acc: 6.55%: 100%|██████████| 313/313 [00:29<00:00, 10.57it/s]

Epoch 3/500, Train Loss: 3.8472, Train Acc: 6.54%





Epoch 3, Val Loss: 3.8246, Val Acc: 7.00%
当前学习率: 0.010000
保存最佳模型，验证准确率: 7.00%


Epoch 4, Loss: 3.8319, Acc: 6.71%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 4/500, Train Loss: 3.8336, Train Acc: 6.70%





Epoch 4, Val Loss: 3.7173, Val Acc: 8.98%
当前学习率: 0.010000
保存最佳模型，验证准确率: 8.98%


Epoch 5, Loss: 3.7633, Acc: 7.64%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 5/500, Train Loss: 3.7624, Train Acc: 7.64%





Epoch 5, Val Loss: 3.7060, Val Acc: 7.00%
当前学习率: 0.010000


Epoch 6, Loss: 3.7326, Acc: 7.98%: 100%|██████████| 313/313 [00:29<00:00, 10.44it/s]

Epoch 6/500, Train Loss: 3.7310, Train Acc: 8.02%





Epoch 6, Val Loss: 3.6323, Val Acc: 8.08%
当前学习率: 0.010000


Epoch 7, Loss: 3.7107, Acc: 8.34%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 7/500, Train Loss: 3.7108, Train Acc: 8.32%





Epoch 7, Val Loss: 3.7878, Val Acc: 6.64%
当前学习率: 0.010000


Epoch 8, Loss: 3.6578, Acc: 8.74%: 100%|██████████| 313/313 [00:29<00:00, 10.59it/s]

Epoch 8/500, Train Loss: 3.6602, Train Acc: 8.72%





Epoch 8, Val Loss: 3.5876, Val Acc: 8.26%
当前学习率: 0.010000


Epoch 9, Loss: 3.6678, Acc: 8.78%: 100%|██████████| 313/313 [00:29<00:00, 10.57it/s]

Epoch 9/500, Train Loss: 3.6691, Train Acc: 8.74%





Epoch 9, Val Loss: 3.6274, Val Acc: 10.59%
当前学习率: 0.010000
保存最佳模型，验证准确率: 10.59%


Epoch 10, Loss: 3.5756, Acc: 10.51%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 10/500, Train Loss: 3.5741, Train Acc: 10.54%





Epoch 10, Val Loss: 4.0464, Val Acc: 4.85%
当前学习率: 0.010000


Epoch 11, Loss: 3.5627, Acc: 10.59%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 11/500, Train Loss: 3.5604, Train Acc: 10.62%





Epoch 11, Val Loss: 3.4058, Val Acc: 9.52%
当前学习率: 0.010000


Epoch 12, Loss: 3.5832, Acc: 9.75%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s] 

Epoch 12/500, Train Loss: 3.5837, Train Acc: 9.74%





Epoch 12, Val Loss: 3.3315, Val Acc: 13.82%
当前学习率: 0.010000
保存最佳模型，验证准确率: 13.82%


Epoch 13, Loss: 3.4781, Acc: 10.87%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 13/500, Train Loss: 3.4790, Train Acc: 10.82%





Epoch 13, Val Loss: 3.5743, Val Acc: 7.36%
当前学习率: 0.010000


Epoch 14, Loss: 3.4826, Acc: 10.97%: 100%|██████████| 313/313 [00:29<00:00, 10.57it/s]

Epoch 14/500, Train Loss: 3.4819, Train Acc: 11.02%





Epoch 14, Val Loss: 3.3422, Val Acc: 12.39%
当前学习率: 0.010000


Epoch 15, Loss: 3.4127, Acc: 11.62%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 15/500, Train Loss: 3.4114, Train Acc: 11.68%





Epoch 15, Val Loss: 3.3296, Val Acc: 11.85%
当前学习率: 0.010000


Epoch 16, Loss: 3.4321, Acc: 10.51%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 16/500, Train Loss: 3.4311, Train Acc: 10.50%





Epoch 16, Val Loss: 3.2898, Val Acc: 13.29%
当前学习率: 0.010000


Epoch 17, Loss: 3.3706, Acc: 12.38%: 100%|██████████| 313/313 [00:29<00:00, 10.61it/s]

Epoch 17/500, Train Loss: 3.3716, Train Acc: 12.36%





Epoch 17, Val Loss: 3.2959, Val Acc: 14.00%
当前学习率: 0.010000
保存最佳模型，验证准确率: 14.00%


Epoch 18, Loss: 3.3698, Acc: 12.72%: 100%|██████████| 313/313 [00:29<00:00, 10.60it/s]

Epoch 18/500, Train Loss: 3.3686, Train Acc: 12.76%





Epoch 18, Val Loss: 3.2184, Val Acc: 14.54%
当前学习率: 0.010000
保存最佳模型，验证准确率: 14.54%


Epoch 19, Loss: 3.3453, Acc: 12.80%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 19/500, Train Loss: 3.3443, Train Acc: 12.80%





Epoch 19, Val Loss: 3.2155, Val Acc: 13.29%
当前学习率: 0.010000


Epoch 20, Loss: 3.3695, Acc: 12.90%: 100%|██████████| 313/313 [00:29<00:00, 10.57it/s]

Epoch 20/500, Train Loss: 3.3683, Train Acc: 12.92%





Epoch 20, Val Loss: 3.2025, Val Acc: 15.44%
当前学习率: 0.010000
保存最佳模型，验证准确率: 15.44%


Epoch 21, Loss: 3.4779, Acc: 10.65%: 100%|██████████| 313/313 [00:29<00:00, 10.56it/s]

Epoch 21/500, Train Loss: 3.4760, Train Acc: 10.66%





Epoch 21, Val Loss: 3.3470, Val Acc: 9.87%
当前学习率: 0.010000


Epoch 22, Loss: 3.3783, Acc: 12.54%: 100%|██████████| 313/313 [00:29<00:00, 10.56it/s]

Epoch 22/500, Train Loss: 3.3753, Train Acc: 12.54%





Epoch 22, Val Loss: 3.2038, Val Acc: 14.36%
当前学习率: 0.010000


Epoch 23, Loss: 3.3225, Acc: 12.86%: 100%|██████████| 313/313 [00:29<00:00, 10.60it/s]

Epoch 23/500, Train Loss: 3.3228, Train Acc: 12.84%





Epoch 23, Val Loss: 3.2073, Val Acc: 15.62%
当前学习率: 0.010000
保存最佳模型，验证准确率: 15.62%


Epoch 24, Loss: 3.2211, Acc: 15.19%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 24/500, Train Loss: 3.2214, Train Acc: 15.16%





Epoch 24, Val Loss: 3.0701, Val Acc: 15.80%
当前学习率: 0.010000
保存最佳模型，验证准确率: 15.80%


Epoch 25, Loss: 3.1989, Acc: 16.24%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 25/500, Train Loss: 3.2020, Train Acc: 16.20%





Epoch 25, Val Loss: 2.9419, Val Acc: 20.47%
当前学习率: 0.010000
保存最佳模型，验证准确率: 20.47%


Epoch 26, Loss: 3.0983, Acc: 17.58%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 26/500, Train Loss: 3.0973, Train Acc: 17.60%





Epoch 26, Val Loss: 2.9998, Val Acc: 17.77%
当前学习率: 0.010000


Epoch 27, Loss: 3.1347, Acc: 17.04%: 100%|██████████| 313/313 [00:29<00:00, 10.57it/s]

Epoch 27/500, Train Loss: 3.1347, Train Acc: 17.04%





Epoch 27, Val Loss: 3.0775, Val Acc: 16.34%
当前学习率: 0.010000


Epoch 28, Loss: 3.2193, Acc: 15.49%: 100%|██████████| 313/313 [00:29<00:00, 10.57it/s]

Epoch 28/500, Train Loss: 3.2190, Train Acc: 15.52%





Epoch 28, Val Loss: 2.9107, Val Acc: 19.75%
当前学习率: 0.010000


Epoch 29, Loss: 3.1839, Acc: 16.44%: 100%|██████████| 313/313 [00:29<00:00, 10.57it/s]

Epoch 29/500, Train Loss: 3.1862, Train Acc: 16.48%





Epoch 29, Val Loss: 3.0096, Val Acc: 16.52%
当前学习率: 0.010000


Epoch 30, Loss: 3.1374, Acc: 16.86%: 100%|██████████| 313/313 [00:30<00:00, 10.42it/s]

Epoch 30/500, Train Loss: 3.1390, Train Acc: 16.84%





Epoch 30, Val Loss: 2.9409, Val Acc: 21.36%
当前学习率: 0.010000
保存最佳模型，验证准确率: 21.36%


Epoch 31, Loss: 3.1026, Acc: 17.20%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 31/500, Train Loss: 3.1024, Train Acc: 17.14%





Epoch 31, Val Loss: 2.9739, Val Acc: 19.21%
当前学习率: 0.010000


Epoch 32, Loss: 3.0991, Acc: 16.98%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 32/500, Train Loss: 3.1004, Train Acc: 16.94%





Epoch 32, Val Loss: 3.1562, Val Acc: 15.26%
当前学习率: 0.005000


Epoch 33, Loss: 3.0488, Acc: 19.23%: 100%|██████████| 313/313 [00:29<00:00, 10.57it/s]

Epoch 33/500, Train Loss: 3.0481, Train Acc: 19.24%





Epoch 33, Val Loss: 2.8229, Val Acc: 21.72%
当前学习率: 0.005000
保存最佳模型，验证准确率: 21.72%


Epoch 34, Loss: 2.9691, Acc: 19.19%: 100%|██████████| 313/313 [00:29<00:00, 10.55it/s]

Epoch 34/500, Train Loss: 2.9681, Train Acc: 19.16%





Epoch 34, Val Loss: 2.7329, Val Acc: 23.34%
当前学习率: 0.005000
保存最佳模型，验证准确率: 23.34%


Epoch 35, Loss: 2.9097, Acc: 21.46%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 35/500, Train Loss: 2.9160, Train Acc: 21.38%





Epoch 35, Val Loss: 2.6943, Val Acc: 25.67%
当前学习率: 0.005000
保存最佳模型，验证准确率: 25.67%


Epoch 36, Loss: 2.9532, Acc: 19.92%: 100%|██████████| 313/313 [00:29<00:00, 10.57it/s]

Epoch 36/500, Train Loss: 2.9539, Train Acc: 19.90%





Epoch 36, Val Loss: 2.7509, Val Acc: 23.52%
当前学习率: 0.005000


Epoch 37, Loss: 2.9690, Acc: 18.71%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 37/500, Train Loss: 2.9714, Train Acc: 18.70%





Epoch 37, Val Loss: 2.7705, Val Acc: 20.47%
当前学习率: 0.005000


Epoch 38, Loss: 2.9032, Acc: 20.12%: 100%|██████████| 313/313 [00:29<00:00, 10.55it/s]

Epoch 38/500, Train Loss: 2.9028, Train Acc: 20.06%





Epoch 38, Val Loss: 2.7445, Val Acc: 25.31%
当前学习率: 0.005000


Epoch 39, Loss: 2.8708, Acc: 21.97%: 100%|██████████| 313/313 [00:29<00:00, 10.57it/s]

Epoch 39/500, Train Loss: 2.8681, Train Acc: 22.04%





Epoch 39, Val Loss: 2.6476, Val Acc: 26.75%
当前学习率: 0.005000
保存最佳模型，验证准确率: 26.75%


Epoch 40, Loss: 2.8545, Acc: 21.40%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 40/500, Train Loss: 2.8546, Train Acc: 21.42%





Epoch 40, Val Loss: 2.6193, Val Acc: 24.78%
当前学习率: 0.005000


Epoch 41, Loss: 2.9121, Acc: 20.90%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 41/500, Train Loss: 2.9128, Train Acc: 20.98%





Epoch 41, Val Loss: 2.7122, Val Acc: 22.80%
当前学习率: 0.005000


Epoch 42, Loss: 2.8939, Acc: 21.58%: 100%|██████████| 313/313 [00:29<00:00, 10.57it/s]

Epoch 42/500, Train Loss: 2.8933, Train Acc: 21.58%





Epoch 42, Val Loss: 2.7545, Val Acc: 23.88%
当前学习率: 0.005000


Epoch 43, Loss: 2.8486, Acc: 21.97%: 100%|██████████| 313/313 [00:29<00:00, 10.57it/s]

Epoch 43/500, Train Loss: 2.8522, Train Acc: 21.98%





Epoch 43, Val Loss: 2.6534, Val Acc: 23.52%
当前学习率: 0.005000


Epoch 44, Loss: 2.7831, Acc: 22.73%: 100%|██████████| 313/313 [00:29<00:00, 10.57it/s]

Epoch 44/500, Train Loss: 2.7804, Train Acc: 22.80%





Epoch 44, Val Loss: 2.6147, Val Acc: 26.93%
当前学习率: 0.005000
保存最佳模型，验证准确率: 26.93%


Epoch 45, Loss: 2.7762, Acc: 24.04%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 45/500, Train Loss: 2.7763, Train Acc: 24.06%





Epoch 45, Val Loss: 2.6202, Val Acc: 24.42%
当前学习率: 0.005000


Epoch 46, Loss: 2.7906, Acc: 23.09%: 100%|██████████| 313/313 [00:29<00:00, 10.59it/s]

Epoch 46/500, Train Loss: 2.7903, Train Acc: 23.10%





Epoch 46, Val Loss: 2.6276, Val Acc: 27.47%
当前学习率: 0.005000
保存最佳模型，验证准确率: 27.47%


Epoch 47, Loss: 2.7077, Acc: 24.56%: 100%|██████████| 313/313 [00:29<00:00, 10.57it/s]

Epoch 47/500, Train Loss: 2.7081, Train Acc: 24.62%





Epoch 47, Val Loss: 2.5426, Val Acc: 29.08%
当前学习率: 0.005000
保存最佳模型，验证准确率: 29.08%


Epoch 48, Loss: 2.7619, Acc: 24.08%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 48/500, Train Loss: 2.7615, Train Acc: 24.14%





Epoch 48, Val Loss: 2.5791, Val Acc: 28.73%
当前学习率: 0.005000


Epoch 49, Loss: 2.7505, Acc: 24.36%: 100%|██████████| 313/313 [00:29<00:00, 10.59it/s]

Epoch 49/500, Train Loss: 2.7519, Train Acc: 24.34%





Epoch 49, Val Loss: 2.6097, Val Acc: 24.78%
当前学习率: 0.005000


Epoch 50, Loss: 2.8014, Acc: 23.23%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 50/500, Train Loss: 2.8071, Train Acc: 23.20%





Epoch 50, Val Loss: 2.6112, Val Acc: 27.11%
当前学习率: 0.005000


Epoch 51, Loss: 2.7297, Acc: 24.30%: 100%|██████████| 313/313 [00:29<00:00, 10.57it/s]

Epoch 51/500, Train Loss: 2.7286, Train Acc: 24.28%





Epoch 51, Val Loss: 2.5773, Val Acc: 26.75%
当前学习率: 0.002500


Epoch 52, Loss: 2.7050, Acc: 24.54%: 100%|██████████| 313/313 [00:29<00:00, 10.56it/s]

Epoch 52/500, Train Loss: 2.7064, Train Acc: 24.52%





Epoch 52, Val Loss: 2.4935, Val Acc: 29.80%
当前学习率: 0.002500
保存最佳模型，验证准确率: 29.80%


Epoch 53, Loss: 2.6813, Acc: 25.10%: 100%|██████████| 313/313 [00:29<00:00, 10.55it/s]

Epoch 53/500, Train Loss: 2.6827, Train Acc: 25.10%





Epoch 53, Val Loss: 2.5013, Val Acc: 28.73%
当前学习率: 0.002500


Epoch 54, Loss: 2.6217, Acc: 26.63%: 100%|██████████| 313/313 [00:29<00:00, 10.57it/s]

Epoch 54/500, Train Loss: 2.6225, Train Acc: 26.58%





Epoch 54, Val Loss: 2.4719, Val Acc: 28.01%
当前学习率: 0.002500


Epoch 55, Loss: 2.6051, Acc: 26.89%: 100%|██████████| 313/313 [00:29<00:00, 10.59it/s]

Epoch 55/500, Train Loss: 2.6029, Train Acc: 26.92%





Epoch 55, Val Loss: 2.4002, Val Acc: 30.70%
当前学习率: 0.002500
保存最佳模型，验证准确率: 30.70%


Epoch 56, Loss: 2.6342, Acc: 27.17%: 100%|██████████| 313/313 [00:29<00:00, 10.57it/s]

Epoch 56/500, Train Loss: 2.6362, Train Acc: 27.16%





Epoch 56, Val Loss: 2.4575, Val Acc: 28.73%
当前学习率: 0.002500


Epoch 57, Loss: 2.6323, Acc: 26.25%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 57/500, Train Loss: 2.6292, Train Acc: 26.30%





Epoch 57, Val Loss: 2.4179, Val Acc: 30.70%
当前学习率: 0.002500


Epoch 58, Loss: 2.6060, Acc: 26.97%: 100%|██████████| 313/313 [00:29<00:00, 10.55it/s]

Epoch 58/500, Train Loss: 2.6075, Train Acc: 26.96%





Epoch 58, Val Loss: 2.3946, Val Acc: 30.88%
当前学习率: 0.002500
保存最佳模型，验证准确率: 30.88%


Epoch 59, Loss: 2.5827, Acc: 27.31%: 100%|██████████| 313/313 [00:29<00:00, 10.59it/s]

Epoch 59/500, Train Loss: 2.5851, Train Acc: 27.32%





Epoch 59, Val Loss: 2.3736, Val Acc: 32.68%
当前学习率: 0.002500
保存最佳模型，验证准确率: 32.68%


Epoch 60, Loss: 2.5795, Acc: 28.09%: 100%|██████████| 313/313 [00:30<00:00, 10.25it/s]

Epoch 60/500, Train Loss: 2.5803, Train Acc: 28.10%





Epoch 60, Val Loss: 2.3833, Val Acc: 28.73%
当前学习率: 0.002500


Epoch 61, Loss: 2.5447, Acc: 28.18%: 100%|██████████| 313/313 [00:35<00:00,  8.80it/s]

Epoch 61/500, Train Loss: 2.5419, Train Acc: 28.28%





Epoch 61, Val Loss: 2.3799, Val Acc: 30.52%
当前学习率: 0.002500


Epoch 62, Loss: 2.5236, Acc: 28.26%: 100%|██████████| 313/313 [00:35<00:00,  8.82it/s]

Epoch 62/500, Train Loss: 2.5243, Train Acc: 28.26%





Epoch 62, Val Loss: 2.3961, Val Acc: 31.42%
当前学习率: 0.002500


Epoch 63, Loss: 2.5070, Acc: 28.58%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 63/500, Train Loss: 2.5065, Train Acc: 28.56%





Epoch 63, Val Loss: 2.3865, Val Acc: 30.70%
当前学习率: 0.001250


Epoch 64, Loss: 2.4953, Acc: 29.64%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 64/500, Train Loss: 2.4946, Train Acc: 29.64%





Epoch 64, Val Loss: 2.3556, Val Acc: 32.68%
当前学习率: 0.001250


Epoch 65, Loss: 2.4545, Acc: 30.93%: 100%|██████████| 313/313 [00:29<00:00, 10.56it/s]

Epoch 65/500, Train Loss: 2.4554, Train Acc: 30.92%





Epoch 65, Val Loss: 2.3563, Val Acc: 34.47%
当前学习率: 0.001250
保存最佳模型，验证准确率: 34.47%


Epoch 66, Loss: 2.4767, Acc: 29.46%: 100%|██████████| 313/313 [00:29<00:00, 10.59it/s]

Epoch 66/500, Train Loss: 2.4799, Train Acc: 29.42%





Epoch 66, Val Loss: 2.3052, Val Acc: 35.37%
当前学习率: 0.001250
保存最佳模型，验证准确率: 35.37%


Epoch 67, Loss: 2.4716, Acc: 29.20%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 67/500, Train Loss: 2.4738, Train Acc: 29.16%





Epoch 67, Val Loss: 2.3524, Val Acc: 32.50%
当前学习率: 0.001250


Epoch 68, Loss: 2.4777, Acc: 30.02%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 68/500, Train Loss: 2.4789, Train Acc: 29.98%





Epoch 68, Val Loss: 2.3673, Val Acc: 33.75%
当前学习率: 0.001250


Epoch 69, Loss: 2.4653, Acc: 29.92%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 69/500, Train Loss: 2.4613, Train Acc: 29.94%





Epoch 69, Val Loss: 2.3262, Val Acc: 33.03%
当前学习率: 0.001250


Epoch 70, Loss: 2.4663, Acc: 30.12%: 100%|██████████| 313/313 [00:29<00:00, 10.59it/s]

Epoch 70/500, Train Loss: 2.4641, Train Acc: 30.12%





Epoch 70, Val Loss: 2.3154, Val Acc: 32.14%
当前学习率: 0.000625


Epoch 71, Loss: 2.4129, Acc: 31.27%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 71/500, Train Loss: 2.4128, Train Acc: 31.24%





Epoch 71, Val Loss: 2.3041, Val Acc: 33.57%
当前学习率: 0.000625


Epoch 72, Loss: 2.4333, Acc: 31.05%: 100%|██████████| 313/313 [00:29<00:00, 10.59it/s]

Epoch 72/500, Train Loss: 2.4381, Train Acc: 30.98%





Epoch 72, Val Loss: 2.3196, Val Acc: 31.96%
当前学习率: 0.000625


Epoch 73, Loss: 2.3913, Acc: 31.57%: 100%|██████████| 313/313 [00:29<00:00, 10.57it/s]

Epoch 73/500, Train Loss: 2.3996, Train Acc: 31.48%





Epoch 73, Val Loss: 2.3012, Val Acc: 33.21%
当前学习率: 0.000625


Epoch 74, Loss: 2.4327, Acc: 30.91%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 74/500, Train Loss: 2.4313, Train Acc: 30.94%





Epoch 74, Val Loss: 2.3223, Val Acc: 33.03%
当前学习率: 0.000625


Epoch 75, Loss: 2.3952, Acc: 31.51%: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]

Epoch 75/500, Train Loss: 2.3975, Train Acc: 31.46%





Epoch 75, Val Loss: 2.2859, Val Acc: 34.29%
当前学习率: 0.000625


Epoch 76, Loss: 2.4062, Acc: 31.95%: 100%|██████████| 313/313 [00:29<00:00, 10.57it/s]

Epoch 76/500, Train Loss: 2.4088, Train Acc: 31.88%





Epoch 76, Val Loss: 2.3037, Val Acc: 34.11%
当前学习率: 0.000625
早停于第 76 轮
创建评估器...
评估模型...
测试集准确率: 37.58%
