# Non-IID Data in Federated Learning

In real federated learning, client data is rarely **IID** (independent & identically distributed). Clients usually reflect different users, devices, regions, or institutions, so their local label distributions differ.

In this notebook we simulate **label distribution skew** on MNIST and observe how it affects FedAvg training.

## IID vs Non-IID (what we mean here)
- **IID split:** each client receives an unbiased random subset of the global dataset.
- **Non-IID split (label skew):** clients receive different class mixtures (some clients see mostly a few digits).

## Dirichlet partitioning (how we create skew)
We use a Dirichlet-based split to generate per-class allocation proportions across clients. The **concentration parameter** controls skew:
- smaller concentration → more skew (clients specialize),
- larger concentration → more uniform (closer to IID).

### Notebook knob: `non_iid_per`
This notebook uses a single knob `non_iid_per ∈ [0, 1]` and converts it to a Dirichlet concentration `alpha`:
- `non_iid_per = 0` → close to IID (`alpha ≈ 1`)
- `non_iid_per → 1` → strong skew (`alpha` near `0.01`)

**What to watch:** as skew increases, FedAvg typically converges slower and/or reaches lower global accuracy.


In [None]:
import os, math, random, logging
from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms


In [None]:
%matplotlib inline

In [None]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


In [None]:
class Net(nn.Module):

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool  = nn.MaxPool2d(2)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(64 * 14 * 14, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)


In [None]:
def download_mnist(root="./data"):

    _ = datasets.MNIST(root=root, train=True,  download=True)
    _ = datasets.MNIST(root=root, train=False, download=True)
    print(f"MNIST downloaded to: {os.path.abspath(root)}")
    


In [None]:
def mnist_transforms():

    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

def load_mnist_with_transforms(root="./data", transform=None):

    if transform is None:
        transform = mnist_transforms()
    train_ds = datasets.MNIST(root=root, train=True,  download=False, transform=transform)
    test_ds  = datasets.MNIST(root=root, train=False, download=False, transform=transform)
    return train_ds, test_ds


In [None]:
def _iid_indices(n_items, num_clients, seed=42):
    # print("iid")
    rng = np.random.default_rng(seed)
    idx = np.arange(n_items)
    rng.shuffle(idx)
    splits = np.array_split(idx, num_clients)
    return [np.array(s, dtype=int) for s in splits]


In [2]:

def _dirichlet_label_skew_indices(targets, num_clients, alpha=0.5, seed=42):

    # print("skew")
    rng = np.random.default_rng(seed)
    targets = np.asarray(targets)
    classes = np.unique(targets)
    client_indices = [[] for _ in range(num_clients)]

    for c in classes:
        c_idx = np.where(targets == c)[0]
        rng.shuffle(c_idx)

        props = rng.dirichlet([alpha] * num_clients)

        counts = (len(c_idx) * props).astype(int)

        while counts.sum() < len(c_idx):
            counts[np.argmax(props)] += 1

        start = 0
        for i, cnt in enumerate(counts):
            if cnt > 0:
                client_indices[i].extend(c_idx[start:start+cnt])
                start += cnt


    for i in range(num_clients):
        client_indices[i] = np.array(client_indices[i], dtype=int)
        rng.shuffle(client_indices[i])
    return client_indices


# Load client data (DataLoaders per client)

We convert the global training dataset into **one DataLoader per client**:

- `client_loaders[i]` yields batches from *client i’s private training set*.
- `test_loader` is a shared test DataLoader used only for evaluation of the global model.

## Split logic
- If `non_iid_per` is ~0, we create an IID random split of indices across clients.
- Otherwise, we build a **Dirichlet label-skew split** so each client gets a different label mixture.

This is the key step that turns a centralized dataset into a federated setting.


In [None]:

def make_client_loaders(
    train_ds,
    test_ds,
    num_clients=10,
    batch_size=64,
    non_iid_per=0.0,
    seed=42,
    return_indices=False,
):

    if non_iid_per <= 1e-8:
        client_idxs = _iid_indices(len(train_ds), num_clients, seed=seed)
    else:
        alpha = max(0.01, 1.0 - 0.99 * non_iid_per)
        targets = train_ds.targets if hasattr(train_ds, "targets") else train_ds.labels
        client_idxs = _dirichlet_label_skew_indices(targets, num_clients, alpha=alpha, seed=seed)

    local_loaders = [
        DataLoader(Subset(train_ds, idxs), batch_size=batch_size, shuffle=True, drop_last=False)
        for idxs in client_idxs
    ]
    test_loader = DataLoader(test_ds, batch_size=512, shuffle=False, drop_last=False)

    return (local_loaders, test_loader, client_idxs) if return_indices else (local_loaders, test_loader)


# Client: local training on private data

A `Client` simulates one participant in FL. Each client holds:
- `data`: its own DataLoader (private samples)
- `x`: a reference to the **current global model** (received each round)
- `y`: the **local model copy** trained on this client (produced each round)

## Local update idea
In each round:
1. The server assigns `client.x = global_model`
2. The client deep-copies `x → y` (so it can train locally)
3. The client trains `y` for `num_epochs` on its private batches
4. The trained `y` is sent back to the server for aggregation

This notebook uses a simple manual SGD-style update using gradients computed from the loss.


In [None]:
class Client():
    def __init__(self, client_id, local_data, device, num_epochs, criterion, lr):
        self.id = client_id
        self.data = local_data
        self.device = device if isinstance(device, torch.device) else torch.device(device)
        self.num_epochs = num_epochs
        self.lr = lr
        self.criterion = criterion
        self.x = None
        self.y = None

    def client_update(self):
        self.y = deepcopy(self.x)
        self.y.to(self.device)

        for epoch in range(self.num_epochs):
            for inputs, labels in self.data:
                inputs = inputs.float().to(self.device)
                labels = labels.long().to(self.device)

                output = self.y(inputs)
                loss = self.criterion(output, labels)
                grads = torch.autograd.grad(loss, self.y.parameters())

                with torch.no_grad():
                    for param, grad in zip(self.y.parameters(), grads):
                        param.data = param.data - self.lr * grad.data


# Evaluate function (global model performance)

After aggregation, we evaluate the **global model** on a shared test set.

What this function does:
- sets the model to `eval()` mode,
- disables gradient computation (`torch.no_grad()`),
- computes average loss across all test samples,
- computes accuracy (%) across all test samples.

This gives us a consistent metric to compare across rounds and across different non-IID settings.


In [5]:
def evaluate_fn(test_loader, model, criterion, device):
    # print("eval")
    model.eval()
    n, total_loss, correct = 0, 0.0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device, dtype=torch.float)
            y = y.to(device, dtype=torch.long)
            logits = model(x)
            loss = criterion(logits, y)
            total_loss += loss.item() * x.size(0)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            n += x.size(0)
    return (total_loss / max(1, n)), (100.0 * correct / max(1, n))


# Client Sampling (partial participation)

Real FL systems rarely have all clients available every round.
To simulate this, we sample only a fraction of clients each round:

- `fraction = 0.2` means ~20% of clients participate per round
- Sampling changes training dynamics and adds variance to the updates

This is why we track performance across many rounds rather than judging a single round.


In [6]:

def sample_clients(num_clients, fraction, rng):
    k = max(1, int(math.ceil(fraction * num_clients)))
    return sorted(rng.choice(np.arange(num_clients), size=k, replace=False).tolist())


# Federated Averaging (FedAvg)

FedAvg aggregates the client-trained models into a new global model by averaging parameters.

In this implementation:
- we average each parameter tensor across the participating clients
- each participating client contributes equally

Note: Many FL implementations use a **weighted average** (by each client’s dataset size). This notebook uses the simplest unweighted baseline.



In [7]:

def fedavg_aggregate(global_model, client_models, device):

    with torch.no_grad():
        # init accumulators
        acc = [torch.zeros_like(p, device=device) for p in global_model.parameters()]
        for cm in client_models:
            for a, p in zip(acc, cm.parameters()):
                a.add_(p.to(device))
        for gp, a in zip(global_model.parameters(), acc):
            gp.copy_(a / len(client_models))


# Training (FedAvg rounds)

`train_fedavg(...)` runs the full FL process:

Per round:
1. **Sample clients** based on `fraction`
2. **Broadcast** the current global model to selected clients (`client.x = global_model`)
3. **Local training** on each selected client (`client.client_update()`)
4. **Aggregate** client models into the global model (FedAvg)
5. **Evaluate** the global model on the test set
6. **Log** metrics into `history` for plotting

Output:
- `global_model`: trained model after all rounds
- `history`: per-round loss/accuracy curves


In [8]:

def train_fedavg(
    model_ctor,
    client_loaders,
    test_loader,
    *,
    device,
    rounds=20,
    local_epochs=1,
    fraction=0.1,
    local_lr=0.01,
    criterion=None,
    seed=42,
    log_level=logging.INFO,
):

    # print("train")
    logging.basicConfig(level=log_level, format="%(message)s")
    rng = np.random.default_rng(seed)

    # global model
    global_model = model_ctor().to(device)
    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    # wrap clients
    clients = [
        Client(i, ld, device=device, num_epochs=local_epochs, criterion=criterion, lr=local_lr)
        for i, ld in enumerate(client_loaders)
    ]

    history = {"round": [], "loss": [], "acc": []}

    for r in range(1, rounds + 1):
        # print("round")
        # sample & broadcast
        ids = sample_clients(len(clients), fraction, rng)
        for i in ids:
            clients[i].x = global_model  # reference (they deepcopy inside)

        # local updates
        for i in ids:
            clients[i].client_update()

        # aggregate
        fedavg_aggregate(global_model, [clients[i].y for i in ids], device=device)

        # evaluate
        test_loss, test_acc = evaluate_fn(test_loader, global_model, criterion, device)
        history["round"].append(r); history["loss"].append(test_loss); history["acc"].append(test_acc)
        logging.info(f"Round {r:03d} | test loss: {test_loss:.4f} | test acc: {test_acc:.2f}%")
        print(f"Round {r:03d} | test loss: {test_loss:.4f} | test acc: {test_acc:.2f}%")

    return global_model, history

# Run Experiment

This section sets experiment hyperparameters and executes the pipeline end-to-end.

Key knobs:
- `num_clients`: number of simulated clients
- `non_iid_per`: degree of label skew across clients
- `rounds`: number of FL rounds
- `local_epochs`: how long each client trains per round
- `fraction`: percentage of clients participating per round
- `local_lr`: client learning rate

Suggested experiments:
1. Fix everything, sweep `non_iid_per` (0 → 1) and compare curves.
2. Increase `local_epochs` and observe whether skew causes more client drift.
3. Change `fraction` (e.g., 0.1 vs 0.5) and compare stability/convergence.


In [9]:
seed = 42
set_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_root = "./data"
num_clients = 10
batch_size = 64
non_iid_per = 0.5     # 0 = IID, 1 = very non-IID
rounds = 20
local_epochs = 1
fraction = 0.2         # 20% clients per round
local_lr = 0.01
criterion = nn.CrossEntropyLoss()

# Data
download_mnist(data_root)
train_ds, test_ds = load_mnist_with_transforms(data_root)
client_loaders, test_loader = make_client_loaders(
    train_ds, test_ds, num_clients=num_clients,
    batch_size=batch_size, non_iid_per=non_iid_per, seed=seed
)

# Train
global_model, history = train_fedavg(
    Net, client_loaders, test_loader,
    device=device, rounds=rounds, local_epochs=local_epochs,
    fraction=fraction, local_lr=local_lr, criterion=criterion, seed=seed
)


100%|██████████| 9.91M/9.91M [00:00<00:00, 138MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 14.4MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 44.9MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 3.54MB/s]


MNIST downloaded to: /home/ahoop004/T3-Ciders-FL/2_IID_Concepts/data


Round 001 | test loss: 1.4681 | test acc: 57.74%


Round 001 | test loss: 1.4681 | test acc: 57.74%


Round 002 | test loss: 0.5020 | test acc: 84.10%


Round 002 | test loss: 0.5020 | test acc: 84.10%


Round 003 | test loss: 0.7259 | test acc: 75.35%


Round 003 | test loss: 0.7259 | test acc: 75.35%


Round 004 | test loss: 0.6772 | test acc: 78.30%


Round 004 | test loss: 0.6772 | test acc: 78.30%


Round 005 | test loss: 0.5084 | test acc: 81.45%


Round 005 | test loss: 0.5084 | test acc: 81.45%


Round 006 | test loss: 1.0009 | test acc: 69.00%


Round 006 | test loss: 1.0009 | test acc: 69.00%


Round 007 | test loss: 0.3791 | test acc: 88.33%


Round 007 | test loss: 0.3791 | test acc: 88.33%


Round 008 | test loss: 0.3841 | test acc: 87.39%


Round 008 | test loss: 0.3841 | test acc: 87.39%


Round 009 | test loss: 0.3543 | test acc: 88.67%


Round 009 | test loss: 0.3543 | test acc: 88.67%


Round 010 | test loss: 0.3273 | test acc: 89.87%


Round 010 | test loss: 0.3273 | test acc: 89.87%


Round 011 | test loss: 0.2175 | test acc: 93.51%


Round 011 | test loss: 0.2175 | test acc: 93.51%


Round 012 | test loss: 0.2556 | test acc: 91.92%


Round 012 | test loss: 0.2556 | test acc: 91.92%


Round 013 | test loss: 0.2572 | test acc: 92.00%


Round 013 | test loss: 0.2572 | test acc: 92.00%


Round 014 | test loss: 0.2370 | test acc: 92.26%


Round 014 | test loss: 0.2370 | test acc: 92.26%


Round 015 | test loss: 0.2689 | test acc: 91.75%


Round 015 | test loss: 0.2689 | test acc: 91.75%


Round 016 | test loss: 0.2019 | test acc: 93.73%


Round 016 | test loss: 0.2019 | test acc: 93.73%


Round 017 | test loss: 0.2724 | test acc: 90.62%


Round 017 | test loss: 0.2724 | test acc: 90.62%


Round 018 | test loss: 0.1938 | test acc: 93.92%


Round 018 | test loss: 0.1938 | test acc: 93.92%


Round 019 | test loss: 0.1769 | test acc: 94.51%


Round 019 | test loss: 0.1769 | test acc: 94.51%


Round 020 | test loss: 0.2238 | test acc: 93.02%


Round 020 | test loss: 0.2238 | test acc: 93.02%


## Sweep experiment: vary `non_iid_per` and collect results

This cell runs the full FL pipeline multiple times to compare different non-IID strengths:

- Sweeps `non_iid_per ∈ {0.10, 0.25, 0.50, 0.75, 0.90}`
- For each value:
  1. Re-builds client datasets/loaders using that non-IID setting
  2. Computes a per-client label count matrix (clients × classes)
  3. Trains FedAvg and records the per-round `history`
  4. Stores everything in `sweep_results[non_iid_per]`

**Output object:** `sweep_results`  
Contains:
- `history`: global loss/accuracy vs round
- `label_counts`: client label distribution matrix (for visualization)
- `global_model`: final model for that sweep setting

Note: this can take time on HPC; reduce `rounds`, `num_clients`, or skip label counting if needed.


In [None]:
# Sweep values you requested
non_iid_sweep = [0.10, 0.25, 0.50, 0.75, 0.90]

# Use existing notebook hyperparams if they exist, otherwise fall back
seed         = globals().get("seed", 42)
num_clients  = globals().get("num_clients", 10)
batch_size   = globals().get("batch_size", 64)
rounds       = globals().get("rounds", 20)
local_epochs = globals().get("local_epochs", 1)
fraction     = globals().get("fraction", 0.2)
local_lr     = globals().get("local_lr", 0.01)
device       = globals().get("device", "cpu")
criterion    = globals().get("criterion", None)

data_root = globals().get("data_root", "./data")

# Storage
sweep_results = {}   # maps non_iid_per -> {"history":..., "label_counts":...}

def label_count_from_indices(targets, client_idxs, num_classes=10):
    """Fast counts: build a [num_clients, num_classes] matrix from split indices (no DataLoader iteration)."""
    if hasattr(targets, "detach"):
        t = targets.detach().cpu().numpy()
    else:
        t = np.asarray(targets)

    counts = np.zeros((len(client_idxs), num_classes), dtype=int)
    for ci, idxs in enumerate(client_idxs):
        idxs = np.asarray(idxs, dtype=int)
        counts[ci] = np.bincount(t[idxs], minlength=num_classes)
    return counts

for non_iid_per in non_iid_sweep:
    set_seed(seed)

    # Data + client loaders
    train_ds, test_ds = load_mnist_with_transforms(data_root)
    client_loaders, test_loader, client_idxs = make_client_loaders(
        train_ds, test_ds,
        num_clients=num_clients,
        batch_size=batch_size,
        non_iid_per=non_iid_per,
        seed=seed,
        return_indices=True
    )

    # Label distribution snapshot (clients x classes)
    targets = train_ds.targets if hasattr(train_ds, "targets") else train_ds.labels
    counts = label_count_from_indices(targets, client_idxs, num_classes=10)

    # Train
    global_model, history = train_fedavg(
        Net, client_loaders, test_loader,
        device=device,
        rounds=rounds,
        local_epochs=local_epochs,
        fraction=fraction,
        local_lr=local_lr,
        criterion=criterion,
        seed=seed
    )

    sweep_results[non_iid_per] = {
        "history": history,
        "label_counts": counts,
        "global_model": global_model,
    }

print("Done. Keys:", list(sweep_results.keys()))


## Plot comparison: client distributions + training curves

This cell visualizes two things for each `non_iid_per` in the sweep:

1. **Client label distributions (heatmaps)**  
   - Rows: clients  
   - Columns: classes (digits 0–9)  
   - Color intensity: number of samples  
   This shows how skewed each client’s data is.

2. **Training results across rounds (overlay plots)**  
   - Global **Accuracy vs Round** (one line per `non_iid_per`)
   - Global **Loss vs Round** (one line per `non_iid_per`)

These plots make it easy to see how increasing non-IID skew changes convergence speed and final performance.


In [None]:

def _get_hist(history, key_candidates):
    for k in key_candidates:
        if isinstance(history, dict) and k in history:
            return np.array(history[k])
    raise KeyError(f"History missing keys {key_candidates}. Available: {list(history.keys()) if isinstance(history, dict) else type(history)}")

# 1) Plot label distributions (heatmap per sweep value)
for non_iid_per, pack in sweep_results.items():
    counts = pack["label_counts"]  # shape [num_clients, 10]

    plt.figure()
    plt.imshow(counts, aspect="auto")
    plt.colorbar(label="samples")
    plt.xlabel("Class (digit)")
    plt.ylabel("Client index")
    plt.title(f"Client label distribution (non_iid_per={non_iid_per})")
    plt.xticks(range(10))
    plt.show()

# 2) Plot accuracy vs round (all sweeps on one plot)
plt.figure()
for non_iid_per, pack in sweep_results.items():
    h = pack["history"]
    rounds_arr = _get_hist(h, ["round", "rounds", "global_round"])
    acc_arr    = _get_hist(h, ["acc", "accuracy", "test_acc", "global_acc"])
    plt.plot(rounds_arr, acc_arr, marker="o", label=f"{non_iid_per}")
plt.xlabel("Round")
plt.ylabel("Accuracy (%)")
plt.title("Global Accuracy vs Round (sweep over non_iid_per)")
plt.grid(True)
plt.legend(title="non_iid_per")
plt.show()

# 3) Plot loss vs round (all sweeps on one plot)
plt.figure()
for non_iid_per, pack in sweep_results.items():
    h = pack["history"]
    rounds_arr = _get_hist(h, ["round", "rounds", "global_round"])
    loss_arr   = _get_hist(h, ["loss", "test_loss", "global_loss"])
    plt.plot(rounds_arr, loss_arr, marker="o", label=f"{non_iid_per}")
plt.xlabel("Round")
plt.ylabel("Loss")
plt.title("Global Loss vs Round (sweep over non_iid_per)")
plt.grid(True)
plt.legend(title="non_iid_per")
plt.show()
