In [None]:
# ==============================
# FEDAVG CIFAR-10 NON-IID BASELINE (NO DP, NO OPACUS)
# Single-cell, fresh, runnable in Google Colab
# ==============================

import os
import json
import csv
import math
import random
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt

# ------------------------------
# 0) Repro + Device
# ------------------------------
# Seed every RNG the script touches (Python, NumPy, PyTorch CPU and all CUDA
# devices) so client sampling, partitioning and weight init are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# ------------------------------
# 1) Config
# ------------------------------
# Single source of truth for the experiment's hyper-parameters; the dict is
# also dumped to disk below so a results folder records how it was produced.
CONFIG = {
    "seed": SEED,
    "dataset": "CIFAR-10",
    "num_clients": 50,
    "clients_per_round": 10,
    "rounds": 10,                 # increase to 50-200 later
    "local_epochs": 1,            # increase to 2-5 later
    "batch_size": 64,
    "lr": 0.01,
    "momentum": 0.9,
    "weight_decay": 0.0,
    "dirichlet_alpha": 0.5,       # lower = more Non-IID
    "test_batch_size": 256,
    "results_dir": "results_fedavg_cifar10_non_iid",
}

# Create the output folder up front (no error if it already exists).
RESULTS_DIR = CONFIG["results_dir"]
os.makedirs(RESULTS_DIR, exist_ok=True)

# Persist the configuration next to the metrics/model written at the end.
with open(os.path.join(RESULTS_DIR, "config.json"), "w") as f:
    json.dump(CONFIG, f, indent=2)

# ------------------------------
# 2) Data
# ------------------------------
# Training pipeline: random 32x32 crop after 4-pixel padding plus random
# horizontal flip as augmentation, then tensor conversion and per-channel
# normalisation.
# NOTE(review): the mean/std constants are presumably CIFAR-10 training-set
# statistics - confirm against the source they were taken from.
transform_train = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Test pipeline: no augmentation, same normalisation as training.
transform_test = T.Compose([
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Both splits are stored under ./data and downloaded when requested via
# download=True.
train_dataset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform_train
)
test_dataset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform_test
)

# One shared, unshuffled loader over the test split used for every evaluation.
test_loader = DataLoader(test_dataset, batch_size=CONFIG["test_batch_size"], shuffle=False, num_workers=2)

# ------------------------------
# 3) Non-IID partition: Dirichlet over classes
# ------------------------------
def dirichlet_partition(dataset, num_clients: int, alpha: float, num_classes: int = 10, seed: int = 42):
    """Split sample indices across clients using per-class Dirichlet shares.

    For each label in ``range(num_classes)`` the indices carrying that label
    are shuffled, a share vector over the clients is drawn from a symmetric
    Dirichlet with concentration ``alpha``, and consecutive slices of the
    shuffled indices are handed to the clients according to those shares.

    Args:
        dataset: object exposing a ``targets`` sequence of class labels.
        num_clients: number of partitions to build.
        alpha: Dirichlet concentration parameter (lower = more Non-IID).
        num_classes: labels ``0 .. num_classes - 1`` are partitioned.
        seed: seed of the NumPy generator driving all randomness here.

    Returns:
        A list with ``num_clients`` entries, each a list of dataset indices
        in shuffled order.
    """
    rng = np.random.default_rng(seed)
    labels = np.array(dataset.targets)

    per_class = [np.where(labels == label)[0] for label in range(num_classes)]
    buckets = [[] for _ in range(num_clients)]

    for members in per_class:
        rng.shuffle(members)
        n_members = len(members)

        shares = rng.dirichlet(alpha * np.ones(num_clients))
        # Truncating to int can leave the quotas short of n_members ...
        quota = (shares * n_members).astype(int)

        # ... so hand the remainder out one unit at a time to randomly drawn
        # clients (drawn with replacement, a client may be picked repeatedly).
        remainder = n_members - quota.sum()
        step = 1 if remainder > 0 else -1
        picked = rng.choice(num_clients, size=abs(remainder), replace=True)
        np.add.at(quota, picked, step)
        # Quotas must never be negative.
        quota = np.clip(quota, 0, None)

        # Cumulative quotas give each client's [lo, hi) slice of this class.
        edges = np.concatenate(([0], np.cumsum(quota)))
        for bucket, lo, hi in zip(buckets, edges[:-1], edges[1:]):
            bucket.extend(members[lo:hi].tolist())

    # Mix the classes inside every client's index list.
    for bucket in buckets:
        rng.shuffle(bucket)

    return buckets

# Build the per-client index lists once; later blocks look clients up by id.
client_indices = dirichlet_partition(
    train_dataset,
    CONFIG["num_clients"],
    CONFIG["dirichlet_alpha"],
    num_classes=10,
    seed=CONFIG["seed"],
)

# Summarise how unevenly the samples ended up spread over the clients.
sizes = list(map(len, client_indices))
print("\nPartition stats")
print("Total train samples:", len(train_dataset))
print("Clients:", CONFIG["num_clients"])
print(f"Samples per client: min={min(sizes)} max={max(sizes)} mean={np.mean(sizes):.1f} std={np.std(sizes):.1f}")

# ------------------------------
# 4) Model (simple CNN that matches CIFAR-10 shape)
# ------------------------------
class CIFAR10CNN(nn.Module):
    """Small CNN: three 3x3 conv layers with two 2x max-pools, then an MLP head.

    The first linear layer expects ``128 * 8 * 8`` features, i.e. 32x32
    inputs that the two poolings have reduced to 8x8.

    Args:
        num_classes: width of the final linear layer (number of logits).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Layers are created in the same order they run, convolutions first.
        conv_stack = [
            nn.Conv2d(3, 32, kernel_size=3, padding=1),    # 32x32
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),   # 32x32
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),                               # 16x16
            nn.Conv2d(64, 128, kernel_size=3, padding=1),  # 16x16
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),                               # 8x8
        ]
        head = [
            nn.Flatten(),
            nn.Linear(128 * 8 * 8, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        ]
        self.features = nn.Sequential(*conv_stack)
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        return self.classifier(self.features(x))

# ------------------------------
# 5) Utils: evaluation + FedAvg
# ------------------------------
def evaluate(model: nn.Module, loader: DataLoader):
    """Measure mean cross-entropy and accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and left there; gradients are
    disabled for the whole pass.

    Args:
        model: network producing class logits.
        loader: iterable of ``(inputs, labels)`` batches.

    Returns:
        ``(loss, accuracy)`` where loss is the summed cross-entropy divided
        by the number of samples and accuracy is the fraction predicted
        correctly.
    """
    model.eval()
    n_seen = 0
    n_hit = 0
    nll_total = 0.0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            scores = model(inputs)
            # Sum (not mean) per batch so the final division is per sample.
            nll_total += F.cross_entropy(scores, labels, reduction="sum").item()
            n_hit += (scores.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)

    return nll_total / n_seen, n_hit / n_seen

def get_client_loader(cid: int):
    """Return a shuffling DataLoader over the training samples of client ``cid``."""
    return DataLoader(
        Subset(train_dataset, client_indices[cid]),
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        num_workers=2,
        drop_last=False,
    )

def train_one_client(global_state: dict, cid: int):
    """Train a fresh copy of the global model on one client's data.

    Args:
        global_state: state_dict loaded into the local model before training.
        cid: client id used to select that client's data loader.

    Returns:
        The trained local model's state_dict as detached CPU clones.
    """
    local_model = CIFAR10CNN().to(DEVICE)
    local_model.load_state_dict(global_state)
    local_model.train()

    batches = get_client_loader(cid)
    # A new optimizer per call: momentum buffers start empty every round.
    sgd = torch.optim.SGD(
        local_model.parameters(),
        lr=CONFIG["lr"],
        momentum=CONFIG["momentum"],
        weight_decay=CONFIG["weight_decay"],
    )

    for _epoch in range(CONFIG["local_epochs"]):
        for inputs, labels in batches:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            sgd.zero_grad()
            F.cross_entropy(local_model(inputs), labels).backward()
            sgd.step()

    return {
        name: tensor.detach().cpu().clone()
        for name, tensor in local_model.state_dict().items()
    }

def fedavg(states: list, weights=None):
    """Average a list of state_dicts entry by entry.

    Args:
        states: non-empty list of state_dicts (CPU tensors) that share the
            keys and shapes of ``states[0]``.
        weights: optional sequence with one non-negative number per state,
            e.g. each client's local sample count. The numbers are
            normalised to sum to 1. When omitted every state counts equally
            (each tensor is divided by ``len(states)``).

    Returns:
        A new dict mapping every key of ``states[0]`` to the (weighted)
        average tensor, with the dtype of the corresponding input tensor.

    Raises:
        ValueError: if ``states`` is empty, or ``weights`` has the wrong
            length, contains a negative value, or sums to zero.
    """
    if not states:
        raise ValueError("fedavg() needs at least one state_dict")

    n_states = len(states)
    coeffs = None
    if weights is not None:
        if len(weights) != n_states:
            raise ValueError("weights must contain one entry per state_dict")
        if any(w < 0 for w in weights):
            raise ValueError("weights must be non-negative")
        total = float(sum(weights))
        if total <= 0:
            raise ValueError("weights must sum to a positive value")
        coeffs = [float(w) / total for w in weights]

    avg = {}
    for key, ref in states[0].items():
        # Integer/bool entries (e.g. step counters kept as buffers) cannot
        # accumulate fractional terms in place, so they are averaged in
        # float64 and rounded back to their own dtype afterwards.
        native = ref.is_floating_point() or ref.is_complex()
        acc = torch.zeros_like(ref) if native else torch.zeros_like(ref, dtype=torch.float64)
        for pos, sd in enumerate(states):
            term = sd[key] if native else sd[key].to(torch.float64)
            if coeffs is None:
                acc += term / n_states
            else:
                acc += term * coeffs[pos]
        avg[key] = acc if native else acc.round().to(ref.dtype)
    return avg

# ------------------------------
# 6) Federated training loop
# ------------------------------
# Freshly initialised global model; `metrics` collects one
# [round, test_loss, test_accuracy] row per federated round.
global_model = CIFAR10CNN().to(DEVICE)
metrics = []

# Reference point: test performance before any training has happened.
init_loss, init_acc = evaluate(global_model, test_loader)
print(f"\nInitial | test_loss={init_loss:.4f} test_acc={init_acc*100:.2f}%")

for rnd in range(1, CONFIG["rounds"] + 1):
    print(f"\nRound {rnd}/{CONFIG['rounds']}")

    # Pick this round's participants (random.sample draws without replacement).
    selected = random.sample(range(CONFIG["num_clients"]), CONFIG["clients_per_round"])
    # CPU snapshot of the current global weights; every selected client starts
    # from this same copy.
    global_state = {k: v.detach().cpu().clone() for k, v in global_model.state_dict().items()}

    # Clients are trained one after another, each returning its own state_dict.
    client_states = []
    for cid in selected:
        cs = train_one_client(global_state, cid)
        client_states.append(cs)

    # Aggregate the client models and install the result as the new global
    # model. No per-client weights are passed to fedavg here.
    new_global_state = fedavg(client_states)
    global_model.load_state_dict(new_global_state)

    # Track test performance of the aggregated model after every round.
    test_loss, test_acc = evaluate(global_model, test_loader)
    print(f"  test_loss={test_loss:.4f} test_acc={test_acc*100:.2f}%")
    metrics.append([rnd, test_loss, test_acc])

# ------------------------------
# 7) Save outputs
# ------------------------------
# Per-round metrics as CSV (header row followed by one row per round).
metrics_path = os.path.join(RESULTS_DIR, "metrics.csv")
with open(metrics_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["round", "test_loss", "test_accuracy"])
    writer.writerows(metrics)

# Only the weights (state_dict) of the final global model are stored.
model_path = os.path.join(RESULTS_DIR, "final_model.pt")
torch.save(global_model.state_dict(), model_path)

# Split the metric rows into columns for plotting.
rounds = [m[0] for m in metrics]
losses = [m[1] for m in metrics]
accs = [m[2] for m in metrics]

# Two side-by-side panels: accuracy on the left, loss on the right.
plt.figure(figsize=(10,4))
plt.subplot(1,2,1)
plt.plot(rounds, accs)
plt.title("Test Accuracy")

plt.subplot(1,2,2)
plt.plot(rounds, losses)
plt.title("Test Loss")

# Write the figure to disk and close it instead of displaying it.
fig_path = os.path.join(RESULTS_DIR, "training_curves.png")
plt.savefig(fig_path)
plt.close()

print("\nFEDAVG CIFAR-10 Non-IID BASELINE FINISHED")
print("Saved:")
print(" -", metrics_path)
print(" -", fig_path)
print(" -", model_path)


Device: cpu


100%|██████████| 170M/170M [00:11<00:00, 15.2MB/s]



Partition stats
Total train samples: 50000
Clients: 50
Samples per client: min=281 max=1836 mean=1000.0 std=374.2

Initial | test_loss=2.3031 test_acc=11.09%

Round 1/10
  test_loss=2.3156 test_acc=13.54%

Round 2/10
  test_loss=2.3326 test_acc=10.00%

Round 3/10
  test_loss=2.3127 test_acc=16.48%

Round 4/10
  test_loss=2.2128 test_acc=15.21%

Round 5/10
  test_loss=2.1549 test_acc=21.29%

Round 6/10
  test_loss=2.0967 test_acc=22.32%

Round 7/10
  test_loss=2.1304 test_acc=20.39%

Round 8/10
  test_loss=2.0562 test_acc=19.88%

Round 9/10
  test_loss=2.0874 test_acc=21.96%

Round 10/10
  test_loss=2.0791 test_acc=21.29%

FEDAVG CIFAR-10 Non-IID BASELINE FINISHED
Saved:
 - results_fedavg_cifar10_non_iid/metrics.csv
 - results_fedavg_cifar10_non_iid/training_curves.png
 - results_fedavg_cifar10_non_iid/final_model.pt


In [1]:
# ===============================
# FEDAVG CIFAR-10 NON-IID BASELINE
# ===============================

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import random
import os
import csv
import matplotlib.pyplot as plt

# -------------------------------
# CONFIG
# -------------------------------
# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Federation size and schedule: CLIENTS_PER_ROUND of the NUM_CLIENTS clients
# are trained in each of the ROUNDS rounds, LOCAL_EPOCHS epochs apiece.
NUM_CLIENTS = 10
CLIENTS_PER_ROUND = 5
ROUNDS = 10
LOCAL_EPOCHS = 1
BATCH_SIZE = 64
LR = 0.01
# NOTE(review): this is the same directory name the first cell of this
# notebook writes to, so metrics.csv / training_curves.png / final_model.pt
# from that run get overwritten - confirm this is intended.
RESULTS_DIR = "results_fedavg_cifar10_non_iid"

# Create the output folder (no error if it already exists) and seed the
# PyTorch, Python and NumPy RNGs for repeatable runs.
os.makedirs(RESULTS_DIR, exist_ok=True)
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

# -------------------------------
# DATA
# -------------------------------
# The same augmentation-free transform is used for training and testing.
# NOTE(review): mean/std are 1-element tuples although the images have three
# channels; presumably the value is broadcast to every channel - confirm.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)

test_dataset = datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)

# One shared, unshuffled loader over the test split used for every evaluation.
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

# -------------------------------
# NON-IID SPLIT (label skew)
# -------------------------------
def non_iid_split(dataset, num_clients, shards_per_client=1, seed=42):
    """Partition sample indices into label-sorted shards, one set per client.

    All indices are ordered by label with a stable sort (ties keep their
    dataset order, so the result does not depend on the sort implementation)
    and cut into contiguous shards of near-equal size.

    Args:
        dataset: object exposing a ``targets`` sequence of class labels.
        num_clients: number of clients to create.
        shards_per_client: how many shards each client receives. With the
            default of 1 client ``i`` gets the ``i``-th contiguous block of
            the label-sorted indices. With a larger value the indices are cut
            into ``num_clients * shards_per_client`` shards which are dealt
            out in random order, so one client can hold several label blocks.
        seed: seed for the shard shuffling; only used when
            ``shards_per_client > 1``.

    Returns:
        Dict mapping client id ``0 .. num_clients - 1`` to a NumPy array of
        dataset indices.

    Raises:
        ValueError: if ``shards_per_client`` is smaller than 1.
    """
    if shards_per_client < 1:
        raise ValueError("shards_per_client must be at least 1")

    labels = np.asarray(dataset.targets)
    # Positions sorted by label; kind="stable" pins the within-class order.
    order = np.argsort(labels, kind="stable")

    if shards_per_client == 1:
        shards = np.array_split(order, num_clients)
        return {i: shards[i] for i in range(num_clients)}

    shards = np.array_split(order, num_clients * shards_per_client)
    # Deal the shards out in random order, shards_per_client to each client.
    deal = np.random.default_rng(seed).permutation(len(shards))
    return {
        i: np.concatenate(
            [shards[j] for j in deal[i * shards_per_client:(i + 1) * shards_per_client]]
        )
        for i in range(num_clients)
    }

# Per-client training indices, looked up by client id during local training.
# NOTE(review): with 10 clients the label-sorted split presumably leaves each
# client with (almost) a single class - confirm this extreme skew is intended.
client_indices = non_iid_split(train_dataset, NUM_CLIENTS)

# -------------------------------
# MODEL (FIXED DIMENSIONS)
# -------------------------------
class CIFAR10CNN(nn.Module):
    """Two conv + max-pool stages followed by two fully connected layers.

    ``fc1`` expects ``64 * 8 * 8`` features, i.e. 32x32 inputs halved twice
    by the shared pooling layer. The network outputs 10 logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 64 channels at 8x8 after the second pooling.
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(x.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

# -------------------------------
# TRAIN ONE CLIENT
# -------------------------------
def train_one_client(global_state, cid):
    """Run LOCAL_EPOCHS of SGD on client ``cid`` starting from ``global_state``.

    Args:
        global_state: state_dict loaded into the local model before training.
        cid: key into ``client_indices`` selecting the client's samples.

    Returns:
        The trained local model's ``state_dict()``; its tensors stay on
        ``DEVICE``.
    """
    local_model = CIFAR10CNN().to(DEVICE)
    local_model.load_state_dict(global_state)
    local_model.train()

    sgd = optim.SGD(local_model.parameters(), lr=LR)

    client_data = Subset(train_dataset, client_indices[cid])
    batches = DataLoader(client_data, batch_size=BATCH_SIZE, shuffle=True)

    for _epoch in range(LOCAL_EPOCHS):
        for inputs, labels in batches:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            sgd.zero_grad()
            F.cross_entropy(local_model(inputs), labels).backward()
            sgd.step()

    return local_model.state_dict()

# -------------------------------
# EVALUATION
# -------------------------------
@torch.no_grad()
def evaluate(model):
    """Return ``(mean cross-entropy, accuracy)`` of ``model`` on ``test_loader``.

    The model is switched to eval mode and left there; gradients are
    disabled for the whole pass.
    """
    model.eval()
    n_hit = 0
    n_seen = 0
    nll_total = 0.0
    for inputs, labels in test_loader:
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        scores = model(inputs)
        # Sum (not mean) per batch so the final division is per sample.
        nll_total += F.cross_entropy(scores, labels, reduction="sum").item()
        n_hit += (scores.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)
    return nll_total / n_seen, n_hit / n_seen

# -------------------------------
# FEDERATED TRAINING
# -------------------------------
# Freshly initialised global model; `metrics` collects one
# [round, loss, accuracy] row per federated round.
global_model = CIFAR10CNN().to(DEVICE)
metrics = []

for rnd in range(1, ROUNDS + 1):
    print(f"\nRound {rnd}/{ROUNDS}")

    # Pick this round's participants (random.sample draws without replacement).
    selected = random.sample(range(NUM_CLIENTS), CLIENTS_PER_ROUND)
    global_state = global_model.state_dict()

    # Running average of the client models, started from zeros with the same
    # shapes and dtypes as the global weights.
    agg_state = {k: torch.zeros_like(v) for k, v in global_state.items()}

    for cid in selected:
        local_state = train_one_client(global_state, cid)
        # Every client contributes with equal weight 1 / CLIENTS_PER_ROUND.
        for k in agg_state:
            agg_state[k] += local_state[k] / CLIENTS_PER_ROUND

    # Install the averaged weights as the new global model.
    global_model.load_state_dict(agg_state)

    # Track test performance of the aggregated model after every round.
    loss, acc = evaluate(global_model)
    metrics.append([rnd, loss, acc])
    print(f"Test Acc: {acc*100:.2f}% | Loss: {loss:.4f}")

# -------------------------------
# SAVE RESULTS
# -------------------------------
# Per-round metrics as CSV (header row followed by one row per round).
with open(os.path.join(RESULTS_DIR, "metrics.csv"), "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["Round", "Loss", "Accuracy"])
    writer.writerows(metrics)

# Only the weights (state_dict) of the final global model are stored.
torch.save(
    global_model.state_dict(),
    os.path.join(RESULTS_DIR, "final_model.pt")
)

# Only the accuracy column is plotted; the loss is available in the CSV.
rounds = [m[0] for m in metrics]
accs = [m[2] for m in metrics]

plt.figure(figsize=(8, 4))
plt.plot(rounds, accs)
plt.xlabel("Round")
plt.ylabel("Test Accuracy")
plt.title("FedAvg CIFAR-10 Non-IID Baseline")
# Write the figure to disk and close it instead of displaying it.
plt.savefig(os.path.join(RESULTS_DIR, "training_curves.png"))
plt.close()

print("\nFEDAVG CIFAR-10 NON-IID BASELINE FINISHED")
print("Saved:")
print(f"- {RESULTS_DIR}/metrics.csv")
print(f"- {RESULTS_DIR}/training_curves.png")
print(f"- {RESULTS_DIR}/final_model.pt")

100%|██████████| 170M/170M [00:02<00:00, 73.1MB/s]



Round 1/10
Test Acc: 10.00% | Loss: 2.5301

Round 2/10
Test Acc: 11.57% | Loss: 2.5730

Round 3/10
Test Acc: 13.25% | Loss: 2.6168

Round 4/10
Test Acc: 10.79% | Loss: 2.7831

Round 5/10
Test Acc: 10.66% | Loss: 2.8081

Round 6/10
Test Acc: 9.93% | Loss: 2.7666

Round 7/10
Test Acc: 10.71% | Loss: 2.8367

Round 8/10
Test Acc: 12.63% | Loss: 2.6903

Round 9/10
Test Acc: 10.30% | Loss: 2.6871

Round 10/10
Test Acc: 10.24% | Loss: 2.6997

FEDAVG CIFAR-10 NON-IID BASELINE FINISHED
Saved:
- results_fedavg_cifar10_non_iid/metrics.csv
- results_fedavg_cifar10_non_iid/training_curves.png
- results_fedavg_cifar10_non_iid/final_model.pt
