<a href="https://colab.research.google.com/github/ILoveCoder999/FederatedLearning/blob/master/fed_avg.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Put the Drive root on the import path so project modules stored there can be
# imported (presumably preprocessing.py lives in MyDrive -- confirm its location).
import sys
sys.path.append('/content/drive/MyDrive')
from preprocessing import FederatedDataBuilder

Mounted at /content/drive


In [None]:
"""
Fixed FedAvg Implementation for CIFAR-100
==========================================
Constraint: J (local SGD steps per client per round) must equal 4, as required by the assignment.
Target: 30 communication rounds instead of 50.

Key fixes when J is constrained to 4:
1. Significantly increase the learning rate: 0.01 -> 0.1 or higher
2. Better initialization of the classification head
3. Potential backbone unfreezing
4. Optimize within the 4-step constraint
"""

import torch
import torch.nn as nn
import torch.optim as optim
import copy
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Subset

from preprocessing import FederatedDataBuilder
from centralized_model import DINOCIFAR100


def initialize_model(num_classes=100):
    """
    Build a DINOCIFAR100 network and re-initialize its classification head.

    The head's weight matrix is drawn from a Kaiming-normal distribution
    (fan-out mode) and its bias is set to zero; the rest of the network is
    left exactly as the constructor produced it.

    Args:
        num_classes: Number of output classes passed to the model constructor.

    Returns:
        The freshly constructed model with the re-initialized head.
    """
    net = DINOCIFAR100(num_classes=num_classes)
    classifier = net.head

    # Zeroing the bias draws no random numbers, so doing it first leaves the
    # random stream seen by the weight initialization untouched.
    nn.init.zeros_(classifier.bias)
    nn.init.kaiming_normal_(classifier.weight, mode='fan_out')

    print("✓ Model head initialized with Kaiming normal")
    return net


def fed_avg_aggregate(global_model, local_weights, client_sample_counts):
    """
    Compute the sample-weighted average of client state dicts (FedAvg).

    Every entry of the global state dict is replaced by
    ``sum_i (n_i / n_total) * local_weights[i][key]``, except entries whose
    key contains ``'num_batches_tracked'``; those are copied from the global
    model unchanged.

    Args:
        global_model: Module whose ``state_dict()`` supplies the key set and
            the values of the entries that are not averaged. It is not modified.
        local_weights: Sequence of client state dicts, one per client.
        client_sample_counts: Number of samples held by each client, aligned
            index-by-index with ``local_weights``.

    Returns:
        A new state dict holding the aggregated weights.

    Raises:
        ValueError: If no client updates are given, if the two sequences have
            different lengths, or if the total sample count is not positive.
    """
    if len(local_weights) == 0:
        # Without this guard the result would be an all-zero model.
        raise ValueError("fed_avg_aggregate needs at least one client update")
    if len(local_weights) != len(client_sample_counts):
        raise ValueError(
            f"got {len(local_weights)} client updates but "
            f"{len(client_sample_counts)} sample counts"
        )
    total_samples = sum(client_sample_counts)
    if total_samples <= 0:
        raise ValueError("total number of client samples must be positive")

    # Deep copy so the caller's model is untouched and state-dict metadata is kept.
    global_dict = copy.deepcopy(global_model.state_dict())
    averaged_keys = [k for k in global_dict if 'num_batches_tracked' not in k]

    # Start from true zeros. Multiplying the old value by 0.0 would carry any
    # NaN/Inf already present in the global model into the average.
    for k in averaged_keys:
        reference = global_dict[k]
        if reference.is_floating_point() or reference.is_complex():
            acc_dtype = reference.dtype
        else:
            # Non-float entries are accumulated in the default float dtype.
            acc_dtype = torch.get_default_dtype()
        global_dict[k] = torch.zeros_like(reference, dtype=acc_dtype)

    # Weighted sum: each client contributes in proportion to its sample count.
    for weights, count in zip(local_weights, client_sample_counts):
        ratio = float(count) / float(total_samples)
        for k in averaged_keys:
            accumulator = global_dict[k]
            accumulator += weights[k].to(accumulator.device) * ratio

    return global_dict


class LocalClient:
    """
    One federated client: a private data shard plus a short local-training routine.

    The client owns a shuffled DataLoader over ``indices`` of ``dataset`` and,
    on request, runs a fixed number of SGD steps starting from the global weights.
    """

    def __init__(self, client_id, dataset, indices, device, model_class, batch_size=128):
        """
        Args:
            client_id: Identifier of this client (used in error messages).
            dataset: Full training dataset; the client reads only ``indices`` of it.
            indices: Iterable of sample indices owned by this client.
            device: Device on which local training runs.
            model_class: Callable invoked as ``model_class(num_classes=100)``
                to build the local model.
            batch_size: Mini-batch size of the local loader. Defaults to 128,
                the value that was previously hard-coded.
        """
        self.client_id = client_id
        self.indices = indices
        self.device = device
        self.model_class = model_class

        # A larger batch gives more stable gradients when only a handful of
        # local steps are taken. Pinned memory is requested only for CUDA.
        self.trainloader = DataLoader(
            Subset(dataset, list(indices)),
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=torch.device(device).type == 'cuda'
        )

    def train(self, global_weights, local_steps=4, lr=0.1):
        """
        Run ``local_steps`` SGD steps starting from ``global_weights``.

        Each step uses one mini-batch; when the loader is exhausted it is
        restarted, which reshuffles the shard.

        Args:
            global_weights: State dict loaded into the fresh local model.
            local_steps: Number of optimizer steps to take (must be >= 1).
            lr: SGD learning rate.

        Returns:
            Tuple ``(state_dict, mean_loss)``: the local model's state dict
            after training and the mean of the per-step training losses.

        Raises:
            ValueError: If ``local_steps`` is smaller than 1 or the client's
                loader contains no batches.
        """
        if local_steps < 1:
            raise ValueError(f"local_steps must be >= 1, got {local_steps}")
        if len(self.trainloader) == 0:
            raise ValueError(f"client {self.client_id} has no training data")

        # Fresh local model initialised from the global weights.
        local_model = self.model_class(num_classes=100).to(self.device)
        local_model.load_state_dict(global_weights)
        local_model.train()

        # SGD with Nesterov momentum; the high learning rate compensates for
        # the very small number of local steps.
        optimizer = optim.SGD(
            local_model.parameters(),
            lr=lr,
            momentum=0.9,
            weight_decay=1e-4,
            nesterov=True
        )
        criterion = nn.CrossEntropyLoss()

        step_losses = []
        batches = iter(self.trainloader)

        for _ in range(local_steps):
            try:
                inputs, targets = next(batches)
            except StopIteration:
                # Shard exhausted before the step budget: start a new pass.
                batches = iter(self.trainloader)
                inputs, targets = next(batches)

            inputs, targets = inputs.to(self.device), targets.to(self.device)

            optimizer.zero_grad()
            outputs = local_model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()

            # Gradient clipping to prevent explosion with the high learning rate.
            torch.nn.utils.clip_grad_norm_(local_model.parameters(), max_norm=1.0)

            optimizer.step()
            step_losses.append(loss.item())

        return local_model.state_dict(), sum(step_losses) / len(step_losses)


def evaluate_global(model, test_loader, device):
    """
    Evaluate the global model on the server test set.

    The loss is averaged over samples (not over batches), so a shorter final
    batch does not receive extra weight.

    Args:
        model: Model to evaluate; it is switched to eval mode and left there.
        test_loader: Iterable yielding ``(inputs, targets)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)`` over all evaluated samples.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    # Sum per-sample losses so the final division is by the sample count.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    loss_total = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss_total += criterion(outputs, targets).item()
            predicted = outputs.argmax(dim=1)
            correct += predicted.eq(targets).sum().item()
            total += targets.size(0)

    if total == 0:
        raise ValueError("test_loader yielded no samples to evaluate")

    return loss_total / total, 100. * correct / total


def _plot_history(history, lr):
    """Plot per-round test accuracy and test loss, save the figure as a PNG and show it."""
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(history['round'], history['accuracy'], 'b-o', linewidth=2, markersize=4)
    plt.title(f'FedAvg (J=4, LR={lr}) - Test Accuracy', fontsize=14, fontweight='bold')
    plt.xlabel('Communication Rounds', fontsize=12)
    plt.ylabel('Test Accuracy (%)', fontsize=12)
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    plt.plot(history['round'], history['loss'], 'r-o', linewidth=2, markersize=4)
    plt.title('Test Loss', fontsize=14, fontweight='bold')
    plt.xlabel('Communication Rounds', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('fedavg_j4_results.png', dpi=150, bbox_inches='tight')
    print("Figure saved to: fedavg_j4_results.png")
    plt.show()


def run_fedavg_experiment_j4(
    K=100,
    C=0.1,
    J=4,                # FIXED: Must be 4
    rounds=30,          # Reduced from 50 to 30
    lr=0.1,             # CRITICAL: Increased from 0.01 to 0.1
    batch_size=128      # Increased from 64
):
    """
    Run a FedAvg experiment with a small, fixed number of local steps.

    Each round samples ``max(int(C * K), 1)`` clients (capped at ``K``)
    without replacement, trains each for ``J`` local steps from the current
    global weights, aggregates the results with ``fed_avg_aggregate`` and
    evaluates the new global model on the test set.

    Args:
        K: Total number of clients.
        C: Fraction of clients selected per round.
        J: Number of local SGD steps per selected client per round.
        rounds: Number of communication rounds.
        lr: Learning rate used for local training.
        batch_size: Mini-batch size used for local training.

    Returns:
        Tuple ``(history, global_model)``. ``history`` is a dict with the
        lists ``'loss'``, ``'accuracy'`` and ``'round'`` (one entry per
        round); ``global_model`` is the final aggregated model. When
        ``rounds`` is 0 the history lists are empty and no figure is produced.
    """
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_pinned_memory = DEVICE.type == "cuda"

    print(f"\n{'='*70}")
    print("FedAvg with J=4 Constraint (Assignment Requirement)")
    print(f"{'='*70}")
    print(f"K={K}, C={C}, J={J} (FIXED), Rounds={rounds}")
    print(f"Learning rate: {lr} (HIGH for J=4)")
    print(f"Batch size: {batch_size} (LARGE for stable gradients)")
    print(f"Device: {DEVICE}")
    print(f"{'='*70}\n")

    # Data Preparation
    print("Preparing data...")
    data_builder = FederatedDataBuilder(val_split_ratio=0.1, K=K)
    client_dict = data_builder.get_iid_partition()
    data_builder.verify_partition(client_dict)

    test_loader = DataLoader(
        data_builder.test_dataset,
        batch_size=256,  # Larger batch for faster evaluation
        shuffle=False,
        num_workers=2,
        pin_memory=use_pinned_memory
    )

    # Global Model Initialization
    print("\nInitializing global model...")
    global_model = initialize_model(num_classes=100).to(DEVICE)

    history = {'loss': [], 'accuracy': [], 'round': []}

    # Federated Training
    print("\nStarting federated training...")
    # At least one client per round, and never more than exist (sampling is
    # without replacement).
    m = min(max(int(C * K), 1), K)
    print(f"Selecting {m} clients per round")
    print(f"Each client trains for {J} steps with batch_size={batch_size}")
    print(f"Effective samples per client per round: {J} x {batch_size} = {J*batch_size}\n")

    for r in range(rounds):
        # Client Selection
        selected_clients = np.random.choice(K, m, replace=False)

        local_weights = []
        client_sample_counts = []
        client_losses = []

        # Snapshot of the global weights that every selected client starts from.
        global_weights = copy.deepcopy(global_model.state_dict())

        for client_idx in selected_clients:
            client_idx = int(client_idx)  # plain int for dict lookup
            client_indices = client_dict[client_idx]

            client = LocalClient(
                client_id=client_idx,
                dataset=data_builder.train_dataset,
                indices=client_indices,
                device=DEVICE,
                model_class=DINOCIFAR100
            )
            # Install a loader built with this experiment's batch_size so the
            # configured value is the one actually used for local training.
            client.trainloader = DataLoader(
                Subset(data_builder.train_dataset, list(client_indices)),
                batch_size=batch_size,
                shuffle=True,
                num_workers=2,
                pin_memory=use_pinned_memory
            )

            w_local, loss_local = client.train(
                global_weights,
                local_steps=J,
                lr=lr
            )

            local_weights.append(w_local)
            client_sample_counts.append(len(client_indices))
            client_losses.append(loss_local)

        # Aggregation
        new_weights = fed_avg_aggregate(global_model, local_weights, client_sample_counts)
        global_model.load_state_dict(new_weights)

        # Evaluation
        test_loss, test_acc = evaluate_global(global_model, test_loader, DEVICE)
        history['loss'].append(test_loss)
        history['accuracy'].append(test_acc)
        history['round'].append(r + 1)

        # Print progress
        avg_client_loss = sum(client_losses) / len(client_losses)
        print(f"Round {r+1:2d}/{rounds} -> "
              f"Train Loss: {avg_client_loss:.4f} | "
              f"Test Loss: {test_loss:.4f} | "
              f"Test Acc: {test_acc:.2f}%")

    # Results
    print(f"\n{'='*70}")
    print("Training Complete!")
    print(f"{'='*70}")
    if not history['accuracy']:
        # rounds == 0: there is nothing to summarise or plot.
        print("No rounds were run; skipping summary and plot.")
        print(f"{'='*70}\n")
        return history, global_model

    print(f"Final Test Accuracy: {history['accuracy'][-1]:.2f}%")
    print(f"Best Test Accuracy: {max(history['accuracy']):.2f}%")
    print(f"{'='*70}\n")

    _plot_history(history, lr)

    return history, global_model


if __name__ == "__main__":
    # Driver: runs one FedAvg experiment and prints a short closing note.
    # The two triple-quoted blocks below are disabled alternative experiments;
    # as bare string expressions they have no effect when this cell runs.


    # Run with different learning rates to find best one

    # EXPERIMENT 1: LR = 0.1 (recommended)
    print("\n" + "="*70)
    print("EXPERIMENT 1: Learning Rate = 0.1")
    print("="*70)
    # history1 holds the per-round test metrics, model1 the final global model.
    history1, model1 = run_fedavg_experiment_j4(
        K=100,
        C=0.1,
        J=4,
        rounds=30,
        lr=0.1,
        batch_size=128
    )

    # EXPERIMENT 2: LR = 0.05 (conservative)
    # Remove the surrounding triple quotes to run this more conservative setting.
    """
    print("\n" + "="*70)
    print("EXPERIMENT 2: Learning Rate = 0.05")
    print("="*70)
    history2, model2 = run_fedavg_experiment_j4(
        K=100,
        C=0.1,
        J=4,
        rounds=30,
        lr=0.05,
        batch_size=128
    )
    """

    # EXPERIMENT 3: LR = 0.2 (aggressive)
    # Remove the surrounding triple quotes to run this more aggressive setting.
    """
    print("\n" + "="*70)
    print("EXPERIMENT 3: Learning Rate = 0.2")
    print("="*70)
    history3, model3 = run_fedavg_experiment_j4(
        K=100,
        C=0.1,
        J=4,
        rounds=30,
        lr=0.2,
        batch_size=128
    )
    """

    # NOTE(review): the accuracy figures printed below are hard-coded
    # expectations, not values measured by the run above.
    print("\n✓ Experiment completed!")
    print("\nWith J=4 constraint, realistic expectations:")
    print("  - With original config (LR=0.01): ~1-2% accuracy")
    print("  - With optimized config (LR=0.1): ~10-20% accuracy")
    print("  - Theoretical limit with J=4: ~20-25% accuracy")
    print("\nNote: J=4 is a severe limitation. For >30% accuracy,")
    print("      you would need J≥50 or more rounds (100+).")


    ╔════════════════════════════════════════════════════════════════╗
    ║         FedAvg Fixed Implementation for CIFAR-100              ║
    ║                                                                ║
    ║  This version includes critical fixes for the 1.8% accuracy:  ║
    ║  ✓ Increased local steps: 4 → 100                            ║
    ║  ✓ Increased learning rate: 0.01 → 0.05                      ║
    ║  ✓ Better model initialization                                 ║
    ║  ✓ Option for epoch-based training                            ║
    ╚════════════════════════════════════════════════════════════════╝
    

[OPTION 1] Running with 100 steps per round...

Running FedAvg Experiment (FIXED VERSION)
K=100, C=0.1, Rounds=100
Training mode: 100 LOCAL STEPS
Learning rate: 0.05
Device: cuda

Preparing data...
Creating IID partition for 100 clients...

Verifying Partition
Total samples: 45000/45000
No overlap
✓ Avg classes per client: 98.8


Initializing global model...

Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


✓ Model head initialized with Xavier uniform

Starting federated training...
Selecting 10 clients per round

Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Round   1/100 -> Train Loss: 208.1513 | Test Loss: 82.0057 | Test Acc: 0.89%
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Round   2/100 -> Train Loss: 21.4596 | Test Loss: 26.9646 | Test Acc: 1.08%
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Round   3/100 -> Train Loss: 7.9617 | Test Loss: 4.9241 | Test Acc: 1.83%
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Round   4/100 -> Train Loss: 4.9095 | Test Loss: 4.7356 | Test Acc: 2.03%
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Round   5/100 -> Train Loss: 4.7303 | Test Loss: 4.5754 | Test Acc: 2.20%
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Round   6/100 -> Train Loss: 4.7250 | Test Loss: 4.5639 | Test Acc: 2.90%
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Round   7/100 -> Train Loss: 4.7066 | Test Loss: 4.5517 | Test Acc: 2.57%
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Round   8/100 -> Train Loss: 4.6062 | Test Loss: 4.5344 | Test Acc: 2.40%
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main
