In [None]:
"""
eXtreme Federated Learning guided by Sensitivity (XFL-S)
Author: Khaji Sana 
Description:
    XFL-S implements sensitivity-based layer selection in federated learning to 
    reduce communication overhead while maintaining model accuracy.
"""
# IMPORTS
import os
os.environ["RAY_DEDUP_LOGS"] = "0"

# Core libraries
import flwr
import flwr as fl
from collections import OrderedDict, defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from flwr_datasets import FederatedDataset
from datasets.utils.logging import disable_progress_bar
from flwr_datasets.partitioner import IidPartitioner, PathologicalPartitioner
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize, ToTensor
import numpy as np
from typing import Dict, List, Tuple, Union, Optional, Callable
from flwr.client import NumPyClient, ClientApp
from flwr.common import Parameters, ndarrays_to_parameters, parameters_to_ndarrays, Context, Metrics, FitRes, Scalar
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.client_proxy import ClientProxy
import logging
from flwr.simulation import run_simulation
from sklearn.metrics import precision_score, recall_score, f1_score
import random
import copy

In [None]:
# DEVICE CONFIGURATION
# Use the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Training on {device}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

# Keep the Hugging Face `datasets` progress bars out of the simulation logs.
disable_progress_bar()

In [None]:
# REPRODUCIBILITY SETUP
SEED = 42

# Seed every RNG used in this process: Python's `random`, NumPy's legacy global
# generator, and PyTorch's CPU generator (in that order).
# NOTE(review): this only affects the current process; if simulated clients run
# in separate worker processes they may need their own seeding — confirm.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

if torch.cuda.is_available():
    # Seed the current GPU and then every visible GPU.
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Give up cuDNN autotuning in exchange for run-to-run deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

print("Random seeds set for reproducibility")


In [None]:
# SCENARIO CONFIGURATIONS
# Four experimental scenarios are available:
#   1. 6-Layer Model with IID Data Distribution
#   2. 6-Layer Model with Non-IID Data Distribution
#   3. 10-Layer Model with IID Data Distribution
#   4. 10-Layer Model with Non-IID Data Distribution
# IID runs use 18 clients x 40 local epochs; Non-IID runs use 30 clients x 50
# local epochs. All other hyperparameters are shared.

SCENARIOS = {
    "6_layer_iid": dict(
        name="6-Layer Model with IID Data",
        model_type="6_layer",
        data_distribution="iid",
        num_rounds=20,
        num_clients=18,
        local_epochs=40,
        learning_rate=0.01,
        batch_size=32,
        fraction_evaluate=1.0,
    ),
    "6_layer_noniid": dict(
        name="6-Layer Model with Non-IID Data",
        model_type="6_layer",
        data_distribution="noniid",
        num_rounds=20,
        num_clients=30,
        local_epochs=50,
        learning_rate=0.01,
        batch_size=32,
        fraction_evaluate=1.0,
    ),
    "10_layer_iid": dict(
        name="10-Layer Model with IID Data",
        model_type="10_layer",
        data_distribution="iid",
        num_rounds=20,
        num_clients=18,
        local_epochs=40,
        learning_rate=0.01,
        batch_size=32,
        fraction_evaluate=1.0,
    ),
    "10_layer_noniid": dict(
        name="10-Layer Model with Non-IID Data",
        model_type="10_layer",
        data_distribution="noniid",
        num_rounds=20,
        num_clients=30,
        local_epochs=50,
        learning_rate=0.01,
        batch_size=32,
        fraction_evaluate=1.0,
    ),
}




In [None]:
# SCENARIO SELECTION
# Set ACTIVE_SCENARIO to exactly one key of SCENARIOS:
#   "6_layer_iid"     - 6-Layer Model with IID Data Distribution
#   "6_layer_noniid"  - 6-Layer Model with Non-IID Data Distribution
#   "10_layer_iid"    - 10-Layer Model with IID Data Distribution
#   "10_layer_noniid" - 10-Layer Model with Non-IID Data Distribution
ACTIVE_SCENARIO = "10_layer_noniid"

# Fail early with the list of valid choices instead of a bare KeyError.
if ACTIVE_SCENARIO not in SCENARIOS:
    raise KeyError(
        f"Unknown scenario {ACTIVE_SCENARIO!r}; choose one of {sorted(SCENARIOS)}"
    )
scenario_config = SCENARIOS[ACTIVE_SCENARIO]

In [None]:
# EXTRACT CONFIGURATION PARAMETERS
# Unpack the active scenario into the module-level constants used by later cells.
(
    NUM_ROUNDS,
    NUM_CLIENTS,
    LOCAL_EPOCHS,
    LEARNING_RATE,
    BATCH_SIZE,
    FRACTION_EVALUATE,
    MODEL_TYPE,
    DATA_DISTRIBUTION,
) = (
    scenario_config[field]
    for field in (
        "num_rounds",
        "num_clients",
        "local_epochs",
        "learning_rate",
        "batch_size",
        "fraction_evaluate",
        "model_type",
        "data_distribution",
    )
)

# Print a summary banner of the chosen configuration.
_BANNER_RULE = "=" * 70
print("\n" + _BANNER_RULE)
print(f"SCENARIO: {scenario_config['name']}")
print(_BANNER_RULE)
print(f"  - Model: {MODEL_TYPE}")
print(f"  - Data Distribution: {DATA_DISTRIBUTION.upper()}")
print(f"  - Rounds: {NUM_ROUNDS}")
print(f"  - Clients: {NUM_CLIENTS}")
print(f"  - Local Epochs (t): {LOCAL_EPOCHS}")
print(f"  - Learning Rate: {LEARNING_RATE}")
print(f"  - Batch Size: {BATCH_SIZE}")
print(_BANNER_RULE + "\n")


In [None]:
# NEURAL NETWORK MODEL DEFINITIONS

class Net6Layer(nn.Module):
    """
    6-Layer CNN Model for 32x32 RGB inputs (CIFAR-10).

    Architecture: 4 convolutional layers + 2 fully connected layers.
    - Each conv is 3x3 / stride 1 / padding 1, so it preserves spatial size.
    - Each conv is followed by ReLU and a 2x2 max-pool, so a 32x32 image shrinks
      to a 2x2x128 feature map before the classifier.
    - Parameters appear in ``state_dict`` as (weight, bias) pairs in the order
      conv1..conv4, fc1, fc2; the attribute order below fixes that order.
    """
    def __init__(self):
        super().__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)   # Layer 1
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)  # Layer 2
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)  # Layer 3
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) # Layer 4
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        # Fully connected layers
        self.fc1 = nn.Linear(2 * 2 * 128, 256)  # Layer 5
        self.fc2 = nn.Linear(256, 10)           # Layer 6 (output)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) image tensor to (batch, 10) class logits."""
        x = self.pool(F.relu(self.conv1(x)))  # 32x32 -> 16x16
        x = self.pool(F.relu(self.conv2(x)))  # 16x16 -> 8x8
        x = self.pool(F.relu(self.conv3(x)))  # 8x8 -> 4x4
        x = self.pool(F.relu(self.conv4(x)))  # 4x4 -> 2x2
        # Flatten per sample, keeping the batch dimension. view(-1, 512) would
        # silently fold a wrongly-sized input into extra "batch" rows; flatten
        # makes fc1 raise a shape error instead.
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class Net10Layer(nn.Module):
    """
    10-Layer CNN Model for 32x32 RGB inputs (CIFAR-10).

    Architecture: 6 convolutional layers + 4 fully connected layers.
    - conv1-conv5 (3x3 / stride 1 / padding 1) are each followed by ReLU and a
      2x2 max-pool, reducing 32x32 to 1x1.
    - conv6 runs on the 1x1 map without pooling, so the classifier receives 512
      features per sample.
    - Parameters appear in ``state_dict`` as (weight, bias) pairs in the order
      conv1..conv6, fc1..fc4; the attribute order below fixes that order.
    """
    def __init__(self):
        super().__init__()
        # Convolutional layers (in_channels, out_channels, kernel 3, stride 1, padding 1)
        self.conv1 = nn.Conv2d(3, 16, 3, 1, 1)    # Layer 1
        self.conv2 = nn.Conv2d(16, 32, 3, 1, 1)   # Layer 2
        self.conv3 = nn.Conv2d(32, 64, 3, 1, 1)   # Layer 3
        self.conv4 = nn.Conv2d(64, 128, 3, 1, 1)  # Layer 4
        self.conv5 = nn.Conv2d(128, 256, 3, 1, 1) # Layer 5
        self.conv6 = nn.Conv2d(256, 512, 3, 1, 1) # Layer 6
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.4)
        # Fully connected layers
        self.fc1 = nn.Linear(1 * 1 * 512, 512)  # Layer 7
        self.fc2 = nn.Linear(512, 256)          # Layer 8
        self.fc3 = nn.Linear(256, 64)           # Layer 9
        self.fc4 = nn.Linear(64, 10)            # Layer 10 (output)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) image tensor to (batch, 10) class logits."""
        x = self.pool(F.relu(self.conv1(x)))  # 32x32 -> 16x16
        x = self.pool(F.relu(self.conv2(x)))  # 16x16 -> 8x8
        x = self.pool(F.relu(self.conv3(x)))  # 8x8 -> 4x4
        x = self.pool(F.relu(self.conv4(x)))  # 4x4 -> 2x2
        x = self.pool(F.relu(self.conv5(x)))  # 2x2 -> 1x1
        x = F.relu(self.conv6(x))             # 1x1 -> 1x1 (no pooling)
        # Flatten per sample, keeping the batch dimension. view(-1, 512) would
        # silently fold a wrongly-sized input into extra "batch" rows; flatten
        # makes fc1 raise a shape error instead.
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)

def get_model(model_type: str) -> nn.Module:
    """Build a freshly initialised network for the given scenario ``model_type``.

    Args:
        model_type: Either ``"6_layer"`` or ``"10_layer"``.

    Raises:
        ValueError: for any other value.
    """
    if model_type == "6_layer":
        return Net6Layer()
    if model_type == "10_layer":
        return Net10Layer()
    raise ValueError(f"Unknown model type: {model_type}")


# Report which model architecture the active scenario (MODEL_TYPE) selected.
print(f"Neural Network Model defined: {MODEL_TYPE}")

In [None]:
# UTILITY FUNCTIONS
def get_weights(net: nn.Module) -> List[np.ndarray]:
    """Return a snapshot of the model's ``state_dict`` values as NumPy arrays.

    The arrays are independent copies, in ``state_dict`` order. For a CPU tensor,
    ``Tensor.numpy()`` shares memory with the tensor. Without the copy, the
    returned arrays would keep changing as the optimizer updates parameters in
    place. On CUDA the ``.cpu()`` transfer already copies, which hides the
    problem on GPU runs.
    """
    return [val.detach().cpu().numpy().copy() for val in net.state_dict().values()]


def set_weights(net: nn.Module, parameters: List[np.ndarray]) -> None:
    """Load NumPy arrays into ``net``, matched to ``state_dict`` entries by position.

    Args:
        net: Model whose parameters are overwritten.
        parameters: One array per ``state_dict`` entry, in ``state_dict`` order.

    Raises:
        ValueError: if the number of arrays differs from the number of
            ``state_dict`` entries. Plain ``zip`` would silently ignore extra arrays.
        RuntimeError: from ``load_state_dict`` if an array has the wrong shape.
    """
    keys = list(net.state_dict().keys())
    if len(parameters) != len(keys):
        raise ValueError(
            f"Expected {len(keys)} parameter arrays, got {len(parameters)}"
        )
    state_dict = OrderedDict((k, torch.tensor(v)) for k, v in zip(keys, parameters))
    net.load_state_dict(state_dict, strict=True)


print("Utility functions defined")

In [None]:
# DATA LOADING FUNCTIONS

# Global cache for FederatedDataset. The cache is keyed by the partitioning
# configuration. Switching DATA_DISTRIBUTION (or the client count) and re-running
# cells therefore rebuilds the partitions instead of silently reusing the old ones.
_fds = None
_fds_key = None


def _get_cifar10_fds(key: Tuple, make_partitioner: Callable):
    """Return the cached CIFAR-10 FederatedDataset, rebuilding it when ``key`` changes.

    Args:
        key: Hashable description of the partitioning (distribution and settings).
        make_partitioner: Zero-argument factory for the ``train`` partitioner. It is
            only called when the cache has to be (re)built.
    """
    global _fds, _fds_key
    if _fds is None or _fds_key != key:
        _fds = FederatedDataset(
            dataset="uoft-cs/cifar10",
            partitioners={"train": make_partitioner()},
        )
        _fds_key = key
    return _fds


def _make_partition_loaders(partition, batch_size: int):
    """Split one client partition 80/20 (seed 42), normalise images, and wrap both halves in DataLoaders.

    Images go through ToTensor + Normalize(mean=0.5, std=0.5) per channel. Only
    the training loader is shuffled.
    """
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)

    pytorch_transforms = Compose([
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    def apply_transforms(batch):
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(
        partition_train_test["train"], batch_size=batch_size, shuffle=True
    )
    testloader = DataLoader(
        partition_train_test["test"], batch_size=batch_size
    )
    return trainloader, testloader


def load_data_iid(partition_id: int, num_partitions: int, batch_size: int):
    """
    Load IID partitioned CIFAR-10 data.

    Args:
        partition_id: Client partition ID
        num_partitions: Total number of partitions
        batch_size: Batch size for DataLoader

    Returns:
        trainloader, testloader: DataLoaders for training and testing
    """
    fds = _get_cifar10_fds(
        key=("iid", num_partitions),
        make_partitioner=lambda: IidPartitioner(num_partitions=num_partitions),
    )
    return _make_partition_loaders(fds.load_partition(partition_id), batch_size)


def load_data_noniid(
    partition_id: int,
    num_partitions: int,
    batch_size: int,
    num_classes_per_client: int = 4,
):
    """
    Load Non-IID partitioned CIFAR-10 data using pathological partitioning.

    Args:
        partition_id: Client partition ID
        num_partitions: Total number of partitions
        batch_size: Batch size for DataLoader
        num_classes_per_client: Number of classes per client partition

    Returns:
        trainloader, testloader: DataLoaders for training and testing
    """
    fds = _get_cifar10_fds(
        key=("noniid", num_partitions, num_classes_per_client),
        make_partitioner=lambda: PathologicalPartitioner(
            num_partitions=num_partitions,
            partition_by="label",
            num_classes_per_partition=num_classes_per_client,
            class_assignment_mode="deterministic",
            shuffle=False,
            seed=42,
        ),
    )
    return _make_partition_loaders(fds.load_partition(partition_id), batch_size)


def load_data(partition_id: int, num_partitions: int, batch_size: int):
    """
    Dispatch to the IID or Non-IID loader according to the global DATA_DISTRIBUTION.

    Args:
        partition_id: Client partition ID
        num_partitions: Total number of partitions
        batch_size: Batch size for DataLoader

    Returns:
        trainloader, testloader: DataLoaders for training and testing

    Raises:
        ValueError: if DATA_DISTRIBUTION is neither "iid" nor "noniid".
    """
    loaders_by_distribution = {
        "iid": load_data_iid,
        "noniid": load_data_noniid,
    }
    loader_fn = loaders_by_distribution.get(DATA_DISTRIBUTION)
    if loader_fn is None:
        raise ValueError(f"Unknown data distribution: {DATA_DISTRIBUTION}")
    return loader_fn(partition_id, num_partitions, batch_size)


print(f"Data loading functions defined (Distribution: {DATA_DISTRIBUTION.upper()})")


In [None]:
# TRAINING AND TESTING FUNCTIONS

def train_with_epoch_tracking(
    net: nn.Module,
    trainloader: DataLoader,
    valloader: DataLoader,
    epochs: int,
    learning_rate: float,
    device: torch.device
) -> Dict:
    """
    Train ``net`` with SGD (momentum 0.9) and record a parameter snapshot per epoch.

    The snapshots feed the sensitivity calculation. The returned
    ``epoch_parameters`` is [w_0, w_1, ..., w_epochs]: the weights before training,
    then the weights after each epoch, each a list of arrays in state_dict order.

    Args:
        net: Neural network model
        trainloader: Training data loader (batches are dicts with "img" and "label")
        valloader: Validation data loader
        epochs: Number of training epochs
        learning_rate: Learning rate
        device: Device to train on

    Returns:
        Dict with val_loss, val_accuracy, val_precision, val_recall, val_f1
        (computed by ``test`` after training) and epoch_parameters.
    """
    net.to(device)
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
    net.train()

    # w_0: deep-copied so later in-place optimizer steps cannot alter the snapshot.
    epoch_parameters = [copy.deepcopy(get_weights(net))]

    for _ in range(epochs):
        for batch in trainloader:
            inputs = batch["img"].to(device)
            targets = batch["label"].to(device)
            optimizer.zero_grad()
            batch_loss = criterion(net(inputs), targets)
            batch_loss.backward()
            optimizer.step()
        # w_e: weights at the end of this epoch.
        epoch_parameters.append(copy.deepcopy(get_weights(net)))

    val_loss, val_acc, val_precision, val_recall, val_f1 = test(net, valloader, device)

    return {
        "val_loss": val_loss,
        "val_accuracy": val_acc,
        "val_precision": val_precision,
        "val_recall": val_recall,
        "val_f1": val_f1,
        "epoch_parameters": epoch_parameters,
    }


def test(
    net: nn.Module,
    testloader: DataLoader,
    device: torch.device
) -> Tuple[float, float, float, float, float]:
    """
    Evaluate ``net`` on ``testloader`` and compute loss plus classification metrics.

    The model is put in eval mode for the pass, so Dropout is disabled and
    predictions are deterministic. Afterwards it is restored to whatever
    train/eval mode the caller had it in.

    Args:
        net: Neural network model
        testloader: Test data loader (batches are dicts with "img" and "label")
        device: Device to test on

    Returns:
        Tuple of (loss, accuracy, precision, recall, f1_score):
        - loss is the mean of the per-batch mean cross-entropy.
        - accuracy is over all samples.
        - precision/recall/F1 are macro-averaged, with ill-defined per-class
          scores counted as 0.
    """
    net.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    correct, loss = 0, 0.0
    all_predictions = []
    all_labels = []

    # Without eval(), Dropout stays active during evaluation. That makes the
    # metrics noisy and systematically pessimistic (freshly built models and
    # models just trained are both in train mode).
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for batch in testloader:
                images = batch["img"].to(device)
                labels = batch["label"].to(device)
                outputs = net(images)
                loss += criterion(outputs, labels).item()
                _, predicted = torch.max(outputs.data, 1)
                correct += (predicted == labels).sum().item()
                all_predictions.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
    finally:
        net.train(was_training)

    # Calculate metrics
    accuracy = correct / len(testloader.dataset)
    loss = loss / len(testloader)

    y_true = np.array(all_labels)
    y_pred = np.array(all_predictions)

    precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
    recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    return loss, accuracy, precision, recall, f1


print("Training and testing functions defined with epoch tracking")

In [None]:
# XFL-S CLIENT IMPLEMENTATION

# Configure client-side logging
client_logger = logging.getLogger("XFLSClient")
client_logger.setLevel(logging.INFO)


class XFLSClient(NumPyClient):
    """
    XFL-S (eXtreme Federated Learning guided by Sensitivity) Client.

    Each round the client:
    1. trains locally for ``local_epochs`` epochs;
    2. scores every layer by the norm of its average per-epoch update relative to
       the norm of the global weights;
    3. samples ONE layer with probability proportional to that score;
    4. uploads only that layer's weight and bias, together with its 1-based index.
    """

    def __init__(
        self,
        net: nn.Module,
        trainloader: DataLoader,
        valloader: DataLoader,
        local_epochs: int,
        learning_rate: float,
        rank: int,
        cid: str
    ):
        self.net = net
        self.rank = rank  # used only in log messages
        self.cid = cid    # used only in log messages
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net.to(self.device)
        self.learning_rate = learning_rate
        self.trainloader = trainloader
        self.valloader = valloader
        self.local_epochs = local_epochs  # This is 't' in the formula

    def calculate_average_local_updates(
        self,
        epoch_parameters: List[List[np.ndarray]]
    ) -> List[np.ndarray]:
        """
        Calculate the average per-epoch local update Δ_{t,l} for every parameter tensor.

        Δ_{t,l} = (1/t) * Σ_{e=1..t} (w_{e,l} - w_{e-1,l}). The sum telescopes to
        w_{t,l} - w_{0,l}, so only the first and last snapshots are needed. This
        avoids t full-model subtractions and the float32 rounding error they
        accumulate.

        Args:
            epoch_parameters: Snapshots [w_0, w_1, ..., w_t], each a list of arrays
                in state_dict order. t is taken from the number of snapshots;
                it equals ``local_epochs`` when produced by ``fit``.

        Returns:
            List of average local updates, one array per parameter tensor.
            With a single snapshot (no epochs run) every update is zero.
        """
        initial = epoch_parameters[0]
        t = len(epoch_parameters) - 1
        if t <= 0:
            return [np.zeros_like(p) for p in initial]
        final = epoch_parameters[-1]
        return [(w_t - w_0) / t for w_t, w_0 in zip(final, initial)]

    def calculate_sensitivity_scores(
        self,
        delta_t_l: List[np.ndarray],
        global_parameters: List[np.ndarray]
    ) -> List[float]:
        """
        Calculate sensitivity scores s_{t,l} = ||Δ_{t,l}||_2 / (||x_{t,l}||_2 + 1e-12).

        Norms are ``np.linalg.norm`` without an axis, i.e. the 2-norm of each
        tensor flattened.

        Args:
            delta_t_l: Average local updates
            global_parameters: Global model parameters

        Returns:
            List of sensitivity scores, one per parameter tensor (weights and
            biases scored separately).
        """
        sensitivity_scores = []
        eps = 1e-12  # Small epsilon to avoid division by zero

        for l in range(len(delta_t_l)):
            # L2 norm of average local update
            delta_norm = np.linalg.norm(delta_t_l[l])

            # L2 norm of global parameters
            global_norm = np.linalg.norm(global_parameters[l])

            sensitivity = delta_norm / (global_norm + eps)
            sensitivity_scores.append(sensitivity)

        return sensitivity_scores

    def calculate_probability_distribution(
        self,
        sensitivity_scores: List[float]
    ) -> Tuple[np.ndarray, List[float]]:
        """
        Turn per-tensor sensitivities into a layer-selection distribution p_{t,l}.

        Assumes tensors come in (weight, bias) pairs, one pair per layer. A
        layer's sensitivity is the mean of its weight and bias scores, and
        p_{t,l} = s_l / Σ s. If the total is zero or not finite, the
        distribution falls back to uniform.

        Args:
            sensitivity_scores: Sensitivity scores for each parameter tensor

        Returns:
            Tuple of (probabilities, layer_sensitivities), both indexed by
            0-based layer position.
        """
        # Group by layers (each layer has weight + bias, so we average them)
        num_layers = len(sensitivity_scores) // 2
        layer_sensitivities = []

        for layer_idx in range(num_layers):
            weight_sens = sensitivity_scores[2 * layer_idx]      # Weight sensitivity
            bias_sens = sensitivity_scores[2 * layer_idx + 1]    # Bias sensitivity
            layer_sensitivity = (weight_sens + bias_sens) / 2
            layer_sensitivities.append(layer_sensitivity)

        total_sensitivity = sum(layer_sensitivities)

        # Handle edge case where all sensitivities are zero (or a NaN/inf crept in)
        if total_sensitivity <= 0 or not np.isfinite(total_sensitivity):
            # Uniform distribution as fallback
            probabilities = [1.0 / num_layers] * num_layers
        else:
            probabilities = [s / total_sensitivity for s in layer_sensitivities]

        # Ensure probabilities sum to 1 (numerical stability for np.random.choice)
        probabilities = np.array(probabilities)
        probabilities = probabilities / probabilities.sum()

        return probabilities, layer_sensitivities

    def fit(self, parameters: List[np.ndarray], config: Dict) -> Tuple[List[np.ndarray], int, Dict]:
        """
        Train locally, score layers by sensitivity, and upload one sampled layer.

        Args:
            parameters: Global model parameters x_t (state_dict order)
            config: Round configuration. Only "round" is read, for logging
                (default 1).

        Returns:
            Tuple of (payload, num_examples, metrics):
            - payload is [int64 array([layer_index]), weight_array, bias_array]
              with a 1-based layer index;
            - num_examples is the size of the training set.
        """
        # Set global weights (x_{t,l})
        set_weights(self.net, parameters)
        # Take an independent snapshot of x_{t,l}. On CPU, Tensor.numpy() shares
        # memory with the parameters, and the optimizer updates them in place
        # during training. Without the copy, the sensitivity denominator would
        # silently become the *trained* weights instead of the global ones.
        global_parameters = [np.array(p, copy=True) for p in get_weights(self.net)]

        # Train with epoch tracking
        train_results = train_with_epoch_tracking(
            self.net,
            trainloader=self.trainloader,
            valloader=self.valloader,
            epochs=self.local_epochs,
            learning_rate=self.learning_rate,
            device=self.device
        )

        # Extract epoch parameters for Δ calculation
        epoch_parameters = train_results["epoch_parameters"]

        # STEP 1: Calculate average local updates Δ_{t,l}
        delta_t_l = self.calculate_average_local_updates(epoch_parameters)

        # STEP 2: Calculate sensitivity scores s_{t,l}
        sensitivity_scores = self.calculate_sensitivity_scores(delta_t_l, global_parameters)

        # STEP 3: Calculate probability distribution p_{t,l}
        probabilities, layer_sensitivities = self.calculate_probability_distribution(sensitivity_scores)

        # STEP 4: Select layer based on probability distribution
        selected_layer_idx = np.random.choice(len(probabilities), p=probabilities)
        layer_selected = selected_layer_idx + 1  # Convert to 1-indexed

        round_num = int(config.get("round", 1))
        client_logger.info(
            f"[Round {round_num}] Client {self.rank} (CID: {self.cid}) selected layer {layer_selected} "
            f"with probability {probabilities[selected_layer_idx]:.4f} "
            f"(sensitivity: {layer_sensitivities[selected_layer_idx]:.6f})"
        )

        # Get final local parameters
        final_local_parameters = get_weights(self.net)

        # Package selected layer parameters (weight + bias)
        weight_idx = 2 * (layer_selected - 1)
        bias_idx = weight_idx + 1
        selected_params = [final_local_parameters[weight_idx], final_local_parameters[bias_idx]]

        # Communication cost: index array + weight + bias, in raw array bytes
        effective_upload_bytes_client = np.array([layer_selected], dtype=np.int64).nbytes
        effective_upload_bytes_client += sum(p.nbytes for p in selected_params)

        metrics = {
            "layer_selected": float(layer_selected),
            "effective_upload_bytes": effective_upload_bytes_client,
            "val_accuracy": train_results["val_accuracy"],
            "val_precision": train_results["val_precision"],
            "val_recall": train_results["val_recall"],
            "val_f1": train_results["val_f1"],
            "selection_probability": float(probabilities[selected_layer_idx]),
            "layer_sensitivity": float(layer_sensitivities[selected_layer_idx])
        }

        # Add detailed sensitivity information for analysis
        for l, s in enumerate(sensitivity_scores):
            metrics[f"sensitivity_param_{l}"] = float(s)
        for l, p in enumerate(probabilities):
            metrics[f"probability_layer_{l+1}"] = float(p)

        client_logger.info(
            f"[Round {round_num}] Client {self.rank} XFL-S training complete - "
            f"Val Acc: {train_results['val_accuracy']:.4f}, "
            f"Selected Layer: {layer_selected} (prob: {probabilities[selected_layer_idx]:.4f})"
        )

        # Return payload: [layer_index, weight_array, bias_array]
        payload = [
            np.array([layer_selected], dtype=np.int64),
            *selected_params
        ]

        return payload, len(self.trainloader.dataset), metrics

    def evaluate(
        self,
        parameters: List[np.ndarray],
        config: Dict
    ) -> Tuple[float, int, Dict]:
        """
        Evaluate the global model on this client's validation data.

        Args:
            parameters: Global model parameters
            config: Configuration dictionary (unused)

        Returns:
            Tuple of (loss, num_examples, metrics)
        """
        set_weights(self.net, parameters)
        loss, accuracy, precision, recall, f1 = test(self.net, self.valloader, self.device)

        return float(loss), len(self.valloader.dataset), {
            "accuracy": float(accuracy),
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "eval_loss": float(loss),
        }


print("XFL-S Client implementation defined")

In [None]:
# XFL-S STRATEGY (SERVER-SIDE AGGREGATION)

# Set up server-side logging
strategy_logger = logging.getLogger("XFLSStrategy")
strategy_logger.setLevel(logging.INFO)


class XFLSStrategy(fl.server.strategy.FedAvg):
    """
    XFL-S Strategy for server-side aggregation.

    Each client uploads a single layer as [layer_index, weight, bias]. The server
    keeps the full global model in ``global_state``. Every round, each layer that
    at least one client uploaded is replaced by the example-weighted average of
    those uploads. Layers nobody selected keep their previous global values.
    """

    def __init__(
        self,
        *,
        initial_parameters: Parameters,
        param_keys: List[str],
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 0.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 0,
        min_available_clients: int = 2,
        evaluate_fn: Optional[Callable] = None,
        evaluate_metrics_aggregation_fn: Optional[Callable] = None,
    ):
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            initial_parameters=initial_parameters,
            evaluate_fn=evaluate_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        # param_keys must list the model's parameters in the same order as
        # initial_parameters, as (weight, bias) pairs per layer.
        self.param_keys = param_keys
        self.global_state = OrderedDict(
            zip(self.param_keys, parameters_to_ndarrays(initial_parameters))
        )

        # Track statistics
        self.layer_selection_history = defaultdict(list)
        self.communication_savings = []

    def configure_fit(
        self,
        server_round: int,
        parameters: Parameters,
        client_manager
    ) -> List[Tuple]:
        """
        Configure clients for a training round.

        Samples ``min_fit_clients`` clients, orders them by cid, and sends each the
        global parameters with {"round": server_round, "rank": position}.

        Args:
            server_round: Current round number
            parameters: Global model parameters
            client_manager: Client manager

        Returns:
            List of (client, FitIns) tuples
        """
        clients = client_manager.sample(num_clients=self.min_fit_clients)
        clients = sorted(clients, key=lambda c: c.cid)

        fit_ins = []
        for rank, client in enumerate(clients):
            config = {
                "round": server_round,
                "rank": rank,
            }
            fit_ins.append((client, fl.common.FitIns(parameters, config)))

        strategy_logger.info(
            f"[Round {server_round}] Configured {len(fit_ins)} clients for XFL-S layer selection"
        )
        return fit_ins

    def _validated_layer_index(
        self,
        server_round: int,
        client_payload: List[np.ndarray],
        num_layers: int,
    ) -> Optional[int]:
        """Return a payload's 1-based layer index, or None (with a warning) if the payload is unusable.

        A payload is usable when all of the following hold:
        - it has exactly three arrays, [layer_index, weight, bias];
        - the index lies in 1..num_layers;
        - the weight and bias shapes match the corresponding global parameters.
        """
        if len(client_payload) != 3 or np.size(client_payload[0]) == 0:
            strategy_logger.warning(
                f"[Round {server_round}] Ignoring malformed client payload "
                f"with {len(client_payload)} arrays"
            )
            return None
        layer_idx = int(np.ravel(client_payload[0])[0])
        # Without this check an index <= 0 would wrap around param_keys via
        # negative indexing and silently overwrite a different layer.
        if not 1 <= layer_idx <= num_layers:
            strategy_logger.warning(
                f"[Round {server_round}] Ignoring client update for out-of-range "
                f"layer {layer_idx} (model has {num_layers} layers)"
            )
            return None
        weight_key = self.param_keys[2 * (layer_idx - 1)]
        bias_key = self.param_keys[2 * (layer_idx - 1) + 1]
        if (np.shape(client_payload[1]) != np.shape(self.global_state[weight_key])
                or np.shape(client_payload[2]) != np.shape(self.global_state[bias_key])):
            strategy_logger.warning(
                f"[Round {server_round}] Ignoring client update for layer {layer_idx}: "
                f"shape mismatch with global parameters"
            )
            return None
        return layer_idx

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple],
        failures: List[Tuple]
    ) -> Tuple[Parameters, Dict]:
        """
        Aggregate layer-wise updates from clients into the global model.

        Invalid payloads (wrong arity, out-of-range layer index, shape mismatch)
        are skipped with a warning. Their upload bytes are still counted, since
        the data was transmitted.

        Args:
            server_round: Current round number
            results: List of (client, FitRes) tuples
            failures: List of failed clients (not used)

        Returns:
            Tuple of (aggregated_parameters, metrics)
        """
        if not results:
            return ndarrays_to_parameters(list(self.global_state.values())), {}

        # The model is a sequence of (weight, bias) pairs, so layer l (1-based)
        # lives at param_keys[2*(l-1)] and param_keys[2*(l-1)+1].
        num_layers = len(self.param_keys) // 2

        # Track layer updates: {layer_idx: [((weight, bias), num_examples), ...]}
        layer_updates = {}

        # Tracking variables
        effective_upload_bytes_total = 0
        training_accuracies = []
        training_precisions = []
        training_recalls = []
        training_f1s = []
        selection_probabilities = []
        layer_selections = []

        # Process each client's result
        for _, fit_res in results:
            client_payload = parameters_to_ndarrays(fit_res.parameters)
            num_examples = fit_res.num_examples

            # Count uploaded bytes before validation: they were sent either way.
            if "effective_upload_bytes" in fit_res.metrics:
                effective_upload_bytes_total += fit_res.metrics["effective_upload_bytes"]

            layer_idx = self._validated_layer_index(server_round, client_payload, num_layers)
            if layer_idx is None:
                continue
            weight_param = client_payload[1]
            bias_param = client_payload[2]

            probability = fit_res.metrics.get("selection_probability", 1.0)

            strategy_logger.info(
                f"[Round {server_round}] Client selected layer {layer_idx} (prob: {probability:.3f})"
            )

            # Group updates by layer
            layer_updates.setdefault(layer_idx, []).append(
                ((weight_param, bias_param), num_examples)
            )

            # Training metrics
            if "val_accuracy" in fit_res.metrics:
                training_accuracies.append(fit_res.metrics["val_accuracy"])
            if "val_precision" in fit_res.metrics:
                training_precisions.append(fit_res.metrics["val_precision"])
            if "val_recall" in fit_res.metrics:
                training_recalls.append(fit_res.metrics["val_recall"])
            if "val_f1" in fit_res.metrics:
                training_f1s.append(fit_res.metrics["val_f1"])

            selection_probabilities.append(probability)
            layer_selections.append(layer_idx)

        # Weighted aggregation for selected layers
        for layer_idx, updates in layer_updates.items():
            total_examples = sum(n for _, n in updates)

            if total_examples > 0:
                # Weighted average by number of training examples
                avg_weight = sum(w * n for (w, _b), n in updates) / total_examples
                avg_bias = sum(b * n for (_w, b), n in updates) / total_examples
            else:
                # No example counts reported: fall back to an unweighted mean
                # instead of dividing by zero.
                avg_weight = sum(w for (w, _b), _n in updates) / len(updates)
                avg_bias = sum(b for (_w, b), _n in updates) / len(updates)

            # Update global state
            weight_pos = 2 * (layer_idx - 1)
            bias_pos = weight_pos + 1
            weight_key = self.param_keys[weight_pos]
            bias_key = self.param_keys[bias_pos]

            self.global_state[weight_key] = avg_weight
            self.global_state[bias_key] = avg_bias

            strategy_logger.info(
                f"[Round {server_round}] Updated layer {layer_idx} with {len(updates)} client updates"
            )

        # Track statistics
        self.layer_selection_history[server_round] = layer_selections

        # Communication accounting: every client downloads the full model; the
        # "theoretical" baseline is every client uploading the full model too.
        total_model_size_bytes = sum(param.nbytes for param in self.global_state.values())
        download_bytes_total = total_model_size_bytes * len(results)
        total_bytes_effective = effective_upload_bytes_total + download_bytes_total

        theoretical_full_upload = total_model_size_bytes * len(results)
        communication_reduction = 1 - (effective_upload_bytes_total / theoretical_full_upload)
        self.communication_savings.append(communication_reduction)

        # Convert to MB
        effective_upload_mb = effective_upload_bytes_total / (1024 ** 2)
        download_mb = download_bytes_total / (1024 ** 2)
        total_mb_effective = total_bytes_effective / (1024 ** 2)
        theoretical_upload_mb = theoretical_full_upload / (1024 ** 2)

        # Calculate average training metrics
        avg_training_metrics = {}
        if training_accuracies:
            avg_training_metrics["avg_train_accuracy"] = np.mean(training_accuracies)
        if training_precisions:
            avg_training_metrics["avg_train_precision"] = np.mean(training_precisions)
        if training_recalls:
            avg_training_metrics["avg_train_recall"] = np.mean(training_recalls)
        if training_f1s:
            avg_training_metrics["avg_train_f1"] = np.mean(training_f1s)

        avg_probability = np.mean(selection_probabilities) if selection_probabilities else 0

        metrics = {
            "upload_MB": effective_upload_mb,
            "download_MB": download_mb,
            "total_update_MB": total_mb_effective,
            "theoretical_upload_MB": theoretical_upload_mb,
            "communication_reduction": communication_reduction,
            "avg_selection_probability": avg_probability,
            "layer_diversity": len(set(layer_selections)),
            **avg_training_metrics,
        }

        # Diversity is reported against the model's layer count, not the
        # highest layer index that happened to be selected this round.
        strategy_logger.info(
            f"[Round {server_round}]  XFL-S Communication - "
            f"Upload: {effective_upload_mb:.4f} MB (vs {theoretical_upload_mb:.4f} MB full), "
            f"Reduction: {communication_reduction:.1%}, "
            f"Avg Probability: {avg_probability:.3f}, "
            f"Layer Diversity: {len(set(layer_selections))}/{num_layers}"
        )

        if avg_training_metrics:
            strategy_logger.info(
                f"[Round {server_round}] Training Metrics - "
                f"Avg Accuracy: {avg_training_metrics.get('avg_train_accuracy', 0):.4f}, "
                f"Avg F1: {avg_training_metrics.get('avg_train_f1', 0):.4f}"
            )

        return ndarrays_to_parameters(list(self.global_state.values())), metrics


print("XFL-S Strategy implementation defined")

In [None]:
# SERVER FUNCTIONS

def load_centralized_test_data() -> DataLoader:
    """
    Load the held-out CIFAR-10 test split for server-side evaluation.

    A slice of the ``train`` split is deliberately NOT used here: the
    ``partitioners={"train": ...}`` setup indicates that client data comes
    from the ``train`` split (verify against ``load_data``). A "test" set
    carved from it would overlap the images clients train on and inflate
    centralized accuracy. The dataset's own ``test`` split is disjoint
    from ``train``.

    Returns:
        DataLoader for centralized test data (batch size ``BATCH_SIZE``,
        no shuffling), with the ``"img"`` column converted to normalized
        tensors on access.
    """
    fds_centralized = FederatedDataset(
        dataset="uoft-cs/cifar10",
        # FederatedDataset expects a partitioners mapping; it is not used
        # when loading the whole test split below.
        partitioners={"train": IidPartitioner(num_partitions=1)},
    )
    partition_test = fds_centralized.load_split("test")

    pytorch_transforms = Compose([
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    def apply_transforms(batch):
        # Applied lazily per batch by `with_transform`.
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    partition_test = partition_test.with_transform(apply_transforms)
    centralized_testloader = DataLoader(
        partition_test, batch_size=BATCH_SIZE, shuffle=False
    )
    print(f"Centralized test dataset size: {len(partition_test)}")
    return centralized_testloader


def gen_evaluate_fn(testloader: DataLoader, device: torch.device) -> Callable:
    """
    Build the server-side (centralized) evaluation callback for the strategy.

    Args:
        testloader: DataLoader over the centralized test data.
        device: Device the freshly built model is moved to before testing.

    Returns:
        A callable ``evaluate(server_round, parameters_ndarrays, config)`` that
        loads the given global weights into a new model, runs ``test`` on
        ``testloader``, prints a report, and returns ``(loss, metrics_dict)``.
    """
    def evaluate(server_round: int, parameters_ndarrays: List[np.ndarray], config: Dict):
        # `config` is part of the callback signature; it is not read here.
        print(f"[Server] Starting centralized evaluation for round {server_round}")

        global_net = get_model(MODEL_TYPE)
        set_weights(global_net, parameters_ndarrays)
        global_net.to(device)
        loss, accuracy, precision, recall, f1 = test(global_net, testloader, device)

        print(f"[Server] XFL-S Centralized Evaluation Round {server_round}:")
        report_rows = (
            ("Loss", loss),
            ("Accuracy", accuracy),
            ("Precision", precision),
            ("Recall", recall),
            ("F1-Score", f1),
        )
        for label, value in report_rows:
            print(f"  - {label}: {value:.4f}")

        centralized_metrics = {
            "centralized_accuracy": accuracy,
            "centralized_precision": precision,
            "centralized_recall": recall,
            "centralized_f1": f1,
            "centralized_loss": loss,
        }
        return loss, centralized_metrics

    return evaluate


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """
    Aggregate client metrics with an example-weighted average.

    Each metric key is averaged only over the clients that actually reported
    it. A client that omits a key therefore does not drag that key's mean
    toward 0. A key reported by a later client but not the first is still
    aggregated. Keys appear in first-seen order.

    Args:
        metrics: List of (num_examples, metrics_dict) tuples; metric values
            are assumed to be numeric.

    Returns:
        Dict mapping each reported key to its weighted mean, or 0.0 when the
        clients reporting that key contributed zero examples in total.
        An empty input yields an empty dict.
    """
    if not metrics:
        return {}

    # Per-key running totals, accumulated only from clients that report the key.
    weighted_sums: Dict[str, float] = {}
    example_counts: Dict[str, int] = {}
    for num_examples, client_metrics in metrics:
        for key, value in client_metrics.items():
            weighted_sums[key] = weighted_sums.get(key, 0) + num_examples * value
            example_counts[key] = example_counts.get(key, 0) + num_examples

    return {
        key: weighted_sums[key] / example_counts[key] if example_counts[key] > 0 else 0.0
        for key in weighted_sums
    }


print("Server functions defined")

In [None]:
# CLIENT APP FUNCTION

def client_fn(context: Context) -> fl.client.Client:
    """
    Create the Flower client for one simulated node.

    Args:
        context: Client context; ``node_config`` supplies ``"partition-id"``
            and ``"num-partitions"``, and ``node_id`` is passed on as ``cid``.

    Returns:
        The XFLSClient converted with ``.to_client()``, i.e. a generic Flower
        ``Client`` rather than the XFLSClient instance itself (XFLSClient is
        presumably a NumPyClient subclass — confirm in its definition).
    """
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]

    trainloader, valloader = load_data(partition_id, num_partitions, BATCH_SIZE)
    net = get_model(MODEL_TYPE)

    return XFLSClient(
        net,
        trainloader,
        valloader,
        LOCAL_EPOCHS,
        LEARNING_RATE,
        rank=partition_id,
        cid=context.node_id,
    ).to_client()


client = ClientApp(client_fn=client_fn)

In [None]:
# SERVER APP FUNCTION 

def server_fn(context: Context) -> ServerAppComponents:
    """
    Create server components for federated learning.

    Builds the initial global parameters from a fresh model, the centralized
    evaluation callback, and the XFL-S strategy. The strategy is paired with
    a ``ServerConfig`` running ``NUM_ROUNDS`` rounds.

    Args:
        context: Server context (not read here).

    Returns:
        ServerAppComponents with strategy and config
    """
    # Separate name so the module-level `device` string is not shadowed.
    eval_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    centralized_testloader = load_centralized_test_data()

    # Initial global model. The strategy also receives the state_dict key
    # order as `param_keys`.
    model = get_model(MODEL_TYPE)
    parameters = ndarrays_to_parameters(get_weights(model))
    param_keys = list(model.state_dict().keys())

    print("\n" + "="*70)
    print("XFL-S: eXtreme Federated Learning guided by Sensitivity")
    print("="*70)
    print("Mathematical Formulation:")
    print("  1. Δ_{t,l} = (1/t) × Σ(w_{e,l}^local - w_{e-1,l}^local)")
    print("  2. s_{t,l} = ||Δ_{t,l}|| / ||x_{t,l}||")
    print("  3. p_{t,l} = s_{t,l} / Σ(s_{t,k})")
    print("="*70)
    print("Expected Benefits:")
    print("  - Sensitivity-based layer selection")
    print("  - Reduced communication overhead")
    print("  - Maintained model accuracy")
    print("  - Adaptive to data heterogeneity")
    print("="*70)

    # FedAvg-style sampling evaluates max(int(available * fraction_evaluate),
    # min_evaluate_clients) clients. Pinning min_evaluate_clients to
    # NUM_CLIENTS would make FRACTION_EVALUATE a no-op, so the floor is
    # derived from it instead. The behavior is identical when
    # FRACTION_EVALUATE == 1.0.
    # NOTE(review): assumes XFLSStrategy inherits FedAvg's sampling — confirm.
    min_evaluate_clients = max(1, int(NUM_CLIENTS * FRACTION_EVALUATE))

    strategy = XFLSStrategy(
        fraction_fit=1.0,
        fraction_evaluate=FRACTION_EVALUATE,
        min_fit_clients=NUM_CLIENTS,
        min_evaluate_clients=min_evaluate_clients,
        min_available_clients=NUM_CLIENTS,
        initial_parameters=parameters,
        param_keys=param_keys,
        evaluate_fn=gen_evaluate_fn(centralized_testloader, eval_device),
        evaluate_metrics_aggregation_fn=weighted_average,
    )

    config = ServerConfig(num_rounds=NUM_ROUNDS)
    return ServerAppComponents(strategy=strategy, config=config)


server = ServerApp(server_fn=server_fn)
print("Client and Server apps configured")

In [None]:
# RESOURCE CONFIGURATION
# Configure computational resources
# One availability check drives both the resource request and the report
# below, so the printed values can never disagree with what is requested.
_use_cuda = torch.cuda.is_available()
backend_config = {
    "client_resources": {
        "num_cpus": 4 if _use_cuda else 2,
        # Fractional GPU requests let several simulated clients share one GPU.
        "num_gpus": 0.5 if _use_cuda else 0.0,
    }
}
_client_resources = backend_config["client_resources"]

print("\n" + "="*70)
print("Resource Configuration:")
print("="*70)
print(f"  - CUDA Available: {torch.cuda.is_available()}")
if _use_cuda:
    print(f"  - CUDA Device Count: {torch.cuda.device_count()}")
    print(f"  - Current Device: {torch.cuda.current_device()}")
    print(f"  - Device Name: {torch.cuda.get_device_name(0)}")
    print(f"  - GPU Memory per Client: {_client_resources['num_gpus']} GPU")
else:
    print("  - Using CPU for training")
    print(f"  - CPU Cores per Client: {_client_resources['num_cpus']}")
print("="*70)

In [None]:
# MAIN SIMULATION FUNCTION
def main():
    """
    Run the XFL-S simulation end to end.

    Prints a configuration summary and launches ``run_simulation`` with the
    module-level ``server`` / ``client`` apps and ``backend_config``. When the
    simulation returns, it prints a completion banner listing the metrics
    worth inspecting.
    """
    divider = "=" * 70

    print("\n" + divider)
    print("STARTING XFL-S SIMULATION")
    print(divider)
    print("Configuration Summary:")
    summary_rows = (
        ("Scenario", scenario_config['name']),
        ("Algorithm", "XFL-S (eXtreme Federated Learning guided by Sensitivity)"),
        ("Model", MODEL_TYPE),
        ("Data Distribution", DATA_DISTRIBUTION.upper()),
        ("Clients", NUM_CLIENTS),
        ("Rounds", NUM_ROUNDS),
        ("Local Epochs (t)", LOCAL_EPOCHS),
        ("Dataset", "CIFAR-10"),
        ("Device", device.upper()),
    )
    for label, value in summary_rows:
        print(f"  - {label}: {value}")
    print(divider)

    # Run the simulation
    run_simulation(
        server_app=server,
        client_app=client,
        num_supernodes=NUM_CLIENTS,
        backend_config=backend_config,
    )

    metrics_to_review = (
        "Communication Reduction",
        "Model Accuracy",
        "Layer Selection Distribution",
        "Sensitivity Score Analysis",
        "Probability Distribution Effectiveness",
    )
    print("\n" + divider)
    print("XFL-S SIMULATION COMPLETED!")
    print(divider)
    print("Key Metrics to Analyze:")
    for metric_name in metrics_to_review:
        print(f"  - {metric_name}")
    print(divider)
    print("Author: Khaji Sana")
    print(divider + "\n")
    
    

In [None]:
# ENTRY POINT
if __name__ == "__main__":
    # Inside a Jupyter kernel __name__ is "__main__", so executing this cell
    # launches the simulation; importing the code as a module does not.
    main()