In [15]:
# import all needed packages
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset

# Resolve the src directory one level above the current working directory to an absolute path
src_path = os.path.abspath(os.path.join(os.pardir, 'src'))

# Register it on the import search path only once, so re-running this cell adds no duplicates
if sys.path.count(src_path) == 0:
    sys.path.append(src_path)

# Imported after the path tweak; presumably fl lives in ../src — confirm
import fl

## Obtain Dataset

In [6]:
# Fetch MNIST (downloading it if needed) and attach the preprocessing pipeline
from torchvision import datasets, transforms

# Preprocessing: convert each image to a tensor, then normalize with mean 0.5 and std 0.5,
# which maps pixel values onto the range [-1, 1]
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Build the training split first, then the test split; both are stored under ../data
mnist_train, mnist_test = (
    datasets.MNIST(root="../data", train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

# Report how many samples each split holds
print(f"Train dataset size: {len(mnist_train)}, Test dataset size: {len(mnist_test)}")

Train dataset size: 60000, Test dataset size: 10000


## Split Dataset

In [7]:
from torch.utils.data import random_split

# Define the number of clients and the base number of samples per client
num_clients = 5
client_data_size = len(mnist_train) // num_clients

# random_split requires integer lengths that sum to len(dataset). Hand one extra sample to the
# first `leftover` clients so the split also works when the dataset size is not a multiple of
# num_clients (for MNIST with 5 clients the remainder is 0 and every client gets the same size).
leftover = len(mnist_train) % num_clients
client_split_sizes = [
    client_data_size + (1 if client_idx < leftover else 0)
    for client_idx in range(num_clients)
]

# Use a dedicated, seeded generator so every run assigns the same samples to each client.
# Without it the partition changes on each execution and results cannot be reproduced.
split_generator = torch.Generator().manual_seed(0)

# Split the training data into smaller datasets for each client
client_datasets = random_split(mnist_train, client_split_sizes, generator=split_generator)

# Create DataLoaders for each client (batches are reshuffled every epoch)
client_loaders = [DataLoader(ds, batch_size=32, shuffle=True) for ds in client_datasets]

# Test DataLoader for evaluation (fixed order, no shuffling)
test_loader = DataLoader(mnist_test, batch_size=32, shuffle=False)

print(f"Simulated {num_clients} clients, each with {client_data_size} training samples.")


Simulated 5 clients, each with 12000 training samples.


In [8]:
# Display the list of equal per-client sizes (one entry per client)
[client_data_size for client_idx in range(num_clients)]

[12000, 12000, 12000, 12000, 12000]

## Set up and initialize the Global Model

In [17]:
# Build the server-side (global) model; `model` is kept as a second name for the same object
global_model = model = fl.create_model()
print(global_model)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=128, bias=True)
  (2): ReLU()
  (3): Linear(in_features=128, out_features=10, bias=True)
)


## Client training loop

In [18]:
# set up a training loop that is run on a client for a number of epochs
def train_client(model, dataloader, epochs=1, lr=0.01):
    """Train a private copy of ``model`` on one client's data and return its weights.

    The model passed in is only read: training happens on a deep copy, so the
    caller's (global) model keeps its weights until the caller updates it.

    Args:
        model: The ``nn.Module`` to start from. Its architecture and current
            weights are copied.
        dataloader: Iterable yielding ``(inputs, labels)`` batches. The labels
            are passed to ``nn.CrossEntropyLoss`` together with the model output.
        epochs: Number of full passes over ``dataloader``. ``0`` returns the
            unmodified copy's weights.
        lr: Learning rate for the plain SGD optimizer (defaults to the
            previously hard-coded ``0.01``).

    Returns:
        The ``state_dict()`` of the locally trained copy.
    """
    # Local import keeps the function self-contained regardless of cell execution order
    import copy

    # Clone the incoming model (architecture + weights). deepcopy works for any nn.Module and
    # does not need a model factory to be in scope, unlike building a fresh model and loading
    # the state dict into it.
    local_model = copy.deepcopy(model)

    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(local_model.parameters(), lr=lr)

    # Training loop: one optimizer step per batch, `epochs` passes over the data
    local_model.train()
    for _ in range(epochs):
        for inputs, labels in dataloader:
            optimizer.zero_grad()  # clear gradients accumulated by the previous batch
            loss = criterion(local_model(inputs), labels)
            loss.backward()
            optimizer.step()

    return local_model.state_dict()  # Return updated model parameters


## Run Federated Learning Round

In [14]:
# Federated training driver (the original author measured about 2 minutes for this cell).
# In each round every client trains locally starting from the current global model, and the
# collected client weights are then aggregated into the global model.
num_rounds = 4
assert len(client_loaders) == num_clients

for round_num in range(num_rounds):
    print(f"Round {round_num + 1}")

    # Gather one trained state_dict per client, in client_loaders order
    client_states = [
        train_client(global_model, loader, epochs=3) for loader in client_loaders
    ]

    # Aggregate with fl.federated_averaging (implementation not shown here) and install the
    # resulting weights as the new global model
    new_global_state = fl.federated_averaging(global_model, client_states)
    global_model.load_state_dict(new_global_state)

    print(f"Global model updated for round {round_num + 1}")


Round 1
Global model updated for round 1
Round 2
Global model updated for round 2
Round 3
Global model updated for round 3
Round 4
Global model updated for round 4


## Evaluate Model

In [19]:
def evaluate_model(model, dataloader):
    """Compute and report the classification accuracy of ``model`` on ``dataloader``.

    The predicted class for each sample is the index of the largest value along
    dimension 1 of the model output. The accuracy is printed and also returned so
    callers can log or compare it.

    Args:
        model: The ``nn.Module`` to evaluate. It is switched to eval mode and is
            left in eval mode afterwards.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        float: Fraction of correctly classified samples, between 0 and 1.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    # Gradients are not needed for scoring; disabling them avoids building the autograd graph
    with torch.no_grad():
        for inputs, labels in dataloader:
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    print(f"Global model accuracy: {accuracy:.2%}")
    return accuracy

# Score the global model on the loader built from the MNIST test split
evaluate_model(model=global_model, dataloader=test_loader)


Global model accuracy: 93.38%
