## Tut1-2NN

In [4]:
import numpy as np
from statistics import mean, pstdev
import random
import os
import torch
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

Reproducible

In [5]:
def set_seed(seed: int = 2025):
    random.seed(seed)
    np.random.seed(seed)                    # Fix NumPy random seed
    torch.manual_seed(seed)                 # Fix PyTorch random seed

2NN

In [6]:
class TwoNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 200)
        self.fc2 = nn.Linear(200, 200)
        self.out = nn.Linear(200, 10) 

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)

IID, where the data is shuffled, and then partitioned into 100 clients each receiving 600 examples.

In [7]:
def make_iid_parts(train_ds, K=100):
    idx = np.arange(len(train_ds))              # Get all sample indices
    np.random.shuffle(idx)                      # Shuffle globally (ensures IID)
    parts = np.array_split(idx, K)              # Split into K parts evenly
    return [p.astype(int) for p in parts]       # Convert to int, needed by Subset

Non-IID,
where we first sort the data by digit label, divide it into 200 shards of size 300, and assign each of 100 clients 2 shards.

In [8]:
def make_pathological_noniid_parts(train_ds, K=100):
    n = len(train_ds)                           # MNIST training set has 60,000 samples
    # Extract labels and create (idx, label) mapping key
    labels = np.fromiter((train_ds[i][1] for i in range(n)), dtype=np.int64, count=n)
    idx_sorted = np.argsort(labels)             # Sort indices by label (so each shard is almost single-class)

    shards = 2 * K                              # Paper setting: 200 shards
    shard_size = n // shards                    # Each shard has 300 samples
    shard_list = [idx_sorted[i*shard_size:(i+1)*shard_size] for i in range(shards)]
    np.random.shuffle(shard_list)               # Shuffle shard order to avoid bias

    parts = []
    for k in range(K):
        # Each client gets 2 shards (so most clients only see 2 digit classes → non-IID)
        part = np.concatenate([shard_list[2*k], shard_list[2*k + 1]])
        parts.append(part.astype(int))
    return parts

ClientUpdate

In [9]:
def client_update(global_model, dataset, indices, E=1, B=10, lr=0.1, device=torch.device("cpu")):
    # 1) Copy the global model locally (each client trains locally starting from the same weights)
    model = TwoNN().to(device)
    model.load_state_dict({k: v.detach().clone() for k, v in global_model.state_dict().items()})
    model.train()  # Switch to training mode (enable Dropout/BN, etc.)

    # 2) Define loss function and optimizer (paper uses SGD; simplified here)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # 3) Build the client’s DataLoader (Subset = local client dataset; B=None means full batch)
    subset = Subset(dataset, indices.tolist())
    if B is None:
        loader = DataLoader(subset, batch_size=len(subset), shuffle=False)   # Full batch GD
    else:
        loader = DataLoader(subset, batch_size=B, shuffle=True, drop_last=False)

    # 4) Local training for E epochs (multiple local steps before communication)
    for _ in range(E):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()           # Reset gradients
            logits = model(x)               # Forward pass
            loss = criterion(logits, y)     # Compute cross-entropy loss
            loss.backward()                 # Backward pass
            optimizer.step()                # Update parameters

    # 5) Return updated local model parameters and the number of samples (for weighted aggregation)
    return model.state_dict(), len(indices)

In [10]:
def fedavg_aggregate(updates):
    total = sum(n_k for _, n_k in updates)   # Total number of samples n = sum(n_k)
    keys = list(updates[0][0].keys())        # Get all parameter keys (e.g., 'fc1.weight')
    agg = {k: torch.zeros_like(updates[0][0][k]) for k in keys}  # Initialize with zeros

    for state_k, n_k in updates:             # Iterate over each client’s parameters
        w = n_k / total                      # Weight = client sample size / total samples
        for k in keys:
            agg[k] += state_k[k] * w         # Weighted accumulation -> weighted average

    return agg                               # Return the new global parameter dict

In [11]:
@torch.no_grad()
def evaluate(model, test_ds, device=torch.device("cpu")):
    model.eval()   # Evaluation mode (disable Dropout/BN training behavior)
    loader = DataLoader(test_ds, batch_size=1024, shuffle=False)  # Large batch for faster testing
    correct, total = 0, 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(dim=1)        # Prediction = class with highest probability
        correct += (pred == y).sum().item()  # Count correct predictions
        total += y.size(0)                   # Count total samples

    return correct / total                   # Return accuracy

In [20]:
def run_once_with_lr(
    lr: float,
    *,
    seed: int = 2025,
    partition: str = "noniid",   # Data partitioning scheme: "iid" or "noniid"
    K: int = 100,                # Total number of clients
    C: float = 0.1,              # Fraction of clients participating in each round
    E: int = 1,                  # Local training epochs per client
    B: int | None = 10,          # Local mini-batch size (None = full batch)
    target_acc: float = 0.97,    # Target accuracy threshold to stop training
    max_rounds: int = 500,      
    device: torch.device = torch.device("cpu"), 
    DATA_DIR: str | None = None,
) -> int | None:
    """Run one FedAvg training with given (seed, lr).
    Returns the round number when target_acc is first achieved,
    or None if not reached within max_rounds.
    """
    set_seed(seed)  # Ensure reproducibility for random operations

    # ---- Dataset preparation ----
    transform = transforms.Compose([
        transforms.ToTensor(),                      # Convert images (PIL/ndarray) to tensors scaled in [0,1]
        transforms.Normalize((0.1307,), (0.3081,)), # Normalize MNIST with mean=0.1307 and std=0.3081
    ])
    DATA_DIR = os.path.expanduser("~/datasets/mnist")  # Directory to store MNIST dataset
    os.makedirs(DATA_DIR, exist_ok=True)               # Create directory if not exists
    train_ds = datasets.MNIST(root=DATA_DIR, train=True, download=True, transform=transform)   # Training dataset
    test_ds  = datasets.MNIST(root=DATA_DIR, train=False, download=True, transform=transform)  # Test dataset

    # ---- Partition data among clients ----
    if partition == "iid":
        parts = make_iid_parts(train_ds, K=K)                 # Split dataset into IID partitions
    else:
        parts = make_pathological_noniid_parts(train_ds, K=K) # Split dataset into pathological non-IID partitions

    # ---- Initialize global model ----
    global_model = TwoNN().to(device)  # Two-layer NN model for MNIST, moved to device

    # ---- Federated learning loop ----
    best_acc = 0.0          # Track the best test accuracy observed so far
    rounds_to_target = None # Record the first round achieving target accuracy

    for t in range(1, max_rounds + 1):  # Iterate over communication rounds
        m = max(int(C * K), 1)  # Number of clients selected this round (at least 1)
        selected = np.random.choice(np.arange(K), size=m, replace=False)  # Randomly select m clients

        updates = []  # List to store client updates (model state, sample count)
        for k in selected:
            state_k, n_k = client_update(
                global_model, train_ds, parts[k],  # Client trains on its local partition
                E=E, B=B, lr=lr, device=device     # Using given hyperparameters
            )
            updates.append((state_k, n_k))  # Collect local update and dataset size

        new_state = fedavg_aggregate(updates)  # Aggregate client updates (weighted average)
        global_model.load_state_dict(new_state)  # Update global model with aggregated state

        acc = evaluate(global_model, test_ds, device)  # Evaluate global model on test set
        best_acc = max(best_acc, acc)                  # Track best accuracy seen so far

        if best_acc >= target_acc and rounds_to_target is None:
            rounds_to_target = t  # Record the round when target accuracy is first reached
            break                 # Stop early since target is achieved

    return rounds_to_target  # Return round number, or None if not achieved


In [None]:
def main():
    # ---- Fixed experiment knobs (same as paper defaults) ----
    partition = "noniid"              # "iid" or "noniid"
    K, C, E, B = 100, 0.1, 1, 10
    target_acc = 0.97
    max_rounds = 500
    device = torch.device("cpu")
    DATA_DIR = os.path.expanduser("~/datasets/mnist")
    seed = 2025                       

    # ---- Learning-rate candidates (you can tweak/expand) ----
    lr_candidates = [0.03, 0.05, 0.07, 0.1, 0.15]

    print(f"[Grid] partition={partition}, K={K}, C={C}, E={E}, B={B}, "
          f"target={int(target_acc*100)}%, max_rounds={max_rounds}, seed={seed}")
    print("-" * 72)

    results: list[tuple[float, int | None]] = []
    for lr in lr_candidates:
        r = run_once_with_lr(
            lr,
            seed=seed,
            partition=partition, K=K, C=C, E=E, B=B,
            target_acc=target_acc, max_rounds=max_rounds,
            device=device, DATA_DIR=DATA_DIR,
        )
        results.append((lr, r))
        if r is None:
            print(f"lr={lr:<6g} | rounds=NA  (did not reach {int(target_acc*100)}%)")
        else:
            print(f"lr={lr:<6g} | rounds={r}")

    # ---- pick best: minimal rounds among those that reached target ----
    reached = [(lr, r) for (lr, r) in results if r is not None]
    print("-" * 72)
    if reached:
        best_lr, best_rounds = min(reached, key=lambda x: x[1])
        print(f"BEST lr={best_lr} | rounds-to-{int(target_acc*100)}% = {best_rounds}")
    else:
        print(f"No lr reached {int(target_acc*100)}% within {max_rounds} rounds. "
              f"Try increasing max_rounds or widening lr grid.")

if __name__ == "__main__":
    main()

[Grid] partition=noniid, K=100, C=0.1, E=1, B=10, target=97%, max_rounds=500, seed=2025
------------------------------------------------------------------------
lr=0.03   | rounds=359
lr=0.05   | rounds=227
lr=0.07   | rounds=190
lr=0.1    | rounds=142
lr=0.15   | rounds=145
------------------------------------------------------------------------
BEST lr=0.1 | rounds-to-97% = 142
