In [1]:
# import all needed packages
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset
import numpy as np
import sys
import os
import random

# Get the absolute path of the src directory
src_path = os.path.abspath('../src')

# Add src_path to sys.path
if src_path not in sys.path:
    sys.path.append(src_path)
    
import fl
from client import Client, ClientResources

# Load the Tensorboard notebook

In [2]:
import tensorflow as tf
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="./logs/federated_learning")

# Load TensorBoard extension
%load_ext tensorboard

# Start TensorBoard
%tensorboard --logdir=.

2024-12-02 22:36:41.019330: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Obtain Dataset

In [3]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import torch

# Define transformations for the dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to range [-1, 1]
])

# Download and load MNIST dataset
mnist_full = datasets.MNIST(root="../data", train=True, transform=transform, download=True)
mnist_test = datasets.MNIST(root="../data", train=False, transform=transform, download=True)

# Split training data into training and validation sets
train_size = int(0.8 * len(mnist_full))  # 80% for training
val_size = len(mnist_full) - train_size  # 20% for validation
train_dataset, val_dataset = random_split(mnist_full, [train_size, val_size])

# Extract indices from Subset objects
train_indices = train_dataset.indices  # List of training indices
val_indices = val_dataset.indices      # List of validation indices

# Create training and validation MNIST datasets
mnist_train = datasets.MNIST(root="../data", train=True, transform=transform, download=False)
mnist_val = datasets.MNIST(root="../data", train=True, transform=transform, download=False)

# Filter datasets by indices
mnist_train.data = mnist_train.data[torch.tensor(train_indices)]
mnist_train.targets = mnist_train.targets[torch.tensor(train_indices)]
mnist_val.data = mnist_val.data[torch.tensor(val_indices)]
mnist_val.targets = mnist_val.targets[torch.tensor(val_indices)]

# Create DataLoaders for training and validation datasets
batch_size = 32
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size=batch_size, shuffle=False)

# Print dataset sizes
print(f"Train dataset size: {len(mnist_train)}, Validation dataset size: {len(mnist_val)}, Test dataset size: {len(mnist_test)}")


Train dataset size: 48000, Validation dataset size: 12000, Test dataset size: 10000


### Configuration hyperparameters

In [4]:
# Hyperparameters
hp = {
    "run_id": "delta",
    "learning_rate": 0.1,
    "batch_size": 32,
    "num_clients": 4,
    "num_rounds": 2,
    "num_classes": 10,
    "classes_per_client": 8,
    "epochs": 3,    # number of epochs to train in each round
    "split": "RANDOM"   # ["RANDOM", "NONIID"]
}

# Split Dataset: non-IID / Random

### Create client train and validation datasets

In [5]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Subset
import numpy as np
from torch.utils.data import random_split

# Number of clients and non-IID split
num_clients = hp["num_clients"]
batch_size = hp["batch_size"]
train_data_size = len(mnist_train) // num_clients
val_data_size = len(mnist_val) // num_clients


if hp["split"] == "NONIID":
    classes_per_client = hp["classes_per_client"]
    num_classes = hp["num_classes"]

    # Create indices for each class
    train_class_indices = {i: np.where(np.array(mnist_train.targets) == i)[0] for i in range(num_classes)}
    val_class_indices = {i: np.where(np.array(mnist_val.targets) == i)[0] for i in range(num_classes)}

    # Assign 2 classes per client
    train_indices = []
    val_indices = []
    for client_id in range(num_clients):
        chosen_classes = np.random.choice(num_classes, classes_per_client, replace=False)
        train_client_idx = []
        val_client_idx = []
        for cls in chosen_classes:
            train_cls_size = len(train_class_indices[cls]) // (num_clients // classes_per_client)
            train_cls_idx = np.random.choice(train_class_indices[cls], train_cls_size, replace=False)
            train_client_idx.extend(train_cls_idx)
            val_cls_size = len(val_class_indices[cls]) // (num_clients // classes_per_client)
            val_cls_idx = np.random.choice(val_class_indices[cls], val_cls_size, replace=False)
            val_client_idx.extend(val_cls_idx)
            # Remove assigned indices to avoid duplication
            train_class_indices[cls] = np.setdiff1d(train_class_indices[cls], train_cls_idx)
            val_class_indices[cls] = np.setdiff1d(val_class_indices[cls], val_cls_idx)

        train_indices.append(train_client_idx)
        val_indices.append(val_client_idx)

    # Create datasets and DataLoaders for each client
    train_datasets = [Subset(mnist_train, indices) for indices in train_indices]
    train_loaders = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in train_datasets]
    val_datasets = [Subset(mnist_val, indices) for indices in val_indices]
    val_loaders = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in val_datasets]

    # Example: Check the distribution of classes for a specific client
    train_sample_classes = [mnist_train.targets[idx].item() for idx in train_indices[0]]
    print("Client 0 has classes:", set(train_sample_classes))

else:
    
    # Split the training data into smaller datasets for each client
    train_datasets = random_split(mnist_train, [train_data_size] * num_clients)
    train_loaders = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in train_datasets]

    val_datasets = random_split(mnist_train, [train_data_size] * num_clients)
    val_loaders = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in train_datasets]

# Test DataLoader for evaluation
print(f"Simulated {num_clients} clients, each with {train_data_size} training samples.")
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)


Simulated 4 clients, each with 12000 training samples.


In [6]:
train_datasets

[<torch.utils.data.dataset.Subset at 0x7fc033fbe1d0>,
 <torch.utils.data.dataset.Subset at 0x7fc04a3c6dd0>,
 <torch.utils.data.dataset.Subset at 0x7fc0370c3a60>,
 <torch.utils.data.dataset.Subset at 0x7fc0370c3a90>]

In [7]:
train_data_size

12000

In [8]:
for inputs, labels in val_loaders[0]:  # Example for the first client
    print(inputs.shape, labels.shape)
    break


torch.Size([32, 1, 28, 28]) torch.Size([32])


## Set up and initialize the Global Model

In [9]:
# 1. Initialization: Instantiate the global model (server)
model = fl.create_model()
global_model = model
print(global_model)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=128, bias=True)
  (2): ReLU()
  (3): Linear(in_features=128, out_features=10, bias=True)
)


## Client training loop

## TODO: Delegate this to the fl ml utils file

In [10]:
# set up a training loop that is run on a client for a number of epochs
def train_model(model, train_loader, hp, epochs=1):
    
    lr = hp["learning_rate"]

    # 3. Distribution: Create a copy of the global model
    local_model = fl.create_model()
    local_model.load_state_dict(model.state_dict())
    
    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(local_model.parameters(), lr=lr)
    
    # Training loop
    local_model.train()
    total_loss = 0  # Initialize total loss
    num_batches = 0  # Initialize batch counter

    for epoch in range(epochs):
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = local_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
    
    # Calculate average training loss
    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    return local_model.state_dict(), avg_loss  # Return updated model parameters and average loss


# # Validation function
# def validate_model(model, val_loader):
#     model.eval()
#     criterion = nn.CrossEntropyLoss()  # Define the loss function
#     total_loss = 0
#     num_batches = 0

#     with torch.no_grad():
#         for inputs, labels in val_loader:
#             outputs = model(inputs)
#             loss = criterion(outputs, labels) #F.cross_entropy(outputs, labels, reduction='sum')  # Compute loss for the batch
#             total_loss += loss.item()
#             num_batches += 1
    
#     avg_loss = total_loss / num_batches if num_batches > 0 else 0
#     return avg_loss


# Clients

### Client Creation
We create clients with various resources, in which we take into account, speed, availability, battery level, bandwidth, etc.

In [11]:
# Create client instances

# Example of a client's resources
# resources = ClientResources(
#     speed_factor=1.5,
#     battery_level=80,
#     bandwidth=10.0,
#     dataset_size=1000,
#     CPU_available=True,
#     CPU_memory_availability=64.0,
#     GPU_available=True,
#     GPU_memory_availability=16.0,
# )

clients = []
for i in range(num_clients):
    mock_resources = ClientResources.generate_random((len(train_datasets[i]), len(train_datasets[i])))

    new_client = Client(id=i, resources=mock_resources, dataset=train_datasets[i], dataloader=train_loaders[i], val_loader=val_loaders[i])
    clients.append(new_client)


In [12]:
def client_selection_with_constraints(clients, deadline):
    """
    Select clients based on their resource availability and time constraints.
    Returns a list of client objects.
    
    Args:
        clients (list): List of Client objects.
        deadline (float): Maximum time allowed for training and upload.
    
    Returns:
        List of selected Client objects.
    """
    # Calculate update times for all clients
    client_times = [
        (client, client.resources.dataset_size / client.resources.speed_factor)  # (Client object, update_time)
        for client in clients
    ]
    
    # Sort clients by their update time
    client_times.sort(key=lambda x: x[1])

    selected_clients = []
    total_time = 0

    # Select clients within the deadline
    for client, update_time in client_times:
        if total_time + update_time <= deadline:
            selected_clients.append(client)
            total_time += update_time
        else:
            break  # Stop once adding a client exceeds the deadline

    return selected_clients



def select_indices(n, k):
    return random.sample(range(n), k)

# Validation

In [13]:
def validate_model(global_model, val_loaders):
    """
    Evaluate the global model on multiple client validation sets.

    Args:
        global_model (nn.Module): The global model to evaluate.
        val_loaders (list): List of DataLoaders for each client's validation set.

    Returns:
        tuple: Weighted average validation loss and accuracy.
    """
    global_model.eval()  # Set the model to evaluation mode
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    criterion = torch.nn.CrossEntropyLoss()

    with torch.no_grad():
        for val_loader in val_loaders:
            client_loss = 0.0
            client_correct = 0
            client_samples = 0

            for inputs, labels in val_loader:
                outputs = global_model(inputs)
                loss = criterion(outputs, labels)
                client_loss += loss.item() * len(labels)  # Accumulate loss weighted by batch size
                _, preds = torch.max(outputs, 1)
                client_correct += (preds == labels).sum().item()  # Count correct predictions
                client_samples += len(labels)  # Track number of samples

            # Accumulate client metrics into global metrics
            total_loss += client_loss
            total_correct += client_correct
            total_samples += client_samples

    # Compute global averages
    avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
    avg_accuracy = total_correct / total_samples if total_samples > 0 else 0.0
    return avg_loss, avg_accuracy


## Run Federated Learning Round

In [14]:
# Wrapper function for client training
def train_client_parallel(client, global_model, epochs=1):
    print(f"\nTraining client {client.id} with resources {client.resources}")
    return client.id, client.train(global_model, epochs)

In [15]:
"""
This took me 2 mins to run
Simulate Federated Learning
A learning round consists of all clients training their local models and then aggregating the updates
"""
from datetime import datetime
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

# Generate a unique log directory based on the current time
run_number = datetime.now().strftime('%m-%d_%H-%M')  # Month-Day_Hour-Minute
log_dir = f"./logs/{hp['run_id']}/run_{run_number}"  # Use a timestamp to distinguish runs
writer = SummaryWriter(log_dir=log_dir)

# Log hyperparameters as text
hyperparams_text = "\n".join([f"{key}: {value}" for key, value in hp.items()])

# Start time measurement
start_time = time.time()

# Federated Learning with FedCS Client Selection
num_rounds = hp["num_rounds"]
epochs = hp["epochs"]
round_deadline = np.inf  # Example round deadline (in ARBITRARY time units)

assert num_clients == len(train_loaders)
print("GPU available:", torch.cuda.is_available())

GPU available: False


In [16]:
# # Conduct federated learning rounds
# for round_num in range(num_rounds):
#     print(f"Round {round_num + 1}")
    
#     # 2. Resource Request
#     k = random.randint(1, num_clients)
#     resource_requested_clients = select_indices(num_clients, k)
#     print(f"-- Resource requested from {len(resource_requested_clients)} clients")

#     # 3. Client Selection: Collect client updates
#     # TODO: Client ressources
#     selected_train_clients = client_selection_with_constraints([clients[i] for i in resource_requested_clients], round_deadline)
#     print(f"-- Filtered to have {len(selected_train_clients)} remaining clients")

#     client_states = []
#     round_train_loss = 0  # Initialize round loss
#     round_val_loss = 0  # Initialize round loss
#     num_batches = 0  # Initialize batch counter


#     # Parallelize client training
#     client_states = []
#     with ThreadPoolExecutor(max_workers=len(clients)) as executor:
#         futures = {executor.submit(train_client_parallel, client, global_model, 1): client for client in clients}
        
#         for future in as_completed(futures):
#             client_id, client_state = future.result()
#             print(f"Client {client_id} completed training.")
#             client_states.append(client_state)
            
#     # 6. Aggregation: Aggregate updates using Federated Averaging
#     new_global_state = fl.federated_averaging(global_model, client_states)  
#     global_model.load_state_dict(new_global_state)
#     print(f"Global model updated for round {round_num + 1}")

#     # avg_round_val_loss = round_val_loss / num_batches if num_batches > 0 else 0
#     # writer.add_scalar("Metrics/Validation loss", avg_round_val_loss, round_num + 1)
#     # print(f"Validation loss after round {round_num + 1}: {avg_round_val_loss}")

# # End time measurement
# end_time = time.time()

# # Print total execution time
# print(f"Total time taken: {end_time - start_time:.2f} seconds")

In [17]:
# Conduct federated learning rounds
for round_num in range(num_rounds):
    print(f"Round {round_num + 1}")
    
    # 2. Resource Request
    k = random.randint(1, num_clients)
    resource_requested_clients = select_indices(num_clients, k)
    print(f"-- Resource requested from {len(resource_requested_clients)} clients")

    # 3. Client Selection: Collect client updates
    selected_train_clients = client_selection_with_constraints(
        [clients[i] for i in resource_requested_clients], round_deadline
    )
    print(f"-- Filtered to have {len(selected_train_clients)} remaining clients")

    client_states = []

    # Parallelize client training
    with ThreadPoolExecutor(max_workers=len(selected_train_clients)) as executor:
        futures = {
            executor.submit(train_client_parallel, client, global_model, 1): client
            for client in selected_train_clients
        }
        
        for future in as_completed(futures):
            client_id, client_state = future.result()
            print(f"Client {client_id} completed training.")
            client_states.append(client_state)
            
    # 6. Aggregation: Aggregate updates using Federated Averaging
    new_global_state = fl.federated_averaging(global_model, client_states)  
    global_model.load_state_dict(new_global_state)
    print(f"Global model updated for round {round_num + 1}")

    # 7. Evaluate on Validation Set
    val_loaders = [client.val_loader for client in clients]  # Get all validation loaders
    val_loss, val_accuracy = validate_model(global_model, val_loaders)
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2%}")
    
    # Optional: Log metrics for visualization
    writer.add_scalar("Metrics/Validation Loss", val_loss, round_num + 1)
    writer.add_scalar("Metrics/Validation Accuracy", val_accuracy, round_num + 1)


Round 1
-- Resource requested from 2 clients
-- Filtered to have 2 remaining clients

Training client 2 with resources ClientResources(speed_factor=1.755834982776784, battery_level=82.5207510132472, bandwidth=83.76747232486069, dataset_size=12000, CPU_available=True, CPU_memory_availability=20.852758641065478, GPU_available=False, GPU_memory_availability=0)

Training client 1 with resources ClientResources(speed_factor=1.6026389530755347, battery_level=61.14394697717021, bandwidth=45.310686196441274, dataset_size=12000, CPU_available=True, CPU_memory_availability=9.039496459714954, GPU_available=True, GPU_memory_availability=28.052662035169973)
Training round complete in 10.37: seconds
Training round complete in 10.38: seconds
Client simulated to take 16.62 seconds for training
Client 1 completed training.
Client simulated to take 18.24 seconds for training
Client 2 completed training.
Global model updated for round 1
Validation Loss: 0.4443, Validation Accuracy: 86.10%
Round 2
-- Reso

## Evaluate Model

In [18]:
def evaluate_model(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    test_accuracy = correct / total
    print(f"Global model accuracy: {test_accuracy:.2%}")
    return test_accuracy
    
# Evaluate the model on the test dataset
test_accuracy = evaluate_model(global_model, test_loader)

# End TensorBoard writer
final_metrics = {}
writer.add_hparams(hp, final_metrics)
writer.close()

Global model accuracy: 92.16%


In [19]:
# Save the final model
model_save_path = f"./saved_models/global_model_{hp['run_id']}_{run_number}.pth"
torch.save(global_model.state_dict(), model_save_path)
print(f"Global model saved at {model_save_path}")

Global model saved at ./saved_models/global_model_delta_12-02_22-36.pth
