In [None]:
# Colab setup: make the project modules stored on Google Drive importable.
import sys

from google.colab import drive

# Attach Google Drive to this runtime at the standard Colab mount point.
drive.mount('/content/drive')

# Put the root of "My Drive" on the import path so that `preprocessing`
# (and the other project modules imported in later cells) can be found.
sys.path.append('/content/drive/MyDrive')

from preprocessing import FederatedDataBuilder  # must come after the sys.path change above

In [None]:
import torch
import matplotlib.pyplot as plt
from taskarithmetic import SparseSGDM, compute_fisher_sensitivity, calibrate_masks
from fed_avg_non_iid import DINOCIFAR100 
from preprocessing import FederatedDataBuilder
import torch.nn as nn
from torch.utils.data import DataLoader

# Mask-calibration strategies understood by run_strategy.
_STRATEGIES = ('least_sensitive', 'most_sensitive', 'low_magnitude', 'high_magnitude', 'random')


def _magnitude_masks(model, sparsity, largest):
    """Return {param: 0/1 float mask} selecting a `sparsity` fraction of weights by |w|.

    The threshold comes from a single global ranking over all trainable
    parameters. Masks are produced for every parameter of `model`.

    Args:
        model: module whose parameters are ranked.
        sparsity: fraction of trainable weights to select; clamped to [0, 1].
        largest: if True select the largest-magnitude weights, otherwise the smallest.
    """
    params = list(model.parameters())
    magnitudes = [p.detach().abs().flatten() for p in params if p.requires_grad]
    n = sum(m.numel() for m in magnitudes)
    k = min(max(int(n * sparsity), 0), n)
    if k == 0:
        # torch.kthvalue needs k >= 1; selecting nothing means all-zero masks.
        return {p: torch.zeros_like(p, dtype=torch.float) for p in params}
    all_weights = torch.cat(magnitudes)
    if largest:
        # The k-th largest value is the (n - k + 1)-th smallest; `>=` keeps the top k (plus ties).
        thresh = torch.kthvalue(all_weights, n - k + 1).values.item()
        return {p: (p.detach().abs() >= thresh).float() for p in params}
    thresh = torch.kthvalue(all_weights, k).values.item()
    return {p: (p.detach().abs() <= thresh).float() for p in params}


def _evaluate(model, loader, device):
    """Return top-1 accuracy (a fraction in [0, 1]) of `model` over `loader`.

    Assumes the model returns (batch, num_classes) logits and targets are class
    indices, as implied by the CrossEntropyLoss usage in training. TODO confirm.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            preds = model(inputs).argmax(dim=1)
            correct += (preds == targets).sum().item()
            total += targets.numel()
    return correct / total if total else 0.0


def run_strategy(strategy_name, sparsity=0.1, num_epochs=5, lr=0.01):
    """Run one extension strategy: calibrate a gradient mask, then fine-tune with SparseSGDM.

    A fresh DINOCIFAR100 model is built for every call.

    Args:
        strategy_name: one of
            'least_sensitive' / 'most_sensitive' - masks from Fisher sensitivity
                scores via `calibrate_masks`;
            'low_magnitude' / 'high_magnitude' - global |w| ranking over the
                trainable parameters, selecting a `sparsity` fraction of weights;
            'random' - each weight selected independently with probability `sparsity`.
        sparsity: fraction of weights selected by the mask. For the Fisher
            strategies it is forwarded as `sparsity_ratio` to `calibrate_masks`.
            Presumably SparseSGDM only updates entries whose mask is 1. Confirm
            in taskarithmetic.
        num_epochs: passes over the training set. At least 5-10 epochs are
            suggested to observe trends.
        lr: learning rate for SparseSGDM.

    Returns:
        list[float]: test-set top-1 accuracy after each epoch.

    Raises:
        ValueError: if `strategy_name` is not a known strategy.
    """
    # Fail fast on typos instead of silently training with an empty mask dict.
    if strategy_name not in _STRATEGIES:
        raise ValueError(f"Unknown strategy {strategy_name!r}; expected one of {_STRATEGIES}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data_builder = FederatedDataBuilder()
    train_loader = DataLoader(data_builder.train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(data_builder.test_dataset, batch_size=128, shuffle=False)

    model = DINOCIFAR100(num_classes=100).to(device)
    criterion = nn.CrossEntropyLoss()

    # --- Mask calibration (extension core) ---
    if strategy_name in ('least_sensitive', 'most_sensitive'):
        scores = compute_fisher_sensitivity(model, train_loader, criterion, device, num_batches=10)
        masks = calibrate_masks(
            scores,
            sparsity_ratio=sparsity,
            keep_least_sensitive=(strategy_name == 'least_sensitive'),
        )
    elif strategy_name in ('low_magnitude', 'high_magnitude'):
        masks = _magnitude_masks(model, sparsity, largest=(strategy_name == 'high_magnitude'))
    else:  # 'random'
        masks = {p: (torch.rand_like(p) <= sparsity).float() for p in model.parameters()}

    # --- Training loop ---
    optimizer = SparseSGDM(model.parameters(), lr=lr, masks=masks)
    acc_history = []
    for epoch in range(num_epochs):
        model.train()
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            # Single forward pass; backpropagate the scalar loss, not the logits.
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

        acc = _evaluate(model, test_loader, device)
        acc_history.append(acc)
        print(f"[{strategy_name}] epoch {epoch + 1}/{num_epochs}: test acc = {acc:.4f}")

    return acc_history

if __name__ == "__main__":
    # Run every mask-calibration strategy and collect its per-epoch accuracy history.
    strategies = ['least_sensitive', 'most_sensitive', 'low_magnitude', 'high_magnitude', 'random']
    results = {}
    for s in strategies:
        print(f"Running extension strategy: {s}")
        results[s] = run_strategy(s)

    # Plot one curve per strategy and save the figure for the report.
    fig, ax = plt.subplots(figsize=(8, 5))
    for name, history in results.items():
        # Epochs are 1-based on the x-axis; an empty history simply draws nothing.
        ax.plot(range(1, len(history) + 1), history, marker='o', label=name)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_title('Sparse fine-tuning: mask calibration strategies')
    ax.legend()
    fig.tight_layout()
    fig.savefig('extension_results.png')