In [None]:
import numpy as np
from collections import defaultdict

import torch
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

import pandas as pd
import os


In [None]:
# ==============================
# Dirichlet non-IID splitter
# ==============================
def dirichlet_split_noniid(dataset, num_clients, alpha, seed=42):
    """
    Partition the indices of ``dataset`` across clients by a Dirichlet label split.

    Semantics of ``alpha`` (NOTE: inverted w.r.t. the usual convention where
    alpha *is* the Dirichlet concentration):
      - alpha = 0   -> IID split: a random permutation cut into near-equal
                       pieces with ``np.array_split``.
      - alpha > 0   -> non-IID; the Dirichlet concentration is ``1 / alpha``
                       (clipped to [1e-3, 1e3]), so a LARGER alpha gives a
                       MORE skewed (more non-IID) split.

    All randomness comes from a private ``np.random.RandomState(seed)``, so the
    split is reproducible and the global NumPy RNG state is left untouched.

    Args:
        dataset: sized dataset; labels are read from ``.targets``, ``.labels``
            or, as a fallback, ``dataset[i][1]``. For alpha > 0 the labels
            must be non-negative integer class ids.
        num_clients: number of clients (>= 1).
        alpha: skew parameter (>= 0), see above.
        seed: seed for the private RNG.

    Returns:
        dict: {client_id: np.ndarray of indices} for every client_id in
        ``range(num_clients)``. With alpha > 0 a client may receive no indices.

    Raises:
        ValueError: if ``num_clients < 1`` or ``alpha < 0``.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    if alpha < 0:
        raise ValueError(f"alpha must be >= 0, got {alpha}")

    # Private RNG: same stream the legacy np.random.seed(seed) produced, but
    # without mutating global state other cells may rely on.
    rng = np.random.RandomState(seed)
    n_samples = len(dataset)

    # ----- Special case: alpha = 0 => IID -----
    if alpha == 0:
        splits = np.array_split(rng.permutation(n_samples), num_clients)
        return {cid: splits[cid] for cid in range(num_clients)}

    # ----- alpha > 0 => Dirichlet-based non-IID -----
    # Larger user alpha -> smaller concentration -> more non-IID.
    # Clip to avoid extreme numerical issues.
    dirichlet_conc = float(np.clip(1.0 / alpha, 1e-3, 1e3))

    # Get labels generically
    if hasattr(dataset, "targets"):
        labels = np.array(dataset.targets)
    elif hasattr(dataset, "labels"):
        labels = np.array(dataset.labels)
    else:
        labels = np.array([dataset[i][1] for i in range(n_samples)])

    if labels.size == 0:
        # labels.max() would fail on an empty dataset; every client gets nothing.
        return {cid: np.array([], dtype=int) for cid in range(num_clients)}

    num_classes = int(labels.max()) + 1
    client_lists = {cid: [] for cid in range(num_clients)}

    # For each class, sample Dirichlet proportions over clients
    for c in range(num_classes):
        class_idx = np.where(labels == c)[0]
        if len(class_idx) == 0:
            continue
        rng.shuffle(class_idx)

        proportions = rng.dirichlet(dirichlet_conc * np.ones(num_clients))

        # Convert cumulative proportions to split points; the final point is
        # dropped so np.split always yields exactly num_clients pieces.
        split_points = (np.cumsum(proportions) * len(class_idx)).astype(int)
        for client_id, idx in enumerate(np.split(class_idx, split_points[:-1])):
            client_lists[client_id].extend(idx.tolist())

    # Shuffle indices inside each client (same RNG call order as before).
    client_indices = {}
    for client_id in range(num_clients):
        idx = np.array(client_lists[client_id], dtype=int)
        rng.shuffle(idx)
        client_indices[client_id] = idx

    return client_indices


# ==============================
# CIFAR-10 + per-client loaders
# ==============================
def get_cifar10_dirichlet_clients(
    data_root="./data",
    num_clients=10,
    alpha=0.5,
    batch_size=64,
    seed=42,
):
    """Load the CIFAR-10 training set and split it across clients.

    The split comes from ``dirichlet_split_noniid`` (alpha=0 -> IID, larger
    alpha -> more non-IID). Each client gets a DataLoader over its ``Subset``
    of the augmented training set.

    Returns:
        (trainset, client_idcs, client_loaders): the dataset, a mapping
        client id -> index array, and a mapping client id -> DataLoader.
    """
    # NOTE(review): these std values are the commonly copied CIFAR-10 ones;
    # verify against the dataset's per-channel std if exact scaling matters.
    trainset = torchvision.datasets.CIFAR10(
        root=data_root,
        train=True,
        download=True,
        transform=_cifar_train_transform(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2023, 0.1994, 0.2010),
        ),
    )
    client_idcs = dirichlet_split_noniid(
        dataset=trainset,
        num_clients=num_clients,
        alpha=alpha,
        seed=seed,
    )
    client_loaders = _build_client_loaders(trainset, client_idcs, batch_size)
    return trainset, client_idcs, client_loaders


# ==============================
# CIFAR-100 + per-client loaders
# ==============================
def get_cifar100_dirichlet_clients(
    data_root="./data",
    num_clients=10,
    alpha=0.5,
    batch_size=64,
    seed=42,
):
    """Load the CIFAR-100 training set and split it across clients.

    Same pipeline as ``get_cifar10_dirichlet_clients``, with CIFAR-100
    normalization statistics.

    Returns:
        (trainset, client_idcs, client_loaders).
    """
    trainset = torchvision.datasets.CIFAR100(
        root=data_root,
        train=True,
        download=True,
        transform=_cifar_train_transform(
            mean=(0.5071, 0.4867, 0.4408),
            std=(0.2675, 0.2565, 0.2761),
        ),
    )
    client_idcs = dirichlet_split_noniid(
        dataset=trainset,
        num_clients=num_clients,
        alpha=alpha,
        seed=seed,
    )
    client_loaders = _build_client_loaders(trainset, client_idcs, batch_size)
    return trainset, client_idcs, client_loaders


# =====================================
# Shakespeare + per-client loaders
# =====================================
def get_shakespeare_dirichlet_clients(
    dataset,
    num_clients=10,
    alpha=0.5,
    batch_size=32,
    seed=42,
):
    """
    Split an already-built Shakespeare dataset across clients.

    dataset: a PyTorch Dataset for Shakespeare samples.
             Must expose labels via .targets/.labels or via dataset[i] -> (x, y).
             For alpha > 0 the labels must be non-negative integer class ids,
             because ``dirichlet_split_noniid`` splits per class id.

    Returns:
        (dataset, client_idcs, client_loaders).
    """
    client_idcs = dirichlet_split_noniid(
        dataset=dataset,
        num_clients=num_clients,
        alpha=alpha,
        seed=seed,
    )
    client_loaders = _build_client_loaders(dataset, client_idcs, batch_size)
    return dataset, client_idcs, client_loaders


# ==============================
# Shared helpers
# ==============================
def _cifar_train_transform(mean, std):
    """Training transform: random 32x32 crop (padding 4), random horizontal flip, tensor, normalize."""
    return transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])


def _build_client_loaders(dataset, client_idcs, batch_size, num_workers=2, pin_memory=None):
    """Wrap each client's indices as a ``Subset`` + ``DataLoader``.

    Args:
        dataset: the dataset the indices refer to.
        client_idcs: {client_id: sequence of indices into ``dataset``}.
        batch_size: mini-batch size for every client loader.
        num_workers: DataLoader worker processes.
        pin_memory: pin host memory for faster host-to-GPU copies; ``None``
            means "only when CUDA is available".

    Returns:
        dict: {client_id: DataLoader}, in the iteration order of ``client_idcs``.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()
    client_loaders = {}
    for cid, idxs in client_idcs.items():
        subset = Subset(dataset, idxs)
        client_loaders[cid] = DataLoader(
            subset,
            batch_size=batch_size,
            # A Dirichlet split can leave a client with no samples; shuffling
            # (RandomSampler) rejects an empty dataset, while a sequential
            # loader over it simply yields no batches.
            shuffle=len(subset) > 0,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )
    return client_loaders

In [None]:
# ==============================
# Example usage
# ==============================
if __name__ == "__main__":
    NUM_CLIENTS = 10
    # dirichlet_split_noniid uses an inverted convention: alpha=0 is IID and a
    # LARGER alpha is MORE non-IID (Dirichlet concentration = 1/alpha).
    ALPHA = 5

    trainset, client_idcs, client_loaders = get_cifar10_dirichlet_clients(
        data_root="./data",
        num_clients=NUM_CLIENTS,
        alpha=ALPHA,
        batch_size=64,
        seed=123,
    )

    # Quick sanity check: print per-client sizes
    for cid, idxs in client_idcs.items():
        print(f"Client {cid}: {len(idxs)} samples")

    # Optional: per-client label histograms. The label array is loop-invariant,
    # so build it once, and size the histogram from the data, not a literal 10.
    labels = np.array(trainset.targets)
    num_classes = int(labels.max()) + 1
    for cid in range(NUM_CLIENTS):
        client_labels = labels[client_idcs[cid]]
        print(f"Client {cid} label histogram:")
        print(np.bincount(client_labels, minlength=num_classes))


100%|██████████| 170M/170M [00:19<00:00, 8.67MB/s]


Client 0: 1580 samples
Client 1: 3948 samples
Client 2: 4833 samples
Client 3: 4475 samples
Client 4: 7366 samples
Client 5: 6468 samples
Client 6: 1431 samples
Client 7: 8291 samples
Client 8: 4917 samples
Client 9: 6691 samples
Client 0 label histogram:
[ 55   2   3 493 215   6  13 793   0   0]
Client 1 label histogram:
[ 543    0 1151  166  399    0    1    0 1221  467]
Client 2 label histogram:
[   1   83 1311 1217  580 1629    0    0   11    1]
Client 3 label histogram:
[   2    2    0 2610    0    0 1827    0    1   33]
Client 4 label histogram:
[2682    0 1413    0   22 3055   81   22    0   91]
Client 5 label histogram:
[ 219    0    0  408 1593    0 1126  551 1805  766]
Client 6 label histogram:
[  0   0 727   0   1 244 183   8 261   7]
Client 7 label histogram:
[1497  525    0    3 2171    4 1583  336  214 1958]
Client 8 label histogram:
[   0 4318   98   60   14    0  185    0  231   11]
Client 9 label histogram:
[   1   70  297   43    5   62    1 3290 1256 1666]


In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
import torch.nn as nn
import torch.optim as optim
import torchvision
import pandas as pd
import os
from datetime import datetime
from collections import defaultdict

In [None]:
import copy
import os
from collections import defaultdict
from datetime import datetime, timezone

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset

# ----------------------------
# Parameters
# ----------------------------
# Client count follows the split produced by the earlier data cell.
NUM_CLIENTS = len(client_idcs)

# Per-client held-out fraction and loader batch sizes.
TEST_RATIO = 0.05
TRAIN_BATCH = 16
TEST_BATCH = 32

# Federated schedule. LOCAL_EPOCHS in (0, 1) trains on that random fraction of
# a client's train split each round; LOCAL_EPOCHS >= 1 runs int(LOCAL_EPOCHS)
# full local epochs.
ROUNDS = 100
LOCAL_EPOCHS = 0.01
LR = 1e-3
SEED = 123

# NOTE(review): a "results" directory is created, but AUTOSAVE_PATH writes to
# the working directory — confirm which location is intended.
AUTOSAVE_PATH = "results_cifar10_alpha5_epoch_0_01_new.csv"
os.makedirs("results", exist_ok=True)

# Seed the global NumPy and torch RNGs, then pick the compute device.
np.random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ----------------------------
# Model & Loss
# ----------------------------
def get_model(num_classes=10):
    """Build an untrained torchvision ResNet-18 with a ``num_classes``-way head, moved to ``device``."""
    net = torchvision.models.resnet18(weights=None)
    # Swap the classifier head for one sized to this task.
    head_in = net.fc.in_features
    net.fc = nn.Linear(head_in, num_classes)
    return net.to(device)

# Loss shared by the training loop (train_one_epoch reads this global).
criterion = nn.CrossEntropyLoss()

# ----------------------------
# Split train/test per client (95/5)
# ----------------------------
client_train_idx = {}
client_test_idx = {}
client_train_loaders = {}
client_test_loaders = {}
split_info = {}  # client id -> (n_train, n_test)

# Evaluation view of the training data. It shares the same images/targets but
# drops the random augmentations (RandomCrop / RandomHorizontalFlip) of the
# training transform, so held-out accuracy is measured on deterministic,
# un-augmented inputs. This assumes the dataset applies `.transform` in
# __getitem__, as torchvision's CIFAR datasets do.
eval_set = copy.copy(trainset)
_train_tf = getattr(trainset, "transform", None)
if hasattr(_train_tf, "transforms"):
    eval_set.transform = transforms.Compose([
        t for t in _train_tf.transforms
        if not isinstance(t, (transforms.RandomCrop, transforms.RandomHorizontalFlip))
    ])

for cid, idxs in client_idcs.items():
    idxs = np.array(idxs)
    np.random.shuffle(idxs)  # global RNG, seeded in the config section
    n = len(idxs)
    n_test = max(1, int(TEST_RATIO * n))
    train_idx = idxs[:-n_test]
    test_idx = idxs[-n_test:]

    client_train_idx[cid] = train_idx.tolist()
    client_test_idx[cid] = test_idx.tolist()
    split_info[cid] = (len(train_idx), len(test_idx))

    client_train_loaders[cid] = DataLoader(
        Subset(trainset, train_idx),
        batch_size=TRAIN_BATCH,
        # RandomSampler rejects an empty dataset (possible for tiny clients).
        shuffle=len(train_idx) > 0,
        drop_last=True,
        num_workers=2,
        pin_memory=True
    )
    client_test_loaders[cid] = DataLoader(
        Subset(eval_set, test_idx),
        batch_size=TEST_BATCH,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

print("Client splits complete (train/test counts):")
for cid in range(NUM_CLIENTS):
    print(f" Client {cid}: train={split_info[cid][0]}, test={split_info[cid][1]}")

# ----------------------------
# Training / Evaluation
# ----------------------------
def train_one_epoch(model, loader, optimizer):
    """Run one pass over ``loader``, updating ``model`` in place.

    Uses the module-level ``criterion`` and ``device``.

    Returns:
        (mean loss, accuracy in percent) over the samples seen, or
        ``(0.0, 0.0)`` when the loader yields no batches.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        batch_n = targets.size(0)
        # Weight the batch-mean loss by batch size to get a per-sample mean.
        loss_sum += batch_loss.item() * batch_n
        n_correct += (logits.argmax(1) == targets).sum().item()
        n_seen += batch_n
    if not n_seen:
        return 0.0, 0.0
    return loss_sum / n_seen, 100 * n_correct / n_seen

def evaluate(model, loader):
    """Top-1 accuracy (percent) of ``model`` over ``loader``, computed without gradients.

    Puts ``model`` in eval mode and moves batches to the module-level ``device``.
    Returns 0.0 when the loader yields no samples.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            predictions = model(inputs).argmax(1)
            hits += (predictions == targets).sum().item()
            seen += targets.size(0)
    return 100 * hits / seen if seen else 0.0

# ----------------------------
# Fractional epoch helper
# ----------------------------
def sample_fractional_data(cid, frac, rng, pool=None):
    """Draw a random ``frac`` share of a client's training indices, without replacement.

    Args:
        cid: client id; selects ``client_train_idx[cid]`` when ``pool`` is None.
        frac: fraction of the pool to draw. At least one index is drawn from a
            non-empty pool, and the draw is capped at the pool size.
        rng: ``np.random.RandomState`` (or compatible) providing ``choice``.
        pool: optional explicit index pool that overrides ``client_train_idx[cid]``.

    Returns:
        list of distinct indices; empty if the pool is empty.
    """
    if pool is None:
        pool = client_train_idx[cid]
    n = len(pool)
    if n == 0:
        # rng.choice cannot sample from an empty population.
        return []
    # ceil so tiny fractions still draw something; cap at n because sampling
    # without replacement cannot exceed the population (e.g. frac > 1).
    take = min(n, max(1, int(np.ceil(n * frac))))
    return rng.choice(pool, size=take, replace=False).tolist()

# ----------------------------
# Initialize global model
# ----------------------------
# Fresh global model; between rounds its weights are kept as CPU tensors.
global_model = get_model()
global_state = {name: tensor.cpu() for name, tensor in global_model.state_dict().items()}

# One dict per logged (round, client, epoch) row; dumped to CSV every round.
records = []

# ----------------------------
# Federated rounds
# ----------------------------
def _utc_timestamp():
    """ISO-8601 UTC timestamp without offset, matching what ``datetime.utcnow().isoformat()`` gave.

    ``datetime.utcnow()`` is deprecated (recent Python versions emit a warning).
    Building the value from an aware ``now(timezone.utc)`` and dropping the
    tzinfo keeps the CSV timestamp format unchanged.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


# Integer mode trains range(1, int(LOCAL_EPOCHS) + 1) epochs; a non-positive
# value would train nothing and leave latest_record[cid] unset (KeyError below).
if LOCAL_EPOCHS <= 0:
    raise ValueError(f"LOCAL_EPOCHS must be > 0, got {LOCAL_EPOCHS}")

for rnd in range(1, ROUNDS+1):
    print(f"\n======== ROUND {rnd}/{ROUNDS} ========")
    client_states = {}
    latest_record = {}  # client id -> index in `records` of its last row this round
    fractional_mode = (LOCAL_EPOCHS > 0 and LOCAL_EPOCHS < 1)

    for cid in range(NUM_CLIENTS):
        # Each client starts from a copy of the current global weights.
        local_model = get_model().to(device)
        local_model.load_state_dict({k: global_state[k].clone().to(device) for k in global_state})
        optimizer = optim.Adam(local_model.parameters(), lr=LR)
        n_train, n_test = split_info[cid]

        # Fractional epoch mode
        if fractional_mode:
            # Seed per (round, client) so the sampled subset is reproducible.
            rng = np.random.RandomState(SEED + rnd*1000 + cid)
            selected = sample_fractional_data(cid, LOCAL_EPOCHS, rng)
            n_selected = len(selected)
            loader = DataLoader(
                Subset(trainset, selected),
                batch_size=TRAIN_BATCH,
                shuffle=n_selected > 0,
                # With drop_last=True, a sample smaller than TRAIN_BATCH yields
                # zero batches: the client silently skips training and logs
                # loss=0/acc=0. Keep such a lone partial batch. Still drop
                # ragged tails once full batches exist, and never train on a
                # 1-sample batch, which BatchNorm in train mode can reject.
                drop_last=not (2 <= n_selected < TRAIN_BATCH),
            )
            train_loss, train_acc = train_one_epoch(local_model, loader, optimizer)
            rec = {
                "round": rnd, "client": cid, "num_train": n_train, "num_test": n_test,
                "epoch": 1, "fraction": LOCAL_EPOCHS,
                "train_loss": train_loss, "train_acc": train_acc,
                "local_test_acc_before_agg": None,
                "global_test_acc_on_client": None,
                "timestamp": _utc_timestamp()
            }
            records.append(rec)
            latest_record[cid] = len(records) - 1
            print(f"[Client {cid}] fractional train size={len(selected)} loss={train_loss:.4f} acc={train_acc:.2f}")

        # Integer epoch mode
        else:
            for ep in range(1, int(LOCAL_EPOCHS)+1):
                train_loss, train_acc = train_one_epoch(local_model, client_train_loaders[cid], optimizer)
                rec = {
                    "round": rnd, "client": cid, "num_train": n_train, "num_test": n_test,
                    "epoch": ep, "fraction": 0,
                    "train_loss": train_loss, "train_acc": train_acc,
                    "local_test_acc_before_agg": None,
                    "global_test_acc_on_client": None,
                    "timestamp": _utc_timestamp()
                }
                records.append(rec)
                latest_record[cid] = len(records) - 1
                print(f"[Client {cid}] epoch={ep} loss={train_loss:.4f} acc={train_acc:.2f}")

        # Local test before aggregation
        local_acc = evaluate(local_model, client_test_loaders[cid])
        print(f"[Client {cid}] Local test BEFORE aggregation = {local_acc:.2f}%")
        records[latest_record[cid]]["local_test_acc_before_agg"] = local_acc

        # Save local weights for aggregation
        client_states[cid] = {k: v.cpu() for k, v in local_model.state_dict().items()}

    # Aggregation: plain (unweighted) mean of the client weights, accumulated
    # in float32 and cast back to each tensor's dtype.
    # NOTE(review): FedAvg as defined by McMahan et al. weights each client by
    # its sample count; with these very unequal client sizes the two differ.
    # Confirm which is intended.
    new_global = {k: torch.zeros_like(global_state[k], dtype=torch.float32) for k in global_state}
    for state in client_states.values():
        for k in new_global:
            new_global[k] += state[k].float()
    n_aggregated = len(client_states)  # divide by what was actually summed
    for k in new_global:
        new_global[k] = (new_global[k] / n_aggregated).to(global_state[k].dtype)
    global_state = {k: new_global[k].clone() for k in new_global}
    global_model.load_state_dict({k: v.to(device) for k, v in global_state.items()})

    # Evaluate global model per client
    for cid in range(NUM_CLIENTS):
        gacc = evaluate(global_model, client_test_loaders[cid])
        records[latest_record[cid]]["global_test_acc_on_client"] = gacc
        print(f"[Global model] Accuracy on client {cid} test set = {gacc:.2f}%")

    # Autosave after each round
    df = pd.DataFrame(records)
    df.to_csv(AUTOSAVE_PATH, index=False)
    print(f"Autosaved to {AUTOSAVE_PATH}")

print("\nTraining complete. Final CSV saved.")


Client splits complete (train/test counts):
 Client 0: train=1501, test=79
 Client 1: train=3751, test=197
 Client 2: train=4592, test=241
 Client 3: train=4252, test=223
 Client 4: train=6998, test=368
 Client 5: train=6145, test=323
 Client 6: train=1360, test=71
 Client 7: train=7877, test=414
 Client 8: train=4672, test=245
 Client 9: train=6357, test=334

[Client 0] fractional train size=16 loss=2.6903 acc=6.25


  "timestamp": datetime.utcnow().isoformat()


[Client 0] Local test BEFORE aggregation = 54.43%
[Client 1] fractional train size=38 loss=1.8080 acc=34.38
[Client 1] Local test BEFORE aggregation = 11.17%
[Client 2] fractional train size=46 loss=2.0340 acc=18.75
[Client 2] Local test BEFORE aggregation = 25.31%
[Client 3] fractional train size=43 loss=2.0457 acc=21.88
[Client 3] Local test BEFORE aggregation = 55.61%
[Client 4] fractional train size=70 loss=2.0513 acc=26.56
[Client 4] Local test BEFORE aggregation = 45.11%
[Client 5] fractional train size=62 loss=2.4036 acc=18.75
[Client 5] Local test BEFORE aggregation = 16.10%
[Client 6] fractional train size=14 loss=0.0000 acc=0.00
[Client 6] Local test BEFORE aggregation = 18.31%
[Client 7] fractional train size=79 loss=2.1740 acc=28.12
[Client 7] Local test BEFORE aggregation = 24.64%
[Client 8] fractional train size=47 loss=2.0076 acc=21.88
[Client 8] Local test BEFORE aggregation = 88.57%
[Client 9] fractional train size=64 loss=2.0252 acc=35.94
[Client 9] Local test BEFORE 

In [None]:
from google.colab import files

# Download the CSV written by the training loop. Reusing AUTOSAVE_PATH keeps
# this cell in sync if the output filename changes.
files.download(AUTOSAVE_PATH)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>