<a href="https://colab.research.google.com/github/ILoveCoder999/FederatedLearning/blob/master/fed_avg.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# import module
import sys
sys.path.append('/content/drive/MyDrive')
from preprocessing import FederatedDataBuilder

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import copy
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Subset

from preprocessing import FederatedDataBuilder
from centralized_model import DINOCIFAR100


# ============================================================
# CRITICAL FIX: Model Initialization Function
# ============================================================
def initialize_model(num_classes=100):
    """
    Initialize DINO model with proper head initialization.
    This fixes the 1% accuracy issue.

    IMPORTANT: Always use this function instead of DINOCIFAR100() directly!
    """
    model = DINOCIFAR100(num_classes=num_classes)

    # Properly initialize the classification head
    # Xavier/Glorot initialization for better convergence
    nn.init.xavier_uniform_(model.head.weight)
    nn.init.zeros_(model.head.bias)

    print("✓ Model head initialized with Xavier uniform")

    return model


def fed_avg_aggregate(global_model, local_weights, client_sample_counts):
    """
    Performs the weighted averaging of local model weights.
    w_global = sum(n_k * w_k) / sum(n_k)
    """
    # Create a deep copy of the global model state to update
    global_dict = global_model.state_dict()

    # Calculate total samples in this round for weighted average
    total_samples = sum(client_sample_counts)

    # Initialize the aggregated dictionary
    # We take the first local model as the base (scaled by its weight)
    first_weights = local_weights[0]
    first_ratio = client_sample_counts[0] / total_samples

    for k in global_dict.keys():
        # Handle strict type checking for scalars (long/float)
        if 'num_batches_tracked' in k:
            global_dict[k] = first_weights[k]
        else:
            global_dict[k] = first_weights[k] * first_ratio

    # Add the rest of the models
    for i in range(1, len(local_weights)):
        ratio = client_sample_counts[i] / total_samples
        weights = local_weights[i]
        for k in global_dict.keys():
            if 'num_batches_tracked' not in k:
                global_dict[k] += weights[k] * ratio

    return global_dict


class LocalClient:
    """
    Simulates a local client training process.
    """
    def __init__(self, client_id, dataset, indices, device, model_class):
        self.client_id = client_id
        self.dataset = dataset
        self.indices = indices
        self.device = device
        self.model_class = model_class

        # Create a local dataloader
        # Increased batch size from 32 to 64 for better gradient estimates
        self.trainloader = DataLoader(
            Subset(dataset, list(indices)),
            batch_size=64,  # Improved from 32
            shuffle=True
        )

    def train(self, global_weights, local_steps=4, lr=0.01):
        """
        Runs local training for J steps (not epochs).
        """
        # 1. Initialize local model with global weights
        local_model = self.model_class(num_classes=100).to(self.device)
        local_model.load_state_dict(global_weights)
        local_model.train()

        # 2. Setup Optimizer (SGD is standard for FedAvg)
        # Added weight decay for regularization
        optimizer = optim.SGD(
            local_model.parameters(),
            lr=lr,
            momentum=0.9,
            weight_decay=1e-4  # Added for better generalization
        )
        criterion = nn.CrossEntropyLoss()

        # 3. Local Training Loop (J steps)
        step_count = 0
        epoch_loss = []

        # Create an iterator that resets if we run out of data
        iterator = iter(self.trainloader)

        while step_count < local_steps:
            try:
                inputs, targets = next(iterator)
            except StopIteration:
                # Restart iterator if dataset is exhausted (though unlikely for J=4)
                iterator = iter(self.trainloader)
                inputs, targets = next(iterator)

            inputs, targets = inputs.to(self.device), targets.to(self.device)

            optimizer.zero_grad()
            outputs = local_model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            epoch_loss.append(loss.item())
            step_count += 1

        return local_model.state_dict(), sum(epoch_loss)/len(epoch_loss)


def evaluate_global(model, test_loader, device):
    """Evaluate global model on server test set"""
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0
    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss_sum += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    return loss_sum / len(test_loader), 100. * correct / total


def run_fedavg_experiment():
    # ---------------------------------------------------------
    # Configuration
    # ---------------------------------------------------------
    K = 100             # Total clients
    C = 0.1             # Fraction of clients
    J = 4               # Local steps (NOT Epochs)
    ROUNDS = 50         # Number of communication rounds
    LR = 0.01           # Learning rate
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"\n{'='*70}")
    print(f"Running FedAvg Experiment")
    print(f"{'='*70}")
    print(f"K={K}, C={C}, J={J} steps, Rounds={ROUNDS}")
    print(f"Device: {DEVICE}")
    print(f"{'='*70}\n")

    # 1. Data Preparation
    print("Preparing data...")
    data_builder = FederatedDataBuilder(val_split_ratio=0.1, K=K)
    client_dict = data_builder.get_iid_partition()  # IID sharding

    # Verify partition
    data_builder.verify_partition(client_dict)

    test_loader = DataLoader(data_builder.test_dataset, batch_size=128, shuffle=False)

    # 2. Global Model Initialization
    #  CRITICAL FIX: Use initialize_model() instead of DINOCIFAR100() directly
    print("\nInitializing global model...")
    global_model = initialize_model(num_classes=100).to(DEVICE)

    # History for plotting
    history = {'loss': [], 'accuracy': [], 'round': []}

    # 3. Federated Training Loop
    print(f"\nStarting federated training...")
    m = max(int(C * K), 1)  # Number of clients to select
    print(f"Selecting {m} clients per round\n")

    for r in range(ROUNDS):
        # a. Client Selection
        selected_clients = np.random.choice(range(K), m, replace=False)

        local_weights = []
        client_sample_counts = []

        # b. Local Training (Sequential Simulation)
        global_weights = copy.deepcopy(global_model.state_dict())

        for client_idx in selected_clients:
            # Create client interface
            client = LocalClient(
                client_id=client_idx,
                dataset=data_builder.train_dataset,
                indices=client_dict[client_idx],
                device=DEVICE,
                model_class=DINOCIFAR100
            )

            # Train locally
            w_local, loss_local = client.train(global_weights, local_steps=J, lr=LR)

            local_weights.append(w_local)
            client_sample_counts.append(len(client_dict[client_idx]))

        # c. Aggregation (FedAvg)
        new_weights = fed_avg_aggregate(global_model, local_weights, client_sample_counts)
        global_model.load_state_dict(new_weights)

        # d. Evaluation
        test_loss, test_acc = evaluate_global(global_model, test_loader, DEVICE)
        history['loss'].append(test_loss)
        history['accuracy'].append(test_acc)
        history['round'].append(r + 1)

        # Print progress (more frequent at the beginning)
        if r < 10 or (r + 1) % 5 == 0:
            print(f"Round {r+1}/{ROUNDS} -> Loss: {test_loss:.4f} | Accuracy: {test_acc:.2f}%")

    # 4. Plot Results
    print(f"\n{'='*70}")
    print(f"Training Complete!")
    print(f"{'='*70}")
    print(f"Final Test Accuracy: {history['accuracy'][-1]:.2f}%")
    print(f"Best Test Accuracy: {max(history['accuracy']):.2f}%")
    print(f"{'='*70}\n")

    plt.figure(figsize=(10, 5))
    plt.plot(history['round'], history['accuracy'], 'b-o', linewidth=2)
    plt.title(f'FedAvg (IID) Performance (J={J})', fontsize=14, fontweight='bold')
    plt.xlabel('Communication Rounds', fontsize=12)
    plt.ylabel('Test Accuracy (%)', fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('fedavg_iid_results.png', dpi=150, bbox_inches='tight')
    print("Figure saved to: fedavg_iid_results.png")
    plt.show()

    return history, global_model


if __name__ == "__main__":
    # Run the experiment
    history, model = run_fedavg_experiment()

    print("\n✓ Experiment completed successfully!")
    print("\nNext steps:")
    print("1. Check the results plot: fedavg_iid_results.png")
    print("2. Run experiments with Non-IID data")
    print("3. Try different values of J (local steps)")

In [None]:
if __name__ == "__main__":
    # Ensure dependencies (FederatedDataBuilder, DINOCIFAR100) are in scope
    run_fedavg_experiment()

Running FedAvg | K=100, C=0.1, J=4 steps | Device: cuda


100%|██████████| 169M/169M [00:08<00:00, 20.7MB/s]


Creating IID partition for 100 clients...
Downloading/Loading DINO ViT-S/16...
Downloading: "https://github.com/facebookresearch/dino/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dino_deitsmall16_pretrain.pth


100%|██████████| 82.7M/82.7M [00:03<00:00, 22.7MB/s]



--- Round 1/50 ---
Selected clients: [12 36 37 69 76 93  1 29 82 72]
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Global Model Stats -> Loss: 8.3684 | Accuracy: 0.81%

--- Round 2/50 ---
Selected clients: [94 13 80 82 58 64 54 45 38 42]
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Global Model Stats -> Loss: 8.5881 | Accuracy: 1.00%

--- Round 3/50 ---
Selected clients: [39 43 56 64 10 20 54 80 75 44]
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Global Model Stats -> Loss: 9.0631 | Accuracy: 1.00%

--- Round 4/50 ---
Selected clients: [50 21 79 73 94 39 45  3 97 70]
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Global Model Stats -> Loss: 8.6098 | Accuracy: 1.00%

--- Round 5/50 ---
Selected clients: [38 56 92 87 13 31 72 25 81 80]
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Global Model Stats -> Loss: 8.1339 | Accuracy: 1.29%

--- Round 6/50 ---
Selected clients: [20 94 73 45  4 60  1 66  7 30]
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Global Model Stats -> Loss: 7.1370 | Accuracy: 1.00%

--- Round 7/50 ---
Selected clients: [88 55 77 40 36 26  9 25 60 93]
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Global Model Stats -> Loss: 6.4344 | Accuracy: 1.00%

--- Round 8/50 ---
Selected clients: [77 44 74 94  3 26 56 12 46 91]
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Global Model Stats -> Loss: 6.1984 | Accuracy: 1.07%

--- Round 9/50 ---
Selected clients: [77 13 28  4 17 51 40 46 60 80]
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Global Model Stats -> Loss: 5.9003 | Accuracy: 1.31%

--- Round 10/50 ---
Selected clients: [61 74 31 36 43 23 26 46 72 96]
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main
