In [61]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import math
import random
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from federated_learning import *

In [62]:
# Configuration — every tunable knob for the simulation lives here.
num_clients = 5
num_classes = 10
input_shape = (1, 28, 28)  # MNIST grayscale images: 1 channel, 28x28 pixels
private_dataset_size = 5   # number of training images sampled for each client
batch_size = 1
rounds = 25   # forwarded to federated_learning(rounds=...)
epochs = 3    # forwarded to federated_learning(epochs=...)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = 2  # model selector forwarded to federated_learning(model=...); meaning is defined there — TODO document
seed = 42

# Seed every RNG the notebook relies on so Restart & Run All reproduces the results:
# `random` drives client index sampling, torch drives DataLoader shuffling (and,
# presumably, weight initialisation inside federated_learning); numpy for completeness.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

print(f"Using device: {device}")

Using device: cpu


In [None]:
# Download MNIST and convert each image to a normalized tensor.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST normalization (mean, std)
])

mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Give each client a small private dataset of `private_dataset_size` training images.
# All indices are drawn in ONE sample (without replacement) and then sliced, so no
# image is shared between clients; independent per-client draws could hand the same
# image to several clients. random.sample raises ValueError if
# num_clients * private_dataset_size exceeds the training-set size.
sampled_indices = random.sample(range(len(mnist_train)), num_clients * private_dataset_size)

clients_dataloaders = []
for i in range(num_clients):
    client_indices = sampled_indices[i * private_dataset_size:(i + 1) * private_dataset_size]
    client_subset = Subset(mnist_train, client_indices)
    dataloader = DataLoader(client_subset, batch_size=batch_size, shuffle=True)
    clients_dataloaders.append(dataloader)

# Sanity checks: one loader per client, each holding exactly its private share.
assert len(clients_dataloaders) == num_clients
assert all(len(loader.dataset) == private_dataset_size for loader in clients_dataloaders)

print(f"Prepared {num_clients} client dataloaders")

test_loader = DataLoader(dataset=mnist_test, batch_size=64, shuffle=False)

client 0
tensor([6])
tensor([1])
tensor([7])
tensor([4])
tensor([5])
client 1
tensor([0])
tensor([8])
tensor([1])
tensor([2])
tensor([7])
client 2
tensor([6])
tensor([0])
tensor([4])
tensor([7])
tensor([7])
client 3
tensor([4])
tensor([6])
tensor([8])
tensor([6])
tensor([1])
client 4
tensor([5])
tensor([4])
tensor([9])
tensor([5])
tensor([2])
Prepared 5 client dataloaders


In [64]:
# Train a global model via the simulated federated-learning loop over all clients.
# Only the aggregated global model is kept; the second return value is discarded.
(global_model, _) = federated_learning(
    device=device,
    clients_dataloaders=clients_dataloaders,
    num_classes=num_classes,
    input_shape=input_shape,
    model=model,
    epochs=epochs,
    rounds=rounds,
)

print("Federated learning simulation complete!")

Round 1
Round 2
Round 3
Round 4
Round 5
Round 6
Round 7
Round 8
Round 9
Round 10
Round 11
Round 12
Round 13
Round 14
Round 15
Round 16
Round 17
Round 18
Round 19
Round 20
Round 21
Round 22
Round 23
Round 24
Round 25
Federated learning simulation complete!


In [65]:
def evaluate_accuracy(net, loader, device):
    """Count top-1 correct predictions of `net` over every batch in `loader`.

    Puts `net` in eval mode (and leaves it there) and disables gradient tracking.

    Returns:
        (correct, total): number of correctly classified samples and samples seen.

    Raises:
        ValueError: if `loader` yields no samples (avoids a ZeroDivisionError later).
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            # Predicted class = index of the largest output along dim 1
            # (presumably the class dimension of a (batch, num_classes) output).
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("evaluation loader yielded no samples")
    return correct, total


correct, total = evaluate_accuracy(global_model, test_loader, device)
accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

Test Accuracy: 40.55%


In [66]:
# Persist only the learned parameters (the state_dict) rather than pickling the
# whole module object; reload them into a freshly constructed model later.
torch.save(
    obj=global_model.state_dict(),
    f='trained_model_MNIST.pth',
)