# Federated Learning Under the Lens of Task Arithmetic

This notebook runs all experiments for the project:
1. Centralized baseline training
2. FedAvg with IID sharding
3. FedAvg with non-IID sharding (varying Nc)
4. Sparse fine-tuning with task arithmetic
5. Mask strategy comparison (extension)

**Important**: This notebook is designed for Google Colab with a GPU runtime.

## Setup

In [None]:
!git clone https://github.com/VitoFe/amlproject
%cd amlproject
# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Install uv and dependencies
# First upgrade/install the `uv` tool with pip, then let uv resolve and install
# the project's dependencies.
# NOTE(review): `uv sync` normally installs into a project-local `.venv`, not
# into the environment the running kernel uses — confirm the kernel actually
# sees these packages (otherwise something like `uv pip install --system`
# may be needed).
!pip install --upgrade uv
!uv sync

In [None]:
# Output locations used throughout the notebook: checkpoints and logs
CHECKPOINT_DIR, LOG_DIR = './fl_checkpoints', './fl_logs'
# Import project modules
from src.data.dataset import get_cifar100_datasets, get_dataloaders
from src.models.dino_vit import create_dino_vit
from src.training.centralized import CentralizedTrainer
from src.training.federated import FederatedTrainer
from src.training.federated_sparse import FederatedSparseTrainer
from src.utils.seed import set_seed
from src.utils.logging import setup_logging
from src.utils.visualization import plot_training_curves, plot_comparison

## Configuration

In [None]:
# Experiment configuration
# Single table of settings that the experiment cells read via CONFIG[...].
CONFIG = dict(
    seed=42,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    data_dir='./data',
    checkpoint_dir=CHECKPOINT_DIR,
    log_dir=LOG_DIR,

    # Data
    batch_size=64,
    val_split=0.1,

    # Centralized
    # NOTE(review): these two entries are not read by the visible cells; the
    # centralized training cell sets its own epochs / learning rate.
    centralized_epochs=50,
    centralized_lr=0.001,

    # Federated
    num_clients=100,         # K
    participation_rate=0.1,  # C
    local_steps=4,           # J
    num_rounds=500,
    federated_lr=0.01,

    # Non-IID
    nc_values=[1, 5, 10, 50],
    j_values=[4, 8, 16],     # NOTE(review): not read by the visible cells

    # Sparse
    sparsity_ratio=0.9,
    calibration_rounds=5,
    fisher_samples=512,

    # Multiple runs
    num_runs=3,              # NOTE(review): not read by the visible cells
)

print("Configuration loaded", f"Device: {CONFIG['device']}", sep="\n")

## Load Data

In [None]:
# Load CIFAR-100
# Call set_seed with the configured seed before the splits are built.
set_seed(CONFIG['seed'])

# Train / validation / test datasets; val_split and seed come from CONFIG.
train_dataset, val_dataset, test_dataset = get_cifar100_datasets(
    data_dir=CONFIG['data_dir'], val_split=CONFIG['val_split'], seed=CONFIG['seed'],
)

# One loader per split, all sharing the configured batch size.
train_loader, val_loader, test_loader = get_dataloaders(
    train_dataset,
    val_dataset,
    test_dataset,
    batch_size=CONFIG['batch_size'],
    num_workers=2,
)

# Report the size of each split, one per line.
print(
    f"Train: {len(train_dataset)} samples",
    f"Val: {len(val_dataset)} samples",
    f"Test: {len(test_dataset)} samples",
    sep="\n",
)

## Experiment 1: Centralized Baseline

In [None]:
# Uncomment to delete old checkpoints and start fresh
# import shutil, os
# shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True)
# os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Create model with regularization
# Builds the model used by the centralized baseline, with 100 output classes,
# on the configured device.
# NOTE(review): dropout=0.3 and freeze_layers=6 are passed only in this cell;
# the federated cells call create_dino_vit without them and therefore use that
# function's defaults — confirm the difference is intended.
# presumably freeze_layers is the number of leading layers kept frozen — TODO confirm
model = create_dino_vit(
    num_classes=100, 
    device=CONFIG['device'],
    dropout=0.3,
    freeze_layers=6
)
# Print whatever model.count_parameters() reports for this model
print(f"Model parameters: {model.count_parameters()}")

In [None]:
# Train centralized baseline
# NOTE(review): epochs and learning rate are hard-coded here (20, 0.0005) and
# differ from CONFIG['centralized_epochs'] / CONFIG['centralized_lr']
# (50, 0.001), which this cell does not read — confirm which is intended.
centralized_config = dict(
    epochs=20,
    learning_rate=0.0005,
    momentum=0.9,
    weight_decay=0.01,
    scheduler='cosine',
    label_smoothing=0.1,
    early_stopping_patience=10,
    checkpoint_dir=CHECKPOINT_DIR,
    log_dir=LOG_DIR,
)

# Uses the model and the three data loaders created in the earlier cells.
trainer = CentralizedTrainer(
    model=model,
    train_loader=train_loader, val_loader=val_loader, test_loader=test_loader,
    config=centralized_config,
    device=CONFIG['device'],
    experiment_name='centralized_regularized',
)

# resume=True continues an existing run; pass resume=False to start fresh.
centralized_results = trainer.train(resume=True, save_every=5)
print(f"\nCentralized Test Accuracy: {centralized_results['accuracy']:.4f}")

In [None]:
# Plot training curves
# Passes the metrics history of `trainer` to plot_training_curves, with a
# save_path inside the configured log directory.
# NOTE: `trainer` is a notebook-global name that the later experiment cells
# rebind, so run this cell right after the centralized training cell;
# otherwise the history plotted here belongs to a different trainer.
plot_training_curves(
    trainer.get_metrics_history(),
    title='Centralized Training',
    save_path=f"{CONFIG['log_dir']}/centralized_curves.png"
)

## Experiment 2: FedAvg with IID Sharding

In [None]:
# FedAvg with IID data distribution
# Re-seed and create a new model so this run does not reuse the model object
# trained in the centralized experiment.
set_seed(CONFIG['seed'])
model = create_dino_vit(num_classes=100, device=CONFIG['device'])

# Federation settings come from CONFIG; optimiser extras are set inline.
federated_config = dict(
    num_clients=CONFIG['num_clients'],
    participation_rate=CONFIG['participation_rate'],
    local_steps=CONFIG['local_steps'],
    num_rounds=CONFIG['num_rounds'],
    learning_rate=CONFIG['federated_lr'],
    momentum=0.9,
    weight_decay=1e-4,
    early_stopping_patience=10,
    batch_size=CONFIG['batch_size'],
    checkpoint_dir=CONFIG['checkpoint_dir'],
    log_dir=CONFIG['log_dir'],
    seed=CONFIG['seed'],
    sharding=dict(strategy='iid', nc=100),
)

# The trainer receives the raw training dataset (not a loader) together with
# the validation and test loaders.
trainer = FederatedTrainer(
    model=model,
    train_dataset=train_dataset,
    val_loader=val_loader, test_loader=test_loader,
    config=federated_config,
    device=CONFIG['device'],
    experiment_name='fedavg_iid',
)

fedavg_iid_results = trainer.train(resume=True)
print(f"\nFedAvg (IID) Test Accuracy: {fedavg_iid_results['accuracy']:.4f}")

## Experiment 3: FedAvg with Non-IID Sharding (Varying Nc)

In [None]:
# Test different levels of data heterogeneity
# Maps each Nc value (classes per client) to the test accuracy of its run.
noniid_results = {}

for nc in CONFIG['nc_values']:
    print(f"\n{'='*50}", f"Testing Nc = {nc} classes per client", f"{'='*50}", sep="\n")

    # Every Nc value starts from the same seed and a newly created model.
    set_seed(CONFIG['seed'])
    model = create_dino_vit(num_classes=100, device=CONFIG['device'])

    # NOTE(review): weight_decay is 1e-6 here but 1e-4 in the IID FedAvg
    # cell — confirm the difference is intended, since it changes more than
    # the sharding between the two experiments.
    config = dict(
        num_clients=CONFIG['num_clients'],
        participation_rate=CONFIG['participation_rate'],
        local_steps=CONFIG['local_steps'],
        num_rounds=CONFIG['num_rounds'],
        learning_rate=CONFIG['federated_lr'],
        momentum=0.9,
        weight_decay=1e-6,
        early_stopping_patience=10,
        batch_size=CONFIG['batch_size'],
        checkpoint_dir=CONFIG['checkpoint_dir'],
        log_dir=CONFIG['log_dir'],
        seed=CONFIG['seed'],
        sharding=dict(strategy='non_iid', nc=nc),
    )

    # The experiment name embeds Nc, so each setting gets its own name.
    trainer = FederatedTrainer(
        model=model,
        train_dataset=train_dataset,
        val_loader=val_loader, test_loader=test_loader,
        config=config,
        device=CONFIG['device'],
        experiment_name=f'fedavg_noniid_nc{nc}',
    )

    results = trainer.train(resume=True)
    noniid_results[nc] = results['accuracy']
    print(f"Nc={nc}: Test Accuracy = {results['accuracy']:.4f}")

In [None]:
# Accuracy table for the heterogeneity sweep, in the order the runs were stored
print("\nHeterogeneity Experiment Results:", "-" * 30, sep="\n")
for nc, acc in noniid_results.items():
    print(f"Nc = {nc}: {acc:.4f}")

## Experiment 4: Sparse Fine-tuning (Task Arithmetic)

In [None]:
# Federated sparse fine-tuning with least-sensitive masking
set_seed(CONFIG['seed'])
model = create_dino_vit(num_classes=100, device=CONFIG['device'])

# Same federation settings as the IID FedAvg cell plus a nested 'sparse'
# section for the mask options.
# NOTE(review): unlike the FedAvg configs, no 'early_stopping_patience' is set
# here — confirm that is intended.
sparse_config = dict(
    num_clients=CONFIG['num_clients'],
    participation_rate=CONFIG['participation_rate'],
    local_steps=CONFIG['local_steps'],
    num_rounds=CONFIG['num_rounds'],
    learning_rate=CONFIG['federated_lr'],
    momentum=0.9,
    weight_decay=1e-4,
    batch_size=CONFIG['batch_size'],
    checkpoint_dir=CONFIG['checkpoint_dir'],
    log_dir=CONFIG['log_dir'],
    seed=CONFIG['seed'],
    sharding=dict(strategy='iid', nc=100),
    sparse=dict(
        sparsity_ratio=CONFIG['sparsity_ratio'],
        calibration_rounds=CONFIG['calibration_rounds'],
        mask_strategy='least_sensitive',
        fisher_samples=CONFIG['fisher_samples'],
    ),
)

trainer = FederatedSparseTrainer(
    model=model,
    train_dataset=train_dataset,
    val_loader=val_loader, test_loader=test_loader,
    config=sparse_config,
    device=CONFIG['device'],
    experiment_name='fedavg_sparse_least_sensitive',
)

# calibrate_masks=True is passed through to the trainer's train() call.
sparse_results = trainer.train(resume=True, calibrate_masks=True)
print(f"\nSparse Fine-tuning Test Accuracy: {sparse_results['accuracy']:.4f}")

## Experiment 5 (Extension): Mask Strategy Comparison

In [None]:
# Compare all mask strategies
strategies = [
    'least_sensitive', 'most_sensitive',
    'lowest_magnitude', 'highest_magnitude',
    'random',
]

# Maps each strategy name to the test accuracy of its run.
strategy_results = {}

for strategy in strategies:
    print(f"\n{'='*50}", f"Testing strategy: {strategy}", f"{'='*50}", sep="\n")

    # Every strategy starts from the same seed and a newly created model.
    set_seed(CONFIG['seed'])
    model = create_dino_vit(num_classes=100, device=CONFIG['device'])

    # Shallow copy of sparse_config in which only the 'sparse' section is
    # replaced, so the strategy is the single setting that varies.
    config = {
        **sparse_config,
        'sparse': dict(
            sparsity_ratio=CONFIG['sparsity_ratio'],
            calibration_rounds=CONFIG['calibration_rounds'],
            mask_strategy=strategy,
            fisher_samples=CONFIG['fisher_samples'],
        ),
    }

    # NOTE: for 'least_sensitive' the experiment name is the same one used by
    # the sparse fine-tuning cell, and resume=True is passed below.
    trainer = FederatedSparseTrainer(
        model=model,
        train_dataset=train_dataset,
        val_loader=val_loader, test_loader=test_loader,
        config=config,
        device=CONFIG['device'],
        experiment_name=f'fedavg_sparse_{strategy}',
    )

    results = trainer.train(resume=True, calibrate_masks=True)
    strategy_results[strategy] = results['accuracy']
    print(f"{strategy}: Test Accuracy = {results['accuracy']:.4f}")

In [None]:
# Ranked table of mask strategies, highest test accuracy first
print("\nMask Strategy Comparison:", "=" * 60, sep="\n")
for strategy, acc in sorted(strategy_results.items(), key=lambda pair: pair[1], reverse=True):
    print(f"{strategy:25s}: {acc:.4f}")

# Wrap each accuracy in a {'test_accuracy': value} dict before handing the
# results to plot_comparison.
plot_comparison(
    {name: {'test_accuracy': accuracy} for name, accuracy in strategy_results.items()},
    metric='test_accuracy',
    title='Mask Strategy Comparison',
    save_path=f"{CONFIG['log_dir']}/strategy_comparison.png",
)

## Summary of Results

In [None]:
# Recap of the test accuracy of every experiment run above
print("EXPERIMENT SUMMARY", "=" * 60, sep="\n")

print(f"\n1. Centralized Baseline: {centralized_results['accuracy']:.4f}")
print(f"\n2. FedAvg (IID): {fedavg_iid_results['accuracy']:.4f}")

# One line per Nc value, in the order the runs were stored
print("\n3. FedAvg (Non-IID):")
for nc, acc in noniid_results.items():
    print(f"   Nc={nc}: {acc:.4f}")

print(f"\n4. Sparse Fine-tuning: {sparse_results['accuracy']:.4f}")

# Mask strategies, highest accuracy first
print("\n5. Mask Strategy Comparison:")
for strategy, acc in sorted(strategy_results.items(), key=lambda pair: pair[1], reverse=True):
    print(f"   {strategy}: {acc:.4f}")

In [None]:
# Save results to JSON
import json
import os

# Collect the headline test accuracy of every experiment in one structure.
# noniid_results is keyed by the integer Nc values; the json module writes
# those keys as strings.
all_results = {
    'centralized': centralized_results['accuracy'],
    'fedavg_iid': fedavg_iid_results['accuracy'],
    'fedavg_noniid': noniid_results,
    'sparse': sparse_results['accuracy'],
    'mask_strategies': strategy_results
}

# The notebook itself never creates the log directory, so make sure it exists
# before opening a file inside it instead of relying on an earlier cell's
# side effect.
os.makedirs(CONFIG['log_dir'], exist_ok=True)
results_path = f"{CONFIG['log_dir']}/all_results.json"

with open(results_path, 'w') as f:
    json.dump(all_results, f, indent=2)

print(f"Results saved to {results_path}")