In [None]:
# Mount Google Drive so files stored there are reachable under /content/drive.
# This only works inside a Google Colab runtime (google.colab is Colab-specific).
from google.colab import drive
drive.mount('/content/drive')

# Make modules stored at the Drive root importable. The path must be appended
# BEFORE the import below; preprocessing.py is presumably located in
# /content/drive/MyDrive (confirm against the actual Drive layout).
import sys
sys.path.append('/content/drive/MyDrive')
from preprocessing import FederatedDataBuilder

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
import json

# Importing your implemented components
from taskarithmetic import SparseSGDM, compute_fisher_sensitivity, calibrate_masks
from fed_avg_non_iid import DINOCIFAR100 
from preprocessing import FederatedDataBuilder

def _build_magnitude_masks(parameters, sparsity_ratio, keep_low):
    """Build binary masks from a single global weight-magnitude threshold.

    Args:
        parameters: Iterable of parameters; only those with
            ``requires_grad=True`` receive a mask.
        sparsity_ratio: Fraction (0..1) of all trainable entries whose mask
            value is 1.
        keep_low: If True, the smallest-magnitude entries get mask value 1;
            otherwise the largest-magnitude entries do.

    Returns:
        Dict mapping each trainable parameter object to a float32 0/1 tensor
        of the same shape. Ties at the threshold are all included, so the
        number of selected entries can exceed the requested budget.
    """
    trainable = [p for p in parameters if p.requires_grad]
    if not trainable:
        return {}

    # reshape (not view) so non-contiguous parameters do not raise.
    all_weights = torch.cat([p.data.abs().reshape(-1) for p in trainable])
    total = all_weights.numel()
    k = int(total * sparsity_ratio)

    # torch.kthvalue requires 1 <= k <= total, so handle the degenerate
    # budgets explicitly instead of letting it raise.
    if k <= 0:
        return {p: torch.zeros_like(p.data, dtype=torch.float32) for p in trainable}
    if k >= total:
        return {p: torch.ones_like(p.data, dtype=torch.float32) for p in trainable}

    if keep_low:
        threshold = torch.kthvalue(all_weights, k).values.item()
        return {p: (p.data.abs() <= threshold).float() for p in trainable}

    # The k largest entries start at the (total - k + 1)-th smallest value;
    # using (total - k) would select k + 1 entries.
    threshold = torch.kthvalue(all_weights, total - k + 1).values.item()
    return {p: (p.data.abs() >= threshold).float() for p in trainable}


def run_extension_experiment(
    strategy='least_sensitive',
    sparsity_ratio=0.1,
    calibration_batches=10,
    epochs=10,
):
    """Run one sparse fine-tuning experiment (Project Part 4: Guided Extension).

    Calibrates a gradient mask with the chosen rule, fine-tunes the model with
    ``SparseSGDM`` using that mask, and evaluates on the test set after every
    epoch.

    Args:
        strategy: Mask calibration rule. One of ``'least_sensitive'``,
            ``'most_sensitive'``, ``'low_magnitude'``, ``'high_magnitude'``,
            ``'random'``.
        sparsity_ratio: Value in [0, 1]. For the magnitude rules it is the
            fraction of entries with mask value 1; for ``'random'`` it is the
            per-entry probability of mask value 1. For the sensitivity rules
            it is forwarded unchanged to ``calibrate_masks``.
        calibration_batches: Number of batches passed to
            ``compute_fisher_sensitivity`` as ``num_batches``.
        epochs: Number of fine-tuning epochs.

    Returns:
        Dict with two lists of length ``epochs``: ``'train_loss'`` (mean
        batch loss per epoch) and ``'test_acc'`` (test accuracy in percent).

    Raises:
        ValueError: If ``strategy`` is unknown or ``sparsity_ratio`` is
            outside [0, 1].
    """
    # Validate up front: an unknown strategy previously fell through every
    # branch and silently trained with an empty mask dict.
    valid_strategies = (
        'least_sensitive', 'most_sensitive',
        'low_magnitude', 'high_magnitude', 'random',
    )
    if strategy not in valid_strategies:
        raise ValueError(
            f"Unknown strategy {strategy!r}; expected one of {valid_strategies}"
        )
    if not 0.0 <= sparsity_ratio <= 1.0:
        raise ValueError(
            f"sparsity_ratio must be in [0, 1], got {sparsity_ratio!r}"
        )

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Data preparation (CIFAR-100, per the project requirement).
    data_builder = FederatedDataBuilder(val_split_ratio=0.1)
    train_loader = DataLoader(data_builder.train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(data_builder.test_dataset, batch_size=128, shuffle=False)

    # 2. Model initialization (DINO ViT-S/16).
    model = DINOCIFAR100(num_classes=100).to(DEVICE)
    criterion = nn.CrossEntropyLoss()

    # ---------------------------------------------------------
    # STEP 1: Mask calibration
    # ---------------------------------------------------------
    print(f"\n--- Strategy: {strategy} | Sparsity: {sparsity_ratio} ---")

    masks = {}

    if strategy in ('least_sensitive', 'most_sensitive'):
        # Sensitivity-based rules (Fisher information).
        sensitivity_scores = compute_fisher_sensitivity(
            model, train_loader, criterion, DEVICE, num_batches=calibration_batches
        )
        # NOTE(review): the mask format returned by calibrate_masks is assumed
        # to match what SparseSGDM expects; confirm against taskarithmetic.
        masks = calibrate_masks(
            sensitivity_scores,
            sparsity_ratio=sparsity_ratio,
            keep_least_sensitive=(strategy == 'least_sensitive'),
        )
    elif strategy in ('low_magnitude', 'high_magnitude'):
        # Magnitude-based rules with one global threshold over all weights.
        masks = _build_magnitude_masks(
            model.parameters(),
            sparsity_ratio,
            keep_low=(strategy == 'low_magnitude'),
        )
    else:  # 'random'
        # Each entry is selected independently with probability sparsity_ratio.
        for p in model.parameters():
            if p.requires_grad:
                masks[p] = (torch.rand_like(p) <= sparsity_ratio).float()

    # ---------------------------------------------------------
    # STEP 2: Sparse fine-tuning with SparseSGDM
    # ---------------------------------------------------------
    optimizer = SparseSGDM(
        model.parameters(),
        lr=0.01,
        momentum=0.9,
        weight_decay=1e-4,
        masks=masks,
    )

    history = {'train_loss': [], 'test_acc': []}

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()  # the optimizer is given the calibrated masks
            running_loss += loss.item()

        # Record the mean batch loss; this list was previously never filled.
        epoch_loss = running_loss / len(train_loader)
        history['train_loss'].append(epoch_loss)

        # Evaluation on the test split.
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        acc = 100. * correct / total
        history['test_acc'].append(acc)
        print(f"Epoch {epoch+1} | Loss: {epoch_loss:.4f} | Acc: {acc:.2f}%")

    return history

if __name__ == "__main__":
    # Run every mask calibration rule at the same sparsity and plot the
    # per-epoch test accuracy of each run on one figure for the report.
    strategies = ['least_sensitive', 'most_sensitive', 'low_magnitude', 'high_magnitude', 'random']
    results = {}

    for s in strategies:
        results[s] = run_extension_experiment(strategy=s, sparsity_ratio=0.1)

    plt.figure(figsize=(10, 6))
    for s, h in results.items():
        accuracies = h['test_acc']
        # Derive the x-axis from the recorded history instead of hard-coding
        # 1..10, so the plot stays valid if the number of epochs changes.
        plt.plot(range(1, len(accuracies) + 1), accuracies, label=s)
    plt.xlabel('Epoch')
    plt.ylabel('Test Accuracy (%)')
    plt.title('Comparison of Mask Calibration Rules (Extension)')
    plt.legend()
    plt.grid(True)
    # Save before show(): some backends clear the figure once it is displayed.
    plt.savefig('extension_comparison.png')
    plt.show()

Loading DINO backbone (ONE TIME ONLY)...


Using cache found in /Users/van/.cache/torch/hub/facebookresearch_dino_main


✓ DINO backbone loaded and cached globally
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100.0%


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified

--- Strategy: least_sensitive | Sparsity: 0.1 ---
Calculating sensitivity over 10 batches...
