In [1]:
import torch
import torchvision
import numpy as np
import torchvision.transforms as transforms
from collections import defaultdict
from torch.utils.data import DataLoader, Subset

In [5]:
# ==============================
# Dirichlet non-IID splitter
# ==============================
def dirichlet_split_noniid(dataset, num_clients, alpha, seed=42):
    """
    Partition a labelled dataset across clients with Dirichlet label skew.

    Semantics of ``alpha`` (inverted relative to the raw Dirichlet parameter):
      - alpha = 0   -> IID split (random permutation cut into near-equal parts).
      - alpha > 0   -> non-IID; the Dirichlet concentration is ``1 / alpha``
                       clipped to [1e-3, 1e3], so a larger alpha gives a more
                       skewed (more non-IID) split.

    Args:
        dataset: Sized dataset. Labels are read from ``dataset.targets``, else
            ``dataset.labels``, else from ``dataset[i][1]`` for every index.
            Classes are enumerated as ``0 .. labels.max()``, so labels are
            expected to be non-negative integer class ids.
        num_clients: Number of clients to split across (must be >= 1).
        alpha: Non-negative skew parameter, see above.
        seed: Seed passed to ``np.random.seed``. Note that this reseeds
            NumPy's global random state.

    Returns:
        dict: {client_id: np.ndarray of indices}, with exactly one entry for
        every client id in ``range(num_clients)`` (possibly an empty array).

    Raises:
        ValueError: If ``num_clients < 1`` or ``alpha`` is negative or NaN.
    """
    if num_clients < 1:
        raise ValueError(
            f"num_clients must be a positive integer, got {num_clients!r}"
        )
    # `not alpha >= 0` (rather than `alpha < 0`) also rejects NaN.
    if not alpha >= 0:
        raise ValueError(f"alpha must be >= 0, got {alpha!r}")

    np.random.seed(seed)

    n_samples = len(dataset)

    # ----- Special case: alpha = 0 => IID -----
    if alpha == 0:
        shards = np.array_split(np.random.permutation(n_samples), num_clients)
        return dict(enumerate(shards))

    # ----- alpha > 0 => Dirichlet-based non-IID -----
    # Larger user alpha -> smaller concentration -> more non-IID.
    # The clip keeps the concentration inside a numerically workable range.
    dirichlet_conc = float(np.clip(1.0 / alpha, 1e-3, 1e3))

    # Get labels generically
    if hasattr(dataset, "targets"):
        labels = np.array(dataset.targets)
    elif hasattr(dataset, "labels"):
        labels = np.array(dataset.labels)
    else:
        labels = np.array([dataset[i][1] for i in range(n_samples)])

    # An empty dataset has no classes; every client then gets an empty array.
    num_classes = int(labels.max()) + 1 if labels.size else 0

    # Per-client list of index chunks, one chunk appended per non-empty class.
    chunks_per_client = [[] for _ in range(num_clients)]

    for c in range(num_classes):
        class_idx = np.where(labels == c)[0]
        if len(class_idx) == 0:
            continue
        np.random.shuffle(class_idx)

        # Dirichlet proportions for this class across clients
        proportions = np.random.dirichlet(
            dirichlet_conc * np.ones(num_clients)
        )
        if not np.all(np.isfinite(proportions)):
            # At very small concentrations the sampler can return NaN
            # (presumably every underlying gamma draw underflows to zero).
            # That limit is a one-hot allocation, so hand the whole class to
            # one uniformly chosen client instead of casting NaN to int.
            proportions = np.zeros(num_clients)
            proportions[np.random.randint(num_clients)] = 1.0

        # Cumulative proportions -> interior cut positions. The last cumulative
        # value (~1.0) is dropped so the final chunk always runs to the end.
        cut_points = (np.cumsum(proportions) * len(class_idx)).astype(int)[:-1]

        for client_id, chunk in enumerate(np.split(class_idx, cut_points)):
            chunks_per_client[client_id].append(chunk)

    # Merge and shuffle indices inside each client so classes are interleaved.
    client_indices = {}
    for client_id, chunks in enumerate(chunks_per_client):
        if chunks:
            idx = np.concatenate(chunks).astype(int)
        else:
            idx = np.array([], dtype=int)
        np.random.shuffle(idx)
        client_indices[client_id] = idx

    return client_indices


# ==============================
# CIFAR-10 + per-client loaders
# ==============================
def get_cifar10_dirichlet_clients(
    data_root="./data",
    num_clients=10,
    alpha=0.5,
    batch_size=64,
    seed=42,
):
    """
    Load the CIFAR-10 training set and build one DataLoader per client.

    The training set is downloaded into ``data_root`` if needed, partitioned
    with ``dirichlet_split_noniid`` (``num_clients``, ``alpha`` and ``seed``
    are forwarded unchanged), and each client's indices are wrapped in a
    ``Subset`` served by a shuffling ``DataLoader``.

    Returns:
        tuple: (trainset, client_idcs, client_loaders) where ``client_idcs``
        maps client id -> indices and ``client_loaders`` maps client id ->
        DataLoader over that client's subset.
    """
    # Training-time augmentation followed by per-channel normalisation.
    augmentation_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2023, 0.1994, 0.2010),
        ),
    ]

    trainset = torchvision.datasets.CIFAR10(
        root=data_root,
        train=True,
        download=True,
        transform=transforms.Compose(augmentation_steps),
    )

    # Per-client index assignment.
    client_idcs = dirichlet_split_noniid(
        trainset, num_clients, alpha, seed=seed
    )

    # Every client loader shares the same settings.
    loader_options = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )
    client_loaders = {
        cid: DataLoader(Subset(trainset, idxs), **loader_options)
        for cid, idxs in client_idcs.items()
    }

    return trainset, client_idcs, client_loaders


# ==============================
# CIFAR-100 + per-client loaders
# ==============================
def get_cifar100_dirichlet_clients(
    data_root="./data",
    num_clients=10,
    alpha=0.5,
    batch_size=64,
    seed=42,
):
    """
    Load the CIFAR-100 training set and build one DataLoader per client.

    Mirrors the CIFAR-10 variant: same crop/flip augmentation, but with
    CIFAR-100 normalisation statistics. The split itself is delegated to
    ``dirichlet_split_noniid`` with ``num_clients``, ``alpha`` and ``seed``
    forwarded unchanged.

    Returns:
        tuple: (trainset, client_idcs, client_loaders) where ``client_idcs``
        maps client id -> indices and ``client_loaders`` maps client id ->
        shuffling DataLoader over that client's subset.
    """
    normalize = transforms.Normalize(
        mean=(0.5071, 0.4867, 0.4408),
        std=(0.2675, 0.2565, 0.2761),
    )
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR100(
        root=data_root,
        train=True,
        download=True,
        transform=transform_train,
    )

    client_idcs = dirichlet_split_noniid(
        trainset, num_clients, alpha, seed=seed
    )

    def _make_loader(indices):
        # One shuffling loader over the given client's slice of the train set.
        return DataLoader(
            Subset(trainset, indices),
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=True,
        )

    client_loaders = {}
    for cid, idxs in client_idcs.items():
        client_loaders[cid] = _make_loader(idxs)

    return trainset, client_idcs, client_loaders

# =====================================
# Shakespeare + per-client loaders
# =====================================
def get_shakespeare_dirichlet_clients(
    dataset,
    num_clients=10,
    alpha=0.5,
    batch_size=32,
    seed=42,
):
    """
    Split a caller-supplied Shakespeare dataset across clients.

    dataset: a PyTorch Dataset for Shakespeare samples.
             Must expose labels via .targets/.labels or via dataset[i] -> (x, y),
             because the split is delegated to ``dirichlet_split_noniid``.

    Returns:
        tuple: (dataset, client_idcs, client_loaders) -- the dataset that was
        passed in, the client id -> indices mapping, and a client id ->
        shuffling DataLoader mapping.
    """
    client_idcs = dirichlet_split_noniid(
        dataset, num_clients, alpha, seed=seed
    )

    # Settings shared by every per-client loader.
    shared_options = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": 2,
        "pin_memory": True,
    }

    client_loaders = {}
    for cid, idxs in client_idcs.items():
        client_view = Subset(dataset, idxs)
        client_loaders[cid] = DataLoader(client_view, **shared_options)

    return dataset, client_idcs, client_loaders

In [6]:
# ==============================
# Example usage
# ==============================
if __name__ == "__main__":
    NUM_CLIENTS = 2
    # Under dirichlet_split_noniid's convention the Dirichlet concentration is
    # 1 / alpha, so a LARGER alpha is MORE non-IID (alpha = 0 is the IID case).
    # We decided to use alpha of 5 for the experiments.
    ALPHA = 5

    trainset, client_idcs, client_loaders = get_cifar10_dirichlet_clients(
        data_root="/local/scratch/a/dalwis/single_agent_RL_for_pFL/data",
        num_clients=NUM_CLIENTS,
        alpha=ALPHA,
        batch_size=64,
        seed=123,
    )

    # Sanity check 1: how many samples each client received.
    for cid, idxs in client_idcs.items():
        print(f"Client {cid}: {len(idxs)} samples")

    # Sanity check 2: per-client class histogram over the 10 CIFAR-10 classes.
    # The label array does not depend on the client, so build it once.
    labels = np.array(trainset.targets)
    for cid in range(NUM_CLIENTS):
        client_labels = labels[client_idcs[cid]]
        print(f"Client {cid} label histogram:")
        print(np.bincount(client_labels, minlength=10))


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /local/scratch/a/dalwis/single_agent_RL_for_pFL/data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 47922051.37it/s]


Extracting /local/scratch/a/dalwis/single_agent_RL_for_pFL/data/cifar-10-python.tar.gz to /local/scratch/a/dalwis/single_agent_RL_for_pFL/data
Client 0: 22536 samples
Client 1: 27464 samples
Client 0 label histogram:
[ 463 4991    0 4919 4884 4999    0   10    0 2270]
Client 1 label histogram:
[4537    9 5000   81  116    1 5000 4990 5000 2730]


In [8]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import datasets, transforms
import pickle


# ============================================================
# 1. Make CIFAR-10 Dirichlet client splits
# ============================================================

def get_cifar10_dirichlet_clients(data_root, num_clients, alpha, batch_size, seed=123):
    """
    Load CIFAR-10 (ToTensor only) and split it across clients by class.

    For each of the 10 classes, a Dirichlet sample with concentration
    ``[alpha] * num_clients`` decides what fraction of that class every
    client receives. Here ``alpha`` is used directly as the concentration.
    Note: this is the opposite convention to ``dirichlet_split_noniid``
    defined earlier in this notebook, which uses ``1 / alpha``.

    Both NumPy's and PyTorch's global RNGs are seeded with ``seed``.

    Returns:
        tuple: (trainset, client_idcs, client_loaders) where ``client_idcs``
        maps client id -> shuffled Python list of sample indices and
        ``client_loaders`` maps client id -> shuffling DataLoader.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    trainset = datasets.CIFAR10(
        root=data_root,
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    )

    labels = np.array(trainset.targets)
    num_classes = 10

    client_idcs = {cid: [] for cid in range(num_clients)}

    for c in range(num_classes):
        # Shuffle this class's indices, then cut them at the cumulative shares.
        members = np.where(labels == c)[0]
        np.random.shuffle(members)

        shares = np.random.dirichlet(alpha=[alpha] * num_clients)
        # Drop the last cumulative value so the final chunk runs to the end.
        cut_points = (np.cumsum(shares) * len(members)).astype(int)[:-1]

        for cid, chunk in enumerate(np.split(members, cut_points)):
            client_idcs[cid].extend(chunk)

    # Interleave classes inside each client's index list (in place).
    for indices in client_idcs.values():
        np.random.shuffle(indices)

    # Build dataloaders
    client_loaders = {}
    for cid in range(num_clients):
        client_view = Subset(trainset, client_idcs[cid])
        client_loaders[cid] = DataLoader(
            client_view, batch_size=batch_size, shuffle=True
        )

    return trainset, client_idcs, client_loaders



# ============================================================
# 2. Save each client's dataset to disk
# ============================================================

def save_client_splits(save_dir, trainset, client_idcs):
    """
    Write each client's samples and labels to disk as .npy files.

    For every ``cid`` in ``client_idcs`` a directory
    ``<save_dir>/client_<cid>`` is created (along with ``save_dir`` itself)
    containing ``data.npy`` (``trainset.data`` rows for that client) and
    ``targets.npy`` (the matching entries of ``trainset.targets``). One
    confirmation line is printed per client.

    Args:
        save_dir: Root output directory.
        trainset: Object exposing ``.data`` (indexable by an index list/array)
            and ``.targets``.
        client_idcs: Mapping of client id -> indices into the dataset.
    """
    os.makedirs(save_dir, exist_ok=True)

    all_samples = trainset.data
    all_targets = np.array(trainset.targets)

    for cid, idxs in client_idcs.items():
        cid_dir = os.path.join(save_dir, f"client_{cid}")
        os.makedirs(cid_dir, exist_ok=True)

        # Select both arrays first, then write them out.
        to_write = {
            "data.npy": all_samples[idxs],
            "targets.npy": all_targets[idxs],
        }
        for filename, array in to_write.items():
            np.save(os.path.join(cid_dir, filename), array)

        print(f"Saved Client {cid}: {len(idxs)} samples → {cid_dir}")



# ============================================================
# 3. Example usage
# ============================================================

if __name__ == "__main__":
    NUM_CLIENTS = 2
    # In this cell alpha is passed straight through as the Dirichlet
    # concentration, so 5 is on the more-IID side.
    # (We decided to use alpha of 5 for the experiments.)
    ALPHA = 5
    BATCH_SIZE = 64

    trainset, client_idcs, client_loaders = get_cifar10_dirichlet_clients(
        "/local/scratch/a/dalwis/single_agent_RL_for_pFL/data",
        NUM_CLIENTS,
        ALPHA,
        BATCH_SIZE,
        seed=123,
    )

    # Report how many samples landed on each client.
    for cid, idxs in client_idcs.items():
        print(f"Client {cid}: {len(idxs)} samples")

    # Persist every client's images and labels under its own directory.
    save_client_splits(
        "/local/scratch/a/dalwis/single_agent_RL_for_pFL/data/cifar_10/saved_splits",
        trainset,
        client_idcs,
    )


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /local/scratch/a/dalwis/single_agent_RL_for_pFL/data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 33722682.40it/s]


Extracting /local/scratch/a/dalwis/single_agent_RL_for_pFL/data/cifar-10-python.tar.gz to /local/scratch/a/dalwis/single_agent_RL_for_pFL/data
Client 0: 22415 samples
Client 1: 27585 samples
Saved Client 0: 22415 samples → /local/scratch/a/dalwis/single_agent_RL_for_pFL/data/cifar_10/saved_splits/client_0
Saved Client 1: 27585 samples → /local/scratch/a/dalwis/single_agent_RL_for_pFL/data/cifar_10/saved_splits/client_1
