# Module 1 — Federated Learning (FL) Intro: From Centralized Training to FedAvg

This notebook is a hands-on walkthrough of the *minimum working pipeline* for Federated Learning:
- what changes relative to centralized ML,
- how client data is split,
- how local training works,
- how the server aggregates updates (FedAvg),
- and how we evaluate a global model over multiple rounds.

**Output:** a trained *global model* and plots showing training progress across FL rounds.



## Learning objectives

By the end of this notebook you should be able to:
1. Explain the difference between *centralized* training and *federated* training.
2. Describe the 5-step FL loop (select clients → send model → local train → aggregate → repeat).
3. Implement a simple FedAvg round using multiple clients.
4. Evaluate a global model and interpret accuracy/loss curves across rounds.



## Imports

We import:
- **PyTorch** for models, training, and tensors,
- **torchvision** for datasets and transforms,
- **NumPy** for light data manipulation,
- **matplotlib** for visualization.

As you read the code, focus on *where FL logic begins*—the parts that simulate multiple clients and aggregation.



In [2]:
import os, math, random, logging
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

## Reproducibility

We set random seeds so that:
- dataset splits across clients are repeatable,
- training results are easier to compare across runs,
- and debugging is less painful.

Note: exact reproducibility can still vary across GPU/CPU and library versions.



In [15]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

## Model: a small CNN for MNIST

All clients train the **same model architecture**.

Federated Learning does not require a special model; the key change is *how training is coordinated*.

In each round:
- The server sends the current global weights to clients.
- Clients train locally on their private data.
- The server aggregates client updates into a new global model.


In [24]:
class Net(nn.Module):

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool  = nn.MaxPool2d(2)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(64 * 14 * 14, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)

## Dataset overview

We use a standard vision dataset so we can focus on the *federated process* instead of domain-specific data cleaning.

Key idea:
- In centralized ML: all data is available in one place.
- In FL: data is partitioned across many clients (devices/organizations) and **does not move** to the server.



In [25]:
def download_mnist(root="./data"):

    _ = datasets.MNIST(root=root, train=True,  download=True)
    _ = datasets.MNIST(root=root, train=False, download=True)
    print(f"MNIST downloaded to: {os.path.abspath(root)}")

## Data preprocessing and transforms

We convert images to tensors and normalize them so training is stable and comparable across clients.

Typical transform components:
- `ToTensor()` scales pixel values to `[0, 1]`.
- `Normalize(mean, std)` standardizes the input distribution.

All clients should use the same preprocessing for a clean baseline.


In [26]:
def mnist_transforms():

    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

def load_mnist_with_transforms(root="./data", transform=None):

    if transform is None:
        transform = mnist_transforms()
    train_ds = datasets.MNIST(root=root, train=True,  download=False, transform=transform)
    test_ds  = datasets.MNIST(root=root, train=False, download=False, transform=transform)
    return train_ds, test_ds

## Distributing data across clients

We simulate **multiple clients** by partitioning the training dataset into per-client subsets.

Two common modes are supported:
- **IID**: each client receives a random slice of the dataset (similar class distribution).
- **Non-IID**: each client receives a skewed distribution (some classes dominate per client).

In this intro module, we start with IID (or mild non-IID) so the FedAvg loop is easier to validate.


In [27]:
def _iid_indices(n_items, num_clients, seed=42):
    # print("iid")
    rng = np.random.default_rng(seed)
    idx = np.arange(n_items)
    rng.shuffle(idx)
    splits = np.array_split(idx, num_clients)
    return [np.array(s, dtype=int) for s in splits]

def _dirichlet_label_skew_indices(targets, num_clients, alpha=0.5, seed=42):

    # print("skew")
    rng = np.random.default_rng(seed)
    targets = np.asarray(targets)
    classes = np.unique(targets)
    client_indices = [[] for _ in range(num_clients)]

    for c in classes:
        c_idx = np.where(targets == c)[0]
        rng.shuffle(c_idx)

        props = rng.dirichlet([alpha] * num_clients)

        counts = (len(c_idx) * props).astype(int)

        while counts.sum() < len(c_idx):
            counts[np.argmax(props)] += 1

        start = 0
        for i, cnt in enumerate(counts):
            if cnt > 0:
                client_indices[i].extend(c_idx[start:start+cnt])
                start += cnt


    for i in range(num_clients):
        client_indices[i] = np.array(client_indices[i], dtype=int)
        rng.shuffle(client_indices[i])
    return client_indices

def make_client_loaders(
    train_ds,
    test_ds,
    num_clients=10,
    batch_size=64,
    non_iid_per=0.0,
    seed=42,
):

    if non_iid_per <= 1e-8:
        client_idxs = _iid_indices(len(train_ds), num_clients, seed=seed)
    else:

        alpha = max(0.01, 1.0 - 0.99 * non_iid_per)
        targets = train_ds.targets if hasattr(train_ds, "targets") else train_ds.labels
        client_idxs = _dirichlet_label_skew_indices(targets, num_clients, alpha=alpha, seed=seed)

    local_loaders = [
        DataLoader(Subset(train_ds, idxs), batch_size=batch_size, shuffle=True, drop_last=False)
        for idxs in client_idxs
    ]
    test_loader = DataLoader(test_ds, batch_size=512, shuffle=False, drop_last=False)
    return local_loaders, test_loader

## Federated Averaging (FedAvg): client updates + server aggregation

FedAvg is the simplest baseline FL algorithm.

High-level flow per round:
1. Sample a subset of clients.
2. Broadcast the current global model to those clients.
3. Each client trains locally for a small number of epochs.
4. The server averages client models (often weighted by client dataset size).
5. Evaluate the updated global model on a shared test set.

This establishes a baseline you can compare against later modules (non-IID, attacks, defenses).


## Client object: local training on private data

Each client holds:
- its own private dataset (a DataLoader),
- local hyperparameters like local epochs and learning rate,
- two model references: a local copy to train and a way to receive updated global weights.

The client never sends raw data to the server—only model parameters (or parameter differences).


In [28]:
class Client():

    def __init__(self, client_id, local_data, device, num_epochs, criterion, lr):
        self.id = client_id
        self.data = local_data
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.num_epochs = num_epochs
        self.lr = lr
        self.criterion = criterion
        self.x = None
        self.y = None

    def client_update(self):
        # print("update")


        self.y = deepcopy(self.x)
        self.y.to(self.device)

        for epoch in range(self.num_epochs):

            for inputs,labels in self.data:
              inputs, labels = inputs.float().to(self.device), labels.long().to(self.device)
              output = self.y(inputs)
              loss = self.criterion(output, labels)
              grads = torch.autograd.grad(loss,self.y.parameters())

              with torch.no_grad():
                  for param,grad in zip(self.y.parameters(),grads):
                      param.data = param.data - self.lr * grad.data

              if self.device == "cuda": torch.cuda.empty_cache()


## Server-side utilities: evaluation, sampling, aggregation, and history

Server-side responsibilities in this notebook:
- `test_model(...)`: computes loss and accuracy on the shared test set.
- `sample_clients(...)`: selects a subset of clients each round (to mimic partial participation).
- `FedAvg(...)`: aggregates client models into a new global model.
- `history`: stores metrics by round so we can plot progress.

These utilities make the training loop easier to read and debug.


In [None]:

def sample_clients(num_clients, fraction, rng):
    k = max(1, int(math.ceil(fraction * num_clients)))
    return sorted(rng.choice(np.arange(num_clients), size=k, replace=False).tolist())

def fedavg_aggregate(global_model, client_models, device):

    with torch.no_grad():
        # init accumulators
        acc = [torch.zeros_like(p, device=device) for p in global_model.parameters()]
        for cm in client_models:
            for a, p in zip(acc, cm.parameters()):
                a.add_(p.to(device))
        for gp, a in zip(global_model.parameters(), acc):
            gp.copy_(a / len(client_models))


In [None]:

def train_fedavg(
    model_ctor,
    client_loaders,
    test_loader,
    *,
    device,
    rounds=20,
    local_epochs=1,
    fraction=0.1,
    local_lr=0.01,
    criterion=None,
    seed=42,
    log_level=logging.INFO,
):

    # print("train")
    logging.basicConfig(level=log_level, format="%(message)s")
    rng = np.random.default_rng(seed)

    # global model
    global_model = model_ctor().to(device)
    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    # wrap clients
    clients = [
        Client(i, ld, device=device, num_epochs=local_epochs, criterion=criterion, lr=local_lr)
        for i, ld in enumerate(client_loaders)
    ]

    history = {"round": [], "loss": [], "acc": []}

    for r in range(1, rounds + 1):
        # print("round")
        # sample & broadcast
        ids = sample_clients(len(clients), fraction, rng)
        for i in ids:
            clients[i].x = global_model  # reference (they deepcopy inside)

        # local updates
        for i in ids:
            clients[i].client_update()

        # aggregate
        fedavg_aggregate(global_model, [clients[i].y for i in ids], device=device)

        # evaluate
        test_loss, test_acc = evaluate_fn(test_loader, global_model, criterion, device)
        history["round"].append(r); history["loss"].append(test_loss); history["acc"].append(test_acc)
        logging.info(f"Round {r:03d} | test loss: {test_loss:.4f} | test acc: {test_acc:.2f}%")
        print(f"Round {r:03d} | test loss: {test_loss:.4f} | test acc: {test_acc:.2f}%")

    return global_model, history

In [None]:
def evaluate_fn(test_loader, model, criterion, device):
    # print("eval")
    model.eval()
    n, total_loss, correct = 0, 0.0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device, dtype=torch.float)
            y = y.to(device, dtype=torch.long)
            logits = model(x)
            loss = criterion(logits, y)
            total_loss += loss.item() * x.size(0)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            n += x.size(0)
    return (total_loss / max(1, n)), (100.0 * correct / max(1, n))

## Experiment configuration and running training

Here we set:
- the number of clients,
- client batch size,
- IID vs non-IID split strength (`non_iid_per`),
- FL rounds, local epochs, and the fraction of clients per round.

Then we:
1. download/load MNIST,
2. create per-client data loaders,
3. run FedAvg training and collect metrics in `history`.


In [30]:
seed = 42
set_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_root = "./data"
num_clients = 10
batch_size = 64
non_iid_per = 0.0      # 0 = IID, 1 = very non-IID
rounds = 20
local_epochs = 1
fraction = 0.2         # 20% clients per round
local_lr = 0.01
criterion = nn.CrossEntropyLoss()

# Data
download_mnist(data_root)
train_ds, test_ds = load_mnist_with_transforms(data_root)
client_loaders, test_loader = make_client_loaders(
    train_ds, test_ds, num_clients=num_clients,
    batch_size=batch_size, non_iid_per=non_iid_per, seed=seed
)

# Train
global_model, history = train_fedavg(
    Net, client_loaders, test_loader,
    device=device, rounds=rounds, local_epochs=local_epochs,
    fraction=fraction, local_lr=local_lr, criterion=criterion, seed=seed
)


MNIST downloaded to: /content/data
round
Round 001 | test loss: 0.6342 | test acc: 83.18%
round
Round 002 | test loss: 0.3816 | test acc: 88.16%
round
Round 003 | test loss: 0.3553 | test acc: 89.10%
round
Round 004 | test loss: 0.2917 | test acc: 91.24%
round
Round 005 | test loss: 0.2710 | test acc: 92.01%
round
Round 006 | test loss: 0.2376 | test acc: 92.90%
round
Round 007 | test loss: 0.2315 | test acc: 92.98%
round
Round 008 | test loss: 0.2298 | test acc: 92.87%
round
Round 009 | test loss: 0.1945 | test acc: 94.08%
round
Round 010 | test loss: 0.1803 | test acc: 94.42%
round
Round 011 | test loss: 0.1762 | test acc: 94.64%
round
Round 012 | test loss: 0.1821 | test acc: 94.08%
round
Round 013 | test loss: 0.1626 | test acc: 95.05%
round
Round 014 | test loss: 0.1461 | test acc: 95.68%
round
Round 015 | test loss: 0.1427 | test acc: 95.72%
round
Round 016 | test loss: 0.1366 | test acc: 96.00%
round
Round 017 | test loss: 0.1408 | test acc: 95.76%
round
Round 018 | test loss: 0

## Next steps 

- Plot `history["acc"]` and `history["loss"]` versus `history["round"]`.
- Save the trained `global_model` checkpoint for later modules.
- Try changing one knob at a time (clients, local epochs, fraction, non-IID) and compare curves.


In [None]:
import numpy as np
import matplotlib.pyplot as plt

def _get_first_key(d, keys):
    for k in keys:
        if k in d:
            return k
    raise KeyError(f"None of these keys found in history: {keys}. Available keys: {list(d.keys())}")

# Adjust these if your notebook uses different names
round_key = _get_first_key(history, ["round", "rounds", "r", "epoch", "global_round"])
acc_key   = _get_first_key(history, ["acc", "accuracy", "test_acc", "global_acc"])
loss_key  = _get_first_key(history, ["loss", "test_loss", "global_loss"])

rounds = np.array(history[round_key])
acc    = np.array(history[acc_key])
loss   = np.array(history[loss_key])

# Accuracy vs Round
plt.figure()
plt.plot(rounds, acc, marker="o")
plt.xlabel("Round")
plt.ylabel("Accuracy")
plt.title("Global Accuracy vs Round")
plt.grid(True)
plt.show()

# Loss vs Round
plt.figure()
plt.plot(rounds, loss, marker="o")
plt.xlabel("Round")
plt.ylabel("Loss")
plt.title("Global Loss vs Round")
plt.grid(True)
plt.show()


In [None]:

# Choose a location that works well on ODU HPC
save_dir = os.environ.get("FD_learning", ".")
save_path = os.path.join(save_dir, "t3fl_global_model.pt")

# Prefer saving state_dict (portable, standard)
ckpt = {
    "model_state_dict": global_model.state_dict(),
    # Optional but useful:
    "round": history.get("round", None) if isinstance(history, dict) else None,
    "history": history if isinstance(history, dict) else None,
    # "config": config_dict,  # uncomment if you have one
}

torch.save(ckpt, save_path)
print(f"Saved checkpoint to: {save_path}")

# --- Loading later (example) ---
# global_model = Net()  # re-create the same model class/architecture first
# ckpt = torch.load(save_path, map_location="cpu")
# global_model.load_state_dict(ckpt["model_state_dict"])
# global_model.eval()