# 准备

## 导入必要的库

In [None]:
import os, random, time, math, warnings
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
import pandas as pd
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display

# --- Global configuration ---
# Seaborn theme for all figures; the 'SimHei' font is requested so that the
# Chinese titles/labels used below can be rendered (assumes the font is installed).
sns.set_theme(style="whitegrid", palette="deep", font='SimHei', font_scale=1.1)
# Draw the minus sign with an ASCII hyphen instead of the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available, every GPU generator is seeded as well and cuDNN
    is switched to deterministic mode with benchmarking turned off.

    Args:
        seed: Integer seed applied to all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
# Seed every RNG once at start-up, then report the chosen device and seed.
set_seed(42)
print(f"配置完成: 设备={device}, 随机种子=42")

# --- 辅助函数 ---
def add_gaussian_noise_np(image_np: np.ndarray, mean=0, sigma=25) -> np.ndarray:
    """Return a copy of ``image_np`` with additive Gaussian noise.

    Noise is drawn from NumPy's global RNG with the given ``mean`` and
    ``sigma``, added in float32, clipped to [0, 255] and cast to uint8
    (the cast truncates the fractional part).

    Args:
        image_np: Input image array; any shape is accepted.
        mean: Mean of the noise distribution.
        sigma: Standard deviation of the noise distribution.

    Returns:
        A uint8 array with the same shape as ``image_np``.
    """
    perturbation = np.random.normal(mean, sigma, image_np.shape).astype(np.float32)
    pixels = image_np.astype(np.float32)
    noisy = np.clip(pixels + perturbation, 0, 255)
    return noisy.astype(np.uint8)

##  所有类和函数定义 (模型, 数据集, 加密)

In [None]:
# --- 加密函数 ---
def encrypt_block_permutation(img_tensor: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    """Shuffle the positions of non-overlapping square blocks of an image batch.

    A single random permutation is drawn per call and applied to every image
    and channel in the batch; pixel order inside each block is untouched.

    Args:
        img_tensor: Tensor unpacked as (B, C, H, W).
        block_size: Side length of the square blocks.

    Returns:
        The block-permuted tensor, or ``img_tensor`` itself (unchanged) when
        H or W is not a multiple of ``block_size``.
    """
    _, _, height, width = img_tensor.shape
    if height % block_size or width % block_size:
        return img_tensor
    tiling = dict(kernel_size=block_size, stride=block_size)
    # One column per block: (B, C * block_size**2, num_blocks).
    columns = F.unfold(img_tensor, **tiling)
    order = torch.randperm(columns.shape[-1], device=img_tensor.device)
    return F.fold(columns[:, :, order], output_size=(height, width), **tiling)

def encrypt_global_pixel_shuffle(img_tensor: torch.Tensor) -> torch.Tensor:
    """Shuffle all pixel positions of an image batch with one random permutation.

    The same permutation of the H*W spatial positions is applied to every
    image and every channel in the batch.

    Args:
        img_tensor: Tensor unpacked as (B, C, H, W). It may be non-contiguous
            (for example the result of ``transpose``/``permute``).

    Returns:
        A tensor of the same shape with spatially shuffled pixels.
    """
    B, C, H, W = img_tensor.shape
    # reshape (not view) so that non-contiguous inputs do not raise.
    flat_img = img_tensor.reshape(B, C, H * W)
    permuted_indices = torch.randperm(H * W, device=img_tensor.device)
    return flat_img[:, :, permuted_indices].reshape(B, C, H, W)

def encrypt_block_pixel_shuffle(img_tensor: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    """Shuffle pixel positions inside every non-overlapping square block.

    Blocks keep their location; one random permutation of the
    ``block_size * block_size`` in-block positions is drawn per call and
    shared by all blocks, channels and images.

    Args:
        img_tensor: Tensor unpacked as (B, C, H, W).
        block_size: Side length of the square blocks.

    Returns:
        The shuffled tensor, or ``img_tensor`` itself (unchanged) when H or W
        is not a multiple of ``block_size``.
    """
    batch, channels, height, width = img_tensor.shape
    if height % block_size or width % block_size:
        return img_tensor
    pixels_per_block = block_size * block_size
    tiling = dict(kernel_size=block_size, stride=block_size)
    columns = F.unfold(img_tensor, **tiling)
    # Regroup to (B, num_blocks, C, pixels_per_block) so the last axis indexes
    # positions inside one block of one channel.
    per_block = columns.permute(0, 2, 1).reshape(batch, -1, channels, pixels_per_block)
    order = torch.randperm(pixels_per_block, device=img_tensor.device)
    shuffled = per_block[:, :, :, order]
    columns = shuffled.reshape(batch, -1, channels * pixels_per_block).permute(0, 2, 1)
    return F.fold(columns, output_size=(height, width), **tiling)

print("所有加密函数已成功定义。")

# --- 数据集定义 ---
class TripletImageDataset(Dataset):
    """Dataset yielding (anchor, positive, negative) image triplets.

    ``root_dir`` must contain one sub-directory per class. Directory and file
    listings are sorted so that the sample order (and any seeded split built
    on top of it) does not depend on the filesystem's listing order. Class
    folders without any regular file are ignored, and nested directories
    inside a class folder are not treated as images.

    Args:
        root_dir: Directory holding one folder per class.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, root_dir, transform=None):
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.class_to_images = {}
        for entry in sorted(os.listdir(root_dir)):
            class_path = os.path.join(root_dir, entry)
            if not os.path.isdir(class_path):
                continue
            files = [
                os.path.join(class_path, fname)
                for fname in sorted(os.listdir(class_path))
                if os.path.isfile(os.path.join(class_path, fname))
            ]
            # An empty class can serve neither as anchor nor as negative source.
            if files:
                self.class_to_images[entry] = files
        self.class_dirs = list(self.class_to_images)
        # (image path, class index) pairs; the index refers to self.class_dirs.
        self.all_images = [(p, i) for i, d in enumerate(self.class_dirs) for p in self.class_to_images[d]]

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, index):
        """Return the anchor at ``index`` with a random positive and negative.

        The positive is another image of the anchor's class (the anchor itself
        when the class has a single image). The negative comes from a
        different class (from the anchor's class when only one class exists).
        """
        anchor_path, anchor_idx = self.all_images[index]
        anchor_class = self.class_dirs[anchor_idx]
        pos_list = self.class_to_images[anchor_class]
        pos_path = anchor_path
        # Rejection sampling: redraw until the positive differs from the anchor.
        while pos_path == anchor_path and len(pos_list) > 1:
            pos_path = random.choice(pos_list)
        neg_class = anchor_class
        while neg_class == anchor_class and len(self.class_dirs) > 1:
            neg_class = random.choice(self.class_dirs)
        neg_path = random.choice(self.class_to_images[neg_class])
        return self._load_image(anchor_path), self._load_image(pos_path), self._load_image(neg_path)

    def _load_image(self, path):
        """Open ``path`` as RGB and apply ``self.transform`` when it is set."""
        with Image.open(path) as handle:
            img = handle.convert('RGB')
        return self.transform(img) if self.transform else img

class EvaluationDataset(Dataset):
    """Dataset yielding ``(PIL RGB image, class index)`` pairs for evaluation.

    ``root_dir`` must contain one sub-directory per class. Listings are sorted
    so class indices and sample order are reproducible across filesystems,
    and only regular files inside a class folder are treated as images.

    Args:
        root_dir: Directory holding one folder per class.
    """

    def __init__(self, root_dir):
        super().__init__()
        self.root_dir = root_dir
        self.class_dirs = sorted(
            d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))
        )
        # (image path, class index) pairs; the index refers to self.class_dirs.
        self.all_images = []
        for class_idx, class_dir in enumerate(self.class_dirs):
            class_path = os.path.join(root_dir, class_dir)
            for img_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_name)
                if os.path.isfile(img_path):
                    self.all_images.append((img_path, class_idx))

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, index):
        """Return the untransformed RGB image and its integer label."""
        img_path, label = self.all_images[index]
        image = Image.open(img_path).convert('RGB')
        return image, label

# --- Import model classes ---
# The model definitions live in project-local modules. An ImportError is only
# reported here, not re-raised, so a missing module surfaces later when the
# model dictionary references the class names.
try:
#    from LPMP_NET import LPMPNet_V6_Turbo, LPMP_DeepSetNet_V7, LPMPNet_DCT_Turbo_V7, DCT_Guardian_V8, EnhancedLPMP
    from final_model import HDCT, HDCT_G, HieraDCT_v2
    from comparison_models import SimpleCNN, DeepSetNet, DCTNet, HistogramMLP
    print("所有模型类已成功导入。")
except ImportError as e:
    print(f"导入错误: {e}")

## 超参数设置与数据加载

In [None]:
# --- Paths and hyperparameters ---
TRAIN_VALID_DIR, TEST_DIR = './training_data/', './testing_data/'
EXPERIMENT_NAME = './logs/'
MODEL_SAVE_PATH = './saved_models/'
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
BATCH_SIZE, BLOCK_SIZE, OUTPUT_DIM = 32, 32, 256
LEARNING_RATE, NUM_EPOCHS, VALIDATION_SPLIT = 1e-4, 100, 0.2
LOAD_BALANCE_ALPHA = 1e-2  # weight of the MoE load-balancing loss in the total loss
MARGIN = 1.0               # margin of the triplet loss

# --- Data loaders ---
# Every image is resized to BLOCK_SIZE x BLOCK_SIZE and converted to a tensor.
data_transform = transforms.Compose([transforms.Resize((BLOCK_SIZE, BLOCK_SIZE)), transforms.ToTensor()])
try:
    print("\n--- 配置数据加载器 ---")
    full_train_dataset = TripletImageDataset(root_dir=TRAIN_VALID_DIR, transform=data_transform)
    train_size = int((1 - VALIDATION_SPLIT) * len(full_train_dataset))
    # Fixed generator seed so the train/validation split is the same on every run.
    train_subset, valid_subset = random_split(
        full_train_dataset,
        [train_size, len(full_train_dataset) - train_size],
        generator=torch.Generator().manual_seed(42),
    )

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only
    # runs, where DataLoader would otherwise emit a warning.
    _pin_memory = device.type == 'cuda'
    train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=_pin_memory)
    valid_loader = DataLoader(valid_subset, batch_size=BATCH_SIZE * 2, shuffle=False, num_workers=0, pin_memory=_pin_memory)
    test_dataset_plain = EvaluationDataset(root_dir=TEST_DIR)

    print(f"数据加载器准备就绪: 训练集 {len(train_subset)} | 验证集 {len(valid_subset)} | 测试集 {len(test_dataset_plain)}")
except Exception as e:
    print(f"错误: 数据集加载失败。请确保文件夹路径和内容正确。\n错误信息: {e}")
    # Bare raise keeps the original traceback intact.
    raise

## 模型实例化与断点训练

### 非学习型模型

In [None]:
# --- Check the optional scikit-image dependency ---
# When scikit-image is missing, `local_binary_pattern` is bound to None so that
# the LBP feature extractor can detect the absence at call time.
try:
    from skimage.feature import local_binary_pattern
    print("scikit-image 已成功导入，LBP功能可用。")
except ImportError:
    print("警告: scikit-image 未安装。LBP特征提取器将不可用。")
    print("请运行: pip install scikit-image")
    local_binary_pattern = None

# --- 1. 统计直方图代表：颜色直方图 ---
def extract_color_histogram(image_np: np.ndarray, bins: int = 64) -> np.ndarray:
    """Return the concatenated, L1-normalised histograms of channels 0, 1 and 2.

    The original docstring names the channels B, G, R (assumes an
    OpenCV-style image; callers must supply that channel order).

    Args:
        image_np: Image array with at least three channels.
        bins: Number of histogram bins per channel over the range [0, 256).

    Returns:
        A 1-D array of length ``3 * bins``; all zeros when ``image_np`` is
        ``None`` or empty.
    """
    if image_np is None or image_np.size == 0:
        return np.zeros(bins * 3)
    per_channel = [
        cv2.calcHist([image_np], [channel], None, [bins], [0, 256])
        for channel in range(3)
    ]
    stacked = np.concatenate(per_channel).flatten()
    normalised = cv2.normalize(stacked, stacked, norm_type=cv2.NORM_L1)
    return normalised.flatten()

# --- 2. 局部特征代表：LBP纹理特征 ---
def extract_slbp_features(image_np: np.ndarray, radius: int = 1, n_points: int = 8) -> np.ndarray:
    """Return a normalised histogram of uniform LBP codes describing texture.

    Args:
        image_np: Image array; 3-D inputs are converted to grayscale with
            ``cv2.COLOR_BGR2GRAY``, other inputs are used as-is.
        radius: Radius passed to ``local_binary_pattern``.
        n_points: Number of sampling points passed to ``local_binary_pattern``.

    Returns:
        A float histogram with ``n_points + 2`` bins that sums to ~1; all
        zeros when the image is ``None`` or empty; an empty array (after a
        warning) when scikit-image is not installed.
    """
    if local_binary_pattern is None:
        warnings.warn("scikit-image 未安装，LBP特征为空。")
        return np.array([])
    n_bins = n_points + 2
    if image_np is None or image_np.size == 0:
        return np.zeros(n_bins)

    if image_np.ndim == 3:
        gray_image = cv2.cvtColor(image_np, cv2.COLOR_BGR2GRAY)
    else:
        gray_image = image_np
    codes = local_binary_pattern(gray_image, n_points, radius, method='uniform').ravel()
    counts, _ = np.histogram(codes, bins=np.arange(0, n_points + 3), range=(0, n_bins))
    counts = counts.astype("float")
    # The small epsilon avoids a division by zero on an all-zero histogram.
    return counts / (counts.sum() + 1e-6)

print("\n非学习对照组模型函数 (extract_color_histogram, extract_slbp_features) 已定义。")

In [None]:
# --- 1. Define the model dictionaries ---
# Models trained with the default strategy (Adam, no scheduler, no clipping).
standard_models = {
#    'LPMP_Net_V6_Turbo': LPMPNet_V6_Turbo(output_dim=OUTPUT_DIM),
    'HDCT': HDCT(output_dim=OUTPUT_DIM),
    'HDCT_G': HDCT_G(output_dim=OUTPUT_DIM),
    'HDCT_V2': HieraDCT_v2(output_dim=OUTPUT_DIM),
    'SimpleCNN': SimpleCNN(output_dim=OUTPUT_DIM, block_size=BLOCK_SIZE),
    'DeepSetNet': DeepSetNet(output_dim=OUTPUT_DIM),
    'DCTNet': DCTNet(output_dim=OUTPUT_DIM),
    'HistogramMLP': HistogramMLP(output_dim=OUTPUT_DIM),
}
# Models that need the special training strategy (AdamW, warm-up + cosine
# schedule, gradient clipping); currently empty.
lpmp_models = {}
# Merge both groups into one dictionary keyed by model name.
models = {**standard_models, **lpmp_models}

# --- 2. 自动识别特殊模型 ---
def detect_moe_models(models_dict: dict) -> set:
    """Return the names of models whose forward pass yields ``(features, tensor)``.

    Each model is probed with a random batch of shape
    ``(2, 3, BLOCK_SIZE, BLOCK_SIZE)``. A model counts as MoE when its output
    is a 2-element tuple/list whose second element is a tensor.

    The probe input is created on the device of the model's own parameters,
    so detection works whether or not the model has already been moved to the
    training device. The probe runs in eval mode without gradients, so it does
    not touch BatchNorm statistics or build an autograd graph; the previous
    train/eval mode is restored afterwards.

    Args:
        models_dict: Mapping of model name to model instance.

    Returns:
        Set of names detected as MoE models. Models whose probe raises are
        reported with a warning and treated as standard models.
    """
    moe_names = set()
    for name, model in models_dict.items():
        if not hasattr(model, 'forward'):
            continue
        was_training = getattr(model, 'training', False)
        try:
            first_param = next(iter(model.parameters()), None)
            probe_device = first_param.device if first_param is not None else device
            dummy_input = torch.randn(2, 3, BLOCK_SIZE, BLOCK_SIZE, device=probe_device)
            model.eval()
            with torch.no_grad():
                output = model(dummy_input)
            if isinstance(output, (tuple, list)) and len(output) == 2 and isinstance(output[1], torch.Tensor):
                moe_names.add(name)
        except Exception as exc:
            # Best effort: a failed probe must not abort set-up, but it should
            # not go unnoticed either.
            warnings.warn(f"detect_moe_models: probing '{name}' failed ({exc}); treating it as a standard model.")
        finally:
            if hasattr(model, 'train'):
                model.train(was_training)
    return moe_names

# --- Move the models to the target device first ---
# This must happen before the MoE probe (whose dummy input lives on `device`)
# and before optimizer state is restored: Optimizer.load_state_dict casts the
# state to the device the parameters are on at load time, so restoring it
# while the models were still on the CPU left the optimizer state on the CPU
# after the models moved to the GPU.
for name, model in models.items():
    models[name] = model.to(device)

print("\n--- 自动识别 MoE 模型 ---")
moe_model_names = detect_moe_models(models)
print(f"  - MoE 模型: {moe_model_names if moe_model_names else '无'}")
print(f"  - 特殊LPMP模型: {list(lpmp_models.keys())}")

# --- 3. Initialise all training components ---
# LPMP models use AdamW with weight decay; all other models use plain Adam.
optimizers = {}
for name, model in models.items():
    if name in lpmp_models:
        optimizers[name] = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    else:
        optimizers[name] = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# LPMP models additionally get a linear warm-up followed by cosine annealing.
lpmp_schedulers = {}
for name, optimizer in optimizers.items():
    if name in lpmp_models:
        warmup_epochs = 5
        main_epochs = NUM_EPOCHS - warmup_epochs
        if main_epochs > 0:
            warmup_scheduler = LinearLR(optimizer, start_factor=1e-3, end_factor=1.0, total_iters=warmup_epochs)
            main_scheduler = CosineAnnealingLR(optimizer, T_max=main_epochs, eta_min=1e-6)
            lpmp_schedulers[name] = SequentialLR(optimizer, schedulers=[warmup_scheduler, main_scheduler], milestones=[warmup_epochs])

# Gradient scaling is only active on CUDA; elsewhere the scaler is a no-op.
scalers = {name: GradScaler(enabled=(device.type == 'cuda')) for name in models}
print("\n混合精度缩放器 (Scalers) 已为所有模型创建。")

# --- 4. Initialise resume state and load the latest checkpoint if present ---
start_epoch = 0
history = {name: {'train_loss': [], 'valid_loss': [], 'valid_acc': [], 'moe_lb_loss': []} for name in models}
best_acc = {name: 0.0 for name in models}

CHECKPOINT_PATH = os.path.join(MODEL_SAVE_PATH, 'latest_checkpoint.pth')
if os.path.exists(CHECKPOINT_PATH):
    print(f"\n发现检查点文件，正在加载... '{CHECKPOINT_PATH}'")
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    # The checkpoint stores the last finished epoch; resume with the next one.
    start_epoch = checkpoint.get('epoch', -1) + 1

    loaded_history = checkpoint.get('history', {})
    for name in history:
        if name in loaded_history:
            history[name].update(loaded_history[name])
    best_acc.update(checkpoint.get('best_acc', {}))

    for name, model in models.items():
        if name in checkpoint.get('model_state_dicts', {}):
            try:
                model.load_state_dict(checkpoint['model_state_dicts'][name])
                optimizers[name].load_state_dict(checkpoint['optimizer_state_dicts'][name])
                if name in lpmp_schedulers:
                    lpmp_schedulers[name].load_state_dict(checkpoint['scheduler_state_dicts'][name])
                print(f"  - 已成功加载状态: {name}")
            except Exception as e:
                # A model whose state cannot be restored keeps its fresh state.
                print(f"  - 警告: 加载 {name} 状态失败. {e}")
    print(f"将从 Epoch {start_epoch + 1} 继续训练。")
else:
    print("\n未发现检查点文件，将从头开始训练。")

# --- 5. Report trainable parameter counts ---
print("\n所有模型和优化器实例化完成:")
for name, model in models.items():
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"  - {name:<25}: {params/1e6:.3f}M params")

# 训练

## 训练与验证循环

In [None]:
# --- 定义辅助函数 ---

def set_seed(seed=42):
    """Seed all random number generators for reproducible experiments.

    Seeds Python's ``random``, NumPy and PyTorch (CPU); when CUDA is available
    the current GPU and all GPUs are seeded too. Prints the seed when done.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = [random.seed, np.random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders += [torch.cuda.manual_seed, torch.cuda.manual_seed_all]
        # For full reproducibility (at some performance cost) one could also set:
        # torch.backends.cudnn.deterministic = True
        # torch.backends.cudnn.benchmark = False
    for seeder in seeders:
        seeder(seed)
    print(f"随机种子已设置为: {seed}")

def save_best_model(model, name, acc, best_acc_dict, save_path):
    """Persist ``model`` when ``acc`` beats the best accuracy recorded so far.

    Args:
        model: Model whose ``state_dict()`` is saved.
        name: Model name; used as dictionary key and in the file name.
        acc: Accuracy achieved in the current epoch.
        best_acc_dict: Mapping of name to best accuracy so far; updated in
            place on improvement (a missing entry counts as 0.0).
        save_path: Directory that receives ``best_{name}_model.pth``.
    """
    previous_best = best_acc_dict.get(name, 0.0)
    if not acc > previous_best:
        return
    best_acc_dict[name] = acc
    target_file = os.path.join(save_path, f'best_{name}_model.pth')
    # No message is printed here: the per-epoch summary table already marks
    # the best model.
    torch.save(model.state_dict(), target_file)

# --- Initial set-up ---

# Re-seed all RNGs right before training starts.
set_seed(42)


# ==============================================================================
# 2. 核心训练与验证函数 (Core Training & Validation Function)
# ==============================================================================

def run_one_epoch(model, name, data_loader, optimizer, scaler, margin, is_training):
    """Run one training or validation epoch over triplet batches.

    Args:
        model: Network mapping an image batch to embeddings; models listed in
            ``moe_model_names`` return ``(embeddings, load_balance_loss)``.
        name: Model name, used for the progress bar and to select the
            MoE / LPMP specific code paths.
        data_loader: Iterable yielding ``(anchor, positive, negative)`` batches.
        optimizer: Optimizer stepped during training; unused for validation.
        scaler: ``GradScaler`` used during training; ``None`` disables
            autocast (this is how validation is run).
        margin: Margin of the triplet loss.
        is_training: ``True`` to train, ``False`` to evaluate without gradients.

    Returns:
        ``(losses, accuracy)`` where ``losses`` holds the per-sample average
        ``'task'`` (triplet) and ``'lb'`` (load-balancing) loss, and
        ``accuracy`` is the fraction of triplets whose anchor is closer to the
        positive than to the negative. Both are zero for an empty loader.
    """
    model.train(is_training)  # train mode when is_training, eval mode otherwise
    use_amp = device.type == 'cuda' and scaler is not None

    epoch_losses = {'task': 0.0, 'lb': 0.0}
    corrects, total_samples = 0, 0

    mode = "Training" if is_training else "Validating"
    iterable = tqdm(data_loader, desc=f"  - {mode} {name}", leave=False)

    for anchor, positive, negative in iterable:
        anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
        batch_size = anchor.size(0)

        with torch.set_grad_enabled(is_training):
            # Only the forward pass and the loss computation run under autocast.
            with autocast(enabled=use_amp):
                lb_loss = torch.tensor(0.0, device=device)
                if name in moe_model_names:
                    a_f, lb_a = model(anchor)
                    p_f, lb_p = model(positive)
                    n_f, lb_n = model(negative)
                    if is_training:
                        lb_loss = (lb_a + lb_p + lb_n) / 3.0
                else:  # standard models and LearnableLPMP
                    a_f, p_f, n_f = model(anchor), model(positive), model(negative)
                task_loss = F.triplet_margin_loss(a_f, p_f, n_f, margin=margin)

            # Backward pass and optimizer step happen outside the autocast region.
            if is_training:
                optimizer.zero_grad(set_to_none=True)
                total_loss = task_loss + LOAD_BALANCE_ALPHA * lb_loss
                scaler.scale(total_loss).backward()

                if name in lpmp_models:
                    # Gradients must be unscaled before clipping by norm.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                scaler.step(optimizer)
                scaler.update()

        # Accumulate sample-weighted losses and the number of correct triplets.
        epoch_losses['task'] += task_loss.item() * batch_size
        if is_training:
            epoch_losses['lb'] += lb_loss.item() * batch_size
        with torch.no_grad():
            corrects += torch.sum(F.pairwise_distance(a_f, p_f) < F.pairwise_distance(a_f, n_f)).item()
        total_samples += batch_size

        # Live running averages on the progress bar.
        current_loss = epoch_losses['task'] / total_samples
        current_acc = corrects / total_samples
        iterable.set_postfix(Loss=f"{current_loss:.4f}", Acc=f"{current_acc:.2%}")

    # Guard against an empty loader, which would otherwise divide by zero.
    denominator = max(total_samples, 1)
    final_losses = {k: v / denominator for k, v in epoch_losses.items()}
    accuracy = corrects / denominator

    return final_losses, accuracy


# ==============================================================================
# 3. 主训练循环 (Main Training Loop)
# ==============================================================================

# Train only when the data-loading cell has produced `train_loader`.
if 'train_loader' in locals():
    print("-" * 70)
    print(f"所有模型已准备就绪。开始从 Epoch {start_epoch + 1} 进行训练...")
    total_start_time = time.time()

    for epoch in range(start_epoch, NUM_EPOCHS):
        print("-" * 70 + f"\n--- Epoch [{epoch+1:02d}/{NUM_EPOCHS}] ---")
        epoch_start_time = time.time()

        # --- Training phase: results are collected and printed together below ---
        epoch_train_results = []
        for name, model in models.items():
            train_losses, train_acc = run_one_epoch(model, name, train_loader, optimizers[name], scalers[name], margin=MARGIN, is_training=True)
            history[name]['train_loss'].append(train_losses['task'])
            if name in moe_model_names:
                history[name]['moe_lb_loss'].append(train_losses['lb'])
            epoch_train_results.append({'name': name, 'loss': train_losses['task'], 'acc': train_acc})

        # --- Validation phase ---
        epoch_valid_results = []
        for name, model in models.items():
            valid_losses, valid_acc = run_one_epoch(model, name, valid_loader, None, None, margin=MARGIN, is_training=False)
            history[name]['valid_loss'].append(valid_losses['task'])
            history[name]['valid_acc'].append(valid_acc)
            # Decide "new best" BEFORE save_best_model updates best_acc, using
            # the same strict comparison, so the table marks only real improvements.
            is_new_best = valid_acc > best_acc.get(name, 0.0)
            epoch_valid_results.append({'name': name, 'loss': valid_losses['task'], 'acc': valid_acc, 'is_best': is_new_best})
            save_best_model(model, name, valid_acc, best_acc, MODEL_SAVE_PATH)

        # --- Aligned per-epoch summary table ---
        print(f"\n  Epoch {epoch+1} 总结:")
        print("  " + "-" * 65)
        print(f"  {'Model':<25} | {'Train Loss':>12} | {'Valid Loss':>12} | {'Valid Acc':>11}")
        print("  " + "-" * 65)
        for train_res, valid_res in zip(epoch_train_results, epoch_valid_results):
            is_best_str = " *" if valid_res['is_best'] else ""
            print(f"  {valid_res['name']:<25} | {train_res['loss']:>12.4f} | {valid_res['loss']:>12.4f} | {valid_res['acc']:>10.2%}{is_best_str}")
        print("  " + "-" * 65)
        print("  (* 表示该轮达到了新的最佳验证精度)")

        # --- Learning-rate schedule (only LPMP models have a scheduler) ---
        for name, scheduler in lpmp_schedulers.items():
            scheduler.step()
            print(f"  - LR Scheduler Update for {name}: New LR = {optimizers[name].param_groups[0]['lr']:.6f}")

        # --- Save a resumable checkpoint after every epoch ---
        checkpoint_data = {
            'epoch': epoch,
            'model_state_dicts': {n: m.state_dict() for n, m in models.items()},
            'optimizer_state_dicts': {n: o.state_dict() for n, o in optimizers.items()},
            'scheduler_state_dicts': {n: s.state_dict() for n, s in lpmp_schedulers.items()},
            'history': history,
            'best_acc': best_acc,
        }
        torch.save(checkpoint_data, CHECKPOINT_PATH)
        print(f"  -> Checkpoint saved to '{CHECKPOINT_PATH}' (耗时: {time.time() - epoch_start_time:.2f}s)")

    print(f"\n--- 训练完成 | 总耗时: {(time.time() - total_start_time) / 60:.2f} 分钟 ---")
else:
    print("\n训练已跳过，因为数据加载器未准备好。")

# 绘制图表

In [None]:
# --- Plot only when at least one model has recorded training history ---
# (Also safe when `models` is empty, where indexing the first key would fail.)
if any(history.get(name, {}).get('train_loss') for name in models):

    # --- 2x1 canvas: losses on top, validation accuracy below ---
    fig, axes = plt.subplots(2, 1, figsize=(18, 16), sharex=True)
    fig.suptitle('所有学习模型训练过程对比', fontsize=24, y=1.0, weight='bold')

    # --- One colour and marker per model ---
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap no longer exists
    # in recent Matplotlib releases.
    colors = plt.get_cmap('tab20', len(models))
    markers = ['o', 's', '^', 'P', 'X', 'D', 'v', '*', '<', '>']
    model_styles = {name: {'color': colors(i), 'marker': markers[i % len(markers)]} for i, name in enumerate(models.keys())}

    # ==================== Figure 1: task loss (train vs. validation) ====================
    ax = axes[0]
    for name in models.keys():
        if history[name]['train_loss']:
            epochs = range(1, len(history[name]['train_loss']) + 1)
            # Prefer a separate 'moe_task_loss' series when the history has one;
            # otherwise fall back to the regular training loss.
            loss_to_plot = history[name].get('moe_task_loss') or history[name]['train_loss']
            ax.plot(epochs, loss_to_plot, linestyle='-', marker=model_styles[name]['marker'], color=model_styles[name]['color'], label=f'{name} (Train)', alpha=0.9, markersize=5)
        if history[name]['valid_loss']:
            epochs = range(1, len(history[name]['valid_loss']) + 1)
            ax.plot(epochs, history[name]['valid_loss'], linestyle='--', marker=model_styles[name]['marker'], color=model_styles[name]['color'], label=f'{name} (Valid)', alpha=0.7, markersize=5)

    ax.set_title('训练与验证任务损失对比 (越低越好)', fontsize=18, weight='bold')
    ax.set_ylabel('Triplet Loss', fontsize=14)
    ax.legend(title='模型 (实线:训练, 虚线:验证)', fontsize=11, ncol=2)
    ax.set_yscale('log')
    ax.grid(True, which="both", linestyle='--', alpha=0.5)

    # ==================== Figure 2: validation accuracy ====================
    ax = axes[1]
    for name in models.keys():
        valid_acc = history[name]['valid_acc']
        if valid_acc:
            epochs = range(1, len(valid_acc) + 1)
            line, = ax.plot(epochs, valid_acc, linestyle='-', marker=model_styles[name]['marker'], label=name, color=model_styles[name]['color'], alpha=0.9, markersize=5)

            # Circle and annotate the epoch with the best validation accuracy.
            best_epoch_idx = np.argmax(valid_acc)
            best_acc_val = valid_acc[best_epoch_idx]
            ax.scatter(best_epoch_idx + 1, best_acc_val, s=200, facecolors='none', edgecolors=line.get_color(), linewidth=3, zorder=20, label=f'Best: {best_acc_val:.2%}')
            ax.annotate(f'{name} Best\n{best_acc_val:.2%}', xy=(best_epoch_idx + 1, best_acc_val), xytext=(best_epoch_idx + 1, best_acc_val + 0.1), ha='center',
                        arrowprops=dict(arrowstyle="->", color=line.get_color()), bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="none", alpha=0.7))

    ax.set_title('验证准确度对比 (越高越好)', fontsize=18, weight='bold')
    ax.set_xlabel('训练周期 (Epoch)', fontsize=14, weight='bold')
    ax.set_ylabel('三元组准确度', fontsize=14)
    ax.yaxis.set_major_formatter(plt.FuncFormatter('{:.0%}'.format))
    ax.set_ylim(-0.05, 1.05)

    # Keep only the per-model lines in the legend; drop the 'Best: ...' markers.
    handles, labels = ax.get_legend_handles_labels()
    simple_handles = [h for h, l in zip(handles, labels) if not l.startswith('Best')]
    simple_labels = [l for l in labels if not l.startswith('Best')]
    ax.legend(handles=simple_handles, labels=simple_labels, title='模型', fontsize=11, ncol=2)
    ax.grid(True, linestyle='--', alpha=0.5)

    # --- Final layout, save to disk and show ---
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.savefig(os.path.join(MODEL_SAVE_PATH, 'training_curves_final.png'), dpi=300, bbox_inches='tight')
    plt.show()

else:
    print("没有可供可视化的训练历史记录。请先运行训练单元格。")

# 模型评估与对比

In [None]:
import warnings
from typing import Tuple

# --- 1. 辅助函数 ---

def detect_special_models(models_dict: dict) -> dict:
    """Classify models into special categories based on their class name.

    The first matching marker wins, in this order: ``"ConsKD"``, ``"KD"``,
    ``"MoE"``, ``"LearnableLPMP"``.

    Args:
        models_dict: Mapping of model name to model instance.

    Returns:
        Dict with keys ``'moe'``, ``'lpmp'``, ``'kd'`` and ``'cons_kd'``, each
        mapping to the set of model names in that category.
    """
    # "ConsKD" must be tested before "KD" because it contains it.
    marker_to_bucket = (
        ("ConsKD", 'cons_kd'),
        ("KD", 'kd'),
        ("MoE", 'moe'),
        ("LearnableLPMP", 'lpmp'),
    )
    special_model_types = {'moe': set(), 'lpmp': set(), 'kd': set(), 'cons_kd': set()}
    for name, model in models_dict.items():
        class_name = model.__class__.__name__
        for marker, bucket in marker_to_bucket:
            if marker in class_name:
                special_model_types[bucket].add(name)
                break
    return special_model_types

# Classify every model (standard and LPMP) by class name and report the result.
all_models_for_training = {**models, **lpmp_models}
special_models = detect_special_models(all_models_for_training)

print("--- 自动识别的特殊模型 ---")
for type_name, names in special_models.items():
    print(f"  - {type_name.upper()} 模型: {names if names else '无'}")

def load_best_models(models_dict: dict, save_path: str, device: torch.device) -> dict:
    """Load the best saved weights into every model that has a weight file.

    For each name, ``best_{name}_model.pth`` is looked up in ``save_path``.
    The weights are loaded into the instance from ``models_dict`` itself (no
    copy is made), which is then moved to ``device`` and put in eval mode.

    Args:
        models_dict: Mapping of model name to model instance.
        save_path: Directory containing the saved weight files.
        device: Device used for ``map_location`` and for the loaded model.

    Returns:
        Mapping of name to loaded model; models with a missing or unloadable
        weight file are skipped with a printed warning.
    """
    print("\n--- 1. 正在加载最佳模型权重 ---")
    loaded_models = {}
    for name in models_dict:
        weights_file = os.path.join(save_path, f'best_{name}_model.pth')
        if not os.path.exists(weights_file):
            print(f"警告: 未找到 {name} 的权重文件，将跳过。")
            continue
        try:
            network = models_dict[name]
            state = torch.load(weights_file, map_location=device)
            network.load_state_dict(state)
            network.to(device)
            network.eval()
        except Exception as e:
            print(f"警告: 加载模型 {name} 权重失败，将跳过。错误: {e}")
        else:
            loaded_models[name] = network
            print(f"  - 已成功加载模型: {name}")
    return loaded_models

def _forward_embedding(model, name, tensor, special_models):
    """Run ``model`` on ``tensor`` and return only the embedding tensor.

    Models in the ``'cons_kd'`` category receive a globally pixel-shuffled
    copy of the input as second argument. Any model that returns a tuple or
    list (KD, MoE, ...) has its first element taken as the embedding, whether
    or not its class name placed it in a special category.
    """
    if name in special_models.get('cons_kd', ()):
        output = model(tensor, encrypt_global_pixel_shuffle(tensor))
    else:
        output = model(tensor)
    return output[0] if isinstance(output, (tuple, list)) else output


def extract_all_features(
    loaded_learning_models: dict,
    non_learning_models: dict,
    dataset: 'EvaluationDataset',
    encryption_schemes: dict,
    data_transform: transforms.Compose,
    special_models: dict
) -> Tuple[dict, dict, np.ndarray]:
    """Extract and time features for every model under every condition.

    Each image is encrypted with every scheme, once clean and once with
    Gaussian noise added before the transform. Learned models consume the
    encrypted tensors; non-learning extractors consume the encrypted tensor
    converted back to a uint8 array with reversed channel order.

    Args:
        loaded_learning_models: Name -> model in eval mode on ``device``.
        non_learning_models: Name -> callable taking a uint8 image array.
        dataset: Dataset yielding ``(PIL image, label)`` with ``all_images``.
        encryption_schemes: Name -> callable encrypting a (1, C, H, W) tensor.
        data_transform: Transform turning a PIL image into a tensor.
        special_models: Output of ``detect_special_models``.

    Returns:
        ``(features, timings, labels)``: ``features[condition][model]`` is a
        stacked array (one row per image; an empty list when nothing was
        extracted), ``timings[model]`` is the accumulated seconds over all
        images and conditions, and ``labels`` is the integer label array.
    """
    all_model_names = list(loaded_learning_models.keys()) + list(non_learning_models.keys())
    all_conditions = [enc + noise for enc in encryption_schemes for noise in ['', '+噪声']]

    features = {cond: {name: [] for name in all_model_names} for cond in all_conditions}
    timings = {name: 0.0 for name in all_model_names}
    labels = np.array([item[1] for item in dataset.all_images]).astype(int)
    num_images = len(dataset)
    on_cuda = device.type == 'cuda'

    print(f"\n--- 2. 正在为 {num_images} 张图像在 {len(all_conditions)} 种条件下提取特征 ---")

    for img_pil, _ in tqdm(dataset, desc="最终评估进度"):
        tensor_plain = data_transform(img_pil).unsqueeze(0)
        # Noise is added in BGR uint8 space, then converted back to RGB/PIL.
        img_np_noisy = add_gaussian_noise_np(cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR))
        tensor_plain_noisy = data_transform(Image.fromarray(cv2.cvtColor(img_np_noisy, cv2.COLOR_BGR2RGB))).unsqueeze(0)

        for enc_name, enc_func in encryption_schemes.items():
            tensor_encrypted = enc_func(tensor_plain).to(device)
            tensor_encrypted_noisy = enc_func(tensor_plain_noisy).to(device)

            with torch.no_grad():
                for name, model in loaded_learning_models.items():
                    # CUDA kernels launch asynchronously; synchronise around
                    # the timed region so the wall-clock time is meaningful.
                    if on_cuda:
                        torch.cuda.synchronize()
                    start_time = time.perf_counter()

                    feat_norm = _forward_embedding(model, name, tensor_encrypted, special_models)
                    feat_noisy = _forward_embedding(model, name, tensor_encrypted_noisy, special_models)

                    if on_cuda:
                        torch.cuda.synchronize()
                    timings[name] += time.perf_counter() - start_time
                    features[enc_name][name].append(feat_norm.cpu().numpy())
                    features[enc_name + '+噪声'][name].append(feat_noisy.cpu().numpy())

            # Back to HWC uint8 with reversed channel order for the OpenCV-based
            # extractors; made contiguous because the reversed view has a
            # negative stride.
            img_np_encrypted = np.ascontiguousarray(
                np.array(tensor_encrypted.squeeze(0).cpu().permute(1, 2, 0) * 255, dtype=np.uint8)[:, :, ::-1]
            )
            img_np_noisy_encrypted = np.ascontiguousarray(
                np.array(tensor_encrypted_noisy.squeeze(0).cpu().permute(1, 2, 0) * 255, dtype=np.uint8)[:, :, ::-1]
            )
            for name, func in non_learning_models.items():
                start_time = time.perf_counter()
                features[enc_name][name].append(func(img_np_encrypted))
                features[enc_name + '+噪声'][name].append(func(img_np_noisy_encrypted))
                timings[name] += time.perf_counter() - start_time

    # Stack the per-image features into one array per (condition, model).
    for cond in features:
        for name in all_model_names:
            if features[cond][name]:
                features[cond][name] = np.vstack(features[cond][name])

    return features, timings, labels

def calculate_map(features, labels):
    """Compute retrieval mean Average Precision with Euclidean ranking.

    Every sample is used once as the query against all other samples. The
    query is removed from its ranking by index, so tied distances (duplicate
    or degenerate features) can no longer cause the query to be counted as
    one of its own hits. Queries without any other sample of the same label
    are skipped.

    Args:
        features: Array of shape (N, D), one feature vector per sample.
        labels: Length-N sequence of class labels.

    Returns:
        The mean AP as a float in [0, 1]; 0.0 when ``features`` is not an
        ndarray, has fewer than two rows, or no query has a relevant item.
    """
    if not isinstance(features, np.ndarray) or features.shape[0] < 2:
        return 0.0
    labels = np.asarray(labels)
    num_items = features.shape[0]
    ranks = np.arange(1, num_items)  # 1-based ranks of the N-1 gallery items
    aps = []
    for i in range(num_items):
        same_class = labels == labels[i]
        num_relevant = int(same_class.sum()) - 1  # exclude the query itself
        if num_relevant <= 0:
            continue
        distances = np.linalg.norm(features - features[i], axis=1)
        # Stable sort keeps tie order deterministic; then drop the query.
        order = np.argsort(distances, kind='stable')
        order = order[order != i]
        hits = same_class[order]
        precision_at_k = np.cumsum(hits) / ranks
        aps.append(float(np.sum(precision_at_k * hits) / num_relevant))
    return float(np.mean(aps)) if aps else 0.0

# --- 主评估流程 ---
# Evaluate only when the data-loading cell has produced `test_dataset_plain`.
if 'test_dataset_plain' not in locals():
    print("错误: 'test_dataset_plain' 未定义。请先运行数据加载单元格。")
else:
    # 1. Load the best weights of every trained model.
    loaded_models = load_best_models(models, MODEL_SAVE_PATH, device)

    # 2. Hand-crafted (non-learning) baselines.
    non_learning_models = {'ColorHist': extract_color_histogram, 'SLBP': extract_slbp_features}

    # 3. Encryption scenarios.
    encryption_schemes = {
        '块置乱加密': encrypt_block_permutation,
        '块内像素置乱': encrypt_block_pixel_shuffle,
        '全局像素置乱': encrypt_global_pixel_shuffle
    }

    # 4. Feature extraction for all models under all conditions.
    all_features, timings, all_labels = extract_all_features(
        loaded_models, non_learning_models, test_dataset_plain, encryption_schemes, data_transform, special_models
    )

    # 5. mAP per (condition, model) and average time per extracted feature.
    print("\n--- 3. 正在计算所有模型的mAP... ---")
    all_model_names_eval = list(loaded_models.keys()) + list(non_learning_models.keys())
    final_map = {cond: {name: calculate_map(all_features[cond].get(name, np.array([])), all_labels)
                      for name in all_model_names_eval}
                 for cond in all_features}
    # max(..., 1) keeps an empty test set from dividing by zero.
    num_extractions = max(len(test_dataset_plain) * len(all_features), 1)
    avg_timings = {name: t / num_extractions for name, t in timings.items()}

    # 6. Tables and figures.

    # a. Weighted overall score from min-max normalised columns.
    df = pd.DataFrame(final_map)
    df['速度'] = pd.Series(avg_timings)
    df = df.fillna(0)  # replace NaN by 0 so the score computation cannot fail
    df.index.name = '模型'
    weights = {'块置乱加密': 1.0, '块置乱加密+噪声': 0.8, '块内像素置乱': 0.7, '块内像素置乱+噪声': 0.6,
               '全局像素置乱': 0.7, '全局像素置乱+噪声': 0.9, '速度': 1.0}
    normalized_df = pd.DataFrame(index=df.index)
    with warnings.catch_warnings():  # silence divide-by-zero warnings
        warnings.simplefilter("ignore")
        for col in weights:
            # Speed is "lower is better", so its reciprocal is normalised.
            norm_val = 1 / (df[col] + 1e-9) if col == '速度' else df[col]
            normalized_df[col] = (norm_val - norm_val.min()) / (norm_val.max() - norm_val.min() + 1e-9)
    df['综合得分'] = (normalized_df.fillna(0) * pd.Series(weights)).sum(axis=1)
    df = df.sort_values(by='综合得分', ascending=False)

    # b. Ranking charts: 3x2 mAP panels plus one speed and one score panel.
    map_conditions = ['块置乱加密', '块置乱加密+噪声', '块内像素置乱', '块内像素置乱+噪声', '全局像素置乱', '全局像素置乱+噪声']
    fig, axes = plt.subplots(4, 2, figsize=(24, 30))
    fig.suptitle('加密图像检索模型全方位性能排名', fontsize=28, y=0.98)
    model_colors = {model_name: plt.cm.tab20(i / len(df)) for i, model_name in enumerate(df.index)}

    for i, cond in enumerate(map_conditions):
        ax = axes[i // 2, i % 2]
        df_sorted = df.sort_values(by=cond, ascending=True)
        colors = [model_colors[model] for model in df_sorted.index]
        bars = ax.barh(df_sorted.index, df_sorted[cond], color=colors)
        ax.set_title(f'排名: {cond} (mAP)', fontsize=16)
        ax.set_xlabel('mAP', fontsize=12)
        ax.tick_params(axis='y', labelsize=11)
        ax.xaxis.set_major_formatter(plt.FuncFormatter('{:.1%}'.format))
        ax.bar_label(bars, fmt='{:.2%}', padding=3, fontsize=10)
        ax.grid(axis='x', linestyle='--', alpha=0.6)

    ax_speed = axes[3, 0]
    df_sorted_speed = df.sort_values(by='速度', ascending=False)
    colors_speed = [model_colors[model] for model in df_sorted_speed.index]
    bars_speed = ax_speed.barh(df_sorted_speed.index, df_sorted_speed['速度'], color=colors_speed)
    ax_speed.set_title('排名: 推理速度 (秒/张)', fontsize=16)
    ax_speed.set_xlabel('平均处理时间 (秒)', fontsize=12)
    ax_speed.tick_params(axis='y', labelsize=11)
    ax_speed.set_xscale('log')
    ax_speed.bar_label(bars_speed, fmt='{:.5f}', padding=3, fontsize=10)
    ax_speed.grid(axis='x', which='both', linestyle='--', alpha=0.6)

    ax_score = axes[3, 1]
    df_sorted_score = df.sort_values(by='综合得分', ascending=True)
    colors_score = [model_colors[model] for model in df_sorted_score.index]
    bars_score = ax_score.barh(df_sorted_score.index, df_sorted_score['综合得分'], color=colors_score)
    ax_score.set_title('排名: 加权综合得分', fontsize=16)
    ax_score.set_xlabel('综合得分', fontsize=12)
    ax_score.tick_params(axis='y', labelsize=11)
    ax_score.bar_label(bars_score, fmt='{:.3f}', padding=3, fontsize=10)
    ax_score.grid(axis='x', linestyle='--', alpha=0.6)

    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.savefig(os.path.join(MODEL_SAVE_PATH, 'final_performance_ranking.png'), dpi=300, bbox_inches='tight')
    plt.show()

    # c. Styled summary table, sorted by overall score.
    separator = "=" * 140  # same width above and below the title
    print("\n" + separator)
    print("最终性能总结表格 (按'综合得分'降序排列)")
    print(separator)

    df_display = df.copy()
    display_order = ['综合得分', '块置乱加密', '块置乱加密+噪声', '块内像素置乱', '块内像素置乱+噪声', '全局像素置乱', '全局像素置乱+噪声', '速度']
    df_display = df_display[display_order]
    quality_columns = [c for c in df_display.columns if c != '速度']

    def highlight_max(s, props='color:green; font-weight:bold;'):
        """Return ``props`` for the cells holding the column maximum."""
        return np.where(s == np.nanmax(s.values), props, '')

    def highlight_min(s, props='color:blue; font-weight:bold;'):
        """Return ``props`` for the cells holding the column minimum."""
        return np.where(s == np.nanmin(s.values), props, '')

    with pd.option_context('display.precision', 5):
        styled_df = df_display.style.format({
            **{col: '{:.2%}'.format for col in df_display.columns if '加密' in col},
            '速度': '{:.5f}'.format,
            '综合得分': '{:.3f}'.format
        }).apply(highlight_max, props='color:white; background-color:#2ca02c;', subset=quality_columns
        ).apply(highlight_min, props='color:white; background-color:#1f77b4;', subset=['速度']
        ).background_gradient(cmap='viridis_r', subset=quality_columns
        ).background_gradient(cmap='plasma', subset=['速度'])

    display(styled_df)