## Install Prerequisites and Dependencies

In [None]:
!pip install torch torchvision ray flwr hydra-core omegaconf

In [None]:
!pip install -U flwr[simulation]

## Configuration via base.yaml:

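The **'conf/base.yaml'** file contains the essential configuration for the federated learning pipeline. It lets you adjust important parameters before running the simulation, such as the number of rounds and clients, the batch size, the per-round training hyperparameters, the differential-privacy settings, and the per-client compute resources: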

In [None]:
# Sample configuration (mirrors the values expected in your conf/base.yaml file)
cfg = {
    # Global simulation settings
    'num_rounds': 3,
    'num_clients': 5,
    'batch_size': 32,
    'num_classes': 10,
    'num_clients_per_round_fit': 5,
    'num_clients_per_round_eval': 5,
    # Per-round training settings; get_on_fit_config() forwards exactly these keys to the clients
    'config_fit': {
        'lr': 0.001,
        'weight_decay': 1e-4,
        'local_epochs': 2,
        'batch_size': 32,
        'clip_value': 1.0,
        'momentum': 0.9,
    },
    # Differential-privacy settings; key names match the epsilon/delta/noise_scale
    # arguments of generate_client_fn() (presumably passed through by main -- confirm)
    'dp': {
        'epsilon': 1.0,
        'delta': 1e-5,
        'noise_scale': 0.01
    },
    'client_resources': {
        'num_cpus': 1,    # Reduce CPU usage per client
        'num_gpus': 0.0   # Disable GPU usage if not available in Colab
    }
}

## Dataset Preparation (dataset.py):

This script loads and partitions the MNIST and CIFAR10 datasets, applying transformations like normalization and ensuring that data is split among clients. Each client trains its model on its data without sharing it, adhering to federated learning principles.

In [None]:
'''
dataset.py:
This script takes care of loading datasets and splitting them up for our federated learning setup.
The goal here is to support both MNIST and CIFAR10, apply some useful transformations (like normalization),
and then split the data among multiple clients. This is crucial because, in federated learning,
clients train models on their own data without sharing it directly, so data preparation is essential.
'''
import torch
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import ToTensor, Normalize, Compose, RandomHorizontalFlip, RandomCrop



def get_dataset(dataset_choice: str, data_path: str = './data'):
    """
    Loads the dataset the user selects (either MNIST or CIFAR10) with proper transformations.

    Arguments:
        dataset_choice (str): The dataset to load ('mnist' or 'cifar10'). Matched
            case-insensitively, so 'MNIST' and 'CIFAR10' are accepted as well.
        data_path (str): Where to store (and download) the dataset.

    Returns:
        Tuple: A tuple containing the training and test datasets.

    Raises:
        ValueError: If the dataset_choice isn't 'mnist' or 'cifar10'.

    MNIST is grayscale, so it only gets tensor conversion and single-channel normalization.
    The CIFAR10 training split additionally gets random horizontal flips and padded random
    crops; its test split is only converted and normalized.
    """
    # Normalise case so this function agrees with get_model(), which also lower-cases
    # its argument. str() keeps non-string input on the ValueError path below.
    choice = str(dataset_choice).lower()

    if choice == 'mnist':
        # MNIST is pretty straightforward. We convert the images to tensors and normalize them.
        transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
        trainset = MNIST(data_path, train=True, download=True, transform=transform)
        testset = MNIST(data_path, train=False, download=True, transform=transform)

    elif choice == 'cifar10':
        # Training CIFAR10 includes some extra augmentations to improve generalization
        transform_train = Compose([
            RandomHorizontalFlip(),                             # Randomly flip the image horizontally
            RandomCrop(32, padding=4),                          # Crop and pad to simulate data variability
            ToTensor(),
            Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))         # Normalization of RGB channels
        ])

        # There is no need to augment the test set, and normalization is the only processing step.
        transform_test = Compose([
            ToTensor(),
            Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        trainset = CIFAR10(data_path, train=True, download=True, transform=transform_train)
        testset = CIFAR10(data_path, train=False, download=True, transform=transform_test)

    else:
        raise ValueError(f"Unsupported dataset: {dataset_choice}. Please choose either 'mnist' or 'cifar10'.")

    return trainset, testset



def prepare_dataset(num_partitions: int, batch_size: int, dataset_choice: str, val_ratio: float = 0.1):
    """
    Prepares the dataset for federated learning by splitting it into multiple partitions (clients).

    Arguments:
        num_partitions (int): Number of clients (partitions) to split the training data into.
        batch_size (int): Batch size to use for the per-client data loaders.
        dataset_choice (str): The dataset choice ('mnist' or 'cifar10').
        val_ratio (float): Fraction of each client's partition held out for validation (default 10%).

    Returns:
        Tuple: (trainloaders, valloaders, testloader) where the first two are lists with one
        DataLoader per client and the last is a single DataLoader over the whole test set.

    Raises:
        ValueError: If num_partitions is not a positive integer.

    The training data is divided between clients with a seeded random split (IID partitioning):
    every client receives a random sample of the same underlying dataset. When the dataset size
    is not a multiple of num_partitions, the leftover samples are handed out one each to the
    first partitions, so partition sizes differ by at most one and no sample is dropped.
    """
    if num_partitions <= 0:
        raise ValueError(f"num_partitions must be a positive integer, got {num_partitions}.")

    # Loading the dataset based on what the user chose (either MNIST or CIFAR10).
    trainset, testset = get_dataset(dataset_choice)

    # random_split requires the lengths to sum to len(trainset); plain floor division breaks
    # that whenever the size is not divisible, so spread the remainder over the first partitions.
    base_len, remainder = divmod(len(trainset), num_partitions)
    partition_len = [base_len + 1 if idx < remainder else base_len for idx in range(num_partitions)]

    # Randomly splitting the training data among the clients (fixed seed for reproducibility).
    trainsets = random_split(trainset, partition_len, torch.Generator().manual_seed(2024))

    trainloaders = []
    valloaders = []

    # Splitting each client's partition further into training and validation subsets.
    for trainset_ in trainsets:
        num_total = len(trainset_)
        num_val = int(val_ratio * num_total)
        num_train = num_total - num_val

        for_train, for_val = random_split(trainset_, [num_train, num_val], torch.Generator().manual_seed(2024))

        # Appending data loaders for both training and validation sets to their respective lists.
        trainloaders.append(DataLoader(for_train, batch_size=batch_size, shuffle=True, num_workers=2))
        valloaders.append(DataLoader(for_val, batch_size=batch_size, shuffle=False, num_workers=2))

    # The test set is shared by all clients, kept at its original size, and used after
    # aggregation to evaluate the global model.
    testloaders = DataLoader(testset, batch_size=128)

    return trainloaders, valloaders, testloaders


## Model Definition (model.py):

Defines flexible model architectures for MNIST and CIFAR10. Based on the selected dataset, an appropriate model is dynamically initialized. This script also handles optimization, loss calculation, and evaluation, enabling seamless switching between datasets and models.

In [None]:
'''
model.py
The script defines the models we use for the MNIST and CIFAR10 datasets.
The appropriate model architecture is dynamically selected and initialized based on the dataset being used.
As part of our training and testing functions, we also handle optimization, loss calculation, and evaluation.
We can easily switch between datasets and models with this code because it is designed to be flexible.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

# MNIST Model Definition
class MNISTNet(nn.Module):
    """
    Small convolutional classifier for single-channel images.

    Two conv + max-pool stages feed three fully connected layers. The flattened size of
    64 * 4 * 4 is what two 5x5 convolutions and two 2x2 poolings leave from a 28x28 input.
    """

    def __init__(self, num_classes: int) -> None:
        super().__init__()

        # Layers keep their original names and registration order, so the state_dict
        # layout (which clients rebuild by position) is unchanged.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5)   # 1 grayscale channel -> 32 filters
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
        self.fc1 = nn.Linear(in_features=64 * 4 * 4, out_features=120)          # Classification head
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)          # One logit per class

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for a batch of images."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))

        # Flatten the conv output so it can feed the fully connected layers
        flat = features.view(-1, self.fc1.in_features)

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))

        # No activation on the final layer: CrossEntropyLoss applies the softmax itself
        return self.fc3(hidden)

# CIFAR10 Model Definition (CIFAR10 images are RGB, so we have 3 input channels. The network here is deeper than the MNIST model.)
class CIFAR10Net(nn.Module):
    """
    Convolutional classifier for 3-channel (RGB) images.

    Three conv -> batch-norm -> ReLU -> max-pool stages feed a three-layer fully connected
    head with dropout. The flattened size of 256 * 4 * 4 is what three 2x2 poolings leave
    from a 32x32 input, because the padded 3x3 convolutions keep the spatial size.
    """

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        # Layers keep their original names and registration order, so the state_dict
        # layout (which clients rebuild by position) is unchanged.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.batch_norm1 = nn.BatchNorm2d(num_features=64)     # Batch norm for training stability
        self.batch_norm2 = nn.BatchNorm2d(num_features=128)
        self.batch_norm3 = nn.BatchNorm2d(num_features=256)
        self.fc1 = nn.Linear(in_features=256 * 4 * 4, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=256)
        self.fc3 = nn.Linear(in_features=256, out_features=num_classes)
        self.dropout = nn.Dropout(p=0.5)                       # Regularization between fc1 and fc2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for a batch of images."""
        stages = (
            (self.conv1, self.batch_norm1),
            (self.conv2, self.batch_norm2),
            (self.conv3, self.batch_norm3),
        )
        features = x
        for conv, norm in stages:
            features = self.pool(F.relu(norm(conv(features))))

        flat = features.view(-1, self.fc1.in_features)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        hidden = F.relu(self.fc2(hidden))

        # No activation on the final layer: CrossEntropyLoss applies the softmax itself
        return self.fc3(hidden)


# Function to get the appropriate model based on dataset choice
def get_model(dataset_choice: str, num_classes: int):
    """
    Builds a fresh, untrained model for the requested dataset.

    Arguments:
        dataset_choice (str): 'mnist' or 'cifar10' (compared case-insensitively).
        num_classes (int): Number of output classes the model should predict.

    Returns:
        torch.nn.Module: An MNISTNet for 'mnist', a CIFAR10Net for 'cifar10'.

    Raises:
        ValueError: If dataset_choice names neither supported dataset.
    """
    normalized_choice = dataset_choice.lower()

    if normalized_choice == "mnist":
        return MNISTNet(num_classes)

    if normalized_choice == "cifar10":
        return CIFAR10Net(num_classes)

    raise ValueError(f"Unknown dataset: {dataset_choice}. Please choose either 'mnist' or 'cifar10'.")

# Training function
def train(net, trainloader, optimizer, epochs, device: str, clip_value=None):
    """
    Trains the given network using the specified training data loader and optimizer.

    Arguments:
        net: The neural network model to train.
        trainloader: Iterable yielding (images, labels) batches of training data.
        optimizer: The optimizer for training (AdamW is used in this project).
        epochs: The number of passes to make over trainloader.
        device (str): The device to run the training on (e.g. 'cpu' or 'cuda:0').
        clip_value: Optional maximum gradient norm; when given, gradients are clipped to it
            before every optimizer step to guard against exploding gradients.

    Returns:
        avg_loss: Mean per-batch loss over every batch processed (all epochs).
        accuracy: Fraction of correctly classified training samples over all epochs.
        Both are 0.0 when no batch was processed (epochs == 0 or an empty loader).
    """
    criterion = torch.nn.CrossEntropyLoss()             # Standard loss function for classification
    net.train()
    net.to(device)

    correct = 0
    total = 0
    total_loss = 0.0
    num_batches = 0

    for _ in range(epochs):
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)  # Move data to the selected device
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()

            if clip_value is not None:
                torch.nn.utils.clip_grad_norm_(net.parameters(), clip_value)

            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Nothing was processed: report zeros instead of dividing by zero.
    if num_batches == 0 or total == 0:
        return 0.0, 0.0

    # Divide by the number of batches actually seen. Using len(trainloader) would count
    # one epoch only and inflate the average by a factor of `epochs`.
    avg_loss = total_loss / num_batches
    accuracy = correct / total

    return avg_loss, accuracy

# Testing function
def test(net, testloader, device: str):
    """
    Evaluates the model's performance on the given data.

    Args:
        net: The trained neural network model to evaluate.
        testloader: Iterable yielding batches whose first two items are (images, labels).
        device (str): The device to run the evaluation on (e.g. 'cpu' or 'cuda:0').

    Returns:
        loss: The sum of the per-batch cross-entropy losses (not divided by the batch count).
        accuracy: Fraction of evaluated samples that were classified correctly; 0.0 when
            the loader yields no samples.
    """
    criterion = torch.nn.CrossEntropyLoss()                     # Loss function for classification
    correct = 0
    total = 0
    loss = 0.0
    net.eval()
    net.to(device)

    with torch.no_grad():
        for batch in testloader:
            images, labels = batch[0].to(device), batch[1].to(device)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Count the samples actually evaluated (as train() does) rather than reading
    # len(testloader.dataset): it stays correct if the loader skips samples and
    # cannot divide by zero on an empty loader.
    accuracy = correct / total if total else 0.0
    return loss, accuracy


# The name matches pytest's default "test*" collection pattern; opt out so importing this
# helper into a test module does not make pytest try to run it as a test case.
test.__test__ = False


## Client Setup (client.py):

Manages local training and evaluation on each client through the FlowerClient class. Each client is dynamically initialized, trains on its partitioned data, and incorporates Differential Privacy (DP) to protect data by adding noise to updates before they are sent to the server.

In [None]:
'''
# client.py
This script defines the FlowerClient class, which is responsible for managing the local training and evaluation on each client.
In addition, it ensures that each client is correctly initialized, receives updates from the server, and utilizes differential privacy
techniques to protect the data. Additionally, each client is generated dynamically and initialized only once.
'''
import torch
import flwr as fl
from typing import Dict
from flwr.common import NDArrays, Scalar, Context
#from model import MNISTNet, CIFAR10Net, train, test
import logging
import numpy as np

# Set up logging for important events such as client initialization, training, and evaluation
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] [%(levelname)s] [%(name)s] - %(message)s')
logger = logging.getLogger(__name__)


# IDs of the clients constructed so far; FlowerClient.__init__ consults this set to detect
# (and warn about) the same client being constructed more than once
initialized_clients = set()

class FlowerClient(fl.client.NumPyClient):
    """
    Flower NumPy client that trains a local model and uploads noise-perturbed weights.

    Each instance owns one client's train/validation loaders and a freshly built model
    (MNISTNet for 'mnist', CIFAR10Net otherwise). After local training, Gaussian noise
    scaled by `noise_scale` is added to the floating-point weights returned to the server.

    NOTE: `epsilon` and `delta` are stored and logged only. Nothing in this class uses them
    to calibrate the noise, so the logged values are not by themselves a privacy guarantee.
    """

    def __init__(self, trainloader, valloader, num_classes, dataset_choice, client_id, epsilon=0.1, delta=1e-5, noise_scale=0.01) -> None:
        """
        Initializes the client for federated learning with differential privacy settings.

        Arguments:
            trainloader: DataLoader for training data.
            valloader: DataLoader for validation data.
            num_classes: Number of output classes (10 for MNIST or CIFAR-10).
            dataset_choice: The dataset being used ('mnist' or 'cifar10', case-insensitive).
            client_id: Unique ID for each client.
            epsilon: Privacy budget for DP (logged only).
            delta: Privacy tolerance for DP (logged only).
            noise_scale: Standard deviation of the Gaussian noise added to uploaded weights.
        """
        super().__init__()
        self.trainloader = trainloader
        self.valloader = valloader
        self.dataset_choice = dataset_choice
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.client_id = client_id
        self.epsilon = epsilon
        self.delta = delta
        self.noise_scale = noise_scale

        # Compare case-insensitively (as get_model() does); otherwise 'MNIST' would fall
        # through to the CIFAR10 branch and build a model with the wrong input shape.
        if str(self.dataset_choice).lower() == "mnist":
            self.model = MNISTNet(num_classes=num_classes)
        else:
            self.model = CIFAR10Net(num_classes=num_classes)

        self.model.to(self.device)                  # Move the model to the appropriate device (GPU or CPU)

        # Only the log message differs between first and repeated construction; the client
        # is fully set up in both cases.
        if self.client_id in initialized_clients:
            logger.warning(f"[Client {self.client_id}] already initialized, skipping.")
        else:
            initialized_clients.add(self.client_id)
            logger.info(f"[Client {self.client_id}] Initialized with DP settings (ε={self.epsilon}, δ={self.delta}, noise_scale={self.noise_scale})")

    def set_parameters(self, parameters: NDArrays):
        """
        Loads the parameters received from the server into the local model.

        Arrays are matched to state_dict entries by position and cast to each entry's dtype.

        Raises:
            ValueError: If the number of arrays differs from the number of state_dict entries.
        """
        reference_state = self.model.state_dict()
        if len(parameters) != len(reference_state):
            raise ValueError(
                f"Expected {len(reference_state)} parameter arrays, received {len(parameters)}."
            )

        new_state = {
            key: torch.tensor(array, dtype=reference.dtype).to(self.device)
            for (key, reference), array in zip(reference_state.items(), parameters)
        }
        self.model.load_state_dict(new_state, strict=True)

    def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
        """Returns the model's state_dict entries as NumPy arrays, in state_dict order."""
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def fit(self, parameters: NDArrays, config: Dict[str, Scalar]):
        """
        Trains the local model starting from the server's parameters.

        Reads 'lr', 'weight_decay', 'local_epochs' and 'clip_value' from config (with
        defaults) and returns (noised parameters, number of training examples, metrics).
        """
        self.set_parameters(parameters)         # Start from the global model

        lr = config.get('lr', 0.001)
        weight_decay = config.get('weight_decay', 1e-4)
        epochs = config.get('local_epochs', 1)
        clip_value = config.get('clip_value', 1)
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=weight_decay)

        # clip_value must be forwarded explicitly; train() skips gradient clipping when it is None.
        train_loss, train_accuracy = train(
            self.model, self.trainloader, optimizer, epochs, self.device, clip_value=clip_value
        )

        # Perturb the trained weights before they leave the client
        dp_parameters = self.apply_differential_privacy(self.model.state_dict())

        logger.info(f"[Client {self.client_id}] Applied Differential Privacy (ε={self.epsilon}, δ={self.delta})")

        # Flower expects the number of training examples here; len(trainloader) would be
        # the number of batches.
        num_examples = len(self.trainloader.dataset)
        noised_arrays = [tensor.cpu().numpy() for tensor in dp_parameters.values()]
        return noised_arrays, num_examples, {"loss": float(train_loss), "accuracy": float(train_accuracy)}

    def apply_differential_privacy(self, state_dict):
        """
        Returns a copy of state_dict with Gaussian noise added to floating-point entries.

        Noise is drawn as N(0, 1) * self.noise_scale per element. Non-floating entries
        (e.g. BatchNorm's integer `num_batches_tracked` counter) are copied unchanged, since
        noising them would turn a counter into a float.

        NOTE(review): floating-point buffers such as BatchNorm running statistics are still
        noised like ordinary weights -- confirm that this is intended.
        """
        dp_parameters = {}
        for name, param in state_dict.items():
            if torch.is_floating_point(param):
                noise = torch.randn_like(param) * self.noise_scale
                dp_parameters[name] = param + noise
            else:
                dp_parameters[name] = param.clone()
        return dp_parameters

    def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]):
        """
        Evaluates the server's parameters on this client's validation data.

        Returns (loss, number of validation examples, metrics).
        """
        self.set_parameters(parameters)                         # Evaluate the global model, not the local one

        eval_loss, eval_accuracy = test(self.model, self.valloader, self.device)

        logger.info(f"Client {self.client_id} evaluation results: Loss = {eval_loss:.4f}, Accuracy = {eval_accuracy:.4f}")

        # Report the number of validation examples, not the number of batches.
        num_examples = len(self.valloader.dataset)
        return float(eval_loss), num_examples, {"loss": float(eval_loss), "accuracy": float(eval_accuracy)}



def generate_client_fn(trainloaders, valloaders, num_classes, dataset_choice, epsilon=0.1, delta=1e-5, noise_scale=0.01):
    """
    Builds the `client_fn` callback that Flower's simulation uses to create clients on demand.

    Arguments:
        trainloaders: List of training DataLoaders, one per client partition.
        valloaders: List of validation DataLoaders, one per client partition.
        num_classes: Number of output classes for the model.
        dataset_choice: The dataset being used ('MNIST' or 'CIFAR10').
        epsilon: Privacy budget for DP.
        delta: Privacy tolerance for DP.
        noise_scale: Scale of noise added for DP.

    Returns:
        A function taking a Flower Context and returning a FlowerClient (converted with
        `.to_client()`) bound to one partition's loaders.

    Raises:
        ValueError: If the loader lists are empty or differ in length.
    """
    if not trainloaders or len(trainloaders) != len(valloaders):
        raise ValueError("trainloaders and valloaders must be non-empty and of equal length.")

    num_partitions = len(trainloaders)

    def client_fn(context: Context):
        # Prefer the partition id the simulation engine assigns to each node. node_id is an
        # arbitrary identifier, so node_id % num_partitions can send two nodes to the same
        # partition and leave another unused. It is kept only as a fallback for Flower
        # versions that do not populate node_config.
        node_config = getattr(context, "node_config", None) or {}
        partition_id = node_config.get("partition-id")
        if partition_id is None:
            partition_id = context.node_id

        cid = int(partition_id) % num_partitions            # Make sure the client ID is within a valid range
        logger.info(f"Initializing client {cid}")

        return FlowerClient(
            trainloader=trainloaders[cid],
            valloader=valloaders[cid],
            num_classes=num_classes,
            dataset_choice=dataset_choice,
            client_id=cid,
            epsilon=epsilon,
            delta=delta,
            noise_scale=noise_scale,
        ).to_client()

    return client_fn


## Server-Side Operations (server.py):

Defines the central server's responsibilities in aggregating model updates from clients using Secure Aggregation, which ensures that sensitive client information remains private. The server also coordinates training sessions and aggregates parameters and evaluation metrics from clients.

In [None]:
'''
server.py
This script defines the server-side operations in a federated learning setup.
Servers aggregate updates from clients, apply secure aggregation techniques to keep sensitive information private and coordinate training sessions.
Additionally, both parameters and evaluation metrics can be aggregated using custom functions.
'''
import torch
#from model import get_model, test
from typing import Dict, Any, List, Tuple
from flwr.common import NDArrays, Scalar, FitRes, Parameters, parameters_to_ndarrays, ndarrays_to_parameters
import numpy as np
import logging
import flwr as fl
import os

# Configure logging for server operations and any issues that occur during the process
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] [%(levelname)s] [%(name)s] - %(message)s')
logger = logging.getLogger(__name__)


# Secure Aggregation Function
def secure_aggregation(parameters_list: List[NDArrays]) -> Parameters:
    """
    Averages the clients' model parameters layer by layer.

    Arguments:
        parameters_list: One list of NumPy arrays per client, all in the same layer order.

    Returns:
        Parameters: The element-wise, unweighted mean of the clients' arrays, packed into
        Flower's Parameters format.

    Raises:
        ValueError: If parameters_list is empty or None.

    Every client contributes equally: for each layer the clients' arrays are summed and the
    sum is divided by the number of clients. Only the averaged result is returned; no
    masking or encryption of individual updates happens inside this function.
    """
    if not parameters_list:
        raise ValueError("The parameters list is empty or None.")           # Nothing to aggregate

    client_count = len(parameters_list)
    averaged_layers = []

    # zip(*...) groups the same layer from every client together
    for layer_versions in zip(*parameters_list):
        running_sum = np.zeros_like(layer_versions[0])
        for client_layer in layer_versions:
            running_sum = running_sum + client_layer                        # Sum this layer across clients
        averaged_layers.append(running_sum / client_count)                  # Mean over clients

    return ndarrays_to_parameters(averaged_layers)


# Ensure the directory for saving model weights exists. This runs when the cell/module is
# executed, and the path is relative to the current working directory at that moment.
weights_dir = './model_weights'
if not os.path.exists(weights_dir):
    os.makedirs(weights_dir)

# After training and aggregating the model parameters in secure_aggregation, save the weights.
def save_model_weights(model, server_round):
    """
    Writes the model's state_dict to `weights_dir`, one file per round.

    Arguments:
        model: Model whose state_dict is saved.
        server_round: Round number, embedded in the file name.
    """
    checkpoint_name = f'model_weights_round_{server_round}.pth'
    model_path = os.path.join(weights_dir, checkpoint_name)

    weights = model.state_dict()
    torch.save(weights, model_path)

    logger.info(f"Model weights saved at {model_path} for round {server_round}")

class SecureAggregationFedAvg(fl.server.strategy.FedAvg):
    """FedAvg variant whose fit aggregation goes through secure_aggregation()."""

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[Any, FitRes]],
        failures: List[BaseException],
        ) -> Tuple[Parameters, Dict[str, Scalar]]:
        """
        Aggregates the clients' fit results with secure_aggregation() instead of the default.

        Arguments:
            server_round: The current round of federated training.
            results: List of (client handle, FitRes) pairs from clients that finished training.
            failures: List of failures reported for this round.

        Returns:
            Tuple of the aggregated parameters and the aggregated fit metrics. The metrics
            dictionary is empty unless the strategy was given a `fit_metrics_aggregation_fn`.

        Raises:
            ValueError: If no client returned parameters.

        Failures are logged as a warning and otherwise ignored; aggregation proceeds with
        whatever results are available.
        """

        if failures:
            logger.warning(f"Round {server_round}: {len(failures)} clients failed.")                  # Log any client failures

        # Convert client parameters to ndarray format for aggregation
        parameters_list = [parameters_to_ndarrays(res.parameters) for _, res in results if res.parameters is not None]

        if not parameters_list:
            raise ValueError("The parameters list is empty or None.")

        aggregated_params = secure_aggregation(parameters_list)                                      # Applying secure aggregation to the parameters

        logger.info(f"Round {server_round}: Applied Secure Aggregation on client parameters")

        # The default aggregate_fit is overridden, so the metrics callback has to be invoked
        # here; otherwise a configured fit_metrics_aggregation_fn would never run.
        metrics_aggregated = {}
        aggregate_metrics = getattr(self, "fit_metrics_aggregation_fn", None)
        if aggregate_metrics is not None:
            metrics_aggregated = aggregate_metrics([(res.num_examples, res.metrics) for _, res in results])

        return aggregated_params, metrics_aggregated


# Configuring the training settings per round
def get_on_fit_config(config):
    """
    Creates the callback that supplies each round's training configuration.

    Arguments:
        config: Configuration object whose 'config_fit' section holds the hyperparameters.

    Returns:
        A function taking the server round and returning a dict with the 'lr',
        'weight_decay', 'momentum', 'local_epochs', 'batch_size' and 'clip_value' settings.

    The settings are read from config on every call and do not depend on the round number.
    """
    forwarded_keys = ('lr', 'weight_decay', 'momentum', 'local_epochs', 'batch_size', 'clip_value')

    def fit_config_fn(server_round: int):
        fit_section = config['config_fit']
        return {key: fit_section[key] for key in forwarded_keys}

    return fit_config_fn


def get_evaluate_fn(num_classes, testloader, dataset_choice):
    """
    Creates the server-side evaluation callback for the global model.

    Arguments:
        num_classes: The number of output classes in the dataset.
        testloader: DataLoader for the test dataset.
        dataset_choice: The dataset being used ('MNIST' or 'CIFAR10').

    Returns:
        A function (server_round, parameters, config) -> (loss, {"accuracy": ...}) that
        evaluates the given global parameters on the test dataset.

    Each call builds a fresh model for the dataset, loads the received parameters into it,
    and runs the shared test routine on the test loader.
    """
    def evaluate_fn(server_round: int, parameters: NDArrays, config: Dict[str, Scalar]):
        # Fresh model with the architecture matching the dataset
        model = get_model(dataset_choice, num_classes)

        # Pair the received arrays with the model's state_dict keys by position
        layer_names = model.state_dict().keys()
        state_dict = {name: torch.tensor(weights) for name, weights in zip(layer_names, parameters)}

        # strict=True rejects missing or unexpected keys
        model.load_state_dict(state_dict, strict=True)

        # Evaluate the global model on the test dataset, on the GPU when one is available
        eval_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        loss, accuracy = test(model, testloader, device=eval_device)

        logger.info(f"Global model evaluation after round {server_round}: Loss = {loss:.4f}, Accuracy = {accuracy:.4f}")

        return float(loss), {"accuracy": float(accuracy)}

    return evaluate_fn



def fit_metrics_aggregation_fn(metrics_list):
    """
    Averages the training metrics (accuracy and loss) reported by the clients.

    Arguments:
        metrics_list: List of (num_examples, metrics dict) pairs, one per client.

    Returns:
        A dictionary with the unweighted mean "accuracy" and "loss" across clients.

    A client whose metrics dict lacks a key contributes 0 for that metric. The example
    counts are ignored, so every client weighs the same.
    """
    def _mean_across_clients(metric_name):
        return np.mean([client_metrics.get(metric_name, 0) for _, client_metrics in metrics_list])

    return {
        "accuracy": _mean_across_clients("accuracy"),
        "loss": _mean_across_clients("loss"),
    }



def evaluate_metrics_aggregation_fn(metrics_list):
    """
    Averages the evaluation metrics (accuracy and loss) reported by the clients.

    Arguments:
        metrics_list: List of (num_examples, metrics dict) pairs, one per client.

    Returns:
        A dictionary with the unweighted mean "accuracy" and "loss" across clients.

    Mirrors fit_metrics_aggregation_fn: missing keys count as 0 and the example counts
    are ignored.
    """
    summary = {}
    for metric_name in ("accuracy", "loss"):
        per_client_values = [client_metrics.get(metric_name, 0) for _, client_metrics in metrics_list]
        summary[metric_name] = np.mean(per_client_values)

    return summary


## Federated Learning Orchestration (main.py):

The entry point for running the federated learning simulation. This script sets up the clients, dataset, and training strategy, and coordinates multiple rounds of communication between clients and the server, ensuring that results are logged after each round.

In [None]:
import multiprocessing

# Set the multiprocessing start method to 'spawn' to avoid os.fork() issues in multithreaded
# environments; force=True overrides a start method that was already set
multiprocessing.set_start_method('spawn', force=True)
'''
main.py
This script is the entry point for this federated learning simulation using the Flower framework.
It sets up the simulation, including the dataset, clients, and training strategy, then runs the federated learning process over multiple rounds.
Furthermore, it applies secure aggregation to ensure privacy and logs the results after each round.
'''
import logging
import pickle
from pathlib import Path
import os
import hydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig
import flwr as fl
#from dataset import prepare_dataset
#from client import generate_client_fn
#from server import get_on_fit_config, get_evaluate_fn, SecureAggregationFedAvg, fit_metrics_aggregation_fn, evaluate_metrics_aggregation_fn
import numpy as np


# Process-wide environment switches, applied in one call.
# NOTE(review): presumably these quiet TensorFlow's native logging, make Hydra print
# full tracebacks, and turn off oneDNN optimisations - confirm against the respective docs.
os.environ.update({
    'TF_CPP_MIN_LOG_LEVEL': '2',
    'HYDRA_FULL_ERROR': '1',
    'TF_ENABLE_ONEDNN_OPTS': '0',
})


# Root logging setup used to report the progress of the FL simulation:
# INFO level, with timestamp, level and logger name in front of every message.
logging.basicConfig(
    format='[%(asctime)s] [%(levelname)s] [%(name)s] - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)



@hydra.main(config_path="conf", config_name="base", version_base=None)
def main(cfg):
    """
    Run the complete federated learning simulation and report its results.

    Arguments:
        cfg: Mapping-style configuration (indexed with string keys). The keys read here are
             'num_rounds', 'num_clients', 'batch_size', 'num_classes',
             'num_clients_per_round_fit', 'num_clients_per_round_eval',
             'dp' ('epsilon', 'delta', 'noise_scale') and
             'client_resources' ('num_cpus', 'num_gpus'). The whole object is also
             handed to get_on_fit_config.

    Returns:
        None. The simulation history is pickled to ./results.pkl and a per-round
        summary is written to the log.

    Raises:
        ValueError: If the dataset entered at the prompt is neither MNIST nor CIFAR10.

    The function blocks on input() to ask which dataset to use, so it needs an
    interactive stdin.
    """
    # Ask the user which dataset to simulate on; comparison is case-insensitive.
    dataset_choice = input("Choose a dataset for FL simulation (MNIST/CIFAR10): ").strip().lower()

    if dataset_choice not in ["mnist", "cifar10"]:
        raise ValueError("Invalid dataset choice. Please choose 'MNIST' or 'CIFAR10'.")

    # Logging the basic configuration parameters
    logger.info(f"num_rounds: {cfg['num_rounds']}, num_clients: {cfg['num_clients']}, batch_size: {cfg['batch_size']}")
    logger.info(f"num_clients_per_round_fit: {cfg['num_clients_per_round_fit']}, num_clients_per_round_eval: {cfg['num_clients_per_round_eval']}")
    logger.info(f"Chosen dataset: {dataset_choice.upper()}")

    # Preparation of the dataset by splitting it into client-based loaders for training, validation, and testing
    trainloaders, validationloaders, testloader = prepare_dataset(cfg['num_clients'], cfg['batch_size'], dataset_choice)
    logger.info(f"Number of clients: {cfg['num_clients']}, Client 0 dataset size: {len(trainloaders[0].dataset)}")

    # Client function generation to initialize clients using the relevant data and privacy settings
    client_fn = generate_client_fn(
        trainloaders,
        validationloaders,
        cfg['num_classes'],
        dataset_choice,
        epsilon=cfg['dp']['epsilon'],
        delta=cfg['dp']['delta'],
        noise_scale=cfg['dp']['noise_scale'],
    )

    # Federated learning strategy definition with secure aggregation.
    # NOTE(review): the strategy is sized with cfg['num_clients'] everywhere; the
    # 'num_clients_per_round_fit' / 'num_clients_per_round_eval' values are only logged above.
    strategy = SecureAggregationFedAvg(
        fraction_fit=1.0,                                                               # All clients participate in training each round
        min_fit_clients=cfg['num_clients'],                                             # Minimum number of clients needed for training
        fraction_evaluate=1.0,                                                          # All clients participate in evaluation each round
        min_evaluate_clients=cfg['num_clients'],                                        # Minimum number of clients needed for evaluation
        min_available_clients=cfg['num_clients'],                                       # Minimum clients that need to be available
        on_fit_config_fn=get_on_fit_config(cfg),                                        # Configuration settings for each training round
        evaluate_fn=get_evaluate_fn(cfg['num_classes'], testloader, dataset_choice),    # Function to evaluate the global model
        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,                          # Aggregation of training metrics
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,                # Aggregation of evaluation metrics
    )

    # Start of the Flower simulation with the defined client function, strategy, and number of rounds.
    # Per-client resources come from the configuration so they can be lowered for
    # constrained environments such as Colab (fewer CPUs, no GPU).
    logger.info("Starting Flower simulation...")
    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=cfg['num_clients'],
        config=fl.server.ServerConfig(num_rounds=cfg['num_rounds']),
        strategy=strategy,
        client_resources={
            'num_cpus': cfg['client_resources']['num_cpus'],
            'num_gpus': cfg['client_resources']['num_gpus'],
        },
    )

    # Saving the results
    results_path = Path('./') / 'results.pkl'
    with results_path.open('wb') as h:
        pickle.dump({'history': history}, h, protocol=pickle.HIGHEST_PROTOCOL)

    # Logging a summary of the completed simulation
    logger.info(f"Run finished {cfg['num_rounds']} round(s)")

    # Every round that is expected to have produced distributed results.
    expected_rounds = set(range(1, cfg['num_rounds'] + 1))

    # Displaying distributed loss for each round.
    # Each history entry is a (round, value) pair, so the round label and the value are
    # taken from the entry itself rather than from its position in the list.
    # NOTE(review): the "Distributed Training" labels are kept as-is; in Flower these
    # series appear to come from client-side evaluation - confirm before relabelling.
    if hasattr(history, 'losses_distributed'):
        logger.info("\nLoss (Distributed Training):")
        for round_num, loss in history.losses_distributed:
            logger.info(f"    Round {round_num}: loss = {loss}")
        missing_rounds = expected_rounds - {round_num for round_num, _ in history.losses_distributed}
        if missing_rounds:
            logger.info(f"    Missing loss data for rounds: {missing_rounds}")

    # Displaying centralized evaluation loss for each round.
    # The loss is the second element of each entry; the first is the round number.
    if hasattr(history, 'losses_centralized'):
        logger.info("\nLoss (Centralized Evaluation):")
        for round_num, loss in history.losses_centralized:
            logger.info(f"    Round {round_num}: {loss:.2f}")

    # Displaying distributed accuracy for each round
    if hasattr(history, 'metrics_distributed'):
        logger.info("\nAccuracy (Distributed Training):")
        accuracy_history = history.metrics_distributed.get("accuracy", [])
        for round_num, acc in accuracy_history:
            logger.info(f"    Round {round_num}: {acc * 100:.2f}%")
        missing_rounds = expected_rounds - {round_num for round_num, _ in accuracy_history}
        if missing_rounds:
            logger.info(f"    Missing accuracy data for rounds: {missing_rounds}")

    # Logging that secure aggregation has been applied in each round
    logger.info("Secure Aggregation applied in each round")

    # Final logging statement indicating that the simulation has completed
    logger.info("Flower simulation finished")




# Run the simulation with the in-notebook `cfg` dictionary from the configuration cell.
# NOTE(review): main is wrapped by @hydra.main, so this call presumably relies on the
# decorator forwarding an explicitly passed config straight to the function instead of
# loading conf/base.yaml - confirm against the installed Hydra version.
main(cfg)
