In [None]:
import copy
import numpy as np
from tqdm import trange
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, Dataset

In [None]:
# Utils: data partition (Dirichlet)
def partition_dirichlet(labels, n_clients, alpha=0.5, rng=None):
    """
    Split dataset indices across clients with per-class Dirichlet label skew.

    For every class, client proportions are drawn from Dirichlet(alpha, ..., alpha),
    scaled to the class size and floored to integer counts. Samples lost to the
    flooring are handed out one at a time to clients picked uniformly at random
    (with replacement). The class indices are then shuffled and sliced into
    consecutive chunks, one per client.

    Args:
        labels: numpy array of integer class labels for the full dataset.
        n_clients: number of clients to split across.
        alpha: Dirichlet concentration (smaller -> more skewed).
        rng: optional np.random.RandomState; RandomState(0) is used when omitted.

    Returns:
        A list with n_clients entries; entry i is the list of dataset indices
        assigned to client i. A client may receive no samples at all.
    """
    rng = np.random.RandomState(0) if rng is None else rng
    num_classes = labels.max() + 1
    per_client = [[] for _ in range(n_clients)]

    for cls in range(num_classes):
        cls_idx = np.where(labels == cls)[0]
        cls_total = len(cls_idx)

        # Expected (fractional) number of samples of this class per client.
        shares = rng.dirichlet([alpha] * n_clients)
        expected = (shares / shares.sum()) * cls_total

        # Integer counts; flooring can leave a few samples unassigned.
        counts = np.floor(expected).astype(int)
        leftover = cls_total - counts.sum()
        if leftover > 0:
            # Uniform draw with replacement; add.at handles repeated clients.
            extra = rng.choice(n_clients, leftover, replace=True)
            np.add.at(counts, extra, 1)

        # Shuffle the class, then cut it into consecutive per-client chunks.
        order = rng.permutation(cls_idx)
        stops = np.cumsum(counts)
        starts = stops - counts
        for bucket, lo, hi in zip(per_client, starts, stops):
            if hi > lo:
                bucket.extend(order[lo:hi].tolist())
    return per_client

In [None]:
# -------------------
# Simple CNN model
# -------------------
class SimpleCNN(nn.Module):
    """
    Small two-conv-layer classifier for single-channel 28x28 images.

    The whole network lives in one nn.Sequential stored as ``self.net``, so
    state-dict keys have the form ``net.<layer index>.<param>``. The output is a
    tensor of 10 raw class scores (logits) per input image.
    """

    def __init__(self):
        super().__init__()
        feature_layers = [
            # 1x28x28 -> 16x28x28 -> 16x14x14
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            # 16x14x14 -> 32x14x14 -> 32x7x7
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        classifier_layers = [
            nn.Flatten(),
            nn.Linear(in_features=32 * 7 * 7, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=10),
        ]
        self.net = nn.Sequential(*feature_layers, *classifier_layers)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        logits = self.net(x)
        return logits

In [None]:
# -------------------
# Local training routine
# -------------------
def local_train(model, dataloader, device, epochs=1, lr=0.01):
    """
    Train ``model`` in place on one client's data and return its state dict.

    Uses cross-entropy loss and SGD with momentum 0.9. The model is switched to
    training mode and is expected to already live on ``device``; only the
    batches are moved there.

    Args:
        model: the nn.Module to optimise (modified in place).
        dataloader: iterable yielding (inputs, targets) batches.
        device: device the batches are moved to before the forward pass.
        epochs: number of full passes over ``dataloader``.
        lr: SGD learning rate.

    Returns:
        ``model.state_dict()`` after training.
    """
    loss_fn = nn.CrossEntropyLoss()
    sgd = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    model.train()

    for _epoch in range(epochs):
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            sgd.zero_grad()
            batch_loss = loss_fn(model(inputs), targets)
            batch_loss.backward()
            sgd.step()

    return model.state_dict()

In [None]:
# -------------------
# Evaluate model on test set
# -------------------
def evaluate(model, dataloader, device):
    """
    Compute mean cross-entropy loss and accuracy of ``model`` over ``dataloader``.

    The model is put in eval mode (and left there) and gradients are disabled.
    Loss is summed per sample and divided by the number of samples seen, so the
    result does not depend on the batch size.

    Args:
        model: the nn.Module to evaluate; expected to already be on ``device``.
        dataloader: iterable yielding (inputs, targets) batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        (mean_loss, accuracy) as Python floats.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no samples.
    """
    summed_loss_fn = nn.CrossEntropyLoss(reduction='sum')
    model.eval()

    n_seen = 0
    n_hits = 0
    running_loss = 0.0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            scores = model(inputs)

            running_loss += summed_loss_fn(scores, targets).item()
            n_hits += (scores.argmax(dim=1) == targets).sum().item()
            n_seen += targets.size(0)

    return running_loss / n_seen, n_hits / n_seen

In [None]:
# -------------------
# FedAvg aggregation (weighted by n_samples)
# -------------------
def fedavg_aggregate(global_state, local_states, local_sizes):
    """
    Weighted FedAvg: average client state dicts, weighting each by n_samples.

    Every key of ``global_state`` is averaged over ``local_states`` with weight
    ``local_sizes[i] / sum(local_sizes)``. Each result is cast back to the dtype
    of the corresponding ``global_state`` entry, so integer buffers stay integer
    (rounded to nearest) and float64 entries are averaged in float64 instead of
    being truncated to float32. Float32 entries are computed in float32.

    Args:
        global_state: reference state dict; supplies the keys and target dtypes.
        local_states: list of client state dicts containing those keys.
        local_sizes: number of training samples per client, same order.

    Returns:
        A new dict mapping each key to the aggregated tensor.

    Raises:
        ValueError: if the two lists differ in length, no client is given, or
            the total sample count is not positive.
    """
    if len(local_states) != len(local_sizes):
        # zip() would silently drop the surplus and the weights would not sum to 1
        raise ValueError(
            f"got {len(local_states)} local states but {len(local_sizes)} sizes"
        )
    total = sum(local_sizes)
    if not local_states or total <= 0:
        raise ValueError("need at least one client and a positive total sample count")

    new_state = {}
    for key, ref in global_state.items():
        # Accumulate in float32 unless the entry is float64 (keep its precision).
        acc_dtype = torch.float64 if ref.dtype == torch.float64 else torch.float32
        accum = None
        for st, n in zip(local_states, local_sizes):
            w = st[key].to(acc_dtype) * (n / total)
            accum = w if accum is None else accum + w
        if not ref.is_floating_point():
            # Integer/bool buffers: round to nearest before casting back.
            accum = torch.round(accum)
        new_state[key] = accum.to(ref.dtype)
    return new_state

In [None]:
# -------------------
# Main simulation
# -------------------
def main(n_clients=5, alpha=0.5, rounds=30, local_epochs=1, local_batch=64, lr=0.05, seed=42):
    """
    Run a FedAvg simulation on MNIST with Dirichlet label-skewed clients.

    The MNIST training set is split across clients with ``partition_dirichlet``.
    Each round, every client that holds data trains a fresh ``SimpleCNN`` copy
    of the global model with ``local_train``; the results are combined with
    ``fedavg_aggregate`` and the global model is scored on the MNIST test set
    with ``evaluate``. The defaults reproduce the original hard-coded setup.

    Args:
        n_clients: number of simulated clients.
        alpha: Dirichlet concentration (smaller -> more skewed/non-IID).
        rounds: number of communication rounds.
        local_epochs: local epochs per client per round.
        local_batch: client DataLoader batch size.
        lr: client SGD learning rate.
        seed: seeds both the data partition and torch's RNG (weight init and
            DataLoader shuffling) so a run can be repeated.

    Returns:
        dict with lists under "round", "test_loss" and "test_acc", one entry
        per completed round.

    Raises:
        ValueError: if the partition leaves every client without samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    # Without this, model initialisation and batch order differ on every run.
    torch.manual_seed(seed)

    # MNIST transform / datasets (downloaded to ./data on first use)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    mnist_train = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    mnist_test  = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

    # prepare labels array for partitioning
    labels = np.array(mnist_train.targets)
    client_indices = partition_dirichlet(
        labels, n_clients=n_clients, alpha=alpha, rng=np.random.RandomState(seed)
    )
    print("Client sizes:", [len(idxs) for idxs in client_indices])

    # Create DataLoaders only for clients that actually received data: a
    # strongly skewed partition can leave a client empty, and an empty shuffled
    # DataLoader cannot be built / contributes nothing to the average.
    client_loaders = []
    client_sizes = []
    for cid, idxs in enumerate(client_indices):
        if len(idxs) == 0:
            print(f"Client {cid} received no samples; skipping it")
            continue
        sub = Subset(mnist_train, idxs)
        client_loaders.append(
            DataLoader(sub, batch_size=local_batch, shuffle=True, num_workers=0)
        )
        client_sizes.append(len(idxs))
    if not client_loaders:
        raise ValueError("partition produced no client with data")

    # global test loader
    test_loader = DataLoader(mnist_test, batch_size=512, shuffle=False, num_workers=0)

    # Initialize global model
    global_model = SimpleCNN().to(device)

    # Training rounds
    hist = {"round": [], "test_loss": [], "test_acc": []}
    for r in range(1, rounds + 1):
        # Snapshot of the current global weights that every client starts from
        global_state = global_model.state_dict()
        local_states = []
        for loader in client_loaders:
            # create local model copy and load global weights
            local_model = SimpleCNN().to(device)
            local_model.load_state_dict(global_state)
            # train locally
            st = local_train(local_model, loader, device, epochs=local_epochs, lr=lr)
            local_states.append({k: v.cpu() for k, v in st.items()})  # move to cpu for aggregation

        # Aggregation (FedAvg); load_state_dict copies the averaged values into
        # the model's own tensors, keeping their device and dtype.
        aggregated = fedavg_aggregate(global_state, local_states, client_sizes)
        global_model.load_state_dict(aggregated)

        # Evaluate
        test_loss, test_acc = evaluate(global_model, test_loader, device)
        hist["round"].append(r)
        hist["test_loss"].append(test_loss)
        hist["test_acc"].append(test_acc)
        print(f"Round {r:02d} | Test loss {test_loss:.4f} | Test acc {test_acc:.4f}")

    if hist["test_acc"]:
        print("Done. Final test acc:", hist["test_acc"][-1])
    else:
        # rounds <= 0: nothing was trained, so there is no accuracy to report
        print("Done. No rounds were run.")
    return hist

In [None]:
# Run the federated-learning simulation defined in main().
main()