In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
import os, math, random, logging
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
class Net(nn.Module):

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool  = nn.MaxPool2d(2)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(64 * 14 * 14, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)
def download_mnist(root="./data"):

    _ = datasets.MNIST(root=root, train=True,  download=True)
    _ = datasets.MNIST(root=root, train=False, download=True)
    print(f"MNIST downloaded to: {os.path.abspath(root)}")
def mnist_transforms():

    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

def load_mnist_with_transforms(root="./data", transform=None):

    if transform is None:
        transform = mnist_transforms()
    train_ds = datasets.MNIST(root=root, train=True,  download=False, transform=transform)
    test_ds  = datasets.MNIST(root=root, train=False, download=False, transform=transform)
    return train_ds, test_ds
def _iid_indices(n_items, num_clients, seed=42):
    # print("iid")
    rng = np.random.default_rng(seed)
    idx = np.arange(n_items)
    rng.shuffle(idx)
    splits = np.array_split(idx, num_clients)
    return [np.array(s, dtype=int) for s in splits]


# Non-IID Data in Federated Learning

In real-world federated learning, client data is rarely **IID (independent and identically distributed)**.  
For example:
- A phone user may only write certain digits, so their dataset is biased.  
- A hospital may serve a specific demographic, creating skewed medical records.  

To simulate this, we use **non-IID splits** of MNIST.

---

## Dirichlet Distribution for Data Partitioning

We use the **Dirichlet distribution** to control how labels are assigned to clients:

- The Dirichlet distribution is a **probability distribution over probability vectors**.  
- Each vector element represents the proportion of a class assigned to a client.  
- The parameter **α (alpha)** controls how "skewed" the distribution is:

  - **High α (e.g., α = 10)** → more uniform → data is closer to IID.  
  - **Low α (e.g., α = 0.1)** → very skewed → each client may only see a few labels.  

---

### Example

- Suppose we have 10 clients and 10 digit classes.  
- For each digit, we sample a probability vector from Dirichlet(α).  
- This vector decides how much of that digit’s data goes to each client.  
- Repeating for all digits creates **client datasets with unique label distributions**.  

In [2]:

def _dirichlet_label_skew_indices(targets, num_clients, alpha=0.5, seed=42):

    # print("skew")
    rng = np.random.default_rng(seed)
    targets = np.asarray(targets)
    classes = np.unique(targets)
    client_indices = [[] for _ in range(num_clients)]

    for c in classes:
        c_idx = np.where(targets == c)[0]
        rng.shuffle(c_idx)

        props = rng.dirichlet([alpha] * num_clients)

        counts = (len(c_idx) * props).astype(int)

        while counts.sum() < len(c_idx):
            counts[np.argmax(props)] += 1

        start = 0
        for i, cnt in enumerate(counts):
            if cnt > 0:
                client_indices[i].extend(c_idx[start:start+cnt])
                start += cnt


    for i in range(num_clients):
        client_indices[i] = np.array(client_indices[i], dtype=int)
        rng.shuffle(client_indices[i])
    return client_indices


# Load client data
 Create DataLoader objects for each client.
 Each loader serves local batches of training data.
 Supports IID or non-IID distributions.

In [3]:

def make_client_loaders(
    train_ds,
    test_ds,
    num_clients=10,
    batch_size=64,
    non_iid_per=0.0,
    seed=42,
):

    if non_iid_per <= 1e-8:
        client_idxs = _iid_indices(len(train_ds), num_clients, seed=seed)
    else:

        alpha = max(0.01, 1.0 - 0.99 * non_iid_per)
        targets = train_ds.targets if hasattr(train_ds, "targets") else train_ds.labels
        client_idxs = _dirichlet_label_skew_indices(targets, num_clients, alpha=alpha, seed=seed)

    local_loaders = [
        DataLoader(Subset(train_ds, idxs), batch_size=batch_size, shuffle=True, drop_last=False)
        for idxs in client_idxs
    ]
    test_loader = DataLoader(test_ds, batch_size=512, shuffle=False, drop_last=False)
    return local_loaders, test_loader

#Client
## Holds local data.
##Performs local training for a few epochs on its copy of the model.

In [4]:
class Client():

    def __init__(self, client_id, local_data, device, num_epochs, criterion, lr):
        self.id = client_id
        self.data = local_data
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.num_epochs = num_epochs
        self.lr = lr
        self.criterion = criterion
        self.x = None
        self.y = None

    def client_update(self):
        # print("update")


        self.y = deepcopy(self.x)
        self.y.to(self.device)

        for epoch in range(self.num_epochs):

            for inputs,labels in self.data:
              inputs, labels = inputs.float().to(self.device), labels.long().to(self.device)
              output = self.y(inputs)
              loss = self.criterion(output, labels)
              grads = torch.autograd.grad(loss,self.y.parameters())

              with torch.no_grad():
                  for param,grad in zip(self.y.parameters(),grads):
                      param.data = param.data - self.lr * grad.data

              if self.device == "cuda": torch.cuda.empty_cache()


# Evaluate function
## Evaluate a model on the test dataset.
## Returns average loss and accuracy across all test samples.

In [5]:
def evaluate_fn(test_loader, model, criterion, device):
    # print("eval")
    model.eval()
    n, total_loss, correct = 0, 0.0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device, dtype=torch.float)
            y = y.to(device, dtype=torch.long)
            logits = model(x)
            loss = criterion(logits, y)
            total_loss += loss.item() * x.size(0)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            n += x.size(0)
    return (total_loss / max(1, n)), (100.0 * correct / max(1, n))


# Client Sampling
## Randomly sample a fraction of clients each round.
## Simulates partial participation in federated training.

In [6]:

def sample_clients(num_clients, fraction, rng):
    k = max(1, int(math.ceil(fraction * num_clients)))
    return sorted(rng.choice(np.arange(num_clients), size=k, replace=False).tolist())


# Federated Averaging
## Aggregate model parameters from selected clients.
## Compute average weights and update the global model.


In [7]:

def fedavg_aggregate(global_model, client_models, device):

    with torch.no_grad():
        # init accumulators
        acc = [torch.zeros_like(p, device=device) for p in global_model.parameters()]
        for cm in client_models:
            for a, p in zip(acc, cm.parameters()):
                a.add_(p.to(device))
        for gp, a in zip(global_model.parameters(), acc):
            gp.copy_(a / len(client_models))


# Training
## Initialize global model.
## Each round: sample clients, broadcast model, perform local training.
## Aggregate updates with FedAvg.
## Evaluate global model on test data.

In [8]:

def train_fedavg(
    model_ctor,
    client_loaders,
    test_loader,
    *,
    device,
    rounds=20,
    local_epochs=1,
    fraction=0.1,
    local_lr=0.01,
    criterion=None,
    seed=42,
    log_level=logging.INFO,
):

    # print("train")
    logging.basicConfig(level=log_level, format="%(message)s")
    rng = np.random.default_rng(seed)

    # global model
    global_model = model_ctor().to(device)
    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    # wrap clients
    clients = [
        Client(i, ld, device=device, num_epochs=local_epochs, criterion=criterion, lr=local_lr)
        for i, ld in enumerate(client_loaders)
    ]

    history = {"round": [], "loss": [], "acc": []}

    for r in range(1, rounds + 1):
        # print("round")
        # sample & broadcast
        ids = sample_clients(len(clients), fraction, rng)
        for i in ids:
            clients[i].x = global_model  # reference (they deepcopy inside)

        # local updates
        for i in ids:
            clients[i].client_update()

        # aggregate
        fedavg_aggregate(global_model, [clients[i].y for i in ids], device=device)

        # evaluate
        test_loss, test_acc = evaluate_fn(test_loader, global_model, criterion, device)
        history["round"].append(r); history["loss"].append(test_loss); history["acc"].append(test_acc)
        logging.info(f"Round {r:03d} | test loss: {test_loss:.4f} | test acc: {test_acc:.2f}%")
        print(f"Round {r:03d} | test loss: {test_loss:.4f} | test acc: {test_acc:.2f}%")

    return global_model, history

# Run Experiment
## Set hyperparameters (clients, batch size, rounds, etc.).
## Download MNIST and prepare client data loaders.
## Train the model using FedAvg and log performance per round.

In [9]:
seed = 42
set_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_root = "./data"
num_clients = 10
batch_size = 64
non_iid_per = 0.5     # 0 = IID, 1 = very non-IID
rounds = 20
local_epochs = 1
fraction = 0.2         # 20% clients per round
local_lr = 0.01
criterion = nn.CrossEntropyLoss()

# Data
download_mnist(data_root)
train_ds, test_ds = load_mnist_with_transforms(data_root)
client_loaders, test_loader = make_client_loaders(
    train_ds, test_ds, num_clients=num_clients,
    batch_size=batch_size, non_iid_per=non_iid_per, seed=seed
)

# Train
global_model, history = train_fedavg(
    Net, client_loaders, test_loader,
    device=device, rounds=rounds, local_epochs=local_epochs,
    fraction=fraction, local_lr=local_lr, criterion=criterion, seed=seed
)


100%|██████████| 9.91M/9.91M [00:00<00:00, 138MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 14.4MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 44.9MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 3.54MB/s]


MNIST downloaded to: /home/ahoop004/T3-Ciders-FL/2_IID_Concepts/data


Round 001 | test loss: 1.4681 | test acc: 57.74%


Round 001 | test loss: 1.4681 | test acc: 57.74%


Round 002 | test loss: 0.5020 | test acc: 84.10%


Round 002 | test loss: 0.5020 | test acc: 84.10%


Round 003 | test loss: 0.7259 | test acc: 75.35%


Round 003 | test loss: 0.7259 | test acc: 75.35%


Round 004 | test loss: 0.6772 | test acc: 78.30%


Round 004 | test loss: 0.6772 | test acc: 78.30%


Round 005 | test loss: 0.5084 | test acc: 81.45%


Round 005 | test loss: 0.5084 | test acc: 81.45%


Round 006 | test loss: 1.0009 | test acc: 69.00%


Round 006 | test loss: 1.0009 | test acc: 69.00%


Round 007 | test loss: 0.3791 | test acc: 88.33%


Round 007 | test loss: 0.3791 | test acc: 88.33%


Round 008 | test loss: 0.3841 | test acc: 87.39%


Round 008 | test loss: 0.3841 | test acc: 87.39%


Round 009 | test loss: 0.3543 | test acc: 88.67%


Round 009 | test loss: 0.3543 | test acc: 88.67%


Round 010 | test loss: 0.3273 | test acc: 89.87%


Round 010 | test loss: 0.3273 | test acc: 89.87%


Round 011 | test loss: 0.2175 | test acc: 93.51%


Round 011 | test loss: 0.2175 | test acc: 93.51%


Round 012 | test loss: 0.2556 | test acc: 91.92%


Round 012 | test loss: 0.2556 | test acc: 91.92%


Round 013 | test loss: 0.2572 | test acc: 92.00%


Round 013 | test loss: 0.2572 | test acc: 92.00%


Round 014 | test loss: 0.2370 | test acc: 92.26%


Round 014 | test loss: 0.2370 | test acc: 92.26%


Round 015 | test loss: 0.2689 | test acc: 91.75%


Round 015 | test loss: 0.2689 | test acc: 91.75%


Round 016 | test loss: 0.2019 | test acc: 93.73%


Round 016 | test loss: 0.2019 | test acc: 93.73%


Round 017 | test loss: 0.2724 | test acc: 90.62%


Round 017 | test loss: 0.2724 | test acc: 90.62%


Round 018 | test loss: 0.1938 | test acc: 93.92%


Round 018 | test loss: 0.1938 | test acc: 93.92%


Round 019 | test loss: 0.1769 | test acc: 94.51%


Round 019 | test loss: 0.1769 | test acc: 94.51%


Round 020 | test loss: 0.2238 | test acc: 93.02%


Round 020 | test loss: 0.2238 | test acc: 93.02%
