# Managing client-server communication with `fluke`

This tutorial will guide you through the steps required to handle the communication between clients and server when implementing a new federated learning algorithm using the `fluke` library.

```{attention}
This tutorial does not go into the details of the implementation, but it provides a quick overview of the steps required to implement a new federated learning algorithm.
```

Try this notebook: [![Open in Colab](https://img.shields.io/badge/Open_in_Colab-blue?style=flat-square&logo=google-colab&logoColor=yellow&labelColor=gray)
](https://colab.research.google.com/github/CasellaJr/Fluke-tutorial-ECAI25/blob/main/3_fluke_communication.ipynb)

## LG-FedAVG example

We have seen in the theoretical part of the tutorial that in LG-FedAVG, clients send only the global model updates to the server, while keeping their local model private. To implement this behavior in `fluke`, we need to customize the `Client` and `Server` classes to handle the communication accordingly.

But first, we need to define a model that separates the global and local parts.
`fluke` provides a convenient `EncoderHeadNet` class that allows us to easily create such models.

In [None]:
from torch.nn import Module, Linear
from torch import Tensor
import torch.nn.functional as F

from fluke.nets import EncoderHeadNet

# The encoder/backbone
class MLP_E(Module):
    """Encoder (backbone) of the MLP: two fully connected layers with ReLU.

    Args:
        hidden_size: Output widths of the first and the second linear layer.
    """

    def __init__(self, hidden_size: tuple[int, int] = (200, 100)):
        super().__init__()
        first_width = hidden_size[0]
        second_width = hidden_size[1]
        # The input width is fixed to a flattened 28x28 image.
        self.fc1 = Linear(28 * 28, first_width)
        self.fc2 = Linear(first_width, second_width)

    def forward(self, x: Tensor) -> Tensor:
        """Flatten ``x`` to rows of 784 values and apply both ReLU layers.

        Args:
            x: Input tensor whose elements can be viewed as ``(-1, 784)``.

        Returns:
            The activations produced by the second layer.
        """
        flat = x.view(-1, self.fc1.in_features)
        hidden = F.relu(self.fc1(flat))
        return F.relu(self.fc2(hidden))

# The head/classifier
class MLP_D(Module):
    """Head (classifier) of the MLP: a single linear layer with 10 outputs.

    Args:
        hidden_size: Width of the features fed to the head.
    """

    def __init__(self, hidden_size: int = 100):
        super().__init__()
        self.fc3 = Linear(hidden_size, 10)

    def forward(self, x: Tensor) -> Tensor:
        """Return the output of the linear layer for the features ``x``."""
        logits = self.fc3(x)
        return logits

# The complete model
class MLP(EncoderHeadNet):
    """Complete model: an ``MLP_E`` encoder followed by an ``MLP_D`` head.

    Args:
        hidden_size: Widths of the two encoder layers; the second one is also
            the input width of the head.
    """

    def __init__(self, hidden_size: tuple[int, int] = (200, 100)):
        encoder = MLP_E(hidden_size)
        head = MLP_D(hidden_size[1])
        super().__init__(encoder, head)

Then we define the client that only needs to share the head instead of the full model.
To do this, we override the `send_model` and `receive_model` methods so that only the global part of the model is sent and received.

In [None]:
from typing import Sequence
from torch.nn.modules import Module

from fluke.algorithms import PersonalizedFL  # NOQA
from fluke.client import Client  # NOQA
from fluke.comm import Message  # NOQA
from fluke.config import OptimizerConfigurator  # NOQA
from fluke.data import FastDataLoader  # NOQA
from fluke.nets import EncoderHeadNet, HeadGlobalEncoderLocalNet  # NOQA
from fluke.server import Server  # NOQA
from fluke.utils.model import safe_load_state_dict  # NOQA


class MyLGFedAVGClient(Client):
    """LG-FedAVG client that exchanges only the global part of its model.

    The model received at construction time is wrapped in a
    ``HeadGlobalEncoderLocalNet``; only what ``get_global()`` returns is sent
    to, or updated from, the server.
    """

    def __init__(
        self,
        index: int,
        model: EncoderHeadNet,
        train_set: FastDataLoader,
        test_set: FastDataLoader,
        optimizer_cfg: OptimizerConfigurator,
        loss_fn: Module,
        local_epochs: int = 3,
        fine_tuning_epochs: int = 0,
        **kwargs,
    ):
        """Initialize the base client and attach the wrapped model.

        Every argument except ``model`` is forwarded unchanged to
        ``Client.__init__``; ``model`` is wrapped and stored in ``self.model``.
        """
        super().__init__(
            index=index,
            train_set=train_set,
            test_set=test_set,
            optimizer_cfg=optimizer_cfg,
            loss_fn=loss_fn,
            local_epochs=local_epochs,
            fine_tuning_epochs=fine_tuning_epochs,
            **kwargs,
        )
        self.model = HeadGlobalEncoderLocalNet(model)
        # NOTE(review): presumably persists the freshly set model in fluke's
        # client cache - confirm against `Client._save_to_cache`.
        self._save_to_cache()

    def send_model(self) -> None:
        """Send the global part of the model to the server.

        The client talks to the server through ``self.channel``, a field of
        the ``Client`` class that is a reference shared with the server.
        """
        shared_part = self.model.get_global()
        # Communication happens through Messages: each one carries a payload
        # (here the global part of the model), a type (here "model") and the
        # sender (here the client ID). `inmemory=True` means that the message
        # is kept in memory and not serialized to disk.
        outgoing = Message(shared_part, "model", self.index, inmemory=True)
        # The receiver is selected through the mailbox: here the server.
        self.channel.send(message=outgoing, mbox="server")

    def receive_model(self) -> None:
        """Receive the global model from the server and load it in place."""
        # The server sends only the global model to the client.
        incoming = self.channel.receive(self.index, "server", msg_type="model")
        # Only the global part of the local model is overwritten.
        target = self.model.get_global()
        safe_load_state_dict(target, incoming.payload.state_dict())


Customizing the server is not strictly necessary if we assume that the configuration is coherent (i.e., the server does not have a test set). Otherwise, we can enforce this by overriding the server constructor so that it ignores any test set passed to it.

In [None]:
class MyLGFedAVGServer(Server):
    """LG-FedAVG server that never holds a test set.

    The ``test_set`` argument is accepted to keep the constructor signature
    compatible, but it is ignored: ``None`` is always passed to the base class.
    """

    def __init__(
        self,
        model: Module,
        test_set: FastDataLoader,
        clients: Sequence[Client],
        weighted: bool = False,
    ):
        """Initialize the base server, discarding ``test_set``."""
        super().__init__(
            model=model,
            test_set=None,  # whatever was passed in `test_set` is not used
            clients=clients,
            weighted=weighted,
        )


The algorithm now only needs to define the `get_client_class` and `get_server_class` methods to use the customized classes.
Note that it inherits from `PersonalizedFL` since LG-FedAVG is a personalized federated learning algorithm.
The main difference with a standard `CentralizedFL` is that the client must define a model type rather than receiving it from the server.

In [None]:
class MyLGFedAVG(PersonalizedFL):
    """LG-FedAVG algorithm wired to the customized client and server classes."""

    def get_server_class(self) -> type[Server]:
        """Return the server class used by this algorithm."""
        return MyLGFedAVGServer

    def get_client_class(self) -> type[Client]:
        """Return the client class used by this algorithm."""
        return MyLGFedAVGClient

## Running the algorithm

In [None]:
from fluke.data import DataSplitter
from fluke.data.datasets import Datasets
from fluke import DDict
from fluke.utils.log import Log
from fluke.evaluation import ClassificationEval
from fluke import FlukeENV

# Configure the fluke environment: seed, device and evaluator.
env = FlukeENV()
env.set_seed(42) # we set a seed for reproducibility
env.set_device("cpu")

# Load MNIST; the evaluator needs the number of classes of the dataset.
dataset = Datasets.get("mnist", path="./data")
env.set_evaluator(ClassificationEval(eval_every=1, n_classes=dataset.num_classes))

# We split the data client-side in train and test (80%-20%) - no server test set
splitter = DataSplitter(dataset=dataset, distribution="iid", client_split=0.2, server_test=False)

# Without a test set on the server, we need to set post_fit evaluation client-side
env.set_eval_cfg(post_fit=True)

# Client hyper-parameters: each client is given the full encoder+head model.
# NOTE(review): gamma=1 presumably keeps the learning rate constant - confirm
# against the scheduler that fluke instantiates from this configuration.
client_hp = DDict(
    model=MLP(),
    batch_size=10,
    local_epochs=5,
    loss="CrossEntropyLoss",
    optimizer=DDict(lr=0.1),
    scheduler=DDict(
      gamma=1,
      step_size=1)
)

# Overall hyper-parameters: client settings, server settings and shared model.
hyperparams = DDict(client=client_hp,
                    server=DDict(weighted=True),
                    model=MLP_D()) # Here the shared model is only the head!!

In [None]:
# Reference implementation shipped with fluke; kept importable for comparison.
from fluke.algorithms.lg_fedavg import LGFedAVG  # NOQA

# Instantiate the algorithm defined in this tutorial, so that the customized
# client and server classes (and their send/receive logic) are the ones that
# actually run, instead of the library's built-in LGFedAVG.
algorithm = MyLGFedAVG(
    n_clients=100,
    data_splitter=splitter,
    hyper_params=hyperparams,
)

# Register the logger as a callback of the algorithm, then initialize it.
logger = Log()
algorithm.set_callbacks(logger)
logger.init()

In [None]:
# Run the federated training for 5 rounds.
# NOTE(review): eligible_perc=0.2 presumably selects 20% of the clients in each
# round - confirm against the documentation of `run`.
algorithm.run(n_rounds=5, eligible_perc=0.2)