In [45]:
from ultralytics import YOLO
import os

# Pretrained YOLO classification weights; set $YOLO_WEIGHTS to use a different file.
yolo_path = os.environ.get("YOLO_WEIGHTS", "/Users/yaogunzhishen/Desktop/best.pt")
model = YOLO(yolo_path)
print(model)  # dump the module tree to inspect the architecture

    

YOLO(
  (model): ClassificationModel(
    (model): Sequential(
      (0): Conv(
        (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (1): Conv(
        (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (2): C3k2(
        (cv1): Conv(
          (conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU(inplace=True)
        )
        (cv2): Conv(
          (conv): Conv2d(48, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running

attention 消融实验

In [46]:

import os
import torch
from ultralytics import YOLO
from torchvision.datasets import ImageFolder

# ---------------------------
# 配置部分
# ---------------------------
# Paths can be overridden through environment variables so the notebook is not
# tied to one machine; the defaults are the original local paths.
yolo_path = os.environ.get("YOLO_WEIGHTS", "/Users/yaogunzhishen/Desktop/best.pt")  # pretrained YOLO classification weights
test_folder = os.environ.get("YOLO_TEST_DIR", "/Users/yaogunzhishen/Desktop/datasets/test")  # test set root, one sub-folder per class

# Pick an accelerator: Apple MPS first, then CUDA, otherwise CPU.
# NOTE(review): the ultralytics predictor may choose its own device unless
# `device=` is passed to predict -- confirm which device inference really uses.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# ---------------------------
# 辅助函数：加载模型包装器
# ---------------------------
def load_model_wrapper():
    """
    Create a brand-new YOLO wrapper from ``yolo_path``.

    The underlying network is moved to ``device`` and put in evaluation mode.
    Every call reloads the checkpoint, so each experiment starts from
    unmodified weights.
    """
    fresh = YOLO(yolo_path)
    network = fresh.model
    network.to(device)
    network.eval()
    return fresh

# ---------------------------
# 辅助函数：遍历模型中的注意力模块（AAttn）
# ---------------------------
def get_attention_modules(model):
    """
    Collect every submodule of ``model`` whose class is named ``AAttn``.

    Args:
        model: any ``nn.Module``; it is traversed with ``named_modules()``.

    Returns:
        list of ``(qualified_name, module)`` tuples, in traversal order.
    """
    return [
        (qualified_name, submodule)
        for qualified_name, submodule in model.named_modules()
        if type(submodule).__name__ == "AAttn"
    ]

# ---------------------------
# 辅助函数：禁用（消融）指定注意力模块
# ---------------------------
def disable_module(module, zero_norm=False):
    """
    Ablate an attention module in place by zeroing its qkv / proj / pe convolution weights.

    Each of ``qkv``, ``proj`` and ``pe`` may be either a bare layer that exposes
    ``.weight`` (e.g. ``nn.Conv2d``) or a wrapper that keeps its convolution in
    ``.conv`` (optionally followed by a ``.bn`` norm layer), like the ``Conv``
    blocks in the printed model. Checking only for a direct ``.weight`` would
    silently skip wrapped convolutions, so both forms are unwrapped here.

    Zeroing only the conv weight leaves any following BatchNorm offset in
    place, so in eval mode the branch yields a per-channel constant rather than
    zero. Pass ``zero_norm=True`` to also zero the conv bias and the attached
    norm layer's affine weight/bias.

    Args:
        module: attention module to modify (mutated in place).
        zero_norm: also zero conv bias and ``.bn`` affine parameters. Default False.

    Returns:
        list[str]: names of the attributes (subset of "qkv", "proj", "pe")
        whose convolution weights were zeroed.
    """
    zeroed = []
    for attr in ("qkv", "proj", "pe"):
        layer = getattr(module, attr, None)
        if layer is None:
            continue
        # Unwrap Conv(conv, bn, act)-style blocks; bare layers are used as-is.
        conv = layer.conv if hasattr(layer, "conv") else layer
        weight = getattr(conv, "weight", None)
        if not isinstance(weight, torch.Tensor):
            continue
        with torch.no_grad():
            weight.zero_()
            if zero_norm:
                bias = getattr(conv, "bias", None)
                if bias is not None:
                    bias.zero_()
                norm = getattr(layer, "bn", None) if layer is not conv else None
                if norm is not None:
                    for param in (getattr(norm, "weight", None), getattr(norm, "bias", None)):
                        if param is not None:
                            param.zero_()
        zeroed.append(attr)
    return zeroed

# ---------------------------
# 辅助函数：评估模型在测试集上的准确率
# ---------------------------
def evaluate_model(model_wrapper, dataset):
    """
    Compute top-1 accuracy of a YOLO classification wrapper over ``dataset.samples``.

    Each sample is an ``(image_path, true_label)`` pair, as produced by
    ``torchvision.datasets.ImageFolder``. The image is passed to the wrapper by
    path, and ``results[0].probs.top1`` is taken as the predicted class index.

    Predicted and true labels are compared as integer indices. That is only
    meaningful if the model and the dataset order their classes identically.
    When the wrapper exposes a ``names`` dict (index -> class name) and the
    dataset exposes ``classes``, a warning is printed if the two orderings
    disagree.

    Args:
        model_wrapper: callable wrapper with an ``overrides`` dict; calling it
            with an image path returns a list of results.
        dataset: object with a ``samples`` list of ``(path, label)`` pairs.

    Returns:
        float: correct / evaluated. Samples that yield no result or no
        ``probs`` are reported and excluded from the denominator. Returns 0
        when nothing could be evaluated.
    """
    # Disable per-image logging once; the setting does not change inside the loop.
    model_wrapper.overrides['verbose'] = False

    model_names = getattr(model_wrapper, "names", None)
    dataset_classes = getattr(dataset, "classes", None)
    if isinstance(model_names, dict) and dataset_classes:
        expected = [model_names.get(idx) for idx in range(len(dataset_classes))]
        if expected != list(dataset_classes):
            print("Warning: dataset class order does not match the model's class names; "
                  "index-based accuracy may be wrong.")

    num_total = 0
    num_correct = 0
    for image_path, true_label in dataset.samples:
        results = model_wrapper(image_path)
        if not results:
            print(f"No result for {image_path}")
            continue
        res = results[0]
        if getattr(res, "probs", None) is None:
            print(f"Missing 'probs' attribute for {image_path}")
            continue
        pred_label = int(res.probs.top1)

        num_total += 1
        if pred_label == true_label:
            num_correct += 1
    return num_correct / num_total if num_total > 0 else 0

# ---------------------------
# 主程序：分阶段消融实验
# ---------------------------
def _ablate_attention_layers(dataset, layer_ids):
    """
    Disable all AAttn modules under the given top-level layers and return the accuracy.

    A fresh model is loaded, every AAttn module under ``model.<layer_id>.``
    is disabled, and the model is then evaluated on ``dataset``.
    """
    wrapper = load_model_wrapper()
    prefixes = tuple(f"model.{layer_id}." for layer_id in layer_ids)
    targets = [
        (name, module)
        for name, module in get_attention_modules(wrapper.model)
        if name.startswith(prefixes)
    ]
    if not targets:
        print(f"No attention modules found under layer(s) {', '.join(map(str, layer_ids))}!")
    for name, module in targets:
        print(f"Disabling attention in module: {name}")
        disable_module(module)
    return evaluate_model(wrapper, dataset)


def main():
    """
    Staged attention ablation.

    Evaluates the unmodified model first. It then evaluates models with all
    AAttn modules of layer 6 disabled, of layer 8 disabled, and of both
    layers disabled. Modules are selected by qualified name (``model.<layer>.``),
    not by position in the traversal order. The selection therefore does not
    silently shift if the number of attention modules per layer changes.

    Returns:
        dict mapping experiment label to accuracy.
    """
    test_dataset = ImageFolder(root=test_folder)

    # Baseline: model with attention intact.
    print("Evaluating model WITH attention:")
    results = {"baseline": evaluate_model(load_model_wrapper(), test_dataset)}
    print("Baseline Accuracy: {:.2%}".format(results["baseline"]))

    experiments = [
        ("layer 6", (6,)),
        ("layer 8", (8,)),
        ("layers 6 and 8", (6, 8)),
    ]
    for label, layer_ids in experiments:
        print(f"\nEvaluating model with {label} attention disabled:")
        acc = _ablate_attention_layers(test_dataset, layer_ids)
        print(f"Accuracy with {label} attention disabled: {acc:.2%}")
        results[label] = acc
    return results


if __name__ == '__main__':
    main()


Using device: mps
Evaluating model WITH attention:
Baseline Accuracy: 90.44%

Evaluating model with 6th attention disabled:
Disabling attention in module: model.6.m.0.0.attn
Disabling attention in module: model.6.m.0.1.attn
Disabling attention in module: model.6.m.1.0.attn
Disabling attention in module: model.6.m.1.1.attn
Accuracy with 6th attention disabled: 84.33%

Evaluating model with 8th attention disabled:
Disabling attention in module: model.8.m.0.0.attn
Disabling attention in module: model.8.m.0.1.attn
Disabling attention in module: model.8.m.1.0.attn
Disabling attention in module: model.8.m.1.1.attn
Accuracy with 8th attention disabled: 88.22%

Evaluating model with both 6th and 8th attention disabled:
Disabling attention in module: model.6.m.0.0.attn
Disabling attention in module: model.6.m.0.1.attn
Disabling attention in module: model.6.m.1.0.attn
Disabling attention in module: model.6.m.1.1.attn
Disabling attention in module: model.8.m.0.0.attn
Disabling attention in module

C3k2 m 层消融实验

In [1]:
import os
import torch
import torch.nn as nn
from ultralytics import YOLO
from torchvision.datasets import ImageFolder

# ---------------------------
# 配置部分
# ---------------------------
# Paths can be overridden through environment variables so the notebook is not
# tied to one machine; the defaults are the original local paths.
yolo_path = os.environ.get("YOLO_WEIGHTS", "/Users/yaogunzhishen/Desktop/best.pt")  # pretrained YOLO classification weights
test_folder = os.environ.get("YOLO_TEST_DIR", "/Users/yaogunzhishen/Desktop/datasets/test")  # test set root, one sub-folder per class

# Pick an accelerator: Apple MPS first, then CUDA, otherwise CPU.
# NOTE(review): the ultralytics predictor may choose its own device unless
# `device=` is passed to predict -- confirm which device inference really uses.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# ---------------------------
# NOTE: PlainConv (a stack of four conv layers) is not defined in this cell; this experiment only uses ZeroModule below.
# ---------------------------

# ---------------------------
# 定义零模块：ZeroModule
# ---------------------------
class ZeroModule(nn.Module):
    """
    Stand-in module that ignores its input values and returns zeros.

    The output has ``channels`` channels. Its batch size, spatial size, dtype
    and device are taken from the input tensor.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels  # number of channels in the produced tensor

    def forward(self, x):
        batch, _, height, width = x.shape  # unpacking requires exactly four dimensions
        return x.new_zeros((batch, self.channels, height, width))

# ---------------------------
# 辅助函数：加载模型包装器
# ---------------------------
def load_model_wrapper():
    """
    Return a freshly loaded YOLO wrapper for ``yolo_path``.

    The wrapped network is placed on ``device`` and switched to eval mode.
    A new wrapper is built on every call, so edits made by one experiment
    never leak into the next.
    """
    yolo_wrapper = YOLO(yolo_path)
    yolo_wrapper.model.to(device)
    yolo_wrapper.model.eval()
    return yolo_wrapper

# ---------------------------
# 辅助函数：根据目标键名称去掉 C3k2 模块中 ModuleList 部分
# ---------------------------
def remove_modulelist_from_target_c3k2(model, target_keys="2"):
    """
    Recursively replace the ``m`` ModuleList of selected C3k2 blocks with a ZeroModule.

    The model is walked child by child, recursively. A child is a target when
    its key within its parent (e.g. ``"2"`` for ``model.2``) is in
    ``target_keys`` and its class is named ``C3k2``. Its ``m`` is replaced in
    place by ``nn.ModuleList([ZeroModule(dummy_channels)])``, where

        dummy_channels = cv2.conv.in_channels - cv1.conv.out_channels

    This is the number of channels the removed branch contributed to the
    concatenation that feeds cv2. NOTE(review): this assumes the C2f-style
    ``cat(cv1 chunks + m outputs) -> cv2`` layout; confirm for the installed
    ultralytics version. The dummy branch keeps cv2's input width unchanged
    while contributing only zeros.

    Args:
        model: module (or submodule) to process; modified in place.
        target_keys: a single key or an iterable of keys, e.g. ``"2"`` or
            ``["2", "4"]``. Tuples, sets and integers are accepted too; keys
            are compared as strings.

    Returns:
        int: number of C3k2 blocks whose ``m`` was replaced (0 if none matched).
    """
    # Normalise to a set of strings; a bare str/int must not be iterated.
    if isinstance(target_keys, (str, int)):
        target_keys = [target_keys]
    wanted = {str(key) for key in target_keys}

    replaced = 0
    for name, module in list(model._modules.items()):
        if module is None:  # registered-but-empty child slot: nothing to walk
            continue
        # Recurse first so nested targets are handled too.
        replaced += remove_modulelist_from_target_c3k2(module, wanted)

        if name in wanted and module.__class__.__name__ == "C3k2":
            print(f"Found target C3k2 module with key '{name}'. Removing its ModuleList part by replacing it with a ZeroModule.")
            try:
                dummy_channels = module.cv2.conv.in_channels - module.cv1.conv.out_channels
                if dummy_channels < 0:
                    raise ValueError("dummy_channels computed as negative, check the architecture!")
            except Exception as e:
                # Best effort: report and leave this block untouched.
                print(f"Error computing dummy_channels for module '{name}': {e}")
                continue

            module.m = nn.ModuleList([ZeroModule(dummy_channels)])
            print(f"Module '{name}' m set to ModuleList with ZeroModule (dummy_channels={dummy_channels}).")
            replaced += 1
    return replaced

# ---------------------------
# 辅助函数：评估模型在测试集上的准确率
# ---------------------------
def evaluate_model(model_wrapper, dataset):
    """
    Compute top-1 accuracy of a YOLO classification wrapper over ``dataset.samples``.

    Each sample is an ``(image_path, true_label)`` pair, as produced by
    ``torchvision.datasets.ImageFolder``. The image is passed to the wrapper by
    path, and ``results[0].probs.top1`` is taken as the predicted class index.

    Predicted and true labels are compared as integer indices. That is only
    meaningful if the model and the dataset order their classes identically.
    When the wrapper exposes a ``names`` dict (index -> class name) and the
    dataset exposes ``classes``, a warning is printed if the two orderings
    disagree.

    Args:
        model_wrapper: callable wrapper with an ``overrides`` dict; calling it
            with an image path returns a list of results.
        dataset: object with a ``samples`` list of ``(path, label)`` pairs.

    Returns:
        float: correct / evaluated. Samples that yield no result or no
        ``probs`` are reported and excluded from the denominator. Returns 0
        when nothing could be evaluated.
    """
    # Disable per-image logging once; the setting does not change inside the loop.
    model_wrapper.overrides['verbose'] = False

    model_names = getattr(model_wrapper, "names", None)
    dataset_classes = getattr(dataset, "classes", None)
    if isinstance(model_names, dict) and dataset_classes:
        expected = [model_names.get(idx) for idx in range(len(dataset_classes))]
        if expected != list(dataset_classes):
            print("Warning: dataset class order does not match the model's class names; "
                  "index-based accuracy may be wrong.")

    num_total = 0
    num_correct = 0
    for image_path, true_label in dataset.samples:
        results = model_wrapper(image_path)
        if not results:
            print(f"No result for {image_path}")
            continue
        res = results[0]
        if getattr(res, "probs", None) is None:
            print(f"Missing 'probs' attribute for {image_path}")
            continue
        pred_label = int(res.probs.top1)

        num_total += 1
        if pred_label == true_label:
            num_correct += 1
    return num_correct / num_total if num_total > 0 else 0

# ---------------------------
# 主程序：依次进行消融实验
#   1. 仅去掉 module "2" 的 C3k2 中 ModuleList 部分
#   2. 仅去掉 module "4" 的 C3k2 中 ModuleList 部分
#   3. 同时去掉 module "2" 和 "4" 的 C3k2 中 ModuleList 部分
# ---------------------------
def _evaluate_with_removed_modulelist(dataset, keys):
    """
    Zero out the ModuleList branch of the C3k2 layers keyed by ``keys`` and return the accuracy.

    A fresh model is loaded, the branch is replaced, and the model is evaluated
    on ``dataset``. The affected modules are printed afterwards.
    """
    noun = "modules" if len(keys) > 1 else "module"
    label = " and ".join(f"'{key}'" for key in keys)
    print(f"\nEvaluating model with ModuleList removed from {noun} {label} of C3k2:")
    wrapper = load_model_wrapper()
    remove_modulelist_from_target_c3k2(wrapper.model, target_keys=list(keys))
    acc = evaluate_model(wrapper, dataset)
    print(f"Accuracy with ModuleList removed from {noun} {label}: {acc:.2%}")
    for name, module in wrapper.model.named_modules():
        # Compare the last path component exactly; name.endswith("2") would
        # also match unrelated layers such as "model.12".
        if module.__class__.__name__ == "C3k2" and name.rsplit(".", 1)[-1] in keys:
            print(f"Module '{name}' m attribute:", module.m)
    return acc


def main():
    """
    Run the C3k2 ModuleList removal ablation.

    Evaluates the unmodified model first. It then evaluates models whose C3k2
    layer "2", layer "4", or both have their ModuleList branch replaced by
    a ZeroModule.

    Returns:
        dict mapping experiment label ("baseline", "2", "4", "2+4") to accuracy.
    """
    test_dataset = ImageFolder(root=test_folder)

    print("Evaluating baseline model (with original C3k2):")
    baseline_acc = evaluate_model(load_model_wrapper(), test_dataset)
    print("Baseline Accuracy: {:.2%}".format(baseline_acc))

    results = {"baseline": baseline_acc}
    for keys in (("2",), ("4",), ("2", "4")):
        results["+".join(keys)] = _evaluate_with_removed_modulelist(test_dataset, keys)
    return results


if __name__ == '__main__':
    main()

Using device: mps
Evaluating baseline model (with original C3k2):
Baseline Accuracy: 90.44%

Evaluating model with ModuleList removed from module '2' of C3k2:
Found target C3k2 module with key '2'. Removing its ModuleList part by replacing it with a ZeroModule.
Module '2' m set to ModuleList with ZeroModule (dummy_channels=16).
Accuracy with ModuleList removed from module '2': 19.72%
Module 'model.2' m attribute: ModuleList(
  (0): ZeroModule()
)

Evaluating model with ModuleList removed from module '4' of C3k2:
Found target C3k2 module with key '4'. Removing its ModuleList part by replacing it with a ZeroModule.
Module '4' m set to ModuleList with ZeroModule (dummy_channels=32).
Accuracy with ModuleList removed from module '4': 44.17%
Module 'model.4' m attribute: ModuleList(
  (0): ZeroModule()
)

Evaluating model with ModuleList removed from modules '2' and '4' of C3k2:
Found target C3k2 module with key '2'. Removing its ModuleList part by replacing it with a ZeroModule.
Module '2' 

C3k2 m 层退化实验

In [2]:
import os
import torch
import torch.nn as nn
from ultralytics import YOLO
from torchvision.datasets import ImageFolder

# ---------------------------
# 配置部分
# ---------------------------
# Paths can be overridden through environment variables so the notebook is not
# tied to one machine; the defaults are the original local paths.
yolo_path = os.environ.get("YOLO_WEIGHTS", "/Users/yaogunzhishen/Desktop/best.pt")  # pretrained YOLO classification weights
test_folder = os.environ.get("YOLO_TEST_DIR", "/Users/yaogunzhishen/Desktop/datasets/test")  # test set root, one sub-folder per class

# Pick an accelerator: Apple MPS first, then CUDA, otherwise CPU.
# NOTE(review): the ultralytics predictor may choose its own device unless
# `device=` is passed to predict -- confirm which device inference really uses.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# ---------------------------
# 新的简单卷积模块：PlainConv（采用四个连续卷积层的结构）
# ---------------------------
class PlainConv(nn.Module):
    """
    Plain convolutional block intended as a replacement for a C3k2 module.

    The block is four stacked Conv2d -> BatchNorm2d -> SiLU layers.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the last layer.
        kernel_size, padding, bias: forwarded to every Conv2d.
        stride: applied to the first convolution only, so the block as a whole
            downsamples by ``stride``. Applying it to all four layers would
            downsample by ``stride ** 4``.
        hidden_channels: width of the intermediate layers; defaults to
            ``out_channels``.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
                 bias=False, hidden_channels=None):
        super().__init__()
        if hidden_channels is None:
            hidden_channels = out_channels

        # Layer 1: in_channels -> hidden_channels (the only strided convolution).
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size, stride, padding, bias=bias)
        self.bn1 = nn.BatchNorm2d(hidden_channels)
        self.act1 = nn.SiLU(inplace=True)

        # Layer 2: keeps hidden_channels and the resolution from layer 1.
        self.conv2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size, 1, padding, bias=bias)
        self.bn2 = nn.BatchNorm2d(hidden_channels)
        self.act2 = nn.SiLU(inplace=True)

        # Layer 3: keeps hidden_channels.
        self.conv3 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size, 1, padding, bias=bias)
        self.bn3 = nn.BatchNorm2d(hidden_channels)
        self.act3 = nn.SiLU(inplace=True)

        # Layer 4: hidden_channels -> out_channels.
        self.conv4 = nn.Conv2d(hidden_channels, out_channels, kernel_size, 1, padding, bias=bias)
        self.bn4 = nn.BatchNorm2d(out_channels)
        self.act4 = nn.SiLU(inplace=True)

    def forward(self, x):
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x)))
        x = self.act3(self.bn3(self.conv3(x)))
        x = self.act4(self.bn4(self.conv4(x)))
        return x

# ---------------------------
# 定义两个卷积层模块：TwoConv
# ---------------------------
class TwoConv(nn.Module):
    """
    Two stacked Conv2d -> BatchNorm2d -> SiLU layers.

    Used as a stand-in for a C3k2 block's ``m`` ModuleList branch. The caller
    chooses ``in_channels`` / ``out_channels`` so that the branch output still
    fits the block's concatenation.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the second layer.
        kernel_size, padding, bias: forwarded to both Conv2d layers.
        stride: applied to the first convolution only, so the block downsamples
            by ``stride`` rather than ``stride ** 2``.
        hidden_channels: width of the intermediate layer; defaults to
            ``out_channels``.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False, hidden_channels=None):
        super().__init__()
        if hidden_channels is None:
            hidden_channels = out_channels
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size, stride, padding, bias=bias)
        self.bn1 = nn.BatchNorm2d(hidden_channels)
        self.act1 = nn.SiLU(inplace=True)
        # Second layer keeps the resolution produced by the first one.
        self.conv2 = nn.Conv2d(hidden_channels, out_channels, kernel_size, 1, padding, bias=bias)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.act2 = nn.SiLU(inplace=True)

    def forward(self, x):
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x)))
        return x

# ---------------------------
# 定义零模块（备用）：ZeroModule
# ---------------------------
class ZeroModule(nn.Module):
    """
    Module whose forward pass always yields an all-zero tensor.

    Batch size, height, width, dtype and device follow the input. The channel
    count is fixed at construction time.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels  # channel count of every output

    def forward(self, x):
        n, _, h, w = x.shape  # the input must have exactly four dimensions
        shape = (n, self.channels, h, w)
        return torch.zeros(shape, dtype=x.dtype, device=x.device)

# ---------------------------
# 辅助函数：加载模型包装器
# ---------------------------
def load_model_wrapper():
    """
    Instantiate a new YOLO wrapper from ``yolo_path``.

    The wrapper's network is sent to ``device`` and set to eval mode. Each
    experiment calls this to get an untouched copy of the checkpoint.
    """
    loaded = YOLO(yolo_path)
    inner = loaded.model
    inner.to(device)
    inner.eval()
    return loaded

# ---------------------------
# 辅助函数：根据目标键名称去掉 C3k2 模块中 ModuleList 部分，替换为 TwoConv 模块
# ---------------------------
def replace_modulelist_with_twoconv(model, target_keys="2"):
    """
    Recursively replace the ``m`` ModuleList of selected C3k2 blocks with one TwoConv.

    The model is walked child by child, recursively. A child is a target when
    its key within its parent (e.g. ``"2"`` for ``model.2``) is in
    ``target_keys`` and its class is named ``C3k2``. Its ``m`` is replaced in
    place by a ModuleList holding one freshly (randomly) initialised TwoConv.

    Channel bookkeeping. NOTE(review): this assumes the C2f-style forward,
    where cv1's output is split into two halves, the last half is fed to ``m``,
    and everything is concatenated before cv2; confirm for the installed
    ultralytics version.

      * replacement input channels  = cv1.conv.out_channels // 2
      * replacement output channels = cv2.conv.in_channels - cv1.conv.out_channels
        (``dummy_channels``)

    This keeps the concatenation matching cv2's expected input. Under that
    assumption both numbers coincide only when the original ModuleList held a
    single module.

    The new TwoConv is moved to the host block's device/dtype and given its
    train/eval mode, because freshly built modules default to CPU and training
    mode.

    Args:
        model: module (or submodule) to process; modified in place.
        target_keys: a single key or an iterable of keys, e.g. ``"2"`` or
            ``["2", "4"]``. Tuples, sets and integers are accepted too; keys
            are compared as strings.

    Returns:
        int: number of C3k2 blocks whose ``m`` was replaced (0 if none matched).
    """
    # Normalise to a set of strings; a bare str/int must not be iterated.
    if isinstance(target_keys, (str, int)):
        target_keys = [target_keys]
    wanted = {str(key) for key in target_keys}

    replaced = 0
    for name, module in list(model._modules.items()):
        if module is None:  # registered-but-empty child slot: nothing to walk
            continue
        # Recurse first so nested targets are handled too.
        replaced += replace_modulelist_with_twoconv(module, wanted)

        if name in wanted and module.__class__.__name__ == "C3k2":
            print(f"Found target C3k2 module with key '{name}'. Replacing its ModuleList part with TwoConv.")
            try:
                in_channels = module.cv1.conv.out_channels // 2
                dummy_channels = module.cv2.conv.in_channels - module.cv1.conv.out_channels
                if dummy_channels < 0:
                    raise ValueError("dummy_channels computed as negative, check the architecture!")
                reference = module.cv1.conv.weight
            except Exception as e:
                # Best effort: report and leave this block untouched.
                print(f"Error computing dummy_channels for module '{name}': {e}")
                continue

            two_conv_block = TwoConv(in_channels, dummy_channels, kernel_size=3, stride=1, padding=1, bias=False)
            two_conv_block.to(device=reference.device, dtype=reference.dtype)
            two_conv_block.train(module.training)
            module.m = nn.ModuleList([two_conv_block])
            print(f"Module '{name}' m replaced with TwoConv (dummy_channels={dummy_channels}).")
            replaced += 1
    return replaced

# ---------------------------
# 辅助函数：评估模型在测试集上的准确率
# ---------------------------
def evaluate_model(model_wrapper, dataset):
    """
    Compute top-1 accuracy of a YOLO classification wrapper over ``dataset.samples``.

    Each sample is an ``(image_path, true_label)`` pair, as produced by
    ``torchvision.datasets.ImageFolder``. The image is passed to the wrapper by
    path, and ``results[0].probs.top1`` is taken as the predicted class index.

    Predicted and true labels are compared as integer indices. That is only
    meaningful if the model and the dataset order their classes identically.
    When the wrapper exposes a ``names`` dict (index -> class name) and the
    dataset exposes ``classes``, a warning is printed if the two orderings
    disagree.

    Args:
        model_wrapper: callable wrapper with an ``overrides`` dict; calling it
            with an image path returns a list of results.
        dataset: object with a ``samples`` list of ``(path, label)`` pairs.

    Returns:
        float: correct / evaluated. Samples that yield no result or no
        ``probs`` are reported and excluded from the denominator. Returns 0
        when nothing could be evaluated.
    """
    # Disable per-image logging once; the setting does not change inside the loop.
    model_wrapper.overrides['verbose'] = False

    model_names = getattr(model_wrapper, "names", None)
    dataset_classes = getattr(dataset, "classes", None)
    if isinstance(model_names, dict) and dataset_classes:
        expected = [model_names.get(idx) for idx in range(len(dataset_classes))]
        if expected != list(dataset_classes):
            print("Warning: dataset class order does not match the model's class names; "
                  "index-based accuracy may be wrong.")

    num_total = 0
    num_correct = 0
    for image_path, true_label in dataset.samples:
        results = model_wrapper(image_path)
        if not results:
            print(f"No result for {image_path}")
            continue
        res = results[0]
        if getattr(res, "probs", None) is None:
            print(f"Missing 'probs' attribute for {image_path}")
            continue
        pred_label = int(res.probs.top1)

        num_total += 1
        if pred_label == true_label:
            num_correct += 1
    return num_correct / num_total if num_total > 0 else 0

# ---------------------------
# 主程序：依次进行消融实验
#   1. 仅替换 module "2" 的 C3k2 模块中 ModuleList 部分为 TwoConv 模块
#   2. 仅替换 module "4" 的 C3k2 模块中 ModuleList 部分为 TwoConv 模块
#   3. 同时替换 module "2" 和 "4" 的 C3k2 模块中 ModuleList 部分为 TwoConv 模块
# ---------------------------
def _evaluate_with_twoconv(dataset, keys, seed=0):
    """
    Replace the ModuleList of the C3k2 layers keyed by ``keys`` with a random TwoConv and return the accuracy.

    A fresh model is loaded, the replacement is made, and the model is
    evaluated on ``dataset``; the replaced modules are printed afterwards.
    The torch RNG is seeded right before the replacement, so the random
    TwoConv weights, and therefore the accuracy, are reproducible.
    """
    noun = "modules" if len(keys) > 1 else "module"
    label = " and ".join(f"'{key}'" for key in keys)
    print(f"\nEvaluating model with ModuleList replaced (TwoConv) for {noun} {label} of C3k2:")
    wrapper = load_model_wrapper()
    torch.manual_seed(seed)
    replace_modulelist_with_twoconv(wrapper.model, target_keys=list(keys))
    acc = evaluate_model(wrapper, dataset)
    print(f"Accuracy with {noun} {label} replaced: {acc:.2%}")
    for name, module in wrapper.model.named_modules():
        # Compare the last path component exactly; name.endswith("2") would
        # also match unrelated layers such as "model.12".
        if module.__class__.__name__ == "C3k2" and name.rsplit(".", 1)[-1] in keys:
            print(f"Module '{name}' m attribute:", module.m)
    return acc


def main(seed=0):
    """
    Run the C3k2 ModuleList degradation study.

    Evaluates the unmodified model first. It then evaluates models whose C3k2
    layer "2", layer "4", or both have their ModuleList replaced by a randomly
    initialised TwoConv. ``seed`` is applied before every replacement, which
    keeps the results reproducible across runs.

    Returns:
        dict mapping experiment label ("baseline", "2", "4", "2+4") to accuracy.
    """
    test_dataset = ImageFolder(root=test_folder)

    print("Evaluating baseline model (with original C3k2):")
    baseline_acc = evaluate_model(load_model_wrapper(), test_dataset)
    print("Baseline Accuracy: {:.2%}".format(baseline_acc))

    results = {"baseline": baseline_acc}
    for keys in (("2",), ("4",), ("2", "4")):
        results["+".join(keys)] = _evaluate_with_twoconv(test_dataset, keys, seed)
    return results


if __name__ == '__main__':
    main()

Using device: mps
Evaluating baseline model (with original C3k2):
Baseline Accuracy: 90.44%

Evaluating model with ModuleList replaced (TwoConv) for module '2' of C3k2:
Found target C3k2 module with key '2'. Replacing its ModuleList part with TwoConv.
Module '2' m replaced with TwoConv (dummy_channels=16).
Accuracy with module '2' replaced: 20.22%
Module 'model.2' m attribute: ModuleList(
  (0): TwoConv(
    (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): SiLU(inplace=True)
    (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act2): SiLU(inplace=True)
  )
)

Evaluating model with ModuleList replaced (TwoConv) for module '4' of C3k2:
Found target C3k2 module with key '4'. Replacing its ModuleList part with TwoConv.
