In [2]:
import numpy as np
from collections import defaultdict, Counter
from torchvision import datasets, transforms
from torch.utils.data import Subset
import random

# Load MNIST training data
# The preprocessing pipeline has a single stage: convert each image with ToTensor.
transform = transforms.Compose([
    transforms.ToTensor(),
])
# train=True selects the training split; download=True lets torchvision fetch
# the files into `root` when they are not already there.
mnist_train = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Parameters
num_clients = 100
total_samples = len(mnist_train)  # 60,000
random_seed = 42      # fixed so that Restart & Run All reproduces the same partition
min_client_size = 1   # lower bound on samples per client (keeps every client non-empty)

# Seed both global RNGs used by this cell: NumPy drives the shuffle and the
# Dirichlet draw below, while the stdlib `random` module drives the per-round
# client selection further down.
np.random.seed(random_seed)
random.seed(random_seed)

# The guaranteed minimum must be satisfiable before anything is split.
if total_samples < num_clients * min_client_size:
    raise ValueError(
        f"Cannot give {num_clients} clients at least {min_client_size} sample(s) each "
        f"from only {total_samples} samples"
    )

# Extract labels and indices
all_indices = np.arange(total_samples)
labels = np.array(mnist_train.targets)
np.random.shuffle(all_indices)  # in-place shuffle; contiguous slices become random draws

# Use Dirichlet to sample varying client sizes (sums to total_samples).
# Every client first receives `min_client_size` samples and only the remainder
# is split by the Dirichlet proportions. Without that floor, a very small
# proportion is truncated to 0 by astype(int), producing a zero-sized client
# (and therefore an empty Subset later on).
proportions = np.random.dirichlet(alpha=np.ones(num_clients), size=1)[0]
spare_samples = total_samples - num_clients * min_client_size
client_sizes = min_client_size + (proportions * spare_samples).astype(int)

# Fix rounding issue: astype(int) truncates, so the sizes can fall short of
# total_samples; the shortfall is handed to the largest client.
diff = total_samples - np.sum(client_sizes)
client_sizes[np.argmax(client_sizes)] += diff

# Step 1: Assign samples randomly while preserving global distribution
# Each client takes the next contiguous slice of the already-shuffled index
# array; slice boundaries are the running totals of client_sizes.
offsets = np.concatenate(([0], np.cumsum(client_sizes)))
client_indices = defaultdict(list)
for client_id, (lo, hi) in enumerate(zip(offsets[:-1], offsets[1:])):
    # .tolist() turns the NumPy slice into a plain list of Python ints.
    client_indices[client_id] = all_indices[lo:hi].tolist()

# Step 2: Create Subset datasets
# One Subset view over mnist_train per client, ordered by client id.
client_datasets = [
    Subset(mnist_train, client_indices[client_id])
    for client_id in range(num_clients)
]

# Debug: Show class distribution for the first (up to) 5 clients
# Slicing the list instead of using range(5) avoids an IndexError when fewer
# than five clients are configured.
for i, client_dataset in enumerate(client_datasets[:5]):
    # One vectorised lookup into the NumPy label array replaces per-sample
    # tensor indexing; .tolist() yields plain Python ints for the Counter keys.
    client_labels = labels[client_dataset.indices].tolist()
    # Sort by class id so the printed dict always reads 0..9 in order.
    label_count = dict(sorted(Counter(client_labels).items()))
    print(f"Client {i} class distribution: {label_count} (Total: {len(client_dataset)})")

# Example: Randomly choose up to 20 clients for training in a round
# random.sample raises ValueError when asked for more items than the population
# holds, so the request is clamped to the number of available clients.
clients_per_round = min(20, num_clients)
selected_clients = random.sample(range(num_clients), clients_per_round)
print(f"\nSelected clients for this round: {selected_clients}")
# Datasets of the chosen clients, in the same order as selected_clients.
selected_datasets = [client_datasets[i] for i in selected_clients]

Client 0 class distribution: {0: 84, 1: 88, 2: 72, 3: 72, 4: 64, 5: 74, 6: 65, 7: 54, 8: 72, 9: 61} (Total: 706)
Client 1 class distribution: {0: 76, 1: 89, 2: 63, 3: 77, 4: 79, 5: 63, 6: 69, 7: 87, 8: 71, 9: 71} (Total: 745)
Client 2 class distribution: {0: 130, 1: 159, 2: 142, 3: 142, 4: 142, 5: 107, 6: 134, 7: 143, 8: 131, 9: 128} (Total: 1358)
Client 3 class distribution: {0: 16, 1: 18, 2: 17, 3: 15, 4: 11, 5: 16, 6: 9, 7: 21, 8: 15, 9: 20} (Total: 158)
Client 4 class distribution: {0: 101, 1: 104, 2: 97, 3: 97, 4: 71, 5: 98, 6: 89, 7: 117, 8: 92, 9: 95} (Total: 961)

Selected clients for this round: [52, 94, 64, 99, 75, 37, 93, 15, 35, 96, 14, 76, 40, 62, 5, 61, 33, 49, 86, 87]
