In [204]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import math
import random
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from federated_learning import *

In [None]:
# Configuration
model = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rng_seed = 0

# Clients
num_clients = 10
private_dataset_size = 20

# Dataset
num_classes = 10
input_shape = (1, 28, 28)  # MNIST grayscale images

# Traing params
learning_rate = 0.01
loss_fn = torch.nn.CrossEntropyLoss()
batch_size = 1
rounds = 5
epochs = 1

print(f"Using device: {device}")

Using device: cpu


In [206]:
# Make reproducible
random.seed(rng_seed)
torch.manual_seed(rng_seed)
np.random.seed(rng_seed)

In [207]:
# Download and transform MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)) # normalization params for MNIST
])

mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

clients_dataloaders = []
for i in range(num_clients):
    client_indices = random.sample(range(len(mnist_train)), private_dataset_size)
    client_subset = Subset(mnist_train, client_indices)
    dataloader = DataLoader(client_subset, batch_size=batch_size, shuffle=True)
    clients_dataloaders.append(dataloader)

print(f"Prepared {num_clients} client dataloaders")

# inspect private datasets
for idx in range(len(clients_dataloaders)):
    data = []
    
    for dp, y in clients_dataloaders[idx]:
        data.append(int(y[0]))
    print('Private dataset ', idx, ': ', data, '. Missing: ', [x for x in range(10) if x not in data], sep='')

test_loader = DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=False)

Prepared 10 client dataloaders
Private dataset 0: [8, 1, 7, 2, 1, 7, 0, 3, 1, 7, 0, 0, 9, 7, 9, 8, 9, 3, 0, 1, 7, 5, 5, 2, 4, 9, 2, 5, 3, 5, 9, 7, 2, 2, 6, 0, 2, 6, 9, 8, 8, 9, 5, 3, 4, 0, 0, 5, 9, 1, 8, 8, 9, 1, 9, 4, 9, 8, 8, 0, 4, 5, 4, 7, 1, 1, 9, 1, 6, 8, 2, 9, 0, 3, 1, 9, 9, 2, 5, 4, 4, 9, 9, 2, 9, 7, 8, 5, 1, 8, 1, 0, 3, 6, 8, 0, 2, 8, 6, 6]. Missing: []
Private dataset 1: [0, 4, 3, 9, 2, 3, 6, 0, 8, 2, 2, 3, 9, 4, 4, 7, 7, 0, 8, 2, 2, 9, 3, 1, 4, 2, 6, 8, 7, 1, 4, 5, 8, 8, 4, 3, 9, 2, 8, 1, 3, 9, 5, 0, 1, 7, 7, 2, 9, 8, 4, 5, 9, 9, 5, 6, 3, 3, 7, 4, 3, 0, 0, 6, 2, 7, 5, 5, 5, 5, 0, 1, 0, 9, 9, 2, 4, 6, 6, 6, 4, 3, 4, 7, 2, 3, 8, 9, 1, 4, 1, 9, 4, 5, 1, 8, 0, 0, 9, 5]. Missing: []
Private dataset 2: [1, 0, 2, 6, 4, 9, 3, 4, 7, 6, 3, 0, 7, 2, 0, 6, 7, 3, 6, 5, 7, 7, 1, 2, 2, 8, 7, 4, 6, 1, 3, 6, 0, 5, 4, 9, 2, 6, 2, 0, 7, 0, 7, 8, 1, 6, 5, 0, 8, 6, 9, 9, 0, 9, 8, 6, 5, 8, 2, 5, 0, 3, 9, 8, 9, 3, 1, 4, 6, 8, 4, 4, 6, 7, 5, 9, 3, 7, 7, 4, 8, 4, 4, 8, 7, 9, 7, 3, 8, 6, 7, 7, 2, 1, 6

In [208]:
# Model training
global_model, _ = federated_learning(
    clients_dataloaders=clients_dataloaders,
    input_shape=input_shape,
    num_classes=num_classes,
    lr=learning_rate,
    criterion=loss_fn,
    model=model,
    rounds=rounds,
    epochs=epochs,
    device=device
)

print("Federated learning simulation complete!")

Round 1
Round 2
Round 3
Round 4
Round 5
Federated learning simulation complete!


In [209]:
# Model evaluation
global_model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = global_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

Test Accuracy: 19.00%


In [210]:
# Save trained model to file
torch.save(global_model.state_dict(), 'trained_model_MNIST.pth')