In [1]:
# Mount Google Drive when running on Colab so the project modules stored there
# (e.g. `preprocessing`) become importable. On a local Jupyter kernel
# google.colab is not installed, so the mount is skipped instead of making the
# whole cell (and Restart & Run All) fail.
import sys

try:
    from google.colab import drive
except ImportError:  # not running on Colab
    drive = None

if drive is not None:
    drive.mount('/content/drive')

# Make the Drive folder importable; the membership check keeps re-runs of this
# cell from appending duplicate entries to sys.path.
PROJECT_DIR = '/content/drive/MyDrive'
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

from preprocessing import FederatedDataBuilder

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# 导入你已经实现的组件
from taskarithmetic import SparseSGDM, compute_fisher_sensitivity, calibrate_masks
from fed_avg_non_iid import DINOCIFAR100 # 或者使用你定义的模型
from preprocessing import FederatedDataBuilder

def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``loader``; return the mean per-sample loss.

    Each batch loss is weighted by its batch size so that a smaller final batch
    does not skew the average (this assumes ``criterion`` uses the default
    'mean' reduction). Returns NaN if the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    n_seen = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        # SparseSGDM is expected to apply its gradient masks inside step().
        optimizer.step()

        batch_size = targets.size(0)
        loss_sum += loss.item() * batch_size
        n_seen += batch_size
    return loss_sum / n_seen if n_seen else float('nan')


def _evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` on ``loader`` in percent (NaN if empty)."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            _, predicted = model(inputs).max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return 100. * correct / total if total else float('nan')


def run_task_arithmetic_experiment(sparsity_ratio=0.1, calibration_batches=10, epochs=10,
                                   lr=0.05, momentum=0.9, weight_decay=1e-4,
                                   train_batch_size=64, eval_batch_size=128, seed=None):
    """Sparse fine-tuning using task-arithmetic-style gradient masks.

    Stage 1 scores parameter sensitivity with ``compute_fisher_sensitivity`` and
    builds masks with ``calibrate_masks(..., keep_least_sensitive=True)``.
    Stage 2 fine-tunes the model with ``SparseSGDM`` using those masks and
    evaluates on the test set after every epoch.

    Args:
        sparsity_ratio: Forwarded to ``calibrate_masks``; its exact meaning
            (fraction kept vs. fraction masked) is defined there.
        calibration_batches: Number of training batches used for the
            sensitivity estimate.
        epochs: Number of fine-tuning epochs.
        lr, momentum, weight_decay: ``SparseSGDM`` hyper-parameters.
        train_batch_size, eval_batch_size: DataLoader batch sizes.
        seed: If not None, seeds torch's RNGs before data/model creation so
            runs with different ``sparsity_ratio`` values are comparable.

    Returns:
        dict with per-epoch lists ``'train_loss'`` and ``'test_acc'`` (percent).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if seed is not None:
        # Seeds the CPU and all CUDA generators.
        torch.manual_seed(seed)

    # 1. Data: CIFAR-100 via the project's FederatedDataBuilder.
    # NOTE(review): a validation split is requested but per-epoch evaluation
    # below uses the test set; confirm whether data_builder exposes a
    # validation set that should be used for model selection instead.
    data_builder = FederatedDataBuilder(val_split_ratio=0.1)
    train_loader = DataLoader(data_builder.train_dataset, batch_size=train_batch_size, shuffle=True)
    test_loader = DataLoader(data_builder.test_dataset, batch_size=eval_batch_size, shuffle=False)

    # 2. Model (DINO ViT-S/16 backbone with a 100-class head).
    model = DINOCIFAR100(num_classes=100).to(device)
    criterion = nn.CrossEntropyLoss()

    # ---------------------------------------------------------
    # Stage 1: mask calibration, selecting the least-sensitive parameters.
    # ---------------------------------------------------------
    print(f"\n--- 阶段 1: 掩码校准 (Sparsity: {sparsity_ratio}, Batches: {calibration_batches}) ---")

    # Sensitivity scores based on the diagonal Fisher information.
    sensitivity_scores = compute_fisher_sensitivity(
        model, train_loader, criterion, device, num_batches=calibration_batches
    )

    # keep_least_sensitive=True: lower sensitivity -> more likely mask == 1
    # (parameter allowed to update).
    masks = calibrate_masks(
        sensitivity_scores,
        sparsity_ratio=sparsity_ratio,
        keep_least_sensitive=True
    )

    # Mask statistics; guard against a model with no trainable parameters.
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    active_params = sum(m.sum().item() for m in masks.values())
    if total_params:
        print(f"掩码校准完成。可更新参数占比: {100 * active_params / total_params:.2f}%")
    else:
        print("掩码校准完成。可更新参数占比: n/a (no trainable parameters)")

    # ---------------------------------------------------------
    # Stage 2: sparse fine-tuning with SparseSGDM and the calibrated masks.
    # ---------------------------------------------------------
    print("\n--- 阶段 2: 稀疏微调 ---")

    optimizer = SparseSGDM(
        model.parameters(),
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
        masks=masks
    )

    history = {'train_loss': [], 'test_acc': []}
    for epoch in range(epochs):
        avg_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        acc = _evaluate(model, test_loader, device)
        history['train_loss'].append(avg_loss)
        history['test_acc'].append(acc)

        print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f} | Test Acc: {acc:.2f}%")

    return history

if __name__ == "__main__":
    # Compare several sparsity ratios, as suggested by the project spec.
    ratios = [0.05, 0.1, 0.2]

    # Keep every run's history (previously discarded) so the ratios can be compared.
    histories = {}
    for r in ratios:
        print(f"\n{'='*50}")
        print(f"实验开始: Sparsity Ratio = {r}")
        histories[r] = run_task_arithmetic_experiment(sparsity_ratio=r)

    # Side-by-side curves: training loss and test accuracy per epoch for each ratio.
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))
    for r, history in histories.items():
        epoch_axis = range(1, len(history['train_loss']) + 1)
        ax_loss.plot(epoch_axis, history['train_loss'], marker='o', label=f"ratio={r}")
        ax_acc.plot(epoch_axis, history['test_acc'], marker='o', label=f"ratio={r}")
    ax_loss.set(title="Sparse fine-tuning: training loss", xlabel="Epoch", ylabel="Loss")
    ax_acc.set(title="Sparse fine-tuning: test accuracy", xlabel="Epoch", ylabel="Accuracy (%)")
    for ax in (ax_loss, ax_acc):
        ax.legend()
    fig.tight_layout()
    plt.show()

Loading DINO backbone (ONE TIME ONLY)...


Using cache found in /Users/van/.cache/torch/hub/facebookresearch_dino_main


✓ DINO backbone loaded and cached globally
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100.0%


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified

--- Strategy: least_sensitive | Sparsity: 0.1 ---
Calculating sensitivity over 10 batches...
