In [7]:
import numpy as np
import torch
from torchvision import datasets, transforms
from collections import defaultdict
import random
from torch.utils.data import Subset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import copy

In [2]:
# Load the MNIST training split (downloaded to ./data on first use).
# ToTensor is the only preprocessing step applied to each image.
transform = transforms.Compose([
    transforms.ToTensor(),
])
mnist_data = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

100.0%
100.0%
100.0%
100.0%


In [3]:
# Organize indices by class.
# Labels are read straight from the dataset's label tensor. Iterating the
# dataset itself would decode and transform every image only to throw the
# pixels away, which makes this cell needlessly slow.
class_indices = defaultdict(list)
for idx, label in enumerate(mnist_data.targets.tolist()):
    # `label` is a plain Python int, so the keys match those produced by
    # indexing the dataset directly.
    class_indices[label].append(idx)

In [4]:
# Shuffle within each class.
# A dedicated, seeded generator makes the client split reproducible across
# kernel restarts without disturbing the global `random` state.
SHUFFLE_SEED = 42
shuffle_rng = random.Random(SHUFFLE_SEED)
for cls in sorted(class_indices):
    # Sort first so that re-running this cell always starts from the same
    # order and therefore produces the same shuffle.
    class_indices[cls].sort()
    shuffle_rng.shuffle(class_indices[cls])

# Generate 10 uneven clients
NUM_CLIENTS = 10
NUM_CLASSES = 10
DOMINANT_CLASSES_PER_CLIENT = 2
client_data = [[] for _ in range(NUM_CLIENTS)]
# Client i is dominated by classes i, i+1, ... (wrapping around at NUM_CLASSES).
classes_per_client = [
    [(i + j) % NUM_CLASSES for j in range(DOMINANT_CLASSES_PER_CLIENT)]
    for i in range(NUM_CLIENTS)
]

In [5]:
# Build every client's list of sample indices.
DOMINANT_SAMPLES = 600  # samples drawn from each of a client's dominant classes
MINORITY_SAMPLES = 100  # samples drawn from each remaining class

# Re-create the per-client lists and track consumption with cursors instead of
# slicing `class_indices` in place. This keeps the cell idempotent: re-running
# it used to append a second helping of data to every client.
client_data = [[] for _ in classes_per_client]
next_free = defaultdict(int)  # class -> position of the first unassigned index

for client_id, cls_list in enumerate(classes_per_client):
    for cls in cls_list:
        # Assign up to DOMINANT_SAMPLES samples per dominant class.
        start = next_free[cls]
        selected = class_indices[cls][start:start + DOMINANT_SAMPLES]
        client_data[client_id].extend(selected)
        next_free[cls] = start + len(selected)

    # Add MINORITY_SAMPLES samples from every other class, skipping a class
    # whose remaining pool is too small to supply a full share.
    for cls in range(10):
        start = next_free[cls]
        remaining = len(class_indices[cls]) - start
        if cls not in cls_list and remaining >= MINORITY_SAMPLES:
            selected = class_indices[cls][start:start + MINORITY_SAMPLES]
            client_data[client_id].extend(selected)
            next_free[cls] = start + MINORITY_SAMPLES

In [6]:
# Wrap each client's index list in a Subset view over the shared MNIST dataset.
from torch.utils.data import Subset

client_datasets = []
for client_indices in client_data:
    client_datasets.append(Subset(mnist_data, client_indices))

In [8]:
# Step 2: Define the model
class MLP(nn.Module):
    """Fully connected classifier: 784 -> 128 -> 64 -> 10 with ReLU activations."""

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before so that a seeded
        # initialisation consumes random numbers in the same sequence.
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Flatten `x` to rows of 784 values and return the 10 raw class scores."""
        flat = x.view(-1, 784)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return logits

In [10]:
# Step 3: Federated training loop
def train_model(model, dataloader, criterion, optimizer, epochs=1):
    """Train `model` in place for `epochs` full passes over `dataloader`.

    Args:
        model: Module to optimise; it is switched to training mode.
        dataloader: Iterable yielding `(inputs, targets)` batches.
        criterion: Callable mapping `(outputs, targets)` to a scalar loss.
        optimizer: Optimizer already bound to the model's parameters.
        epochs: Number of passes over `dataloader`.

    Returns:
        None. The model's parameters are updated as a side effect.
    """
    model.train()
    for _epoch in range(epochs):
        for inputs, targets in dataloader:
            # Clear gradients left over from the previous step before
            # computing new ones.
            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

def evaluate_model(model, dataloader):
    """Return the classification accuracy of `model` over `dataloader`.

    The model is switched to evaluation mode and left in that mode.

    Args:
        model: Module whose output has one score per class along dim 1.
        dataloader: Iterable yielding `(inputs, targets)` batches.

    Returns:
        Fraction of correctly predicted samples as a float in [0, 1].
        Returns 0.0 when `dataloader` yields no samples, instead of raising
        ZeroDivisionError.
    """
    model.eval()
    correct = 0
    total = 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for x, y in dataloader:
            output = model(x)
            # The predicted class is the index of the highest score.
            pred = output.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    if total == 0:
        # Accuracy over zero samples is undefined; report 0.0 rather than crash.
        return 0.0
    return correct / total

# Federated Averaging
def federated_avg(models, weights=None):
    """Combine several client models into a new, averaged global model.

    Args:
        models: Non-empty sequence of modules with identical `state_dict`
            keys and tensor shapes.
        weights: Optional per-model weights (for example each client's number
            of training samples). They are normalised to sum to one. `None`
            (the default) gives every model equal weight.

    Returns:
        A deep copy of `models[0]` whose floating-point (or complex) state
        entries are the average across `models`. Entries of any other dtype,
        such as integer counters, cannot be averaged and keep the value from
        `models[0]`. The input models are not modified.

    Raises:
        IndexError: If `models` is empty.
        ValueError: If `weights` has a different length than `models` or its
            sum is not positive.
    """
    global_model = copy.deepcopy(models[0])
    client_states = [model.state_dict() for model in models]

    if weights is not None:
        if len(weights) != len(models):
            raise ValueError(
                f"expected {len(models)} weights, got {len(weights)}"
            )
        if float(sum(weights)) <= 0:
            raise ValueError("weights must sum to a positive value")

    averaged_state = {}
    for key, reference in client_states[0].items():
        if not (reference.is_floating_point() or reference.is_complex()):
            # mean() is undefined for integer/bool tensors; keep the first
            # model's value rather than raising.
            averaged_state[key] = reference.clone()
            continue

        stacked = torch.stack([state[key] for state in client_states], dim=0)
        if weights is None:
            averaged_state[key] = stacked.mean(dim=0)
        else:
            coeffs = torch.as_tensor(weights, dtype=stacked.dtype, device=stacked.device)
            coeffs = coeffs / coeffs.sum()
            # Reshape to (num_models, 1, 1, ...) so it broadcasts over the
            # parameter dimensions.
            broadcast_shape = (-1,) + (1,) * (stacked.dim() - 1)
            averaged_state[key] = (stacked * coeffs.view(broadcast_shape)).sum(dim=0)

    # load_state_dict is the supported way to write values back into a module.
    global_model.load_state_dict(averaged_state)
    return global_model

In [12]:
# Load test set
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False)

# Federated training
rounds = 50
# Seed PyTorch before creating the model so that weight initialisation and the
# per-epoch batch shuffling can be repeated after a kernel restart.
torch.manual_seed(0)
global_model = MLP()
criterion = nn.CrossEntropyLoss()

# The client loaders do not depend on the round, so build them once rather
# than once per client per round. With shuffle=True a loader still draws a new
# order every time it is iterated.
client_loaders = [
    DataLoader(client_dataset, batch_size=64, shuffle=True)
    for client_dataset in client_datasets
]

for r in range(rounds):
    local_models = []
    for loader in client_loaders:
        # Each client starts the round from its own copy of the global model,
        # so local training never modifies the global weights directly.
        local_model = copy.deepcopy(global_model)
        # The optimizer must be bound to the copy's parameters, hence a new
        # one per client and round.
        optimizer = optim.SGD(local_model.parameters(), lr=0.01)
        train_model(local_model, loader, criterion, optimizer, epochs=1)
        local_models.append(local_model)

    # Aggregate the locally trained models into the next global model and
    # measure it on the held-out test set.
    global_model = federated_avg(local_models)
    accuracy = evaluate_model(global_model, test_loader)
    print(f'Round {r+1}, Test Accuracy: {accuracy:.4f}')

Round 1, Test Accuracy: 0.1012
Round 2, Test Accuracy: 0.1310
Round 3, Test Accuracy: 0.1673
Round 4, Test Accuracy: 0.1827
Round 5, Test Accuracy: 0.1880
Round 6, Test Accuracy: 0.1985
Round 7, Test Accuracy: 0.2262
Round 8, Test Accuracy: 0.2689
Round 9, Test Accuracy: 0.3262
Round 10, Test Accuracy: 0.3887
Round 11, Test Accuracy: 0.4561
Round 12, Test Accuracy: 0.5223
Round 13, Test Accuracy: 0.5869
Round 14, Test Accuracy: 0.6514
Round 15, Test Accuracy: 0.7020
Round 16, Test Accuracy: 0.7330
Round 17, Test Accuracy: 0.7558
Round 18, Test Accuracy: 0.7698
Round 19, Test Accuracy: 0.7795
Round 20, Test Accuracy: 0.7854
Round 21, Test Accuracy: 0.7899
Round 22, Test Accuracy: 0.7971
Round 23, Test Accuracy: 0.7987
Round 24, Test Accuracy: 0.8032
Round 25, Test Accuracy: 0.8064
Round 26, Test Accuracy: 0.8078
Round 27, Test Accuracy: 0.8124
Round 28, Test Accuracy: 0.8156
Round 29, Test Accuracy: 0.8206
Round 30, Test Accuracy: 0.8220
Round 31, Test Accuracy: 0.8254
Round 32, Test Ac

In [13]:
import os

# Directory that receives one index file per client.
save_dir = "saved_clients"
os.makedirs(save_dir, exist_ok=True)

# Only the index lists are written; the images themselves can be re-downloaded
# with MNIST and re-selected through these indices.
for i in range(len(client_data)):
    torch.save(client_data[i], os.path.join(save_dir, f"client_{i}_indices.pt"))

print("Client dataset indices saved to disk.")


Client dataset indices saved to disk.


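client_dataset: Each client holds its own non-IID subset of MNIST — 600 samples from each of its two dominant digit classes plus 100 samples from each of the other classes.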

local_model: At the start of every round, each client receives its own copy of the current global model.

train_model(...): The model is trained only on that client's data — no sharing of raw data.

local_models: All clients' updated models after local training.

federated_avg(...): Averages all model parameters (e.g., weights, biases) layer-by-layer across the clients, giving each client equal weight, to produce the next global model.

