In [1]:
# Step 1: Setup
import numpy as np
import random
from collections import defaultdict
from torchvision import datasets, transforms
from torch.utils.data import Subset

# Seed every RNG this cell draws from so the client split is identical after
# Restart Kernel -> Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

NUM_CLIENTS = 10
NUM_CLASSES = 10

transform = transforms.Compose([transforms.ToTensor()])
mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Organize indices by class.
# Labels are read from `mnist_data.targets` instead of iterating the dataset:
# iterating would load and transform every image only to throw it away.
class_indices = defaultdict(list)
for idx, label in enumerate(mnist_data.targets.tolist()):
    class_indices[label].append(idx)

for indices in class_indices.values():
    random.shuffle(indices)

# Step 2: Generate clients with quantity skew
client_data = [[] for _ in range(NUM_CLIENTS)]

# Random target size per client: 500 <= size < 2500 (the upper bound is exclusive).
target_sizes = np.random.randint(500, 2500, size=NUM_CLIENTS)

# Client i is associated with classes i and (i + 1) mod NUM_CLASSES.
classes_per_client = [[(i + j) % NUM_CLASSES for j in range(2)] for i in range(NUM_CLIENTS)]

# NOTE(review): each pass of the loop below takes one sample from EVERY class
# (the client's two listed classes first, then the remaining eight), so the listed
# classes end up as roughly 2/10 of a client's data rather than the majority.
# The split is therefore skewed in quantity but close to uniform in labels.
# Confirm whether label skew was intended as well.
for client_id, (cls_list, target_size) in enumerate(zip(classes_per_client, target_sizes)):
    samples_needed = int(target_size)
    # Same visiting order as before: listed classes first, then the others ascending.
    draw_order = list(cls_list) + [cls for cls in range(NUM_CLASSES) if cls not in cls_list]
    while samples_needed > 0:
        drawn_this_pass = 0
        for cls in draw_order:
            if samples_needed == 0:
                break
            if class_indices[cls]:
                client_data[client_id].append(class_indices[cls].pop())
                samples_needed -= 1
                drawn_this_pass += 1
        if drawn_this_pass == 0:
            # Every class pool is empty; without this guard the while loop would spin forever.
            print(f"Client {client_id}: sample pool exhausted, {samples_needed} samples short")
            break

# Convert to datasets
client_datasets = [Subset(mnist_data, indices) for indices in client_data]

# Show client dataset sizes
for i, dataset in enumerate(client_datasets):
    print(f"Client {i}: {len(dataset)} samples")


Client 0: 1483 samples
Client 1: 2040 samples
Client 2: 676 samples
Client 3: 1058 samples
Client 4: 1021 samples
Client 5: 902 samples
Client 6: 1470 samples
Client 7: 2279 samples
Client 8: 525 samples
Client 9: 2021 samples


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import copy

# Define your model
class MLP(nn.Module):
    """Fully connected classifier with layer widths 784 -> 128 -> 64 -> 10.

    Inputs are reshaped to rows of 28*28 values before the first layer, and the
    output of the last layer is returned without any activation applied.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return the 10 output scores for each row of the flattened input."""
        hidden = x.view(-1, 28 * 28)
        # ReLU follows the two hidden layers only; fc3's output is returned raw.
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

# Training function
def train_model(model, dataloader, criterion, optimizer, epochs=1):
    """Train `model` in place for `epochs` passes over `dataloader`.

    Args:
        model: Module to train; it is switched to training mode first.
        dataloader: Iterable yielding `(inputs, targets)` pairs.
        criterion: Callable mapping `(model_output, targets)` to a loss that
            supports `.backward()`.
        optimizer: Optimizer whose `zero_grad()` / `step()` are called once per batch.
        epochs: Number of full passes over `dataloader`.

    Returns:
        None.
    """
    def _step(inputs, targets):
        # Clear stale gradients, backpropagate this batch's loss, apply the update.
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

    model.train()
    for _epoch in range(epochs):
        for inputs, targets in dataloader:
            _step(inputs, targets)

# Evaluation function
def evaluate_model(model, dataloader):
    """Return the fraction of samples in `dataloader` that `model` classifies correctly.

    The model is switched to evaluation mode and is left in that mode on return.
    Predictions are the argmax over dimension 1 of the model output.

    Args:
        model: Module to evaluate.
        dataloader: Iterable yielding `(inputs, targets)` pairs.

    Returns:
        Accuracy as a float in [0, 1]. An iterable that yields no samples gives
        0.0 instead of raising ZeroDivisionError.
    """
    model.eval()
    correct = 0
    total = 0
    # Gradients are not needed for scoring.
    with torch.no_grad():
        for x, y in dataloader:
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    if total == 0:
        # Nothing was evaluated; report zero accuracy rather than dividing by zero.
        return 0.0
    return correct / total

# Federated Averaging
def federated_avg(models, weights=None):
    """Average the state dicts of `models` into a new model.

    Args:
        models: Non-empty sequence of modules whose state dicts share the same
            keys and tensor shapes.
        weights: Optional sequence with one non-negative number per model
            (for example each client's sample count). When given, each model
            contributes in proportion to its weight. When omitted, every model
            contributes equally.

    Returns:
        A deep copy of `models[0]` whose state-dict entries hold the averaged
        values. The input models are not modified.

    Raises:
        IndexError: If `models` is empty.
        ValueError: If `weights` has a different length than `models` or does
            not sum to a positive value.
    """
    global_model = copy.deepcopy(models[0])
    # Build each state dict once instead of once per key per model.
    state_dicts = [model.state_dict() for model in models]

    coeffs = None
    if weights is not None:
        coeffs = [float(w) for w in weights]
        if len(coeffs) != len(state_dicts):
            raise ValueError("weights must contain one entry per model")
        total = sum(coeffs)
        if total <= 0:
            raise ValueError("weights must sum to a positive value")
        coeffs = [c / total for c in coeffs]

    averaged = {}
    for key, reference in state_dicts[0].items():
        # Integer/bool entries (e.g. BatchNorm's num_batches_tracked) cannot be
        # averaged in their own dtype, so they are averaged in float64 and cast back.
        is_inexact = reference.is_floating_point() or reference.is_complex()
        work_dtype = reference.dtype if is_inexact else torch.float64
        stacked = torch.stack([sd[key].to(work_dtype) for sd in state_dicts], dim=0)
        if coeffs is None:
            avg = stacked.mean(dim=0)
        else:
            scale = torch.tensor(coeffs, dtype=work_dtype, device=stacked.device)
            # Reshape to (n_models, 1, ..., 1) so it broadcasts over each entry.
            scale = scale.view(-1, *([1] * (stacked.dim() - 1)))
            avg = (stacked * scale).sum(dim=0)
        if not is_inexact:
            avg = avg.round()
        averaged[key] = avg.to(reference.dtype)

    global_model.load_state_dict(averaged)
    return global_model

# Seed torch so weight initialisation and per-epoch batch shuffling are
# repeatable after Restart Kernel -> Run All.
TORCH_SEED = 42
torch.manual_seed(TORCH_SEED)

# Load the test set
transform = transforms.ToTensor()
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False)

# Federated training
rounds = 50
global_model = MLP()
criterion = nn.CrossEntropyLoss()

# Build each client's loader once; with shuffle=True a DataLoader draws a fresh
# permutation every time it is iterated, so there is no need to rebuild it per round.
client_loaders = [
    DataLoader(client_dataset, batch_size=64, shuffle=True)
    for client_dataset in client_datasets
]

for r in range(rounds):
    local_models = []
    for loader in client_loaders:
        # Every client starts the round from the current global weights.
        local_model = copy.deepcopy(global_model)
        optimizer = optim.SGD(local_model.parameters(), lr=0.01)
        train_model(local_model, loader, criterion, optimizer, epochs=1)
        local_models.append(local_model)

    # NOTE(review): all clients are averaged with equal weight although their
    # datasets differ in size; the usual FedAvg formulation weights each client
    # by its number of samples. Confirm that equal weighting is intended.
    global_model = federated_avg(local_models)
    accuracy = evaluate_model(global_model, test_loader)
    print(f'Round {r+1}, Test Accuracy: {accuracy:.4f}')


Round 1, Test Accuracy: 0.1033
Round 2, Test Accuracy: 0.1133
Round 3, Test Accuracy: 0.1330
Round 4, Test Accuracy: 0.1682
Round 5, Test Accuracy: 0.2051
Round 6, Test Accuracy: 0.2371
Round 7, Test Accuracy: 0.2675
Round 8, Test Accuracy: 0.2898
Round 9, Test Accuracy: 0.3113
Round 10, Test Accuracy: 0.3432
Round 11, Test Accuracy: 0.3768
Round 12, Test Accuracy: 0.4100
Round 13, Test Accuracy: 0.4377
Round 14, Test Accuracy: 0.4616
Round 15, Test Accuracy: 0.4819
Round 16, Test Accuracy: 0.4978
Round 17, Test Accuracy: 0.5123
Round 18, Test Accuracy: 0.5213
Round 19, Test Accuracy: 0.5323
Round 20, Test Accuracy: 0.5457
Round 21, Test Accuracy: 0.5686
Round 22, Test Accuracy: 0.5854
Round 23, Test Accuracy: 0.6047
Round 24, Test Accuracy: 0.6246
Round 25, Test Accuracy: 0.6437
Round 26, Test Accuracy: 0.6610
Round 27, Test Accuracy: 0.6762
Round 28, Test Accuracy: 0.6899
Round 29, Test Accuracy: 0.6972
Round 30, Test Accuracy: 0.7044
Round 31, Test Accuracy: 0.7118
Round 32, Test Ac