
# Customizing Federated Learning Server Logic

In the [previous section](../01.2_convert_deep_learning_to_federated_learning/convert_dl_to_fl.ipynb), we were able to run federated PyTorch image classification code with NVIDIA FLARE's built-in FedAvg algorithm.

What if we want to build our own algorithm or modify the existing one?

Using FedAvg as a starting point, we will make a few changes to fit our needs:

* Add an early stopping mechanism so that training stops as soon as a criterion is satisfied, instead of always running for the total number of rounds
* Provide our own best model selection instead of relying on the internal approach
* Use our own saving and loading functions instead of the built-in persistor component `PTFileModelPersistor`

In this section, we will go over these changes step-by-step. 

> Reference:
> _[FedAvg with early stopping](https://github.com/NVIDIA/NVFlare/blob/main/examples/hello-world/hello-fedavg/hello-fedavg.ipynb) example_

## Getting started

We will start with the [`BaseFedAvg`](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_common/workflows/base_fedavg.py#L29) class, which provides a base class for customizing FedAvg. It is derived from the [`ModelController`](https://nvflare.readthedocs.io/en/main/programming_guide/controllers/model_controller.html) class, which exposes a `run()` method that orchestrates the overall federated workflow. The `ModelController` class also contains the communication component, which can send the model to clients and wait for their results.

In the following sections, we will show several incremental implementations of `FedAvg`, each of which overrides the `run()` method of the `BaseFedAvg` / `ModelController` class.

## Version 0: writing a basic FedAvg algorithm

In this version, we will look at how the basic `FedAvg` algorithm is implemented.

Using the `BaseFedAvg` class, `FedAvg` can be written as a very simple for-loop inside the [`run()`](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_common/workflows/fedavg.py#L34) function.
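
To make the shape of that loop concrete, here is a minimal, framework-free sketch of federated averaging. It does not use the FLARE API: the toy model is a plain dictionary of scalars, and `fake_client_update`, `weighted_average`, and `run_fedavg` are hypothetical helpers introduced only for this illustration.

```python
from typing import Dict, List, Tuple

Weights = Dict[str, float]  # toy "model": parameter name -> scalar value


def weighted_average(results: List[Tuple[Weights, int]]) -> Weights:
    """Average client weights, weighting each client by its sample count."""
    total = sum(n for _, n in results)
    keys = results[0][0].keys()
    return {k: sum(w[k] * n for w, n in results) / total for k in keys}


def fake_client_update(global_weights: Weights, shift: float) -> Weights:
    """Hypothetical stand-in for local training: nudges each weight by `shift`."""
    return {k: v + shift for k, v in global_weights.items()}


def run_fedavg(num_rounds: int) -> Weights:
    model: Weights = {"w": 0.0, "b": 0.0}
    clients = [(1.0, 10), (3.0, 30)]  # (shift, num_samples) per simulated client
    for _ in range(num_rounds):
        # 1. send the global model, 2. collect client updates, 3. aggregate
        results = [(fake_client_update(model, s), n) for s, n in clients]
        model = weighted_average(results)
    return model


final_model = run_fedavg(num_rounds=2)
print(final_model)  # {'w': 5.0, 'b': 5.0}
```

Each round moves the weights by the sample-weighted mean of the client shifts, here (1.0 × 10 + 3.0 × 30) / 40 = 2.5, so two rounds yield 5.0.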

There are several other factors to consider:
* How do we send the model to clients?
* How do we receive the responses?
* What format do the model and the responses use?
* The model, the responses, and their corresponding objects must be serialized. How do we serialize them?

Let's dive into these questions.

### Transfer format: `FLModel`

FLARE defines a high-level data structure, [`FLModel`](https://nvflare.readthedocs.io/en/main/programming_guide/fl_model.html), that holds the model parameters, metrics, and metadata:

```python
from enum import Enum
from typing import Any, Dict, Optional, Union


class ParamsType(str, Enum):
    FULL = "FULL"
    DIFF = "DIFF"


class FLModel:
    def __init__(
        self,
        params_type: Union[None, str, ParamsType] = None,
        params: Any = None,
        optimizer_params: Any = None,
        metrics: Optional[Dict] = None,
        start_round: Optional[int] = 0,
        current_round: Optional[int] = None,
        total_rounds: Optional[int] = None,
        meta: Optional[Dict] = None,
    ):
        ...  # constructor body omitted in this excerpt
```
Using `FLModel`, model and response data can be packaged for transfer between the clients and the server, as well as among clients.
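
The sketch below illustrates how the two `ParamsType` values might be interpreted when an update reaches the server. It assumes that `FULL` means the payload holds complete weights and `DIFF` means it holds differences relative to the current global weights. `ToyFLModel` and `apply_update` are hypothetical stand-ins written for this example, not FLARE classes or functions.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Optional


class ParamsType(str, Enum):
    FULL = "FULL"
    DIFF = "DIFF"


@dataclass
class ToyFLModel:
    """Hypothetical stand-in with a subset of the FLModel fields shown above."""
    params_type: Optional[ParamsType] = None
    params: Dict[str, float] = field(default_factory=dict)
    metrics: Optional[Dict[str, float]] = None
    current_round: Optional[int] = None
    meta: Dict[str, Any] = field(default_factory=dict)


def apply_update(global_model: ToyFLModel, update: ToyFLModel) -> ToyFLModel:
    """Merge an update into the global model (assumed FULL/DIFF semantics)."""
    if update.params_type == ParamsType.FULL:
        new_params = dict(update.params)  # full weights replace the old ones
    elif update.params_type == ParamsType.DIFF:
        # differences are added to the current global weights
        new_params = {k: global_model.params[k] + d for k, d in update.params.items()}
    else:
        raise ValueError(f"unsupported params_type: {update.params_type}")
    return ToyFLModel(ParamsType.FULL, new_params, update.metrics, update.current_round, dict(update.meta))


global_model = ToyFLModel(ParamsType.FULL, {"w": 1.0}, current_round=0)
diff = ToyFLModel(ParamsType.DIFF, {"w": 0.5}, metrics={"accuracy": 0.7}, current_round=0)
updated = apply_update(global_model, diff)
print(updated.params)  # {'w': 1.5}
```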

### Serialization 

Many deep learning and machine learning frameworks use Python's `pickle` as their default serialization mechanism. However, unpickling untrusted data can execute arbitrary code, so FLARE does not use `pickle` for object serialization.

NVIDIA FLARE introduces the FLARE Object Serializer (FOBS), which uses a [MessagePack](https://msgpack.org/index.html)-based serialization approach. To serialize and deserialize an object with FOBS, the user needs to register a component called a "Decomposer".

For example, for a PyTorch tensor, we need to register the [TensorDecomposer](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_opt/pt/decomposers.py) component with FOBS.

```python
from nvflare.app_opt.pt.decomposers import TensorDecomposer
from nvflare.fuel.utils import fobs

# Use FOBS for serializing/deserializing PyTorch tensors
fobs.register(TensorDecomposer)
```
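
To give an intuition for what a decomposer does, here is a toy registry that breaks an object down into plain data before encoding it and rebuilds it after decoding. This is only a conceptual sketch: it uses JSON from the standard library instead of MessagePack, and `register`, `dumps`, and `loads` below are hypothetical functions that do not reproduce the FOBS API.

```python
import json
from typing import Any, Callable, Dict, Tuple, Type

# type name -> (type, decompose, recompose); a toy registry, not the FOBS API
_REGISTRY: Dict[str, Tuple[Type, Callable[[Any], Any], Callable[[Any], Any]]] = {}


def register(cls: Type, decompose: Callable[[Any], Any], recompose: Callable[[Any], Any]) -> None:
    _REGISTRY[cls.__name__] = (cls, decompose, recompose)


def dumps(obj: Any) -> bytes:
    for name, (cls, decompose, _) in _REGISTRY.items():
        if isinstance(obj, cls):
            # store only plain data plus a type tag, never executable code
            return json.dumps({"__type__": name, "data": decompose(obj)}).encode()
    return json.dumps({"__type__": None, "data": obj}).encode()


def loads(blob: bytes) -> Any:
    envelope = json.loads(blob.decode())
    name = envelope["__type__"]
    if name is None:
        return envelope["data"]
    if name not in _REGISTRY:
        raise ValueError(f"no decomposer registered for {name!r}")
    return _REGISTRY[name][2](envelope["data"])


# A complex number is not JSON-serializable until a decomposer is registered
register(complex, lambda c: [c.real, c.imag], lambda d: complex(d[0], d[1]))
restored = loads(dumps(3 + 4j))
print(restored)  # (3+4j)
```

Because only registered types can be rebuilt, the receiver never runs code chosen by the sender, which is the property that motivates avoiding `pickle`.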

### Send and receive objects

FLARE's `ModelController` class provides a high-level API allowing you to easily send data to a specific target.

```python
results = self.send_model_and_wait(targets=clients, data=model)
```

The `send_model_and_wait` function sends the `FLModel` to the targeted clients and receives their results. It is a synchronous (blocking) function that executes a "[Scatter-and-Gather](https://nvflare.readthedocs.io/en/main/programming_guide/controllers/scatter_and_gather_workflow.html)" workflow: the model is broadcast to all targeted clients, and the call returns once the required clients have sent back their results.
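
The scatter-and-gather pattern itself can be sketched without FLARE. In the example below, threads stand in for remote clients, and `send_and_wait` and `make_client` are hypothetical helpers. Unlike the real controller, this simplified version always waits for every client rather than for a required subset.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, List

Model = Dict[str, float]


def send_and_wait(clients: List[Callable[[Model], Model]], model: Model) -> List[Model]:
    """Broadcast `model` to every client and block until all have replied."""
    with ThreadPoolExecutor(max_workers=len(clients)) as pool:
        # scatter: submit one task per client, each with its own copy of the model
        futures = [pool.submit(client, dict(model)) for client in clients]
        # gather: result() blocks, so the list is complete when we return
        return [f.result() for f in futures]


def make_client(step: float) -> Callable[[Model], Model]:
    """Hypothetical client whose 'training' adds `step` to every weight."""
    def train(model: Model) -> Model:
        return {k: v + step for k, v in model.items()}
    return train


replies = send_and_wait([make_client(0.1), make_client(0.2)], {"w": 1.0})
print(len(replies))  # 2
```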

Let's look at an initial version of the code in [code/src/fedavg_v0.py](code/src/fedavg_v0.py):

In [None]:
! cat code/src/fedavg_v0.py

## Version 1: express and apply the early stop condition

Now that we have our initial implementation of FedAvg, let's look at how to add an early stop condition.

### Create a stop condition

FLARE's `math_utils` provides the `parse_compare_criteria` function that we can leverage to implement custom stop conditions:
```python

math_utils.parse_compare_criteria(compare_expr: Optional[str] = None) -> Tuple[str, float, Callable]

```

Here, `compare_expr` is a string literal representing the stop condition, in the format `"<key> <op> <value>"`. For example: `"accuracy >= 80"`.

The returned tuple contains:
* The key, for instance `accuracy`
* The target value as a float, for instance `80.0`
* The callable comparison function `op_fn`, for instance `ge` (greater than or equal) for `>=`

Execute the following example to generate a stop condition that checks whether the accuracy is greater than 80 percent.

In [None]:
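from nvflare.app_common.utils.math_utils import parse_compare_criteria

# Parse the expression into (key, target value, comparison function)
key, target_value, fn = parse_compare_criteria("accuracy > 80")
print(key, target_value, fn)

# Evaluate the condition against a sample metric value
accuracy = 90
fn(accuracy, target_value)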
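
For intuition, an expression of this form could be parsed with the standard `operator` module roughly as follows. This is a simplified sketch rather than the library's implementation. `parse_criteria` is a hypothetical name, and the operator table is an assumption for this example.

```python
import operator
from typing import Callable, Tuple

# Assumed operator table for this sketch; the library's supported set may differ
_OPS = {">=": operator.ge, "<=": operator.le, ">": operator.gt, "<": operator.lt, "=": operator.eq}


def parse_criteria(expr: str) -> Tuple[str, float, Callable]:
    tokens = expr.split()
    if len(tokens) != 3:
        raise ValueError(f"expected '<key> <op> <value>', got {expr!r}")
    key, op, value = tokens
    if op not in _OPS:
        raise ValueError(f"unsupported operator {op!r}")
    return key, float(value), _OPS[op]


key, target, op_fn = parse_criteria("accuracy >= 80")
print(key, target, op_fn(80, target), op_fn(79.9, target))  # accuracy 80.0 True False
```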

### Integrate the early stop condition

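Let's implement a simple `should_stop` function that returns a boolean value indicating whether the stop condition is met. Note that `stop_condition` is the tuple returned by `parse_compare_criteria`, not the original string:

```python
def should_stop(self, metrics: Optional[Dict] = None, stop_condition: Optional[Tuple] = None) -> bool:
    # No condition configured or no metrics reported yet: keep training
    if stop_condition is None or metrics is None:
        return False
    key, target, op_fn = stop_condition
    value = metrics.get(key)
    return value is not None and op_fn(value, target)
```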

Then, we can simply break out of the execution loop if the condition is met:

```python
if self.should_stop(model.metrics, self.stop_condition):
    break
```
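
The following self-contained sketch shows the effect of that `break` on a simulated training loop. It assumes a hypothetical metric that improves by 10 points per round. The `run` function is a stand-in for the controller loop, not the code in `fedavg_v1.py`.

```python
import operator
from typing import Callable, Dict, Optional, Tuple

StopCondition = Tuple[str, float, Callable]


def should_stop(metrics: Optional[Dict[str, float]], cond: Optional[StopCondition]) -> bool:
    if metrics is None or cond is None:
        return False
    key, target, op_fn = cond
    value = metrics.get(key)
    return value is not None and op_fn(value, target)


def run(num_rounds: int, cond: Optional[StopCondition]) -> int:
    """Return the number of rounds actually executed."""
    rounds_done = 0
    for current_round in range(num_rounds):
        # Hypothetical metric: accuracy grows by 10 points per round
        metrics = {"accuracy": 10.0 * (current_round + 1)}
        rounds_done += 1
        if should_stop(metrics, cond):
            break
    return rounds_done


print(run(10, ("accuracy", 25.0, operator.gt)))  # 3, because 30 > 25 in the third round
print(run(10, None))                             # 10, no condition so all rounds run
```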

The complete code with the stop condition integrated can be found in [fedavg_v1.py](code/src/fedavg_v1.py):

In [None]:
! cat code/src/fedavg_v1.py

## Version 2: further customization

We have successfully modified the FedAvg logic to let the user specify an early stop condition.

Now, let's make some additional changes:
* Implement our own best model selection
* Implement our own model saving and loading functions

### Select the best model 

We simply write the following two functions and add them to the previous code:

```python

    def select_best_model(self, curr_model: FLModel):
        if self.best_model is None:
            self.best_model = curr_model
            return

        if self.stop_condition:
            metric, _, op_fn = self.stop_condition
            if self.is_curr_model_better(self.best_model, curr_model, metric, op_fn):
                self.info("Current model is new best model.")
                self.best_model = curr_model
        else:
            self.best_model = curr_model

    def is_curr_model_better(
        self, best_model: FLModel, curr_model: FLModel, target_metric: str, op_fn: Callable
    ) -> bool:
        curr_metrics = curr_model.metrics
        if curr_metrics is None:
            return False
        if target_metric not in curr_metrics:
            return False

        best_metrics = best_model.metrics
        return op_fn(curr_metrics.get(target_metric), best_metrics.get(target_metric))
```
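
Because the comparison reuses the operator from the stop condition, the direction of "better" follows the expression the user wrote. The sketch below replays that idea over a list of per-round metrics. `pick_best` is a hypothetical helper for illustration, and it assumes every round reports the target metric.

```python
import operator
from typing import Callable, Dict, List, Optional


def pick_best(history: List[Dict[str, float]], metric: str, op_fn: Callable) -> Optional[int]:
    """Return the index of the best round, mirroring the selection logic above."""
    best_idx: Optional[int] = None
    for idx, metrics in enumerate(history):
        if best_idx is None:
            best_idx = idx  # the first model is the best so far by default
        elif metric in metrics and op_fn(metrics[metric], history[best_idx][metric]):
            best_idx = idx
    return best_idx


accuracy_history = [{"accuracy": 40.0}, {"accuracy": 55.0}, {"accuracy": 50.0}]
loss_history = [{"loss": 0.9}, {"loss": 0.4}, {"loss": 0.6}]
best_acc_round = pick_best(accuracy_history, "accuracy", operator.gt)
best_loss_round = pick_best(loss_history, "loss", operator.lt)
print(best_acc_round, best_loss_round)  # 1 1
```

With `>` the highest accuracy wins, and with `<` the lowest loss wins, so a condition such as `"loss < 0.5"` would select models by decreasing loss.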

### Customize model saving and loading functions
     
The `BaseFedAvg` class defines `save_model()` and `load_model()` functions that the user can override.
We use the `torch.save` and `torch.load` functions for the model weights, and save the `FLModel` metadata separately with the `fobs.dumpf` and `fobs.loadf` serialization utilities.

```python
    def save_model(self, model, filepath=""):
        params = model.params
        # PyTorch save
        torch.save(params, filepath)

        # save FLModel metadata
        model.params = {}
        fobs.dumpf(model, filepath + ".metadata")
        model.params = params

    def load_model(self, filepath=""):
        # PyTorch load
        params = torch.load(filepath)

        # load FLModel metadata
        model = fobs.loadf(filepath + ".metadata")
        model.params = params
        return model
```
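
The same "weights in one file, metadata in a sidecar file" pattern can be tried out without PyTorch or FLARE. The sketch below substitutes JSON for both `torch.save` and FOBS and represents the model as a plain dictionary. The file name and the dictionary layout are assumptions made for this example.

```python
import json
import os
import tempfile
from typing import Any, Dict


def save_model(model: Dict[str, Any], filepath: str) -> None:
    # weights go to the main file (torch.save in the tutorial code)
    with open(filepath, "w") as f:
        json.dump(model["params"], f)
    # everything except the weights goes to the sidecar metadata file
    metadata = {k: v for k, v in model.items() if k != "params"}
    with open(filepath + ".metadata", "w") as f:
        json.dump(metadata, f)


def load_model(filepath: str) -> Dict[str, Any]:
    with open(filepath + ".metadata") as f:
        model = json.load(f)
    with open(filepath) as f:
        model["params"] = json.load(f)
    return model


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "global_model.json")
    original = {"params": {"w": [0.1, 0.2]}, "metrics": {"accuracy": 31.5}, "current_round": 4}
    save_model(original, path)
    restored = load_model(path)

print(restored == original)  # True
```

This sketch builds a separate metadata dictionary instead of temporarily clearing `params` on the model, which avoids leaving the model in a modified state if writing the metadata fails.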

That's it! Everything is put together in [fedavg_v2](code/src/fedavg_v2.py). Let's take a look at the server code:

In [None]:
! cat code/src/fedavg_v2.py

## Running Customized FedAvg

Let's run our customized `FedAvgV2` with the simulator. Notice that we initialize the `FedAvgV2` class with a stop condition:

```python
stop_cond="accuracy > 25"
```

In [None]:
! pip install nvflare

In [None]:
! pip install -r code/requirements.txt

In [None]:
! python3 code/data/download.py

In [None]:
! cd code && python3 fl_job.py

Next, we are going to see how to [customize the client-side training](../01.4_client_side_customization/customize_client_training.ipynb).