In [1]:
# import all needed packages
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset
import numpy as np
import sys
import os
import random

# Get the absolute path of the src directory
src_path = os.path.abspath('../src')

# Add src_path to sys.path
if src_path not in sys.path:
    sys.path.append(src_path)
    
import fl
from fl import client_selection_with_constraints, select_indices
from models import QuantStubModel,calculate_model_size

# Hyperparameters

In [2]:
# Hyperparameters
hp = {
    "run_id": "test",
    "learning_rate": 1e-4,
    "batch_size": 32,
    "num_clients": 2,
    "num_rounds": 10,
    "num_classes": 10,  # 10 for MNIST
    "classes_per_client": 2,
    "epochs": 1,    # number of epochs to train in each round
    "split": "RANDOM",   # ["RANDOM", "NONIID"]
    "quantize": True,
    "averaging_setting": "standard"
}

## Obtain Dataset

In [3]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import torch

# Define transformations for the dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to range [-1, 1]
])

# Download and load MNIST dataset
mnist_full = datasets.MNIST(root="../data", train=True, transform=transform, download=True)
mnist_test = datasets.MNIST(root="../data", train=False, transform=transform, download=True)

# Split training data into training and validation sets
train_size = int(0.8 * len(mnist_full))  # 80% for training
val_size = len(mnist_full) - train_size  # 20% for validation
train_dataset, val_dataset = random_split(mnist_full, [train_size, val_size])

# Extract indices from Subset objects
train_indices = train_dataset.indices  # List of training indices
val_indices = val_dataset.indices      # List of validation indices

# Create training and validation MNIST datasets
mnist_train = datasets.MNIST(root="../data", train=True, transform=transform, download=False)
mnist_val = datasets.MNIST(root="../data", train=True, transform=transform, download=False)

# Filter datasets by indices
mnist_train.data = mnist_train.data[torch.tensor(train_indices)]
mnist_train.targets = mnist_train.targets[torch.tensor(train_indices)]
mnist_val.data = mnist_val.data[torch.tensor(val_indices)]
mnist_val.targets = mnist_val.targets[torch.tensor(val_indices)]

# Print dataset sizes
print(f"Train dataset size: {len(mnist_train)}, Validation dataset size: {len(mnist_val)}, Test dataset size: {len(mnist_test)}")

Train dataset size: 48000, Validation dataset size: 12000, Test dataset size: 10000


# Split Dataset: Non-IID / Random

In [4]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Subset
import numpy as np
from torch.utils.data import random_split

# Number of clients and non-IID split
num_clients = hp["num_clients"]
batch_size = hp["batch_size"]
classes_per_client = hp["classes_per_client"]
num_classes = hp["num_classes"]

train_data_size = len(mnist_train) // num_clients
val_data_size = len(mnist_val) // num_clients

if hp["split"] == "NONIID":
    
    # Create indices for each class
    train_class_indices = {i: np.where(np.array(mnist_train.targets) == i)[0] for i in range(num_classes)}
    val_class_indices = {i: np.where(np.array(mnist_val.targets) == i)[0] for i in range(num_classes)}

    train_indices = []
    val_indices = []
    for client_id in range(num_clients):
        chosen_classes = np.random.choice(num_classes, classes_per_client, replace=False)
        train_client_idx = []
        val_client_idx = []
        for cls in chosen_classes:
            train_cls_size = max(1, len(train_class_indices[cls]) // (num_clients // classes_per_client))
            val_cls_size = max(1, len(val_class_indices[cls]) // (num_clients // classes_per_client))
            
            # Adjust for insufficient samples
            train_cls_size = min(train_cls_size, len(train_class_indices[cls]))
            val_cls_size = min(val_cls_size, len(val_class_indices[cls]))
            
            train_cls_idx = np.random.choice(train_class_indices[cls], train_cls_size, replace=False)
            train_client_idx.extend(train_cls_idx)
            
            val_cls_idx = np.random.choice(val_class_indices[cls], val_cls_size, replace=False)
            val_client_idx.extend(val_cls_idx)
            
            # Remove assigned indices to avoid duplication
            train_class_indices[cls] = np.setdiff1d(train_class_indices[cls], train_cls_idx)
            val_class_indices[cls] = np.setdiff1d(val_class_indices[cls], val_cls_idx)

        train_indices.append(train_client_idx)
        val_indices.append(val_client_idx)

    # Create datasets and DataLoaders for each client
    train_dataset = [Subset(mnist_train, indices) for indices in train_indices]
    train_loader = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in train_dataset]
    val_dataset = [Subset(mnist_val, indices) for indices in val_indices]
    val_loader = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in val_dataset]

    # Debugging: Output distribution of classes for all clients
    for i in range(num_clients):
        train_sample_classes = [mnist_train.targets[idx].item() for idx in train_indices[i]]
        val_sample_classes = [mnist_val.targets[idx].item() for idx in val_indices[i]]
        
        print(f"Client {i} has training classes: {set(train_sample_classes)} and {len(train_sample_classes)} samples, "
            f"and validation classes: {set(val_sample_classes)} and {len(val_sample_classes)} samples")

else:
    
    # Split the training data into smaller datasets for each client
    train_dataset = random_split(mnist_train, [train_data_size] * num_clients)
    train_loader = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in train_dataset]

    val_dataset = random_split(mnist_train, [train_data_size] * num_clients)
    val_loader = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in train_dataset]

# Test DataLoader for evaluation
print(f"Simulated {num_clients} clients, each with {train_data_size} training samples, and {num_clients} with {val_data_size} validation samples")
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

Simulated 2 clients, each with 24000 training samples, and 2 with 6000 validation samples


## Set up and initialize the Global Model

In [5]:
# 1. Initialization: Instantiate the global model (server)
global_model = QuantStubModel(q=hp['quantize'])
global_model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(global_model, inplace=True)
#torch.quantization.convert(global_model, inplace=True)

model_size_bytes, model_size_mb = calculate_model_size(global_model)
print(f"Model size averaged_state: {model_size_bytes} bytes ({model_size_mb:.2f} MB)")

Model size averaged_state: 407244 bytes (0.39 MB)




## Client training loop

In [6]:
# set up a training loop that is run on a client for a number of epochs
def train_model(model, train_loader, hp, round_num, epochs=1, q=False):
    
    # 3. Distribution: Create a copy of the global model
    local_model = QuantStubModel(q=q)
    local_model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
    torch.quantization.prepare_qat(local_model, inplace=True)
    local_model.load_state_dict(model.state_dict())

    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(local_model.parameters(), lr=hp["learning_rate"])
    
    # Training loop
    local_model.train()
    total_loss = 0  # Initialize total loss
    num_batches = 0  # Initialize batch counter

    for epoch in range(epochs):
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = local_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
    
    # Calculate average training loss
    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    return local_model.state_dict(), avg_loss  # Return updated model parameters and average loss


# Validation function
def validate_model(model, val_loader, round_num):
    model.eval()
    criterion = nn.CrossEntropyLoss()  # Define the loss function
    total_loss = 0
    num_batches = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels) #F.cross_entropy(outputs, labels, reduction='sum')  # Compute loss for the batch
            total_loss += loss.item()
            num_batches += 1
    
    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    return avg_loss


# Client Selection

In [7]:
# Example Resource Info (to simulate resource heterogeneity)
client_resources = [{"comp_capacity": random.randint(10, 100), "data_size": random.randint(1, 10)} for _ in range(num_clients)]

## Run Federated Learning Round

In [8]:
from datetime import datetime
import time
import tensorflow as tf
from torch.utils.tensorboard import SummaryWriter


# Generate a unique log directory based on the current time
run_number = datetime.now().strftime('%m-%d-%H-%M')  # Month-Day_Hour-Minute
log_dir = f"./logs/{hp['run_id']}/run_{run_number}"  # Use a timestamp to distinguish runs
writer = SummaryWriter(log_dir=log_dir)

# Log hyperparameters as text
hyperparams_text = "\n".join([f"{key}: {value}" for key, value in hp.items()])

# Start time measurement
start_time = time.time()

# Federated Learning with FedCS Client Selection
num_rounds = hp["num_rounds"]
epochs = hp["epochs"]
round_deadline = 3  # Example round deadline (in arbitrary time units)

assert num_clients == len(train_loader)
print("GPU available:", torch.cuda.is_available())

# Conduct federated learning rounds
for round_num in range(num_rounds):
    print(f"Round {round_num + 1}")
    
    # 2. Resource Request
    k = random.randint(1, num_clients)
    resource_requested_clients = select_indices(num_clients, k)

    # 2. Client Selection: Collect client updates
    # TODO: Client ressources
    selected_train_clients = client_selection_with_constraints([client_resources[i] for i in resource_requested_clients], round_deadline)
    filtered_train_loaders = [train_loader[i] for i in selected_train_clients]
    filtered_val_loaders = [val_loader[i] for i in selected_train_clients]
    print(f"-> Resource requested from {len(resource_requested_clients)} clients, {len(selected_train_clients)} clients fulfilled the criteria")

    client_states = []
    round_train_loss = 0  # Initialize round loss
    round_val_loss = 0  # Initialize round loss
    num_batches = 0  # Initialize batch counter

    for client_train_loader in filtered_train_loaders:
        
        # 4. Distribution
        client_state, client_loss = train_model(global_model, client_train_loader, hp, round_num, epochs=epochs, q=hp["quantize"])
        round_train_loss += client_loss
        num_batches += 1

        # 5. Update and Upload
        client_states.append(client_state)

    # Average loss for the round
    avg_round_training_loss = round_train_loss / num_batches if num_batches > 0 else 0
    writer.add_scalar("Metrics/Training loss", avg_round_training_loss, round_num + 1)
    print(f"Training loss after round {round_num + 1}: {avg_round_training_loss}")

    # 6. Aggregation: Aggregate updates using Federated Averaging
    new_global_state = fl.federated_averaging(global_model, client_states, setting=hp["averaging_setting"])
    global_model.load_state_dict(new_global_state)

    val_loss = validate_model(global_model, client_val_loader, round_num)
    writer.add_scalar("Metrics/Validation loss", val_loss, round_num + 1)
    print(f"Validation loss after round {round_num + 1}: {val_loss}")
    print(f"Global model updated for round {round_num + 1}")

if hp["quantize"]:
    torch.quantization.convert(global_model, inplace=True)

# End time measurement
end_time = time.time()

# Print total execution time
print(f"Total time taken: {end_time - start_time:.2f} seconds")

GPU available: True
Round 1
-> Resource requested from 2 clients, 2 clients fulfilled the criteria




Training loss after round 1: 2.2735897159576415
Global model updated for round 1
Validation loss after round 1: 2.237831482887268
Round 2
-> Resource requested from 1 clients, 1 clients fulfilled the criteria
Training loss after round 2: 2.2080322046279908
Global model updated for round 2
Validation loss after round 2: 2.179085894584656
Round 3
-> Resource requested from 2 clients, 2 clients fulfilled the criteria
Training loss after round 3: 2.1527898213068646
Global model updated for round 3
Validation loss after round 3: 2.1261504402160645
Round 4
-> Resource requested from 2 clients, 2 clients fulfilled the criteria
Training loss after round 4: 2.0985108772913614
Global model updated for round 4
Validation loss after round 4: 2.0708004872004193
Round 5
-> Resource requested from 2 clients, 2 clients fulfilled the criteria
Training loss after round 5: 2.0441372764110564
Global model updated for round 5
Validation loss after round 5: 2.0163233151435853
Round 6
-> Resource requested f

## Evaluate Model

In [9]:
def evaluate_model(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    test_accuracy = correct / total
    print(f"Global model accuracy: {test_accuracy:.2%}")
    return test_accuracy

# Evaluate the model on the test dataset
test_accuracy = evaluate_model(global_model, test_loader)

# End TensorBoard writer
final_metrics = {}
writer.add_hparams(hp, final_metrics)
writer.close()

Global model accuracy: 69.40%


# Save the model

In [10]:
# Save the final model
model_save_path = f"./saved_models/{hp['run_id']}/global_model_{hp['run_id']}_{run_number}.pth"
torch.save(global_model.state_dict(), model_save_path)
print(f"Global model saved at {model_save_path}")
model_size_bytes, model_size_mb = calculate_model_size(global_model)
print(f"Model size: {model_size_bytes} bytes ({model_size_mb:.2f} MB)")

Global model saved at ./saved_models/test/global_model_test_11-28-16-56.pth
Model size: 12 bytes (0.00 MB)
