# A new federated aggregation rule with `fluke`

This tutorial guides you through the steps required to implement KRUM and Multi-KRUM with `fluke`.

```{attention}
This tutorial does not go into the details of the implementation, but it provides a quick overview of the steps required to implement a new federated learning algorithm.
```

Try this notebook: [![Open in Colab](https://img.shields.io/badge/Open_in_Colab-blue?style=flat-square&logo=google-colab&logoColor=yellow&labelColor=gray)
](https://colab.research.google.com/github/CasellaJr/Fluke-tutorial-ECAI25/blob/main/2_fluke_krum.ipynb)

## Install `fluke` (if not already done)

In [None]:
# Install the `fluke-fl` package from PyPI via an IPython shell escape.
!pip install fluke-fl

# KRUM

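Let's go one step further and implement KRUM aggregation, which should be more robust to Byzantine attacks than median aggregation. KRUM scores each received client model by the sum of squared Euclidean distances to its $n - f - 2$ closest models, where $n$ is the number of received models and $f$ is the assumed number of Byzantine clients. It then adopts the model with the lowest score as the new global model.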

In [None]:
from typing import Sequence, Iterable
from torch.nn import Module
import torch
from copy import deepcopy
import warnings

from fluke.client import Client
from fluke.server import Server
from fluke.data import FastDataLoader

# Ignore every Python warning from here on (keeps the notebook output uncluttered).
warnings.simplefilter("ignore")

class KRUMServer(Server):
    """Server that aggregates client models with the KRUM rule.

    Every received client model is flattened into one vector of its
    parameters. Its KRUM score is the sum of squared Euclidean distances to its
    ``n - f - 2`` nearest *other* models, where ``n`` is the number of received
    models and ``f`` the assumed number of Byzantine clients. The parameters of
    the model with the lowest score are copied into the server model.

    Args:
        model: Global model held by the server.
        test_set: Server-side test data, or ``None``.
        clients: Clients of the federation.
        weighted: Forwarded to ``Server``; KRUM selection does not use it.
        lr: Forwarded to ``Server``; KRUM selection does not use it.
        f: Assumed number of Byzantine clients. KRUM's guarantees require
            ``n > 2f + 2``.
        **kwargs: Extra hyperparameters forwarded to ``Server``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        test_set: FastDataLoader | None,
        clients: Sequence[Client],
        weighted: bool = False,
        lr: float = 1.0,
        f: int = 0,
        **kwargs,
    ):
        super().__init__(model, test_set, clients, weighted, lr, **kwargs)
        self.hyper_params.update(f=f)

    @staticmethod
    def _flatten_params(model: Module, keys: Sequence[str]) -> torch.Tensor:
        """Concatenate the parameters of ``model`` named in ``keys`` into a 1-D tensor."""
        params = dict(model.named_parameters())
        return torch.cat([params[k].data.ravel() for k in keys])

    @staticmethod
    def _krum_scores(mat: torch.Tensor, f: int) -> torch.Tensor:
        """Return the KRUM score of each row of the ``(n, d)`` matrix ``mat``."""
        n = mat.shape[0]
        # Number of nearest neighbours summed per model. Clamped to [1, n - 1]:
        # n - f - 2 <= 0 would otherwise yield a negative slice end and silently
        # sum the wrong entries, and no row has more than n - 1 neighbours.
        nb_small = max(1, min(n - f - 2, n - 1))
        sq_dists = torch.cdist(mat, mat, p=2.0).pow(2)
        # Exclude each model's distance to itself explicitly instead of relying
        # on it sorting first (cdist may return tiny non-zero diagonal values).
        sq_dists.fill_diagonal_(float("inf"))
        nearest, _ = torch.sort(sq_dists, dim=1)
        return nearest[:, :nb_small].sum(dim=1)

    def aggregate(self, eligible: Sequence[Client], client_models: Iterable[Module]) -> Module:
        """Copy the parameters of the lowest-score client model into the server model.

        Args:
            eligible: Clients that took part in the round (not used by KRUM).
            client_models: Models received from the eligible clients.

        Returns:
            A deep copy of the (updated) server model.
        """
        client_models = list(client_models)
        if not client_models:
            # Nothing was received: keep the current global model unchanged.
            return deepcopy(self.model)

        # Flatten every client model following the server's parameter order.
        param_keys = [name for name, _ in self.model.named_parameters()]
        mat = torch.stack([self._flatten_params(cm, param_keys) for cm in client_models])

        scores = self._krum_scores(mat, self.hyper_params.f)
        # argmin returns the first index among ties, making the choice deterministic.
        chosen_idx = int(torch.argmin(scores).item())

        chosen_params = dict(client_models[chosen_idx].named_parameters())
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                param.copy_(chosen_params[name].data)

        return deepcopy(self.model)

## Implementing the new federated algorithm

Now we only need to put everything together in a new class that inherits from `fluke.algorithms.CentralizedFL` and overrides `get_server_class` to return the server class we just implemented.

In [None]:
from fluke.algorithms import CentralizedFL

class KRUMFLAlgorithm(CentralizedFL):
    """Centralized federated algorithm whose server aggregates with KRUM."""

    def get_server_class(self) -> type[Server]:
        """Return ``KRUMServer`` as the server implementation."""
        server_cls: type[Server] = KRUMServer
        return server_cls

Everything is ready! Now we can test our new federated algorithm with `fluke`!

## Ready to test KRUM

In [None]:
from fluke.data import DataSplitter
from fluke.data.datasets import Datasets
from fluke import DDict
from fluke.utils.log import Log
from fluke.evaluation import ClassificationEval
from fluke import FlukeENV

env = FlukeENV()
env.set_seed(42)  # we set a seed for reproducibility
env.set_device("cpu")  # we use the CPU for this example

dataset = Datasets.get("mnist", path="./data")

# we set the evaluator to be used by both the server and the clients
env.set_evaluator(ClassificationEval(eval_every=1, n_classes=dataset.num_classes))

splitter = DataSplitter(dataset=dataset,
                        distribution="iid")

client_hp = DDict(
    batch_size=10,
    local_epochs=5,
    loss="CrossEntropyLoss",
    optimizer=DDict(
      lr=0.01,
      momentum=0.9,
      weight_decay=0.0001),
    scheduler=DDict(
      gamma=1,
      step_size=1)
)

# we put together the hyperparameters for the algorithm.
# KRUM assumes at most f Byzantine clients among n with n > 2f + 2. With the
# 10 clients used below, f can be at most 3. A larger f (e.g. 7) would make
# every score depend on a single nearest neighbour (n - f - 2 = 1).
hyperparams = DDict(client=client_hp,
                    server=DDict(weighted=True, f=3),
                    model="MNIST_2NN")

Here is where the new federated algorithm comes into play.

In [None]:
# Build a federation of 10 clients whose server aggregates with KRUM.
algorithm = KRUMFLAlgorithm(
    n_clients=10,
    data_splitter=splitter,
    hyper_params=hyperparams,
)

# Attach a logger as callback, then initialise it.
logger = Log()
algorithm.set_callbacks(logger)
logger.init()

In [None]:
# Three rounds, every client participating in each round.
algorithm.run(eligible_perc=1, n_rounds=3)

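# Multi-KRUM

Multi-KRUM computes the same scores as KRUM. Instead of keeping a single model, it averages the $m$ models with the lowest scores, so with $m = 1$ it reduces to KRUM.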

In [None]:
class MultiKRUMServer(Server):
    """Server that aggregates client models with the Multi-KRUM rule.

    Each received client model is flattened into one vector of its parameters
    and given a KRUM score. The score is the sum of squared Euclidean distances
    to its ``n - f - 2`` nearest *other* models, where ``n`` is the number of
    received models. The ``m`` models with the lowest scores are averaged
    element-wise into the server model, so ``m = 1`` is plain KRUM.

    Args:
        model: Global model held by the server.
        test_set: Server-side test data, or ``None``.
        clients: Clients of the federation.
        weighted: Forwarded to ``Server``; the Multi-KRUM average ignores it.
        lr: Forwarded to ``Server``; not used by this aggregation.
        f: Assumed number of Byzantine clients (KRUM's guarantees require
            ``n > 2f + 2``).
        m: Number of lowest-score models to average. It is clamped to the
            number of received models. Values below 1 raise ``ValueError``
            at aggregation time.
        **kwargs: Extra hyperparameters forwarded to ``Server``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        test_set: FastDataLoader | None,
        clients: Sequence[Client],
        weighted: bool = False,
        lr: float = 1.0,
        f: int = 0,
        m: int = 1,
        **kwargs,
    ):
        super().__init__(model, test_set, clients, weighted, lr, **kwargs)
        self.hyper_params.update(f=f, m=m)

    @staticmethod
    def _flatten_params(model: Module, keys: Sequence[str]) -> torch.Tensor:
        """Concatenate the parameters of ``model`` named in ``keys`` into a 1-D tensor."""
        params = dict(model.named_parameters())
        return torch.cat([params[k].data.ravel() for k in keys])

    @staticmethod
    def _krum_scores(mat: torch.Tensor, f: int) -> torch.Tensor:
        """Return the KRUM score of each row of the ``(n, d)`` matrix ``mat``."""
        n = mat.shape[0]
        # Number of nearest neighbours summed per model. Clamped to [1, n - 1]:
        # n - f - 2 <= 0 would otherwise yield a negative slice end and silently
        # sum the wrong entries, and no row has more than n - 1 neighbours.
        nb_small = max(1, min(n - f - 2, n - 1))
        sq_dists = torch.cdist(mat, mat, p=2.0).pow(2)
        # Exclude each model's distance to itself explicitly instead of relying
        # on it sorting first (cdist may return tiny non-zero diagonal values).
        sq_dists.fill_diagonal_(float("inf"))
        nearest, _ = torch.sort(sq_dists, dim=1)
        return nearest[:, :nb_small].sum(dim=1)

    def aggregate(self, eligible: Sequence[Client], client_models: Iterable[Module]) -> Module:
        """Average the ``m`` lowest-score client models into the server model.

        Args:
            eligible: Clients that took part in the round (not used here).
            client_models: Models received from the eligible clients.

        Returns:
            A deep copy of the (updated) server model.

        Raises:
            ValueError: If the ``m`` hyperparameter is smaller than 1.
        """
        client_models = list(client_models)
        if not client_models:
            # Nothing was received: keep the current global model unchanged.
            return deepcopy(self.model)

        m = int(self.hyper_params.m)
        if m < 1:
            raise ValueError(f"Multi-KRUM requires m >= 1, got m={m}")
        # Fewer models than m can arrive (e.g. partial participation).
        m = min(m, len(client_models))

        # Flatten every client model following the server's parameter order.
        param_keys = [name for name, _ in self.model.named_parameters()]
        mat = torch.stack([self._flatten_params(cm, param_keys) for cm in client_models])

        scores = self._krum_scores(mat, self.hyper_params.f)
        _, selected = torch.topk(scores, k=m, largest=False)
        chosen = [dict(client_models[i].named_parameters()) for i in selected.tolist()]

        with torch.no_grad():
            for name, param in self.model.named_parameters():
                # Average in at least float32 without downcasting float64 params;
                # for m == 1 the mean is an exact copy of the selected model.
                acc_dtype = torch.promote_types(param.dtype, torch.float32)
                stacked = torch.stack(
                    [c[name].data.to(device=param.device, dtype=acc_dtype) for c in chosen]
                )
                param.copy_(stacked.mean(dim=0))

        return deepcopy(self.model)

In [None]:
from fluke.algorithms import CentralizedFL

class MultiKRUMFLAlgorithm(CentralizedFL):
    """Centralized federated algorithm whose server aggregates with Multi-KRUM."""

    def get_server_class(self) -> type[Server]:
        """Return ``MultiKRUMServer`` as the server implementation."""
        server_cls: type[Server] = MultiKRUMServer
        return server_cls

In [None]:
# Same 10-client federation as before, now aggregated with Multi-KRUM.
algorithm = MultiKRUMFLAlgorithm(
    n_clients=10,
    data_splitter=splitter,
    hyper_params=hyperparams,
)

# Fresh logger for this run.
logger = Log()
algorithm.set_callbacks(logger)
logger.init()

algorithm.run(eligible_perc=1, n_rounds=3)

# Implementing malicious clients

We implement a simple malicious client that, instead of training, replaces its local model with random weights, which are then sent to the server.

In [None]:
import torch
import warnings
from fluke.client import Client

# Ignore every Python warning from here on (keeps the notebook output uncluttered).
warnings.simplefilter("ignore")


class RandomClient(Client):
    """Client that may act maliciously by sending random weights.

    At construction the client draws one uniform sample from torch's global
    RNG and is malicious when that sample is below ``malicious_percentage``.
    A malicious client skips local training and overwrites every parameter of
    its local model with standard-normal noise. An honest client trains
    normally.

    Args:
        malicious_percentage: Probability in ``[0, 1]`` that this client is
            malicious (e.g. ``0.4``).
        *args: Forwarded to ``Client``.
        **kwargs: Forwarded to ``Client``.
    """

    def __init__(self, malicious_percentage: float = 0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.malicious = torch.rand(1).item() < malicious_percentage

    def fit(self, override_local_epochs=0):
        """Train locally, or replace the local model with noise if malicious.

        Args:
            override_local_epochs: Forwarded to ``Client.fit``.

        Returns:
            For honest clients, whatever ``Client.fit`` returns (documented
            by fluke as the training loss); previously this value was dropped.
            For malicious clients, which do not train, the placeholder ``0.0``.
        """
        if not self.malicious:
            return super().fit(override_local_epochs)

        with torch.no_grad():
            for param in self.model.parameters():
                param.copy_(torch.randn_like(param))
        return 0.0

Let's create a FedAvg federation in which some clients are malicious and send random models to the server.

In [None]:
class AttackFLAlgorithm(CentralizedFL):
    """Federation whose clients may be malicious, with no robust aggregation.

    Only the client class is overridden; the server is whatever
    ``CentralizedFL`` uses by default.
    """

    def get_client_class(self) -> type[Client]:
        """Return ``RandomClient`` so that some clients send random models."""
        client_cls: type[Client] = RandomClient
        return client_cls

In [None]:
# Same client setup as before, plus the chance of each client being malicious.
client_hp = DDict(
    batch_size=10,
    local_epochs=5,
    loss="CrossEntropyLoss",
    optimizer=DDict(lr=0.01, momentum=0.9, weight_decay=0.0001),
    scheduler=DDict(gamma=1, step_size=1),
    malicious_percentage=0.4,  # each client is malicious with probability 0.4
)

# No f or m is given for the server, so the KRUM servers used later fall back
# to their defaults (f=0, m=1).
hyperparams = DDict(
    client=client_hp,
    server=DDict(weighted=True),
    model="MNIST_2NN",
)

In [None]:
# Federation under attack, without any robust aggregation on the server.
algorithm = AttackFLAlgorithm(
    n_clients=10,  # 10 clients in the federation
    data_splitter=splitter,
    hyper_params=hyperparams,
)

logger = Log()
algorithm.set_callbacks(logger)
logger.init()

algorithm.run(eligible_perc=1, n_rounds=3)

## Set up the defence

Let's now set up the defence mechanism, i.e., the KRUM aggregation rule.

In [None]:
class AttackDefenceFLAlgorithm(CentralizedFL):
    """Federation of possibly malicious clients defended by a KRUM server."""

    def get_client_class(self) -> type[Client]:
        """Return ``RandomClient``: clients that may send random models."""
        return RandomClient

    def get_server_class(self) -> type[Server]:
        """Return ``KRUMServer``: aggregation with the KRUM rule."""
        return KRUMServer

In [None]:
# Same attack as before, now with KRUM on the server side.
algorithm = AttackDefenceFLAlgorithm(
    n_clients=10,  # 10 clients in the federation
    data_splitter=splitter,
    hyper_params=hyperparams,
)

logger = Log()
algorithm.set_callbacks(logger)
logger.init()

algorithm.run(eligible_perc=1, n_rounds=3)

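Let's try the Multi-KRUM aggregation rule in the presence of malicious clients. Note that the server hyperparameters above do not set `m`, so `MultiKRUMServer` uses its default `m=1` and behaves exactly like KRUM. Add `m` (e.g. `m=3`) to the server hyperparameters to actually average several models.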

In [None]:
class AttackDefenceFLAlgorithm(CentralizedFL):
    """Federation of possibly malicious clients defended by a Multi-KRUM server.

    This redefinition replaces the KRUM-based class of the same name.
    """

    def get_client_class(self) -> type[Client]:
        """Return ``RandomClient``: clients that may send random models."""
        return RandomClient

    def get_server_class(self) -> type[Server]:
        """Return ``MultiKRUMServer``: aggregation with the Multi-KRUM rule."""
        return MultiKRUMServer

In [None]:
# Same attack again, now with Multi-KRUM on the server side.
algorithm = AttackDefenceFLAlgorithm(
    n_clients=10,  # 10 clients in the federation
    data_splitter=splitter,
    hyper_params=hyperparams,
)

logger = Log()
algorithm.set_callbacks(logger)
logger.init()

algorithm.run(eligible_perc=1, n_rounds=3)