<a href="https://colab.research.google.com/github/DawnSpider96/L361-Federated-Learning/blob/release/Copy_of_L361_2025_Lab_1_From_Centralised_To_Federated_Part_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Dependencies
---


In [1]:
# `pip` could produce some errors. Do not worry about them.
# The execution has been verified; it's working anyway.
# ! pip install --quiet --upgrade "pip"
# ! pip install --quiet matplotlib tqdm seaborn
# ! pip install git+https://github.com/Iacob-Alexandru-Andrei/flower.git@teaching \
#     torch torchvision ray=="2.6.3"

### Imports.


In [2]:
import sys
sys.path.append('../../')
print(sys.path)

['/usr/lib/python310.zip', '/usr/lib/python3.10', '/usr/lib/python3.10/lib-dynload', '', '/home/dawn/venvs/fed/lib/python3.10/site-packages', '/home/dawn/repos/cycle-consistent-model-merging/src', '/home/dawn/repos/c2m3-federated', '/home/dawn/repos/c2m3-fed', '../../']


In [3]:
import random
from pathlib import Path
import tarfile
from typing import Any
from logging import INFO
from collections import defaultdict, OrderedDict
from collections.abc import Sequence, Callable
import numbers

import numpy as np
import torch
from torch import nn
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset
from enum import IntEnum
import flwr
from flwr.server import History, ServerConfig
from flwr.server.strategy import FedAvgM as FedAvg, Strategy
from c2m3.flower.fed_frank_wolfe_strategy import FrankWolfeSync
from flwr.common import log, NDArrays, Scalar, Parameters, ndarrays_to_parameters
from flwr.client.client import Client

from c2m3.common.client_utils import (
    Net,
    load_femnist_dataset,
    get_network_generator_cnn as get_network_generator,
    train_femnist,
    test_femnist,
    save_history,
    get_model_parameters,
    set_model_parameters
)


# Add new seeds here for easy autocomplete
class Seeds(IntEnum):
    """Seeds for reproducibility."""

    DEFAULT = 42 # [42, 123, 456, 789, 101]


np.random.seed(Seeds.DEFAULT)
random.seed(Seeds.DEFAULT)
torch.manual_seed(Seeds.DEFAULT)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


PathType = Path | str | None


def get_device() -> str:
    """Get the device (cuda, mps, cpu)."""
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = "mps"
    return device

  from .autonotebook import tqdm as notebook_tqdm
2025-03-28 05:05:58.865429: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-28 05:05:58.936794: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-28 05:05:58.937740: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-03-28 05:06:01,694	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [4]:
home_dir = Path.cwd() / ".."
dataset_dir: Path = home_dir / "data" / "femnist"
data_dir: Path = dataset_dir / "data"
centralized_partition: Path = dataset_dir / "client_data_mappings" / "centralized"
centralized_mapping: Path = dataset_dir / "client_data_mappings" / "centralized" / "0"
federated_partition: Path = dataset_dir / "client_data_mappings" / "lda_0.1"
# Decompress dataset
# if not dataset_dir.exists():
#     with tarfile.open(home_dir / "femnist.tar.gz", "r:gz") as tar:
#         tar.extractall(path=home_dir)
#     log(INFO, "Dataset extracted in %s", dataset_dir)

## Build Flower FL client.
---

In [5]:
class FlowerRayClient(flwr.client.NumPyClient):
    """Flower client for the FEMNIST dataset."""

    def __init__(
        self,
        cid: int,
        partition_dir: Path,
        model_generator: Callable[[], Module],
    ) -> None:
        """Init the client with its unique id and the folder to load data from.

        Parameters
        ----------
            cid (int): Unique client id for a client used to map it to its data
                partition
            partition_dir (Path): The directory containing data for each
                client/client id
            model_generator (Callable[[], Module]): The model generator function
        """
        self.cid = cid
        log(INFO, "cid: %s", self.cid)
        self.partition_dir = partition_dir
        self.device = str(
            torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        )
        self.model_generator: Callable[[], Module] = model_generator
        self.properties: dict[str, Scalar] = {
            "tensor_type": "numpy.ndarray",
            "partition": self.partition_dir,
            "cid": self.cid
            }
        self.data_dir = data_dir

    def set_parameters(self, parameters: NDArrays) -> Module:
        """Load weights inside the network.

        Parameters
        ----------
            parameters (NDArrays): set of weights to be loaded.

        Returns
        -------
            [Module]: Network with new set of weights.
        """
        net = self.model_generator()
        return set_model_parameters(net, parameters)

    def get_parameters(self, config: dict[str, Scalar]) -> NDArrays:
        """Return weights from a given model.

        If no model is passed, then a local model is created.
        This can be used to initialise a model in the
        server.
        The config param is not used but is mandatory in Flower.

        Parameters
        ----------
            config (dict[int, Scalar]): dictionary containing configuration info.

        Returns
        -------
            NDArrays: weights from the model.
        """
        net = self.model_generator()
        return get_model_parameters(net)

    def fit(
        self, parameters: NDArrays, config: dict[str, Scalar]
    ) -> tuple[NDArrays, int, dict]:
        """Receive and train a model on the local client data.

        It uses parameters from the config dict

        Parameters
        ----------
            net (NDArrays): Pytorch model parameters
            config (dict[str, Scalar]): dictionary describing the training parameters

        Returns
        -------
            tuple[NDArrays, int, dict]: Returns the updated model, the size of the local
                dataset and other metrics
        """
        # Only create model right before training/testing
        # To lower memory usage when idle
        net = self.set_parameters(parameters)
        net.to(self.device)

        train_loader: DataLoader = self._create_data_loader(config, name="train")
        train_loss = self._train(net, train_loader=train_loader, config=config)
        return get_model_parameters(net), len(train_loader), {"train_loss": train_loss}

    def evaluate(
        self, parameters: NDArrays, config: dict[str, Scalar]
    ) -> tuple[float, int, dict]:
        """Receive and test a model on the local client data.

        It uses parameters from the config dict

        Parameters
        ----------
            net (NDArrays): Pytorch model parameters
            config (dict[str, Scalar]): dictionary describing the testing parameters

        Returns
        -------
            tuple[float, int, dict]: Returns the loss accumulate during testing, the
                size of the local dataset and other metrics such as accuracy
        """
        net = self.set_parameters(parameters)
        net.to(self.device)

        test_loader: DataLoader = self._create_data_loader(config, name="test")
        loss, accuracy = self._test(net, test_loader=test_loader, config=config)
        return loss, len(test_loader), {"local_accuracy": accuracy}

    def _create_data_loader(self, config: dict[str, Scalar], name: str) -> DataLoader:
        """Create the data loader using the specified config parameters.

        Parameters
        ----------
            config (dict[str, Scalar]): dictionary containing dataloader and dataset
                parameters
            mode (str): Load the training or testing set for the client

        Returns
        -------
            DataLoader: A pytorch dataloader iterable for training/testing
        """
        batch_size = int(config["batch_size"])
        num_workers = int(config["num_workers"])
        dataset = self._load_dataset(name)
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            drop_last=(name == "train"),
        )

    def _load_dataset(self, name: str) -> Dataset:
        full_file: Path = self.partition_dir / str(self.cid)
        return load_femnist_dataset(
            mapping=full_file,
            name=name,
            data_dir=data_dir,
        )

    def _train(
        self, net: Module, train_loader: DataLoader, config: dict[str, Scalar]
    ) -> float:
        return train_femnist(
            net=net,
            train_loader=train_loader,
            epochs=int(config["epochs"]),
            device=self.device,
            optimizer=torch.optim.AdamW(
                net.parameters(),
                lr=float(config["client_learning_rate"]),
                weight_decay=float(config["weight_decay"]),
            ),
            criterion=torch.nn.CrossEntropyLoss(),
            max_batches=int(config["max_batches"]),
        )

    def _test(
        self, net: Module, test_loader: DataLoader, config: dict[str, Scalar]
    ) -> tuple[float, float]:
        return test_femnist(
            net=net,
            test_loader=test_loader,
            device=self.device,
            criterion=torch.nn.CrossEntropyLoss(),
            max_batches=int(config["max_batches"]),
        )

    def get_properties(self, config: dict[str, Scalar]) -> dict[str, Scalar]:
        """Return properties for this client.

        Parameters
        ----------
            config (dict[str, Scalar]): Options to be used for selecting specific
            properties.

        Returns
        -------
            dict[str, Scalar]: Returned properties.
        """
        return self.properties

    def get_train_set_size(self) -> int:
        """Return the client train set size.

        Returns
        -------
            int: train set size of the client.
        """
        return len(self._load_dataset("train"))  # type: ignore[reportArgumentType]

    def get_test_set_size(self) -> int:
        """Return the client test set size.

        Returns
        -------
            int: test set size of the client.
        """
        return len(self._load_dataset("test"))  # type: ignore[reportArgumentType]


# def fit_client_seeded(
#     client: FlowerRayClient,
#     params: NDArrays,
#     conf: dict[str, Any],
#     seed: Seeds = Seeds.DEFAULT,
#     **kwargs: Any,
# ) -> tuple[NDArrays, int, dict]:
#     """Wrap to always seed client training."""
#     np.random.seed(seed)
#     torch.manual_seed(seed)
#     random.seed(seed)
#     return client.fit(params, conf, **kwargs)

In [6]:
def get_flower_client_generator(
    model_generator: Callable[[], Module],
    partition_dir: Path,
    mapping_fn: Callable[[int], int] | None = None,
) -> Callable[[str], FlowerRayClient]:
    """Wrap the client instance generator.

    This provides the client generator with a model generator function.
    Also, the partition directory must be passed.
    A mapping function could be used for filtering/ordering clients.

    Parameters
    ----------
        model_generator (Callable[[], Module]): model generator function.
        partition_dir (Path): directory containing the partition.
        mapping_fn (Optional[Callable[[int], int]]): function mapping sorted/filtered
            ids to real cid.

    Returns
    -------
        Callable[[str], FlowerRayClient]: client instance.
    """

    def client_fn(cid: str) -> FlowerRayClient:
        """Create a single client instance given the client id `cid`.

        Parameters
        ----------
            cid (str): client id, Flower requires this to be of type str.

        Returns
        -------
            FlowerRayClient: client instance.
        """
        return FlowerRayClient(
            cid=mapping_fn(int(cid)) if mapping_fn is not None else int(cid),
            partition_dir=partition_dir,
            model_generator=model_generator,
        )

    return client_fn

In [7]:
network_generator = get_network_generator()
seed_net: Net = network_generator()
seed_model_params: NDArrays = get_model_parameters(seed_net)

centralized_flower_client_generator: Callable[[str], FlowerRayClient] = (
    get_flower_client_generator(network_generator, centralized_partition)
)
centralized_flower_client = centralized_flower_client_generator(str(0))

INFO flwr 2025-03-28 05:06:04,507 | 3963006861.py:21 | cid: 0


In [8]:
seed_net

CNN(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=62, bias=True)
)

In [9]:
seed_net.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[ 0.1529,  0.1660, -0.0469,  0.1837, -0.0438],
                        [ 0.0404, -0.0974,  0.1175,  0.1763, -0.1467],
                        [ 0.1738,  0.0374,  0.1478,  0.0271,  0.0964],
                        [-0.0282,  0.1542,  0.0296, -0.0934,  0.0510],
                        [-0.0921, -0.0235, -0.0812,  0.1327, -0.1579]]],
              
              
                      [[[-0.0922, -0.0565, -0.1203,  0.0189, -0.1975],
                        [ 0.1806, -0.1699,  0.1544,  0.0333, -0.0649],
                        [ 0.1236,  0.0312,  0.1616,  0.0219, -0.0631],
                        [ 0.0537, -0.0542,  0.0842,  0.1786,  0.1156],
                        [-0.0874,  0.1155,  0.0358,  0.1016, -0.1219]]],
              
              
                      [[[-0.1980, -0.0773, -0.1534,  0.1641,  0.0576],
                        [ 0.0828,  0.0633, -0.0035,  0.1565, -0.1421],
                        [ 0.0126, -0.1365,  0.0617, -0

In [10]:
def sample_random_clients(
    total_clients: int,
    filter_less: int,
    partition: Path,
    seed: int | None = Seeds.DEFAULT,
) -> Sequence[int]:
    """Sample randomly clients.

    A filter on the client train set size is performed.

    Parameters
    ----------
        total_clients (int): total number of clients to sample.
        filter_less (int): max number of train samples for which the client is
            **discarded**.
        partition (Path): path to the folder containing the partitioning.
        seed (Optional[int], optional): seed for the random generator. Defaults to None.

    Returns
    -------
        Sequence[int]: list of sample client ids as int.
    """
    real_federated_cid_client_generator: Callable[[str], FlowerRayClient] = (
        get_flower_client_generator(network_generator, federated_partition)
    )
    if seed is not None:
        random.seed(seed)
    list_of_ids = []
    while len(list_of_ids) < total_clients:
        current_id = random.randint(0, 619)
        if (
            real_federated_cid_client_generator(str(current_id)).get_train_set_size()
            > filter_less
        ):
            list_of_ids.append(current_id)
    return list_of_ids

While FEMNIST has more than 3000 clients, our small-scale experiments will not require more than 100 at any point.

In [11]:
total_clients: int = 100
list_of_ids = sample_random_clients(
    total_clients, 32, federated_partition
)

federated_client_generator: Callable[[str], FlowerRayClient] = (
    get_flower_client_generator(
        network_generator, federated_partition, lambda seq_id: list_of_ids[seq_id]
    )
)

INFO flwr 2025-03-28 05:06:04,594 | 3963006861.py:21 | cid: 114


INFO flwr 2025-03-28 05:06:04,602 | 3963006861.py:21 | cid: 104


INFO flwr 2025-03-28 05:06:04,608 | 3963006861.py:21 | cid: 432


INFO flwr 2025-03-28 05:06:04,616 | 3963006861.py:21 | cid: 616


INFO flwr 2025-03-28 05:06:04,622 | 3963006861.py:21 | cid: 558


INFO flwr 2025-03-28 05:06:04,629 | 3963006861.py:21 | cid: 6


INFO flwr 2025-03-28 05:06:04,635 | 3963006861.py:21 | cid: 284


INFO flwr 2025-03-28 05:06:04,641 | 3963006861.py:21 | cid: 389


INFO flwr 2025-03-28 05:06:04,646 | 3963006861.py:21 | cid: 44


INFO flwr 2025-03-28 05:06:04,652 | 3963006861.py:21 | cid: 80


INFO flwr 2025-03-28 05:06:04,657 | 3963006861.py:21 | cid: 370


INFO flwr 2025-03-28 05:06:04,662 | 3963006861.py:21 | cid: 233


INFO flwr 2025-03-28 05:06:04,668 | 3963006861.py:21 | cid: 103


INFO flwr 2025-03-28 05:06:04,674 | 3963006861.py:21 | cid: 166


INFO flwr 2025-03-28 05:06:04,680 | 3963006861.py:21 | cid: 73


INFO flwr 2025-03-28 05:06:04,686 | 3963006861.py:21 | cid: 167


INFO flwr 2025-03-28 05:06:04,692 | 3963006861.py:21 | cid: 570


INFO flwr 2025-03-28 05:06:04,699 | 3963006861.py:21 | cid: 57


INFO flwr 2025-03-28 05:06:04,705 | 3963006861.py:21 | cid: 274


INFO flwr 2025-03-28 05:06:04,710 | 3963006861.py:21 | cid: 322


INFO flwr 2025-03-28 05:06:04,717 | 3963006861.py:21 | cid: 469


INFO flwr 2025-03-28 05:06:04,723 | 3963006861.py:21 | cid: 551


INFO flwr 2025-03-28 05:06:04,728 | 3963006861.py:21 | cid: 408


INFO flwr 2025-03-28 05:06:04,735 | 3963006861.py:21 | cid: 48


INFO flwr 2025-03-28 05:06:04,740 | 3963006861.py:21 | cid: 432


INFO flwr 2025-03-28 05:06:04,745 | 3963006861.py:21 | cid: 541


INFO flwr 2025-03-28 05:06:04,750 | 3963006861.py:21 | cid: 117


INFO flwr 2025-03-28 05:06:04,754 | 3963006861.py:21 | cid: 348


INFO flwr 2025-03-28 05:06:04,760 | 3963006861.py:21 | cid: 269


INFO flwr 2025-03-28 05:06:04,766 | 3963006861.py:21 | cid: 305


INFO flwr 2025-03-28 05:06:04,772 | 3963006861.py:21 | cid: 382


INFO flwr 2025-03-28 05:06:04,777 | 3963006861.py:21 | cid: 543


INFO flwr 2025-03-28 05:06:04,783 | 3963006861.py:21 | cid: 114


INFO flwr 2025-03-28 05:06:04,789 | 3963006861.py:21 | cid: 245


INFO flwr 2025-03-28 05:06:04,794 | 3963006861.py:21 | cid: 87


INFO flwr 2025-03-28 05:06:04,801 | 3963006861.py:21 | cid: 128


INFO flwr 2025-03-28 05:06:04,808 | 3963006861.py:21 | cid: 271


INFO flwr 2025-03-28 05:06:04,814 | 3963006861.py:21 | cid: 552


INFO flwr 2025-03-28 05:06:04,820 | 3963006861.py:21 | cid: 408


INFO flwr 2025-03-28 05:06:04,825 | 3963006861.py:21 | cid: 462


INFO flwr 2025-03-28 05:06:04,830 | 3963006861.py:21 | cid: 602


INFO flwr 2025-03-28 05:06:04,836 | 3963006861.py:21 | cid: 60


INFO flwr 2025-03-28 05:06:04,841 | 3963006861.py:21 | cid: 72


INFO flwr 2025-03-28 05:06:04,847 | 3963006861.py:21 | cid: 552


INFO flwr 2025-03-28 05:06:04,852 | 3963006861.py:21 | cid: 484


INFO flwr 2025-03-28 05:06:04,857 | 3963006861.py:21 | cid: 96


INFO flwr 2025-03-28 05:06:04,863 | 3963006861.py:21 | cid: 478


INFO flwr 2025-03-28 05:06:04,869 | 3963006861.py:21 | cid: 100


INFO flwr 2025-03-28 05:06:04,875 | 3963006861.py:21 | cid: 111


INFO flwr 2025-03-28 05:06:04,881 | 3963006861.py:21 | cid: 432


INFO flwr 2025-03-28 05:06:04,887 | 3963006861.py:21 | cid: 77


INFO flwr 2025-03-28 05:06:04,893 | 3963006861.py:21 | cid: 51


INFO flwr 2025-03-28 05:06:04,898 | 3963006861.py:21 | cid: 242


INFO flwr 2025-03-28 05:06:04,905 | 3963006861.py:21 | cid: 410


INFO flwr 2025-03-28 05:06:04,910 | 3963006861.py:21 | cid: 271


INFO flwr 2025-03-28 05:06:04,915 | 3963006861.py:21 | cid: 569


INFO flwr 2025-03-28 05:06:04,920 | 3963006861.py:21 | cid: 222


INFO flwr 2025-03-28 05:06:04,926 | 3963006861.py:21 | cid: 321


INFO flwr 2025-03-28 05:06:04,933 | 3963006861.py:21 | cid: 543


INFO flwr 2025-03-28 05:06:04,939 | 3963006861.py:21 | cid: 70


INFO flwr 2025-03-28 05:06:04,945 | 3963006861.py:21 | cid: 122


INFO flwr 2025-03-28 05:06:04,950 | 3963006861.py:21 | cid: 40


INFO flwr 2025-03-28 05:06:04,956 | 3963006861.py:21 | cid: 535


INFO flwr 2025-03-28 05:06:04,962 | 3963006861.py:21 | cid: 321


INFO flwr 2025-03-28 05:06:04,968 | 3963006861.py:21 | cid: 307


INFO flwr 2025-03-28 05:06:04,974 | 3963006861.py:21 | cid: 9


INFO flwr 2025-03-28 05:06:04,979 | 3963006861.py:21 | cid: 218


INFO flwr 2025-03-28 05:06:04,984 | 3963006861.py:21 | cid: 70


INFO flwr 2025-03-28 05:06:04,989 | 3963006861.py:21 | cid: 556


INFO flwr 2025-03-28 05:06:04,995 | 3963006861.py:21 | cid: 8


INFO flwr 2025-03-28 05:06:05,000 | 3963006861.py:21 | cid: 106


INFO flwr 2025-03-28 05:06:05,005 | 3963006861.py:21 | cid: 109


INFO flwr 2025-03-28 05:06:05,011 | 3963006861.py:21 | cid: 215


INFO flwr 2025-03-28 05:06:05,017 | 3963006861.py:21 | cid: 270


INFO flwr 2025-03-28 05:06:05,022 | 3963006861.py:21 | cid: 52


INFO flwr 2025-03-28 05:06:05,028 | 3963006861.py:21 | cid: 3


INFO flwr 2025-03-28 05:06:05,034 | 3963006861.py:21 | cid: 452


INFO flwr 2025-03-28 05:06:05,039 | 3963006861.py:21 | cid: 77


INFO flwr 2025-03-28 05:06:05,044 | 3963006861.py:21 | cid: 36


INFO flwr 2025-03-28 05:06:05,049 | 3963006861.py:21 | cid: 130


INFO flwr 2025-03-28 05:06:05,054 | 3963006861.py:21 | cid: 40


INFO flwr 2025-03-28 05:06:05,058 | 3963006861.py:21 | cid: 105


INFO flwr 2025-03-28 05:06:05,065 | 3963006861.py:21 | cid: 158


INFO flwr 2025-03-28 05:06:05,069 | 3963006861.py:21 | cid: 181


INFO flwr 2025-03-28 05:06:05,074 | 3963006861.py:21 | cid: 340


INFO flwr 2025-03-28 05:06:05,080 | 3963006861.py:21 | cid: 254


INFO flwr 2025-03-28 05:06:05,085 | 3963006861.py:21 | cid: 39


INFO flwr 2025-03-28 05:06:05,091 | 3963006861.py:21 | cid: 471


INFO flwr 2025-03-28 05:06:05,097 | 3963006861.py:21 | cid: 228


INFO flwr 2025-03-28 05:06:05,102 | 3963006861.py:21 | cid: 71


INFO flwr 2025-03-28 05:06:05,107 | 3963006861.py:21 | cid: 549


INFO flwr 2025-03-28 05:06:05,114 | 3963006861.py:21 | cid: 182


INFO flwr 2025-03-28 05:06:05,120 | 3963006861.py:21 | cid: 353


INFO flwr 2025-03-28 05:06:05,127 | 3963006861.py:21 | cid: 118


INFO flwr 2025-03-28 05:06:05,135 | 3963006861.py:21 | cid: 446


INFO flwr 2025-03-28 05:06:05,141 | 3963006861.py:21 | cid: 201


INFO flwr 2025-03-28 05:06:05,149 | 3963006861.py:21 | cid: 338


INFO flwr 2025-03-28 05:06:05,155 | 3963006861.py:21 | cid: 307


INFO flwr 2025-03-28 05:06:05,162 | 3963006861.py:21 | cid: 302


INFO flwr 2025-03-28 05:06:05,168 | 3963006861.py:21 | cid: 388


Now, to test that the newly partitioned clients can be trained.

In [12]:
test_config: dict[str, Any] = {
    "batch_size": 32,
    "num_workers": 0,
    "max_batches": 100,
}

In [13]:
# num_clients = 4
# clientIds = random.sample(list(range(total_clients)), num_clients)
# clients = [federated_client_generator(str(cid)) for cid in clientIds]
# print(f'{clients=}')

In [14]:
def get_federated_evaluation_function(
    batch_size: int,
    num_workers: int,
    model_generator: Callable[[], Module],
    criterion: Module,
    max_batches: int,
) -> Callable[[int, NDArrays, dict[str, Any]], tuple[float, dict[str, Scalar]]]:
    """Wrap the external federated evaluation function.

    It provides the external federated evaluation function with some
    parameters for the dataloader, the model generator function, and
    the criterion used in the evaluation.

    Parameters
    ----------
        batch_size (int): batch size of the test set to use.
        num_workers (int): correspond to `num_workers` param in the Dataloader object.
        model_generator (Callable[[], Module]):  model generator function.
        criterion (Module): PyTorch Module containing the criterion for evaluating the
        model.

    Returns
    -------
        Callable[[int, NDArrays, dict[str, Any]], tuple[float, dict[str, Scalar]]]:
            external federated evaluation function.
    """

    def federated_evaluation_function(
        server_round: int,
        parameters: NDArrays,
        fed_eval_config: dict[
            str, Any
        ],  # mandatory argument, even if it's not being used
    ) -> tuple[float, dict[str, Scalar]]:
        """Evaluate federated model on the server.

        It uses the centralized val set for sake of simplicity.

        Parameters
        ----------
            server_round (int): current federated round.
            parameters (NDArrays): current model parameters.
            fed_eval_config (dict[str, Any]): mandatory argument in Flower, can contain
                some configuration info

        Returns
        -------
            tuple[float, dict[str, Scalar]]: evaluation results
        """
        device: str = get_device()
        net: Module = set_model_parameters(model_generator(), parameters)
        net.to(device)

        full_file: Path = centralized_mapping
        dataset: Dataset = load_femnist_dataset(data_dir, full_file, "val")

        valid_loader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            drop_last=False,
        )

        loss, acc = test_femnist(
            net=net,
            test_loader=valid_loader,
            device=device,
            criterion=criterion,
            max_batches=max_batches,
        )
        return loss, {"accuracy": acc}

    return federated_evaluation_function


federated_evaluation_function = get_federated_evaluation_function(
    batch_size=test_config["batch_size"],
    num_workers=test_config["num_workers"],
    model_generator=network_generator,
    criterion=nn.CrossEntropyLoss(),
    max_batches=test_config["max_batches"],
)

In [15]:
def aggregate_weighted_average(metrics: list[tuple[int, dict]]) -> dict:
    """Combine results from multiple clients following training or evaluation.

    Parameters
    ----------
        metrics (list[tuple[int, dict]]): collected clients metrics

    Returns
    -------
        dict: result dictionary containing the aggregate of the metrics passed.
    """
    average_dict: dict = defaultdict(list)
    total_examples: int = 0
    for num_examples, metrics_dict in metrics:
        for key, val in metrics_dict.items():
            if isinstance(val, numbers.Number):
                average_dict[key].append((num_examples, val))
        total_examples += num_examples
    return {
        key: {
            "avg": float(
                sum([num_examples * metric for num_examples, metric in val])
                / float(total_examples)
            ),
            "all": val,
        }
        for key, val in average_dict.items()
    }

In [16]:
# Federated configuration dictionary
federated_train_config: dict[str, Any] = {
    "epochs": 25,
    "batch_size": 32,
    "client_learning_rate": 0.01,
    "weight_decay": 0.001,
    "num_workers": 0,
    "max_batches": 100,
    "central_dir": home_dir / "data" / "femnist" / "client_data_mappings" / "centralized" / "0"
}

FL SIMULATION

In [17]:
def start_seeded_simulation(
    client_fn: Callable[[str], Client],
    num_clients: int,
    config: ServerConfig,
    strategy: Strategy,
    name: str,
    return_all_parameters: bool = False,
    seed: int = Seeds.DEFAULT,
    iteration: int = 0,
) -> tuple[list[tuple[int, NDArrays]], History]:
    """Wrap to seed client selection."""
    np.random.seed(seed ^ iteration)
    torch.manual_seed(seed ^ iteration)
    random.seed(seed ^ iteration)
    parameter_list, hist = flwr.simulation.start_simulation_no_ray(
        client_fn=client_fn,
        num_clients=num_clients,
        client_resources={},
        config=config,
        strategy=strategy,
    )
    save_history(home_dir, hist, name)
    return parameter_list, hist

`run_simulation_frank_wolfe` is an adaptation of the original simulation function (now renamed to `run_simulation_fedavg`), the only difference being the strategy used. The strategy can be found in [c2m3/match/fed_frank_wolfe_strategy.py](https://github.com/DawnSpider96/L361-Federated-Learning/blob/c2m3/c2m3/match/fed_frank_wolfe_strategy.py#L43)

In [18]:
num_rounds = 10

num_total_clients = 20

num_evaluate_clients = 0
num_clients_per_round = 5

initial_parameters = ndarrays_to_parameters(seed_model_params)


def run_simulation_frank_wolfe(
    # How long the FL process runs for:
    num_rounds: int = num_rounds,
    # Number of clients available
    num_total_clients: int = num_total_clients,
    # Number of clients used for train/eval
    num_clients_per_round: int = num_clients_per_round,
    num_evaluate_clients: int = num_evaluate_clients,
    # If less clients are overall available stop FL
    min_available_clients: int = num_total_clients,
    # If less clients are available for fit/eval stop FL
    min_fit_clients: int = num_clients_per_round,
    min_evaluate_clients: int = num_evaluate_clients,
    # Function to test the federated model performance
    # external to a client instantiation
    evaluate_fn: (
        Callable[
            [int, NDArrays, dict[str, Scalar]],
            tuple[float, dict[str, Scalar]] | None,
        ]
        | None
    ) = federated_evaluation_function,
    # Functions to generate a config for client fit/evaluate
    # by-default the same config is shallow-copied to all clients in Flower
    # this version simply uses the configs defined above
    on_fit_config_fn: Callable[
        [int], dict[str, Scalar]
    ] = lambda _x: federated_train_config,
    on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] = lambda _x: test_config,
    # The "Parameters" type is merely a more packed version
    # of numpy array lists, used internally by Flower
    initial_parameters: Parameters = initial_parameters,
    # If this is set to True, aggregation will work even if some clients fail
    accept_failures: bool = False,
    # How to combine the metrics dictionary returned by all clients for fit/eval
    fit_metrics_aggregation_fn: Callable | None = aggregate_weighted_average,
    evaluate_metrics_aggregation_fn: Callable | None = aggregate_weighted_average,
    federated_client_generator: Callable[
        [str], flwr.client.NumPyClient
    ] = federated_client_generator,
    # Aggregation learning rate for FedAvg
    server_learning_rate: float = 1.0,
    server_momentum: float = 0.0,
) -> tuple[list[tuple[int, NDArrays]], History]:
    """Run a federated simulation using Flower."""
    log(INFO, "FL will execute for %s rounds", num_rounds)

    # Percentage of clients used for train/eval
    fraction_fit: float = float(num_clients_per_round) / num_total_clients
    fraction_evaluate: float = float(num_evaluate_clients) / num_total_clients

    strategy = FrankWolfeSync(
        fraction_fit=fraction_fit,
        fraction_evaluate=fraction_evaluate,
        min_fit_clients=min_fit_clients,
        min_evaluate_clients=min_evaluate_clients,
        min_available_clients=min_available_clients,
        on_fit_config_fn=on_fit_config_fn,
        on_evaluate_config_fn=on_evaluate_config_fn,
        evaluate_fn=evaluate_fn,
        initial_parameters=initial_parameters,
        accept_failures=accept_failures,
        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        # batch_size = int(config["batch_size"])
        # num_workers = int(config["num_workers"])
        # dataset = self._load_dataset(name)
    )
    # resetting the seed for the random selection of clients
    # this way the list of clients trained is guaranteed to be always the same

    cfg = ServerConfig(num_rounds)

    def simulator_client_generator(cid: str) -> Client:
        return federated_client_generator(cid).to_client()

    parameters_for_each_round, hist = start_seeded_simulation(
        client_fn=simulator_client_generator,
        num_clients=num_total_clients,
        config=cfg,
        strategy=strategy,
        name="c2m3",
        return_all_parameters=True,
        seed=Seeds.DEFAULT,
    )
    return parameters_for_each_round, hist




def run_simulation_fedavg(
    # How long the FL process runs for:
    num_rounds: int = num_rounds,
    # Number of clients available
    num_total_clients: int = num_total_clients,
    # Number of clients used for train/eval
    num_clients_per_round: int = num_clients_per_round,
    num_evaluate_clients: int = num_evaluate_clients,
    # If less clients are overall available stop FL
    min_available_clients: int = num_total_clients,
    # If less clients are available for fit/eval stop FL
    min_fit_clients: int = num_clients_per_round,
    min_evaluate_clients: int = num_evaluate_clients,
    # Function to test the federated model performance
    # external to a client instantiation
    evaluate_fn: (
        Callable[
            [int, NDArrays, dict[str, Scalar]],
            tuple[float, dict[str, Scalar]] | None,
        ]
        | None
    ) = federated_evaluation_function,
    # Functions to generate a config for client fit/evaluate
    # by-default the same config is shallow-copied to all clients in Flower
    # this version simply uses the configs defined above
    on_fit_config_fn: Callable[
        [int], dict[str, Scalar]
    ] = lambda _x: federated_train_config,
    on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] = lambda _x: test_config,
    # The "Parameters" type is merely a more packed version
    # of numpy array lists, used internally by Flower
    initial_parameters: Parameters = initial_parameters,
    # If this is set to True, aggregation will work even if some clients fail
    accept_failures: bool = False,
    # How to combine the metrics dictionary returned by all clients for fit/eval
    fit_metrics_aggregation_fn: Callable | None = aggregate_weighted_average,
    evaluate_metrics_aggregation_fn: Callable | None = aggregate_weighted_average,
    federated_client_generator: Callable[
        [str], flwr.client.NumPyClient
    ] = federated_client_generator,
    # Aggregation learning rate for FedAvg
    server_learning_rate: float = 1.0,
    server_momentum: float = 0.0,
) -> tuple[list[tuple[int, NDArrays]], History]:
    """Run a federated simulation using Flower."""
    log(INFO, "FL will execute for %s rounds", num_rounds)

    # Percentage of clients used for train/eval
    fraction_fit: float = float(num_clients_per_round) / num_total_clients
    fraction_evaluate: float = float(num_evaluate_clients) / num_total_clients

    strategy = FedAvg(
        fraction_fit=fraction_fit,
        fraction_evaluate=fraction_evaluate,
        min_fit_clients=min_fit_clients,
        min_evaluate_clients=min_evaluate_clients,
        min_available_clients=min_available_clients,
        on_fit_config_fn=on_fit_config_fn,
        on_evaluate_config_fn=on_evaluate_config_fn,
        evaluate_fn=evaluate_fn,
        initial_parameters=initial_parameters,
        accept_failures=accept_failures,
        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        server_learning_rate=server_learning_rate,
        server_momentum=server_momentum,
        # batch_size = int(config["batch_size"])
        # num_workers = int(config["num_workers"])
        # dataset = self._load_dataset(name)
        
    )
    # resetting the seed for the random selection of clients
    # this way the list of clients trained is guaranteed to be always the same

    cfg = ServerConfig(num_rounds)

    def simulator_client_generator(cid: str) -> Client:
        return federated_client_generator(cid).to_client()

    parameters_for_each_round, hist = start_seeded_simulation(
        client_fn=simulator_client_generator,
        num_clients=num_total_clients,
        config=cfg,
        strategy=strategy,
        name="c2m3",
        return_all_parameters=True,
        seed=Seeds.DEFAULT,
    )
    return parameters_for_each_round, hist

In [None]:
parameters_for_each_round, hist = run_simulation_frank_wolfe()

INFO flwr 2025-03-28 05:06:05,344 | 4257692453.py:56 | FL will execute for 10 rounds


INFO flwr 2025-03-28 05:06:05,368 | app.py:149 | Starting Flower simulation, config: ServerConfig(num_rounds=10, round_timeout=None)


INFO flwr 2025-03-28 05:06:05,375 | server_returns_parameters.py:81 | Initializing global parameters


INFO flwr 2025-03-28 05:06:05,379 | server_returns_parameters.py:273 | Using initial parameters provided by strategy


INFO flwr 2025-03-28 05:06:05,389 | server_returns_parameters.py:84 | Evaluating initial parameters


Testing [28510 samples]: 100%|██████████| 100/100 [00:01<00:00, 82.50it/s]
INFO flwr 2025-03-28 05:06:06,648 | server_returns_parameters.py:87 | initial parameters (loss, other metrics): 412.7646117210388, {'accuracy': 0.005625}


INFO flwr 2025-03-28 05:06:06,652 | server_returns_parameters.py:97 | FL starting


DEBUG flwr 2025-03-28 05:06:06,656 | server_returns_parameters.py:223 | fit_round 1: strategy sampled 5 clients (out of 20)


INFO flwr 2025-03-28 05:06:06,661 | 3963006861.py:21 | cid: 370
INFO flwr 2025-03-28 05:06:06,666 | 3963006861.py:21 | cid: 44


INFO flwr 2025-03-28 05:06:06,669 | 3963006861.py:21 | cid: 558
INFO flwr 2025-03-28 05:06:06,673 | 3963006861.py:21 | cid: 284


INFO flwr 2025-03-28 05:06:06,677 | 3963006861.py:21 | cid: 322


Training [Epoch 1/25]:   0%|          | 0/32 [00:00<?, ?it/s]
[A

[A[A


Training [Epoch 1/25]:   3%|▎         | 1/32 [00:00<00:06,  4.78it/s]

[A[A


[A[A[A
[A


[A[A[A

Training [Epoch 1/25]:   9%|▉         | 3/32 [00:00<00:03,  7.82it/s]
[A

Training [Epoch 1/25]:  12%|█▎        | 4/32 [00:00<00:03,  8.00it/s]
[A


[A[A[A

Training [Epoch 1/25]:  16%|█▌        | 5/32 [00:00<00:03,  7.92it/s]
[A


[A[A[A

Training [Epoch 1/25]:  19%|█▉        | 6/32 [00:00<00:03,  7.73it/s]
[A


[A[A[A

Training [Epoch 1/25]:  22%|██▏       | 7/32 [00:00<00:03,  6.74it/s]
[A


[A[A[A

[A[A
[A

[A[A


Training [Epoch 1/25]:  28%|██▊       | 9/32 [00:01<00:03,  7.58it/s]
[A


Training [Epoch 1/25]:  31%|███▏      | 10/32 [00:01<00:02,  7.71it/s]

[A[A
Training [Epoch 1/25]:  34%|███▍      | 11/32 [00:01<00:02,  7.63it/s]
[A


[A[A[A

Training [Epoch 1/25]:  38%|███▊      | 12/32 [00:01<00:02,  7.99it/s]
[A

[A[A


Training [Epoch 1/25]:  41%|████      | 13/32

In [None]:
parameters_for_each_round_fedavg, hist_fedavg = run_simulation_fedavg()

In [None]:
hist

In [None]:
hist_fedavg

In [None]:
import json
import os
from pathlib import Path

save_dir = Path.cwd() / "../../fed_results"
save_dir.mkdir(exist_ok=True)

# Convert the history objects to dictionaries
hist_dict = {
    "metrics_centralized": hist.metrics_centralized,
    "losses_centralized": hist.losses_centralized,
    "metrics_distributed_fit": hist.metrics_distributed_fit,
    "metrics_distributed": hist.metrics_distributed,
    "losses_distributed": hist.losses_distributed,
}

hist_fedavg_dict = {
    "metrics_centralized": hist_fedavg.metrics_centralized,
    "losses_centralized": hist_fedavg.losses_centralized,
    "metrics_distributed_fit": hist_fedavg.metrics_distributed_fit,
    "losses_distributed": hist_fedavg.losses_distributed,
    "metrics_distributed": hist_fedavg.metrics_distributed,
}

results_data = {
    "c2m3": {
        "history": hist_dict
    },
    "fedavg": {
        "history": hist_fedavg_dict
    }
}

# Save to a JSON file
with open(save_dir / f"flwr_LR_0.01_{Seeds.DEFAULT}.json", "w") as f:
    json.dump(results_data, f, indent=2)

print(f"Results saved to {save_dir / 'flower_simulation_results.json'}")

In [24]:
import matplotlib.pyplot as plt

def plot_metrics(hist1, hist2, legend_labels=['FrankWolfe', 'FedAvg'], save_path=None):
    
    acc1 = hist1.metrics_centralized['accuracy']
    rounds_acc1, acc_values1 = zip(*acc1)
    
    acc2 = hist2.metrics_centralized['accuracy']
    rounds_acc2, acc_values2 = zip(*acc2)
    
    loss1 = hist1.losses_centralized
    rounds_loss1, loss_values1 = zip(*loss1)
    
    loss2 = hist2.losses_centralized
    rounds_loss2, loss_values2 = zip(*loss2)
    
    fig, axs = plt.subplots(2, 1, figsize=(12, 10))
    
    axs[0].plot(rounds_acc1, acc_values1, 'o-', color='blue', linewidth=2, markersize=8, 
               label=f'{legend_labels[0]} Accuracy')
    axs[0].plot(rounds_acc2, acc_values2, 's-', color='cyan', linewidth=2, markersize=8, 
               label=f'{legend_labels[1]} Accuracy')
    
    axs[0].set_title('Accuracy Comparison', fontsize=14)
    axs[0].set_xlabel('Round Number', fontsize=12)
    axs[0].set_ylabel('Accuracy', fontsize=12)
    axs[0].grid(True, linestyle='--', alpha=0.7)
    axs[0].legend(loc='best')
    
    all_rounds_acc = sorted(list(set(rounds_acc1 + rounds_acc2)))
    axs[0].set_xticks(all_rounds_acc)
    
    axs[1].plot(rounds_loss1, loss_values1, 'o-', color='red', linewidth=2, markersize=8, 
               label=f'{legend_labels[0]} Loss')
    axs[1].plot(rounds_loss2, loss_values2, 's-', color='orange', linewidth=2, markersize=8, 
               label=f'{legend_labels[1]} Loss')
    
    axs[1].set_title('Loss Comparison', fontsize=14)
    axs[1].set_xlabel('Round Number', fontsize=12)
    axs[1].set_ylabel('Loss', fontsize=12)
    axs[1].grid(True, linestyle='--', alpha=0.7)
    axs[1].legend(loc='best')

In [None]:
plot_metrics(hist, hist_fedavg)

In [26]:
# log(
#     INFO,
#     "Size of the list with the model parameters: %s",
#     len(parameters_for_each_round),
# )