In [102]:
import os
import torch
from ultralytics import YOLO
from torchvision.datasets import ImageFolder

# ---------------------------
# 配置部分
# ---------------------------
yolo_path = "/Users/yaogunzhishen/Desktop/best.pt"  # 预训练 YOLO 分类模型权重路径
test_folder = "/Users/yaogunzhishen/Desktop/datasets/test"   # 测试集目录，每个类别一个子文件夹

# 选择设备（例如 MPS 或 CUDA）
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)

# ---------------------------
# 辅助函数：加载模型包装器
# ---------------------------
def load_model_wrapper():
    """
    重新加载一个新的 YOLO 模型包装器
    """
    wrapper = YOLO(yolo_path)
    wrapper.model.to(device)
    wrapper.model.eval()
    return wrapper

# ---------------------------
# 辅助函数：遍历模型中的注意力模块（AAttn）
# ---------------------------
def get_attention_modules(model):
    """
    遍历模型中所有子模块，返回所有类型为 AAttn 的模块列表，
    每个元素为 (name, module)。
    """
    attn_modules = []
    for name, module in model.named_modules():
        if module.__class__.__name__ == "AAttn":
            attn_modules.append((name, module))
    return attn_modules

# ---------------------------
# 辅助函数：禁用（消融）指定注意力模块
# ---------------------------
def disable_module(module):
    """
    将模块中用于计算注意力的权重置零。
    对于 AAttn 模块，尝试将 qkv、proj 和 pe 卷积层的权重置零（如果存在）。
    """
    if hasattr(module, 'qkv') and hasattr(module.qkv, 'weight'):
        module.qkv.weight.data.zero_()
    if hasattr(module, 'proj') and hasattr(module.proj, 'weight'):
        module.proj.weight.data.zero_()
    if hasattr(module, 'pe') and hasattr(module.pe, 'conv') and hasattr(module.pe.conv, 'weight'):
        module.pe.conv.weight.data.zero_()

# ---------------------------
# 辅助函数：评估模型在测试集上的准确率
# ---------------------------
def evaluate_model(model_wrapper, dataset):
    """
    遍历测试集样本，通过文件路径调用模型包装器进行预测，并统计准确率
    """
    num_total = 0
    num_correct = 0
    for image_path, true_label in dataset.samples:
        # 禁用 verbose 输出
        model_wrapper.overrides['verbose'] = False
        results = model_wrapper(image_path)
        if not results:
            print(f"No result for {image_path}")
            continue
        res = results[0]
        if hasattr(res, "probs") and res.probs is not None:
            pred_label = int(res.probs.top1)
        else:
            print(f"Missing 'probs' attribute for {image_path}")
            continue

        num_total += 1
        if pred_label == true_label:
            num_correct += 1
    accuracy = num_correct / num_total if num_total > 0 else 0
    return accuracy

# ---------------------------
# 主程序：分阶段消融实验
# ---------------------------
def main():
    # 加载测试集
    test_dataset = ImageFolder(root=test_folder)
    
    # Baseline：带注意力的模型
    print("Evaluating model WITH attention:")
    wrapper_baseline = load_model_wrapper()
    baseline_acc = evaluate_model(wrapper_baseline, test_dataset)
    print("Baseline Accuracy: {:.2%}".format(baseline_acc))
    
    # 消融实验 1：仅禁用第6层注意力（遍历所有注意力模块时下标为5）
    print("\nEvaluating model with 6th attention disabled:")
    wrapper_6 = load_model_wrapper()
    attn_modules = get_attention_modules(wrapper_6.model)
    if len(attn_modules) >= 6:
        for i in range(4):
            name6, module6 = attn_modules[i]
            print(f"Disabling attention in module: {name6}")
            disable_module(module6)
    else:
        print("Not enough attention modules for 6th layer!")
    acc_6 = evaluate_model(wrapper_6, test_dataset)
    print("Accuracy with 6th attention disabled: {:.2%}".format(acc_6))
    
    # 消融实验 2：仅禁用第8层注意力（下标7）
    print("\nEvaluating model with 8th attention disabled:")
    wrapper_8 = load_model_wrapper()
    attn_modules = get_attention_modules(wrapper_8.model)
    if len(attn_modules) >= 8:
        for i in range(4,8):
            name6, module6 = attn_modules[i]
            print(f"Disabling attention in module: {name6}")
            disable_module(module6)
    else:
        print("Not enough attention modules for 8th layer!")
    acc_8 = evaluate_model(wrapper_8, test_dataset)
    print("Accuracy with 8th attention disabled: {:.2%}".format(acc_8))
    
    # 消融实验 3：同时禁用第6层和第8层注意力
    print("\nEvaluating model with both 6th and 8th attention disabled:")
    wrapper_both = load_model_wrapper()
    attn_modules = get_attention_modules(wrapper_both.model)
    if len(attn_modules) >= 8:
        for i in range(8):
            name6, module6 = attn_modules[i]
            print(f"Disabling attention in module: {name6}")
            disable_module(module6)
    else:
        print("Not enough attention modules for both layers!")
    acc_both = evaluate_model(wrapper_both, test_dataset)
    print("Accuracy with both attention disabled: {:.2%}".format(acc_both))
    
if __name__ == '__main__':
    main()

Using device: mps
Evaluating model WITH attention:
Baseline Accuracy: 90.44%

Evaluating model with 6th attention disabled:
Disabling attention in module: model.6.m.0.0.attn
Disabling attention in module: model.6.m.0.1.attn
Disabling attention in module: model.6.m.1.0.attn
Disabling attention in module: model.6.m.1.1.attn
Accuracy with 6th attention disabled: 84.33%

Evaluating model with 8th attention disabled:
Disabling attention in module: model.8.m.0.0.attn
Disabling attention in module: model.8.m.0.1.attn
Disabling attention in module: model.8.m.1.0.attn
Disabling attention in module: model.8.m.1.1.attn
Accuracy with 8th attention disabled: 88.22%

Evaluating model with both 6th and 8th attention disabled:
Disabling attention in module: model.6.m.0.0.attn
Disabling attention in module: model.6.m.0.1.attn
Disabling attention in module: model.6.m.1.0.attn
Disabling attention in module: model.6.m.1.1.attn
Disabling attention in module: model.8.m.0.0.attn
Disabling attention in module