In [1]:
# import all needed packages
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset
import numpy as np
import sys
import os
import random

# Get the absolute path of the src directory
src_path = os.path.abspath('../src')

# Add src_path to sys.path
if src_path not in sys.path:
    sys.path.append(src_path)
    
import fl
from fl import FedCSModel

## Obtain Dataset

In [2]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import torch

# Define transformations for the dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to range [-1, 1]
])

# Download and load MNIST dataset
mnist_full = datasets.MNIST(root="../data", train=True, transform=transform, download=True)
mnist_test = datasets.MNIST(root="../data", train=False, transform=transform, download=True)

# Split training data into training and validation sets
train_size = int(0.8 * len(mnist_full))  # 80% for training
val_size = len(mnist_full) - train_size  # 20% for validation
train_dataset, val_dataset = random_split(mnist_full, [train_size, val_size])

# Extract indices from Subset objects
train_indices = train_dataset.indices  # List of training indices
val_indices = val_dataset.indices      # List of validation indices

# Create training and validation MNIST datasets
mnist_train = datasets.MNIST(root="../data", train=True, transform=transform, download=False)
mnist_val = datasets.MNIST(root="../data", train=True, transform=transform, download=False)

# Filter datasets by indices
mnist_train.data = mnist_train.data[torch.tensor(train_indices)]
mnist_train.targets = mnist_train.targets[torch.tensor(train_indices)]
mnist_val.data = mnist_val.data[torch.tensor(val_indices)]
mnist_val.targets = mnist_val.targets[torch.tensor(val_indices)]

# Create DataLoaders for training and validation datasets
batch_size = 32
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size=batch_size, shuffle=False)

# Print dataset sizes
print(f"Train dataset size: {len(mnist_train)}, Validation dataset size: {len(mnist_val)}, Test dataset size: {len(mnist_test)}")


Train dataset size: 48000, Validation dataset size: 12000, Test dataset size: 10000


Hyperparameters

In [3]:
# Hyperparameters
hp = {
    "run_id": "delta",
    "learning_rate": 0.1,
    "batch_size": 32,
    "num_clients": 10,
    "num_rounds": 20,
    "num_classes": 10,
    "classes_per_client": 2,
    "epochs": 3,    # number of epochs to train in each round
    "split": "RANDOM"   # ["RANDOM", "NONIID"]
}

# Split Dataset: Non-IID / Random

In [4]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Subset
import numpy as np
from torch.utils.data import random_split

# Number of clients and non-IID split
num_clients = hp["num_clients"]
batch_size = hp["batch_size"]
train_data_size = len(mnist_train) // num_clients
val_data_size = len(mnist_val) // num_clients


if hp["split"] == "NONIID":
    classes_per_client = hp["classes_per_client"]
    num_classes = hp["num_classes"]

    # Create indices for each class
    train_class_indices = {i: np.where(np.array(mnist_train.targets) == i)[0] for i in range(num_classes)}
    val_class_indices = {i: np.where(np.array(mnist_val.targets) == i)[0] for i in range(num_classes)}

    # Assign 2 classes per client
    train_indices = []
    val_indices = []
    for client_id in range(num_clients):
        chosen_classes = np.random.choice(num_classes, classes_per_client, replace=False)
        train_client_idx = []
        val_client_idx = []
        for cls in chosen_classes:
            train_cls_size = len(train_class_indices[cls]) // (num_clients // classes_per_client)
            train_cls_idx = np.random.choice(train_class_indices[cls], train_cls_size, replace=False)
            train_client_idx.extend(train_cls_idx)
            val_cls_size = len(val_class_indices[cls]) // (num_clients // classes_per_client)
            val_cls_idx = np.random.choice(val_class_indices[cls], val_cls_size, replace=False)
            val_client_idx.extend(val_cls_idx)
            # Remove assigned indices to avoid duplication
            train_class_indices[cls] = np.setdiff1d(train_class_indices[cls], train_cls_idx)
            val_class_indices[cls] = np.setdiff1d(val_class_indices[cls], val_cls_idx)

        train_indices.append(train_client_idx)
        val_indices.append(val_client_idx)

    # Create datasets and DataLoaders for each client
    train_dataset = [Subset(mnist_train, indices) for indices in train_indices]
    train_loader = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in train_dataset]
    val_dataset = [Subset(mnist_val, indices) for indices in val_indices]
    val_loader = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in val_dataset]

    # Example: Check the distribution of classes for a specific client
    train_sample_classes = [mnist_train.targets[idx].item() for idx in train_indices[0]]
    print("Client 0 has classes:", set(train_sample_classes))

else:
    
    # Split the training data into smaller datasets for each client
    train_dataset = random_split(mnist_train, [train_data_size] * num_clients)
    train_loader = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in train_dataset]

    val_dataset = random_split(mnist_train, [train_data_size] * num_clients)
    val_loader = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in train_dataset]

# Test DataLoader for evaluation
print(f"Simulated {num_clients} clients, each with {train_data_size} training samples.")
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)


Simulated 10 clients, each with 4800 training samples.


## Set up and initialize the Global Model

In [5]:
# 1. Initialization: Instantiate the global model (server)
global_model = FedCSModel(q=True)
global_model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(global_model, inplace=True)



FedCSModel(
  (quant): QuantStub(
    (activation_post_process): FusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
  )
  (dequant): DeQuantStub()
  (fc1): Linear(
    in_features=784, out_features=128, bias=True
    (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_channel_symmetric, reduce_range=False
      (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
    )
    (activation_post_process): FusedMovingAvgObsFakeQuantize(
      fake_

## Client training loop

In [6]:
# set up a training loop that is run on a client for a number of epochs
def train_model(model, train_loader, hp, epochs=1):
    
    lr = hp["learning_rate"]

    # 3. Distribution: Create a copy of the global model
    local_model = FedCSModel(q=True)
    local_model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
    torch.quantization.prepare_qat(local_model, inplace=True)
    local_model.load_state_dict(model.state_dict())

    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(local_model.parameters(), lr=lr)
    
    # Training loop
    local_model.train()
    total_loss = 0  # Initialize total loss
    num_batches = 0  # Initialize batch counter

    for epoch in range(epochs):
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = local_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
    
    # Calculate average training loss
    #torch.quantization.convert(local_model, inplace=True)
    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    return local_model.state_dict(), avg_loss  # Return updated model parameters and average loss


# Validation function
def validate_model(model, val_loader):
    model.eval()
    criterion = nn.CrossEntropyLoss()  # Define the loss function
    total_loss = 0
    num_batches = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels) #F.cross_entropy(outputs, labels, reduction='sum')  # Compute loss for the batch
            total_loss += loss.item()
            num_batches += 1
    
    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    return avg_loss


# Client Selection

In [7]:
# Example Resource Info (to simulate resource heterogeneity)
client_resources = [{"comp_capacity": random.randint(10, 100), "data_size": random.randint(1, 10)} for _ in range(num_clients)]

def client_selection_with_constraints(client_resources, deadline):
    """
    Select clients based on their resource availability and time constraints.
    """
    selected_clients = []
    total_time = 0  # Track elapsed time
    remaining_clients = list(range(len(client_resources)))  # Indices of available clients

    while remaining_clients:
        # Sort remaining clients by minimum time to complete training and upload
        best_client = None
        min_time = float('inf')

        for client in remaining_clients:
            resource = client_resources[client]
            update_time = resource["data_size"] / resource["comp_capacity"]  # Simplified time calculation
            if total_time + update_time < deadline and update_time < min_time:
                best_client = client
                min_time = update_time

        if best_client is None:
            break  # No more clients can be selected within the deadline

        # Select the best client
        selected_clients.append(best_client)
        total_time += min_time
        remaining_clients.remove(best_client)

    return selected_clients


def select_indices(n, k):
    return random.sample(range(n), k)

## Run Federated Learning Round

In [None]:
"""
This took me 2 mins to run
Simulate Federated Learning
A learning round consists of all clients training their local models and then aggregating the updates
"""
from datetime import datetime
import time
import tensorflow as tf
from torch.utils.tensorboard import SummaryWriter

# Generate a unique log directory based on the current time
run_number = datetime.now().strftime('%m-%d-%H-%M')  # Month-Day_Hour-Minute
log_dir = f"./logs/{hp['run_id']}/run_{run_number}"  # Use a timestamp to distinguish runs
writer = SummaryWriter(log_dir=log_dir)

# Log hyperparameters as text
hyperparams_text = "\n".join([f"{key}: {value}" for key, value in hp.items()])

# Start time measurement
start_time = time.time()

# Federated Learning with FedCS Client Selection
num_rounds = hp["num_rounds"]
epochs = hp["epochs"]
round_deadline = 3  # Example round deadline (in arbitrary time units)

assert num_clients == len(train_loader)
print("GPU available:", torch.cuda.is_available())

# Conduct federated learning rounds
for round_num in range(num_rounds):
    print(f"Round {round_num + 1}")
    
    # 2. Resource Request
    k = random.randint(1, num_clients)
    resource_requested_clients = select_indices(num_clients, k)
    print(f"-- Resource requested from {len(resource_requested_clients)} clients")

    # 2. Client Selection: Collect client updates
    # TODO: Client ressources
    selected_train_clients = client_selection_with_constraints([client_resources[i] for i in resource_requested_clients], round_deadline)
    print(f"-- Filtered {len(selected_train_clients)} clients")
    filtered_train_loaders = [train_loader[i] for i in selected_train_clients]
    filtered_val_loaders = [val_loader[i] for i in selected_train_clients]

    client_states = []
    round_train_loss = 0  # Initialize round loss
    round_val_loss = 0  # Initialize round loss
    num_batches = 0  # Initialize batch counter

    for client_train_loader, client_val_loader in zip(filtered_train_loaders, filtered_val_loaders):
        
        # 4. Distribution
        client_state, client_loss = train_model(global_model, client_train_loader, hp, epochs=epochs)

        # Log client loss
        round_train_loss += client_loss
        num_batches += 1

        # Validation Phase
        val_loss = validate_model(global_model, client_val_loader)
        round_val_loss += val_loss

        # 5. Update and Upload
        client_states.append(client_state)

    # Average loss for the round
    avg_round_training_loss = round_train_loss / num_batches if num_batches > 0 else 0
    writer.add_scalar("Metrics/Training loss", avg_round_training_loss, round_num + 1)
    print(f"Training loss after round {round_num + 1}: {avg_round_training_loss}")

    # 6. Aggregation: Aggregate updates using Federated Averaging
    new_global_state = fl.federated_averaging(global_model, client_states)  
    global_model.load_state_dict(new_global_state)
    print(f"Global model updated for round {round_num + 1}")

    avg_round_val_loss = round_val_loss / num_batches if num_batches > 0 else 0
    writer.add_scalar("Metrics/Validation loss", avg_round_val_loss, round_num + 1)
    print(f"Validation loss after round {round_num + 1}: {avg_round_val_loss}")

torch.quantization.convert(global_model, inplace=True)

# End time measurement
end_time = time.time()

# Print total execution time
print(f"Total time taken: {end_time - start_time:.2f} seconds")

GPU available: True
Round 1
-- Resource requested from 10 clients
-- Filtered 10 clients




Training loss after round 1: 0.7076105917609399
Global model updated for round 1
Validation loss after round 1: 2.3298020237286883
Round 2
-- Resource requested from 5 clients
-- Filtered 5 clients
Training loss after round 2: 0.3503669883691602
Global model updated for round 2
Validation loss after round 2: 0.3663866283198198
Round 3
-- Resource requested from 4 clients
-- Filtered 4 clients
Training loss after round 3: 0.24976531516036227
Global model updated for round 3
Validation loss after round 3: 0.22389184047468008
Round 4
-- Resource requested from 5 clients
-- Filtered 5 clients
Training loss after round 4: 0.205326588033802
Global model updated for round 4
Validation loss after round 4: 0.17989828748752673
Round 5
-- Resource requested from 4 clients
-- Filtered 4 clients
Training loss after round 5: 0.16651358960532686
Global model updated for round 5
Validation loss after round 5: 0.1478503160442536
Round 6
-- Resource requested from 2 clients
-- Filtered 2 clients
Trainin

In [None]:
# Print the quantized model
print("Quantized model:")
print(global_model)

Quantized model:
FedCSModel(
  (quant): Quantize(scale=tensor([0.0157]), zero_point=tensor([64]), dtype=torch.quint8)
  (dequant): DeQuantize()
  (fc1): QuantizedLinear(in_features=784, out_features=128, scale=0.1566275954246521, zero_point=70, qscheme=torch.per_channel_affine)
  (relu): ReLU()
  (fc2): QuantizedLinear(in_features=128, out_features=10, scale=0.2665318250656128, zero_point=58, qscheme=torch.per_channel_affine)
)


## Evaluate Model

In [None]:
def evaluate_model(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    test_accuracy = correct / total
    print(f"Global model accuracy: {test_accuracy:.2%}")
    return test_accuracy
    
# Evaluate the model on the test dataset
test_accuracy = evaluate_model(global_model, test_loader)

# End TensorBoard writer
final_metrics = {}
writer.add_hparams(hp, final_metrics)
writer.close()

Global model accuracy: 96.73%


In [None]:
# Save the final model
model_save_path = f"./saved_models/global_model_{hp['run_id']}_{run_number}.pth"
torch.save(global_model.state_dict(), model_save_path)
print(f"Global model saved at {model_save_path}")

Global model saved at ./saved_models/global_model_delta_11-26-13-49.pth
