<a href="https://colab.research.google.com/github/DawnSpider96/L361-Federated-Learning/blob/release/Copy_of_L361_2025_Lab_1_From_Centralised_To_Federated_Part_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Dependencies
---


In [1]:
# `pip` could produce some errors. Do not worry about them.
# The execution has been verified; it's working anyway.
# ! pip install --quiet --upgrade "pip"
# ! pip install --quiet matplotlib tqdm seaborn
# ! pip install git+https://github.com/Iacob-Alexandru-Andrei/flower.git@teaching \
#     torch torchvision ray=="2.6.3"

### Imports.


In [2]:
# import sys
# sys.path.append('../../')
# print(sys.path)

In [3]:
import random
from pathlib import Path
import tarfile
from typing import Any
from logging import INFO
from collections import defaultdict, OrderedDict
from collections.abc import Sequence, Callable
import numbers

import numpy as np
import torch
from torch import nn
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset
from enum import IntEnum
import flwr
from flwr.server import History, ServerConfig
from flwr.server.strategy import FedAvgM as FedAvg, Strategy
from c2m3.flower.fed_frank_wolfe_strategy import FrankWolfeSync
from flwr.common import log, NDArrays, Scalar, Parameters, ndarrays_to_parameters
from flwr.client.client import Client

from c2m3.common.client_utils import (
    Net,
    load_femnist_dataset,
    get_network_generator_cnn as get_network_generator,
    train_femnist,
    test_femnist,
    save_history,
    get_model_parameters,
    set_model_parameters
)


# Add new seeds here for easy autocomplete
class Seeds(IntEnum):
    """Seeds for reproducibility."""

    DEFAULT = 1337


np.random.seed(Seeds.DEFAULT)
random.seed(Seeds.DEFAULT)
torch.manual_seed(Seeds.DEFAULT)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


PathType = Path | str | None


def get_device() -> str:
    """Get the device (cuda, mps, cpu)."""
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = "mps"
    return device

2025-03-15 22:03:45.549179: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-15 22:03:45.721313: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm
2025-03-15 22:03:48,846	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [4]:
home_dir = Path.cwd() / ".."
dataset_dir: Path = home_dir / "femnist"
data_dir: Path = dataset_dir / "data"
centralized_partition: Path = dataset_dir / "client_data_mappings" / "centralized"
centralized_mapping: Path = dataset_dir / "client_data_mappings" / "centralized" / "0"
federated_partition: Path = dataset_dir / "client_data_mappings" / "fed_natural"

# Decompress dataset
if not dataset_dir.exists():
    with tarfile.open(home_dir / "femnist.tar.gz", "r:gz") as tar:
        tar.extractall(path=home_dir)
    log(INFO, "Dataset extracted in %s", dataset_dir)

## Build Flower FL client.
---

In [5]:
class FlowerRayClient(flwr.client.NumPyClient):
    """Flower client for the FEMNIST dataset."""

    def __init__(
        self,
        cid: int,
        partition_dir: Path,
        model_generator: Callable[[], Module],
    ) -> None:
        """Init the client with its unique id and the folder to load data from.

        Parameters
        ----------
            cid (int): Unique client id for a client used to map it to its data
                partition
            partition_dir (Path): The directory containing data for each
                client/client id
            model_generator (Callable[[], Module]): The model generator function
        """
        self.cid = cid
        log(INFO, "cid: %s", self.cid)
        self.partition_dir = partition_dir
        self.device = str(
            torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        )
        self.model_generator: Callable[[], Module] = model_generator
        self.properties: dict[str, Scalar] = {
            "tensor_type": "numpy.ndarray",
            "partition": self.partition_dir,
            "cid": self.cid
            }
        self.data_dir = data_dir

    def set_parameters(self, parameters: NDArrays) -> Module:
        """Load weights inside the network.

        Parameters
        ----------
            parameters (NDArrays): set of weights to be loaded.

        Returns
        -------
            [Module]: Network with new set of weights.
        """
        net = self.model_generator()
        return set_model_parameters(net, parameters)

    def get_parameters(self, config: dict[str, Scalar]) -> NDArrays:
        """Return weights from a given model.

        If no model is passed, then a local model is created.
        This can be used to initialise a model in the
        server.
        The config param is not used but is mandatory in Flower.

        Parameters
        ----------
            config (dict[int, Scalar]): dictionary containing configuration info.

        Returns
        -------
            NDArrays: weights from the model.
        """
        net = self.model_generator()
        return get_model_parameters(net)

    def fit(
        self, parameters: NDArrays, config: dict[str, Scalar]
    ) -> tuple[NDArrays, int, dict]:
        """Receive and train a model on the local client data.

        It uses parameters from the config dict

        Parameters
        ----------
            net (NDArrays): Pytorch model parameters
            config (dict[str, Scalar]): dictionary describing the training parameters

        Returns
        -------
            tuple[NDArrays, int, dict]: Returns the updated model, the size of the local
                dataset and other metrics
        """
        # Only create model right before training/testing
        # To lower memory usage when idle
        net = self.set_parameters(parameters)
        net.to(self.device)

        train_loader: DataLoader = self._create_data_loader(config, name="train")
        train_loss = self._train(net, train_loader=train_loader, config=config)
        return get_model_parameters(net), len(train_loader), {"train_loss": train_loss}

    def evaluate(
        self, parameters: NDArrays, config: dict[str, Scalar]
    ) -> tuple[float, int, dict]:
        """Receive and test a model on the local client data.

        It uses parameters from the config dict

        Parameters
        ----------
            net (NDArrays): Pytorch model parameters
            config (dict[str, Scalar]): dictionary describing the testing parameters

        Returns
        -------
            tuple[float, int, dict]: Returns the loss accumulate during testing, the
                size of the local dataset and other metrics such as accuracy
        """
        net = self.set_parameters(parameters)
        net.to(self.device)

        test_loader: DataLoader = self._create_data_loader(config, name="test")
        loss, accuracy = self._test(net, test_loader=test_loader, config=config)
        return loss, len(test_loader), {"local_accuracy": accuracy}

    def _create_data_loader(self, config: dict[str, Scalar], name: str) -> DataLoader:
        """Create the data loader using the specified config parameters.

        Parameters
        ----------
            config (dict[str, Scalar]): dictionary containing dataloader and dataset
                parameters
            mode (str): Load the training or testing set for the client

        Returns
        -------
            DataLoader: A pytorch dataloader iterable for training/testing
        """
        batch_size = int(config["batch_size"])
        num_workers = int(config["num_workers"])
        dataset = self._load_dataset(name)
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            drop_last=(name == "train"),
        )

    def _load_dataset(self, name: str) -> Dataset:
        full_file: Path = self.partition_dir / str(self.cid)
        return load_femnist_dataset(
            mapping=full_file,
            name=name,
            data_dir=data_dir,
        )

    def _train(
        self, net: Module, train_loader: DataLoader, config: dict[str, Scalar]
    ) -> float:
        return train_femnist(
            net=net,
            train_loader=train_loader,
            epochs=int(config["epochs"]),
            device=self.device,
            optimizer=torch.optim.AdamW(
                net.parameters(),
                lr=float(config["client_learning_rate"]),
                weight_decay=float(config["weight_decay"]),
            ),
            criterion=torch.nn.CrossEntropyLoss(),
            max_batches=int(config["max_batches"]),
        )

    def _test(
        self, net: Module, test_loader: DataLoader, config: dict[str, Scalar]
    ) -> tuple[float, float]:
        return test_femnist(
            net=net,
            test_loader=test_loader,
            device=self.device,
            criterion=torch.nn.CrossEntropyLoss(),
            max_batches=int(config["max_batches"]),
        )

    def get_properties(self, config: dict[str, Scalar]) -> dict[str, Scalar]:
        """Return properties for this client.

        Parameters
        ----------
            config (dict[str, Scalar]): Options to be used for selecting specific
            properties.

        Returns
        -------
            dict[str, Scalar]: Returned properties.
        """
        return self.properties

    def get_train_set_size(self) -> int:
        """Return the client train set size.

        Returns
        -------
            int: train set size of the client.
        """
        return len(self._load_dataset("train"))  # type: ignore[reportArgumentType]

    def get_test_set_size(self) -> int:
        """Return the client test set size.

        Returns
        -------
            int: test set size of the client.
        """
        return len(self._load_dataset("test"))  # type: ignore[reportArgumentType]


def fit_client_seeded(
    client: FlowerRayClient,
    params: NDArrays,
    conf: dict[str, Any],
    seed: Seeds = Seeds.DEFAULT,
    **kwargs: Any,
) -> tuple[NDArrays, int, dict]:
    """Wrap to always seed client training."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)
    return client.fit(params, conf, **kwargs)

The underlying FL simulator used by Flower is based on [Ray](https://www.ray.io/). It expects each client only to require a client ID for instantiation. Therefore, using the following generator function, we can determine the specific network used for FL together with the FEMNIST partition to which the `cid` refers.

While we will not use `Ray` in this lab due to its heavyweight nature, we will keep all code API compatible with the default flower framework.

In [6]:
def get_flower_client_generator(
    model_generator: Callable[[], Module],
    partition_dir: Path,
    mapping_fn: Callable[[int], int] | None = None,
) -> Callable[[str], FlowerRayClient]:
    """Wrap the client instance generator.

    This provides the client generator with a model generator function.
    Also, the partition directory must be passed.
    A mapping function could be used for filtering/ordering clients.

    Parameters
    ----------
        model_generator (Callable[[], Module]): model generator function.
        partition_dir (Path): directory containing the partition.
        mapping_fn (Optional[Callable[[int], int]]): function mapping sorted/filtered
            ids to real cid.

    Returns
    -------
        Callable[[str], FlowerRayClient]: client instance.
    """

    def client_fn(cid: str) -> FlowerRayClient:
        """Create a single client instance given the client id `cid`.

        Parameters
        ----------
            cid (str): client id, Flower requires this to be of type str.

        Returns
        -------
            FlowerRayClient: client instance.
        """
        return FlowerRayClient(
            cid=mapping_fn(int(cid)) if mapping_fn is not None else int(cid),
            partition_dir=partition_dir,
            model_generator=model_generator,
        )

    return client_fn

To ensure the Flower client behaves the same as our simple demo client, a simple test using the centralised partition we defined earlier should suffice.

In [7]:
network_generator = get_network_generator()
seed_net: Net = network_generator()
seed_model_params: NDArrays = get_model_parameters(seed_net)

centralized_flower_client_generator: Callable[[str], FlowerRayClient] = (
    get_flower_client_generator(network_generator, centralized_partition)
)
centralized_flower_client = centralized_flower_client_generator(str(0))

INFO flwr 2025-03-15 22:03:52,810 | 1340840146.py:21 | cid: 0


In [8]:
seed_net

Net(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=62, bias=True)
)

In [9]:
seed_net.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[-0.1687, -0.0018,  0.0492, -0.0310, -0.1198],
                        [-0.1885,  0.0340,  0.0787, -0.1295, -0.0962],
                        [ 0.0834,  0.0324, -0.1770,  0.1067,  0.1511],
                        [-0.1027,  0.0402,  0.0832,  0.0041, -0.0374],
                        [ 0.1546, -0.0569,  0.1585,  0.1856, -0.1941]]],
              
              
                      [[[ 0.0943,  0.1154,  0.1787, -0.0138,  0.1351],
                        [-0.1492, -0.1147, -0.0823, -0.0529, -0.0068],
                        [ 0.0701,  0.0090,  0.0214,  0.0280, -0.1008],
                        [ 0.0416,  0.0831, -0.1684,  0.1052,  0.1083],
                        [-0.1703,  0.0304,  0.0741,  0.1640,  0.1074]]],
              
              
                      [[[ 0.0854, -0.1765,  0.1057,  0.1865,  0.1976],
                        [ 0.1198, -0.1540,  0.0689, -0.1652,  0.0416],
                        [-0.0303, -0.1599, -0.1664,  0

In [10]:
centralized_train_config: dict[str, Any] = {
    "epochs": 1,
    "batch_size": 32,
    "client_learning_rate": 0.01,
    "weight_decay": 0.001,
    "num_workers": 0,
    "max_batches": 100,
}

test_config: dict[str, Any] = {
    "batch_size": 32,
    "num_workers": 0,
    "max_batches": 100,
}

# Train parameters on the centralised dataset
trained_params, num_examples, train_metrics = fit_client_seeded(
    centralized_flower_client, params=seed_model_params, conf=centralized_train_config
)
log(INFO, "Train Metrics = %s", train_metrics)

INFO flwr 2025-03-15 22:04:01,124 | 4024557794.py:20 | Train Metrics = {'train_loss': 0.11729214794933795}


In [11]:
def sample_random_clients(
    total_clients: int,
    filter_less: int,
    partition: Path,
    seed: int | None = Seeds.DEFAULT,
) -> Sequence[int]:
    """Sample randomly clients.

    A filter on the client train set size is performed.

    Parameters
    ----------
        total_clients (int): total number of clients to sample.
        filter_less (int): max number of train samples for which the client is
            **discarded**.
        partition (Path): path to the folder containing the partitioning.
        seed (Optional[int], optional): seed for the random generator. Defaults to None.

    Returns
    -------
        Sequence[int]: list of sample client ids as int.
    """
    real_federated_cid_client_generator: Callable[[str], FlowerRayClient] = (
        get_flower_client_generator(network_generator, federated_partition)
    )
    if seed is not None:
        random.seed(seed)
    list_of_ids = []
    while len(list_of_ids) < total_clients:
        current_id = random.randint(0, 3229)
        if (
            real_federated_cid_client_generator(str(current_id)).get_train_set_size()
            > filter_less
        ):
            list_of_ids.append(current_id)
    return list_of_ids

While FEMNIST has more than 3000 clients, our small-scale experiments will not require more than 100 at any point.

In [12]:
total_clients: int = 100
list_of_ids = sample_random_clients(
    total_clients, 32, federated_partition
)

federated_client_generator: Callable[[str], FlowerRayClient] = (
    get_flower_client_generator(
        network_generator, federated_partition, lambda seq_id: list_of_ids[seq_id]
    )
)

INFO flwr 2025-03-15 22:04:01,164 | 1340840146.py:21 | cid: 2530


INFO flwr 2025-03-15 22:04:01,173 | 1340840146.py:21 | cid: 2997


INFO flwr 2025-03-15 22:04:01,178 | 1340840146.py:21 | cid: 1473


INFO flwr 2025-03-15 22:04:01,182 | 1340840146.py:21 | cid: 2688


INFO flwr 2025-03-15 22:04:01,187 | 1340840146.py:21 | cid: 2601


INFO flwr 2025-03-15 22:04:01,192 | 1340840146.py:21 | cid: 1425


INFO flwr 2025-03-15 22:04:01,197 | 1340840146.py:21 | cid: 1273


INFO flwr 2025-03-15 22:04:01,203 | 1340840146.py:21 | cid: 1887


INFO flwr 2025-03-15 22:04:01,208 | 1340840146.py:21 | cid: 2828


INFO flwr 2025-03-15 22:04:01,214 | 1340840146.py:21 | cid: 823


INFO flwr 2025-03-15 22:04:01,219 | 1340840146.py:21 | cid: 3168


INFO flwr 2025-03-15 22:04:01,224 | 1340840146.py:21 | cid: 220


INFO flwr 2025-03-15 22:04:01,229 | 1340840146.py:21 | cid: 1167


INFO flwr 2025-03-15 22:04:01,235 | 1340840146.py:21 | cid: 1287


INFO flwr 2025-03-15 22:04:01,241 | 1340840146.py:21 | cid: 2343


INFO flwr 2025-03-15 22:04:01,246 | 1340840146.py:21 | cid: 2975


INFO flwr 2025-03-15 22:04:01,251 | 1340840146.py:21 | cid: 740


INFO flwr 2025-03-15 22:04:01,255 | 1340840146.py:21 | cid: 206


INFO flwr 2025-03-15 22:04:01,261 | 1340840146.py:21 | cid: 100


INFO flwr 2025-03-15 22:04:01,268 | 1340840146.py:21 | cid: 2985


INFO flwr 2025-03-15 22:04:01,272 | 1340840146.py:21 | cid: 2640


INFO flwr 2025-03-15 22:04:01,276 | 1340840146.py:21 | cid: 241


INFO flwr 2025-03-15 22:04:01,282 | 1340840146.py:21 | cid: 1800


INFO flwr 2025-03-15 22:04:01,287 | 1340840146.py:21 | cid: 2271


INFO flwr 2025-03-15 22:04:01,291 | 1340840146.py:21 | cid: 175


INFO flwr 2025-03-15 22:04:01,296 | 1340840146.py:21 | cid: 1889


INFO flwr 2025-03-15 22:04:01,301 | 1340840146.py:21 | cid: 2772


INFO flwr 2025-03-15 22:04:01,305 | 1340840146.py:21 | cid: 557


INFO flwr 2025-03-15 22:04:01,310 | 1340840146.py:21 | cid: 654


INFO flwr 2025-03-15 22:04:01,315 | 1340840146.py:21 | cid: 796


INFO flwr 2025-03-15 22:04:01,320 | 1340840146.py:21 | cid: 651


INFO flwr 2025-03-15 22:04:01,324 | 1340840146.py:21 | cid: 161


INFO flwr 2025-03-15 22:04:01,330 | 1340840146.py:21 | cid: 1257


INFO flwr 2025-03-15 22:04:01,336 | 1340840146.py:21 | cid: 1469


INFO flwr 2025-03-15 22:04:01,340 | 1340840146.py:21 | cid: 1552


INFO flwr 2025-03-15 22:04:01,344 | 1340840146.py:21 | cid: 3123


INFO flwr 2025-03-15 22:04:01,348 | 1340840146.py:21 | cid: 1887


INFO flwr 2025-03-15 22:04:01,353 | 1340840146.py:21 | cid: 1127


INFO flwr 2025-03-15 22:04:01,360 | 1340840146.py:21 | cid: 571


INFO flwr 2025-03-15 22:04:01,366 | 1340840146.py:21 | cid: 800


INFO flwr 2025-03-15 22:04:01,372 | 1340840146.py:21 | cid: 214


INFO flwr 2025-03-15 22:04:01,378 | 1340840146.py:21 | cid: 1345


INFO flwr 2025-03-15 22:04:01,382 | 1340840146.py:21 | cid: 2756


INFO flwr 2025-03-15 22:04:01,387 | 1340840146.py:21 | cid: 2638


INFO flwr 2025-03-15 22:04:01,392 | 1340840146.py:21 | cid: 609


INFO flwr 2025-03-15 22:04:01,399 | 1340840146.py:21 | cid: 3002


INFO flwr 2025-03-15 22:04:01,403 | 1340840146.py:21 | cid: 85


INFO flwr 2025-03-15 22:04:01,407 | 1340840146.py:21 | cid: 912


INFO flwr 2025-03-15 22:04:01,412 | 1340840146.py:21 | cid: 1888


INFO flwr 2025-03-15 22:04:01,419 | 1340840146.py:21 | cid: 1281


INFO flwr 2025-03-15 22:04:01,426 | 1340840146.py:21 | cid: 419


INFO flwr 2025-03-15 22:04:01,434 | 1340840146.py:21 | cid: 1225


INFO flwr 2025-03-15 22:04:01,441 | 1340840146.py:21 | cid: 806


INFO flwr 2025-03-15 22:04:01,446 | 1340840146.py:21 | cid: 1657


INFO flwr 2025-03-15 22:04:01,453 | 1340840146.py:21 | cid: 241


INFO flwr 2025-03-15 22:04:01,458 | 1340840146.py:21 | cid: 2259


INFO flwr 2025-03-15 22:04:01,464 | 1340840146.py:21 | cid: 277


INFO flwr 2025-03-15 22:04:01,470 | 1340840146.py:21 | cid: 2376


INFO flwr 2025-03-15 22:04:01,475 | 1340840146.py:21 | cid: 1001


INFO flwr 2025-03-15 22:04:01,481 | 1340840146.py:21 | cid: 2670


INFO flwr 2025-03-15 22:04:01,487 | 1340840146.py:21 | cid: 638


INFO flwr 2025-03-15 22:04:01,493 | 1340840146.py:21 | cid: 1226


INFO flwr 2025-03-15 22:04:01,498 | 1340840146.py:21 | cid: 52


INFO flwr 2025-03-15 22:04:01,504 | 1340840146.py:21 | cid: 219


INFO flwr 2025-03-15 22:04:01,510 | 1340840146.py:21 | cid: 1727


INFO flwr 2025-03-15 22:04:01,515 | 1340840146.py:21 | cid: 752


INFO flwr 2025-03-15 22:04:01,521 | 1340840146.py:21 | cid: 2620


INFO flwr 2025-03-15 22:04:01,526 | 1340840146.py:21 | cid: 2469


INFO flwr 2025-03-15 22:04:01,533 | 1340840146.py:21 | cid: 83


INFO flwr 2025-03-15 22:04:01,540 | 1340840146.py:21 | cid: 427


INFO flwr 2025-03-15 22:04:01,546 | 1340840146.py:21 | cid: 2906


INFO flwr 2025-03-15 22:04:01,552 | 1340840146.py:21 | cid: 3183


INFO flwr 2025-03-15 22:04:01,559 | 1340840146.py:21 | cid: 885


INFO flwr 2025-03-15 22:04:01,567 | 1340840146.py:21 | cid: 253


INFO flwr 2025-03-15 22:04:01,575 | 1340840146.py:21 | cid: 1634


INFO flwr 2025-03-15 22:04:01,580 | 1340840146.py:21 | cid: 2953


INFO flwr 2025-03-15 22:04:01,586 | 1340840146.py:21 | cid: 1835


INFO flwr 2025-03-15 22:04:01,592 | 1340840146.py:21 | cid: 2758


INFO flwr 2025-03-15 22:04:01,597 | 1340840146.py:21 | cid: 592


INFO flwr 2025-03-15 22:04:01,605 | 1340840146.py:21 | cid: 670


INFO flwr 2025-03-15 22:04:01,612 | 1340840146.py:21 | cid: 1983


INFO flwr 2025-03-15 22:04:01,619 | 1340840146.py:21 | cid: 2457


INFO flwr 2025-03-15 22:04:01,625 | 1340840146.py:21 | cid: 351


INFO flwr 2025-03-15 22:04:01,632 | 1340840146.py:21 | cid: 2995


INFO flwr 2025-03-15 22:04:01,638 | 1340840146.py:21 | cid: 2885


INFO flwr 2025-03-15 22:04:01,643 | 1340840146.py:21 | cid: 227


INFO flwr 2025-03-15 22:04:01,649 | 1340840146.py:21 | cid: 2689


INFO flwr 2025-03-15 22:04:01,654 | 1340840146.py:21 | cid: 2343


INFO flwr 2025-03-15 22:04:01,659 | 1340840146.py:21 | cid: 817


INFO flwr 2025-03-15 22:04:01,664 | 1340840146.py:21 | cid: 887


INFO flwr 2025-03-15 22:04:01,670 | 1340840146.py:21 | cid: 2965


INFO flwr 2025-03-15 22:04:01,677 | 1340840146.py:21 | cid: 1172


INFO flwr 2025-03-15 22:04:01,685 | 1340840146.py:21 | cid: 1722


INFO flwr 2025-03-15 22:04:01,690 | 1340840146.py:21 | cid: 2216


INFO flwr 2025-03-15 22:04:01,695 | 1340840146.py:21 | cid: 1321


INFO flwr 2025-03-15 22:04:01,700 | 1340840146.py:21 | cid: 2035


INFO flwr 2025-03-15 22:04:01,705 | 1340840146.py:21 | cid: 693


INFO flwr 2025-03-15 22:04:01,711 | 1340840146.py:21 | cid: 301


INFO flwr 2025-03-15 22:04:01,717 | 1340840146.py:21 | cid: 3018


INFO flwr 2025-03-15 22:04:01,723 | 1340840146.py:21 | cid: 2510


Now, to test that the newly partitioned clients can be trained.

In [13]:
test_config: dict[str, Any] = {
    "batch_size": 32,
    "num_workers": 0,
    "max_batches": 100,
}

In [14]:
num_clients = 4
clientIds = random.sample(list(range(total_clients)), num_clients)
clients = [federated_client_generator(str(cid)) for cid in clientIds]
print(f'{clients=}')

INFO flwr 2025-03-15 22:04:01,748 | 1340840146.py:21 | cid: 740


INFO flwr 2025-03-15 22:04:01,752 | 1340840146.py:21 | cid: 419


INFO flwr 2025-03-15 22:04:01,754 | 1340840146.py:21 | cid: 1888


INFO flwr 2025-03-15 22:04:01,757 | 1340840146.py:21 | cid: 2035


clients=[<__main__.FlowerRayClient object at 0x7ca08868ab60>, <__main__.FlowerRayClient object at 0x7ca08868bb50>, <__main__.FlowerRayClient object at 0x7ca2135c3160>, <__main__.FlowerRayClient object at 0x7ca08868b9d0>]


In [15]:
# def train(clients, numEpoch):
#     epoch_config: dict[str, Any] = {
#     "epochs": numEpoch,
#     "batch_size": 32,
#     "client_learning_rate": 0.01,
#     "weight_decay": 0.001,
#     "num_workers": 0,
#     "max_batches": 100,
#     }

#     trained_models = [
#     fit_client_seeded(
#         client, seed_model_params, epoch_config
#     )
#     for client in clients
#     ]

#     params = [model for model, *rest in trained_models]
#     metrics = [rest for _, *rest in trained_models]
#     log(INFO, "Metrics from trained models are: %s", metrics)
#     return params, metrics

The two basic blocks of synchronous server-client FL systems are:
- A client with some local training method and data---i.e., SGD. This is what we have built thus far.
- A server which coordinates training sends the federated model to clients at the start of each round and aggregates model updates at the end of each round.

The pieces necessary for starting an FL simulation are now in play; we need to arrange them to fit the Flower API. First, we shall require a separate federated evaluation function which can be called outside the context of a specific client. It will use the centralised test set to be as simple as possible.

In [16]:
def get_federated_evaluation_function(
    batch_size: int,
    num_workers: int,
    model_generator: Callable[[], Module],
    criterion: Module,
    max_batches: int,
) -> Callable[[int, NDArrays, dict[str, Any]], tuple[float, dict[str, Scalar]]]:
    """Wrap the external federated evaluation function.

    It provides the external federated evaluation function with some
    parameters for the dataloader, the model generator function, and
    the criterion used in the evaluation.

    Parameters
    ----------
        batch_size (int): batch size of the test set to use.
        num_workers (int): correspond to `num_workers` param in the Dataloader object.
        model_generator (Callable[[], Module]):  model generator function.
        criterion (Module): PyTorch Module containing the criterion for evaluating the
        model.

    Returns
    -------
        Callable[[int, NDArrays, dict[str, Any]], tuple[float, dict[str, Scalar]]]:
            external federated evaluation function.
    """

    def federated_evaluation_function(
        server_round: int,
        parameters: NDArrays,
        fed_eval_config: dict[
            str, Any
        ],  # mandatory argument, even if it's not being used
    ) -> tuple[float, dict[str, Scalar]]:
        """Evaluate federated model on the server.

        It uses the centralized val set for sake of simplicity.

        Parameters
        ----------
            server_round (int): current federated round.
            parameters (NDArrays): current model parameters.
            fed_eval_config (dict[str, Any]): mandatory argument in Flower, can contain
                some configuration info

        Returns
        -------
            tuple[float, dict[str, Scalar]]: evaluation results
        """
        device: str = get_device()
        net: Module = set_model_parameters(model_generator(), parameters)
        net.to(device)

        full_file: Path = centralized_mapping
        dataset: Dataset = load_femnist_dataset(data_dir, full_file, "val")

        valid_loader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            drop_last=False,
        )

        loss, acc = test_femnist(
            net=net,
            test_loader=valid_loader,
            device=device,
            criterion=criterion,
            max_batches=max_batches,
        )
        return loss, {"accuracy": acc}

    return federated_evaluation_function


federated_evaluation_function = get_federated_evaluation_function(
    batch_size=test_config["batch_size"],
    num_workers=test_config["num_workers"],
    model_generator=network_generator,
    criterion=nn.CrossEntropyLoss(),
    max_batches=test_config["max_batches"],
)

In [17]:
def aggregate_weighted_average(metrics: list[tuple[int, dict]]) -> dict:
    """Combine results from multiple clients following training or evaluation.

    Parameters
    ----------
        metrics (list[tuple[int, dict]]): collected clients metrics

    Returns
    -------
        dict: result dictionary containing the aggregate of the metrics passed.
    """
    average_dict: dict = defaultdict(list)
    total_examples: int = 0
    for num_examples, metrics_dict in metrics:
        for key, val in metrics_dict.items():
            if isinstance(val, numbers.Number):
                average_dict[key].append((num_examples, val))
        total_examples += num_examples
    return {
        key: {
            "avg": float(
                sum([num_examples * metric for num_examples, metric in val])
                / float(total_examples)
            ),
            "all": val,
        }
        for key, val in average_dict.items()
    }

In [18]:
# Federated configuration dictionary
federated_train_config: dict[str, Any] = {
    "epochs": 50,
    "batch_size": 32,
    "client_learning_rate": 0.01,
    "weight_decay": 0.001,
    "num_workers": 0,
    "max_batches": 100,
}

The only challenge left is the FL simulation itself. In `Flower`, a `Server` object handles this for us by using `Ray` and spawning many heavyweight worker process.

Given the limited-resource scenario in which we find ourselves, we provide you with a slightly modified simulation function which uses a simple thread pool. Feel free to swap it out for the original simulation or replace it with your own implementation if so inclined.

> The server we use is not the default `Flower` server as it returns the model parameters from every single round in a `(round, NDArrays)` tuple.

In [19]:
def start_seeded_simulation(
    client_fn: Callable[[str], Client],
    num_clients: int,
    config: ServerConfig,
    strategy: Strategy,
    name: str,
    return_all_parameters: bool = False,
    seed: int = Seeds.DEFAULT,
    iteration: int = 0,
) -> tuple[list[tuple[int, NDArrays]], History]:
    """Wrap to seed client selection."""
    np.random.seed(seed ^ iteration)
    torch.manual_seed(seed ^ iteration)
    random.seed(seed ^ iteration)
    parameter_list, hist = flwr.simulation.start_simulation_no_ray(
        client_fn=client_fn,
        num_clients=num_clients,
        client_resources={},
        config=config,
        strategy=strategy,
    )
    save_history(home_dir, hist, name)
    return parameter_list, hist

`run_simulation_frank_wolfe` is an adaptation of the original simulation function (now renamed to `run_simulation_fedavg`), the only difference being the strategy used. The strategy can be found in [c2m3/match/fed_frank_wolfe_strategy.py](https://github.com/DawnSpider96/L361-Federated-Learning/blob/c2m3/c2m3/match/fed_frank_wolfe_strategy.py#L43)

In [20]:
num_rounds = 10

num_total_clients = 20

num_evaluate_clients = 0
num_clients_per_round = 5

initial_parameters = ndarrays_to_parameters(seed_model_params)


def run_simulation_frank_wolfe(
    # How long the FL process runs for:
    num_rounds: int = num_rounds,
    # Number of clients available
    num_total_clients: int = num_total_clients,
    # Number of clients used for train/eval
    num_clients_per_round: int = num_clients_per_round,
    num_evaluate_clients: int = num_evaluate_clients,
    # If less clients are overall available stop FL
    min_available_clients: int = num_total_clients,
    # If less clients are available for fit/eval stop FL
    min_fit_clients: int = num_clients_per_round,
    min_evaluate_clients: int = num_evaluate_clients,
    # Function to test the federated model performance
    # external to a client instantiation
    evaluate_fn: (
        Callable[
            [int, NDArrays, dict[str, Scalar]],
            tuple[float, dict[str, Scalar]] | None,
        ]
        | None
    ) = federated_evaluation_function,
    # Functions to generate a config for client fit/evaluate
    # by-default the same config is shallow-copied to all clients in Flower
    # this version simply uses the configs defined above
    on_fit_config_fn: Callable[
        [int], dict[str, Scalar]
    ] = lambda _x: federated_train_config,
    on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] = lambda _x: test_config,
    # The "Parameters" type is merely a more packed version
    # of numpy array lists, used internally by Flower
    initial_parameters: Parameters = initial_parameters,
    # If this is set to True, aggregation will work even if some clients fail
    accept_failures: bool = False,
    # How to combine the metrics dictionary returned by all clients for fit/eval
    fit_metrics_aggregation_fn: Callable | None = aggregate_weighted_average,
    evaluate_metrics_aggregation_fn: Callable | None = aggregate_weighted_average,
    federated_client_generator: Callable[
        [str], flwr.client.NumPyClient
    ] = federated_client_generator,
    # Aggregation learning rate for FedAvg
    server_learning_rate: float = 1.0,
    server_momentum: float = 0.0,
) -> tuple[list[tuple[int, NDArrays]], History]:
    """Run a federated simulation using Flower."""
    log(INFO, "FL will execute for %s rounds", num_rounds)

    # Percentage of clients used for train/eval
    fraction_fit: float = float(num_clients_per_round) / num_total_clients
    fraction_evaluate: float = float(num_evaluate_clients) / num_total_clients

    strategy = FrankWolfeSync(
        fraction_fit=fraction_fit,
        fraction_evaluate=fraction_evaluate,
        min_fit_clients=min_fit_clients,
        min_evaluate_clients=min_evaluate_clients,
        min_available_clients=min_available_clients,
        on_fit_config_fn=on_fit_config_fn,
        on_evaluate_config_fn=on_evaluate_config_fn,
        evaluate_fn=evaluate_fn,
        initial_parameters=initial_parameters,
        accept_failures=accept_failures,
        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        # batch_size = int(config["batch_size"])
        # num_workers = int(config["num_workers"])
        # dataset = self._load_dataset(name)
    )
    # resetting the seed for the random selection of clients
    # this way the list of clients trained is guaranteed to be always the same

    cfg = ServerConfig(num_rounds)

    def simulator_client_generator(cid: str) -> Client:
        return federated_client_generator(cid).to_client()

    parameters_for_each_round, hist = start_seeded_simulation(
        client_fn=simulator_client_generator,
        num_clients=num_total_clients,
        config=cfg,
        strategy=strategy,
        name="c2m3",
        return_all_parameters=True,
        seed=Seeds.DEFAULT,
    )
    return parameters_for_each_round, hist




def run_simulation_fedavg(
    # How long the FL process runs for:
    num_rounds: int = num_rounds,
    # Number of clients available
    num_total_clients: int = num_total_clients,
    # Number of clients used for train/eval
    num_clients_per_round: int = num_clients_per_round,
    num_evaluate_clients: int = num_evaluate_clients,
    # If less clients are overall available stop FL
    min_available_clients: int = num_total_clients,
    # If less clients are available for fit/eval stop FL
    min_fit_clients: int = num_clients_per_round,
    min_evaluate_clients: int = num_evaluate_clients,
    # Function to test the federated model performance
    # external to a client instantiation
    evaluate_fn: (
        Callable[
            [int, NDArrays, dict[str, Scalar]],
            tuple[float, dict[str, Scalar]] | None,
        ]
        | None
    ) = federated_evaluation_function,
    # Functions to generate a config for client fit/evaluate
    # by-default the same config is shallow-copied to all clients in Flower
    # this version simply uses the configs defined above
    on_fit_config_fn: Callable[
        [int], dict[str, Scalar]
    ] = lambda _x: federated_train_config,
    on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] = lambda _x: test_config,
    # The "Parameters" type is merely a more packed version
    # of numpy array lists, used internally by Flower
    initial_parameters: Parameters = initial_parameters,
    # If this is set to True, aggregation will work even if some clients fail
    accept_failures: bool = False,
    # How to combine the metrics dictionary returned by all clients for fit/eval
    fit_metrics_aggregation_fn: Callable | None = aggregate_weighted_average,
    evaluate_metrics_aggregation_fn: Callable | None = aggregate_weighted_average,
    federated_client_generator: Callable[
        [str], flwr.client.NumPyClient
    ] = federated_client_generator,
    # Aggregation learning rate for FedAvg
    server_learning_rate: float = 1.0,
    server_momentum: float = 0.0,
) -> tuple[list[tuple[int, NDArrays]], History]:
    """Run a federated simulation using Flower."""
    log(INFO, "FL will execute for %s rounds", num_rounds)

    # Percentage of clients used for train/eval
    fraction_fit: float = float(num_clients_per_round) / num_total_clients
    fraction_evaluate: float = float(num_evaluate_clients) / num_total_clients

    strategy = FedAvg(
        fraction_fit=fraction_fit,
        fraction_evaluate=fraction_evaluate,
        min_fit_clients=min_fit_clients,
        min_evaluate_clients=min_evaluate_clients,
        min_available_clients=min_available_clients,
        on_fit_config_fn=on_fit_config_fn,
        on_evaluate_config_fn=on_evaluate_config_fn,
        evaluate_fn=evaluate_fn,
        initial_parameters=initial_parameters,
        accept_failures=accept_failures,
        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        server_learning_rate=server_learning_rate,
        server_momentum=server_momentum,
        # batch_size = int(config["batch_size"])
        # num_workers = int(config["num_workers"])
        # dataset = self._load_dataset(name)
        
    )
    # resetting the seed for the random selection of clients
    # this way the list of clients trained is guaranteed to be always the same

    cfg = ServerConfig(num_rounds)

    def simulator_client_generator(cid: str) -> Client:
        return federated_client_generator(cid).to_client()

    parameters_for_each_round, hist = start_seeded_simulation(
        client_fn=simulator_client_generator,
        num_clients=num_total_clients,
        config=cfg,
        strategy=strategy,
        name="c2m3",
        return_all_parameters=True,
        seed=Seeds.DEFAULT,
    )
    return parameters_for_each_round, hist

See below: This is a copy of one of the gradient matrices, output of [collect_gradients_frank_wolfe_model_pair](https://github.com/crisostomi/cycle-consistent-model-merging/blob/6ee822f56114181ea7eba4cb7533a0b6e27ea749/src/ccmm/matching/frank_wolfe_sync_matching.py#L72). It (like the other gradient matrices) is too skewed towards the central diagonal (ie the value of assigning the i-th worker to the i-th task is significantly higher than any other task).

Hence the output of the `linear_sum_assignment` will specify no permutations whatsoever. This effect carries over into the updating of permutation matrices, where the projected gradients will be updated to identity matrices, and afterwards the updated permutation matrices do not change from how they were initialised (identity matrix).

In [21]:
import numpy as np
from scipy.optimize import linear_sum_assignment
cost = np.array([[ 5.6573396 ,  0.12816703,  0.7734268 ,  0.54617214, -0.10447261,
         0.39014438],
         [ 1.1931067 ,  1.5022745 ,  7.1273413 ,  0.79405797,  0.34921485,
         1.2170054 ],
       [ 0.29189748,  5.714079  ,  1.1553278 ,  0.15853024,  0.47528026,
         0.7882026 ],
       [ 0.7397236 ,  0.10359962,  0.6959048 ,  5.283332  ,  0.27297193,
         0.99973345],
       [ 0.5185135 ,  1.0186552 ,  1.0369794 ,  0.6470743 ,  5.6826034 ,
         1.0205868 ],
       [ 0.8446162 ,  0.9361492 ,  1.0751884 ,  1.6859208 ,  0.61274135,
         6.6700554 ]])
linear_sum_assignment(cost, maximize=True)

(array([0, 1, 2, 3, 4, 5]), array([0, 2, 1, 3, 4, 5]))

In [22]:
parameters_for_each_round, hist = run_simulation_frank_wolfe()

INFO flwr 2025-03-15 22:04:01,847 | 4257692453.py:56 | FL will execute for 10 rounds


INFO flwr 2025-03-15 22:04:01,858 | app.py:149 | Starting Flower simulation, config: ServerConfig(num_rounds=10, round_timeout=None)


INFO flwr 2025-03-15 22:04:01,862 | server_returns_parameters.py:81 | Initializing global parameters


INFO flwr 2025-03-15 22:04:01,865 | server_returns_parameters.py:273 | Using initial parameters provided by strategy


INFO flwr 2025-03-15 22:04:01,870 | server_returns_parameters.py:84 | Evaluating initial parameters


 11%|█         | 100/891 [00:01<00:09, 81.59it/s]
INFO flwr 2025-03-15 22:04:03,337 | server_returns_parameters.py:87 | initial parameters (loss, other metrics): 413.68426275253296, {'accuracy': 0.0065625}


INFO flwr 2025-03-15 22:04:03,340 | server_returns_parameters.py:97 | FL starting


DEBUG flwr 2025-03-15 22:04:03,343 | server_returns_parameters.py:223 | fit_round 1: strategy sampled 5 clients (out of 20)


INFO flwr 2025-03-15 22:04:03,346 | 1340840146.py:21 | cid: 2530
INFO flwr 2025-03-15 22:04:03,346 | 1340840146.py:21 | cid: 220


INFO flwr 2025-03-15 22:04:03,349 | 1340840146.py:21 | cid: 2688
INFO flwr 2025-03-15 22:04:03,356 | 1340840146.py:21 | cid: 2343


INFO flwr 2025-03-15 22:04:03,358 | 1340840146.py:21 | cid: 1887


DEBUG flwr 2025-03-15 22:04:21,670 | server_returns_parameters.py:237 | fit_round 1 received 5 results and 0 failures


INFO flwr 2025-03-15 22:04:21,674 | 1340840146.py:21 | cid: 2688


INFO flwr 2025-03-15 22:04:21,680 | 1340840146.py:21 | cid: 1887


props={'tensor_type': 'numpy.ndarray', 'partition': PosixPath('/home/dawn/repos/c2m3-federated/c2m3/notebooks/../femnist/client_data_mappings/fed_natural'), 'cid': 2688}


INFO flwr 2025-03-15 22:04:21,688 | 1340840146.py:21 | cid: 220


props={'tensor_type': 'numpy.ndarray', 'partition': PosixPath('/home/dawn/repos/c2m3-federated/c2m3/notebooks/../femnist/client_data_mappings/fed_natural'), 'cid': 1887}


INFO flwr 2025-03-15 22:04:21,695 | 1340840146.py:21 | cid: 2343


props={'tensor_type': 'numpy.ndarray', 'partition': PosixPath('/home/dawn/repos/c2m3-federated/c2m3/notebooks/../femnist/client_data_mappings/fed_natural'), 'cid': 220}


INFO flwr 2025-03-15 22:04:21,703 | 1340840146.py:21 | cid: 2530


props={'tensor_type': 'numpy.ndarray', 'partition': PosixPath('/home/dawn/repos/c2m3-federated/c2m3/notebooks/../femnist/client_data_mappings/fed_natural'), 'cid': 2343}


  rank_zero_warn(


props={'tensor_type': 'numpy.ndarray', 'partition': PosixPath('/home/dawn/repos/c2m3-federated/c2m3/notebooks/../femnist/client_data_mappings/fed_natural'), 'cid': 2530}
dict_keys(['a', 'b', 'c', 'd', 'e'])
<class 'c2m3.modules.pl_module.MyLightningModule'>


Weight matching:  17%|█▋        | 34/200 [00:14<01:10,  2.35it/s]


perms_to_apply={'P_conv1': tensor([0, 1, 2, 3, 4, 5]), 'P_fc1': tensor([  0,   1,   2,   3,   4,  69,   6,   7,   8,   9,  10,  81,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  68,  40,  41,
         11,   5,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  91,  62,  63,  64,  65,  61,  67, 108,  96,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  42,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  43,  92,  93,  94,  95,  66,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107,  39, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119]), 'P_fc2': tensor([ 0,  1,  2, 10,  4,  5, 47,  7,  8,  9, 52, 11, 12, 13, 14, 15, 28, 17,
        18, 56, 20, 21, 22, 23, 24, 25, 26, 27, 63, 29, 30, 31, 32, 33, 43, 35,
        36, 37, 38, 42, 40, 41, 58, 72, 44, 45, 60, 16, 48, 49, 50, 51, 39,  6,
        54, 5

 11%|█         | 100/891 [00:00<00:07, 107.83it/s]
INFO flwr 2025-03-15 22:04:37,622 | server_returns_parameters.py:120 | fit progress: (1, 388.59137320518494, {'accuracy': 0.1790625}, 34.27909849199932)


INFO flwr 2025-03-15 22:04:37,625 | server_returns_parameters.py:171 | evaluate_round 1: no clients selected, cancel


DEBUG flwr 2025-03-15 22:04:37,627 | server_returns_parameters.py:223 | fit_round 2: strategy sampled 5 clients (out of 20)


INFO flwr 2025-03-15 22:04:37,630 | 1340840146.py:21 | cid: 206
INFO flwr 2025-03-15 22:04:37,632 | 1340840146.py:21 | cid: 100


INFO flwr 2025-03-15 22:04:37,634 | 1340840146.py:21 | cid: 2828
INFO flwr 2025-03-15 22:04:37,636 | 1340840146.py:21 | cid: 2985


INFO flwr 2025-03-15 22:04:37,639 | 1340840146.py:21 | cid: 1473


KeyboardInterrupt: 

In [None]:
parameters_for_each_round_fedavg, hist_fedavg = run_simulation_fedavg()

In [None]:
hist

In [None]:
hist_fedavg

In [26]:
import matplotlib.pyplot as plt

def plot_metrics(hist1, hist2, legend_labels=['FrankWolfe', 'FedAvg'], save_path=None):
    
    acc1 = hist1.metrics_centralized['accuracy']
    rounds_acc1, acc_values1 = zip(*acc1)
    
    acc2 = hist2.metrics_centralized['accuracy']
    rounds_acc2, acc_values2 = zip(*acc2)
    
    loss1 = hist1.losses_centralized
    rounds_loss1, loss_values1 = zip(*loss1)
    
    loss2 = hist2.losses_centralized
    rounds_loss2, loss_values2 = zip(*loss2)
    
    fig, axs = plt.subplots(2, 1, figsize=(12, 10))
    
    axs[0].plot(rounds_acc1, acc_values1, 'o-', color='blue', linewidth=2, markersize=8, 
               label=f'{legend_labels[0]} Accuracy')
    axs[0].plot(rounds_acc2, acc_values2, 's-', color='cyan', linewidth=2, markersize=8, 
               label=f'{legend_labels[1]} Accuracy')
    
    axs[0].set_title('Accuracy Comparison', fontsize=14)
    axs[0].set_xlabel('Round Number', fontsize=12)
    axs[0].set_ylabel('Accuracy', fontsize=12)
    axs[0].grid(True, linestyle='--', alpha=0.7)
    axs[0].legend(loc='best')
    
    all_rounds_acc = sorted(list(set(rounds_acc1 + rounds_acc2)))
    axs[0].set_xticks(all_rounds_acc)
    
    axs[1].plot(rounds_loss1, loss_values1, 'o-', color='red', linewidth=2, markersize=8, 
               label=f'{legend_labels[0]} Loss')
    axs[1].plot(rounds_loss2, loss_values2, 's-', color='orange', linewidth=2, markersize=8, 
               label=f'{legend_labels[1]} Loss')
    
    axs[1].set_title('Loss Comparison', fontsize=14)
    axs[1].set_xlabel('Round Number', fontsize=12)
    axs[1].set_ylabel('Loss', fontsize=12)
    axs[1].grid(True, linestyle='--', alpha=0.7)
    axs[1].legend(loc='best')

In [None]:
plot_metrics(hist, hist_fedavg)

In [None]:
log(
    INFO,
    "Size of the list with the model parameters: %s",
    len(parameters_for_each_round),
)