
# Customizing Federated Learning Server Logic


In previous sections, we ran federated PyTorch image classification code with NVIDIA FLARE's built-in FedAvg algorithm.
What if we want to build our own algorithm or modify an existing one?

In the following, using FedAvg as a starting point, we would like to make a few changes to fit our needs:

* Add an early stopping mechanism so that training can stop as soon as a criterion is satisfied, instead of always running the total number of rounds
* Provide our own best model selection instead of relying on the internal approach
* Use our own save and load functions instead of the built-in persistor component `PTFileModelPersistor`


In this section, we will go over these changes step-by-step. 

> Reference:
> _[FedAvg with early stopping](https://github.com/NVIDIA/NVFlare/blob/main/examples/hello-world/hello-fedavg/hello-fedavg.ipynb) example_


## Customized FedAvg v1

There are several factors to consider:

* **How to write a Federated Averaging algorithm**

* **How to express and apply the early stop condition** 


### Write a FedAvg Algorithm

FedAvg can be written as a very simple for-loop, but there are several other factors to consider:

* How do we send the model to the clients?
* How do we receive the responses?
* What format should the model and the responses use?
* The model, the responses, and the corresponding objects must be serialized; how do we serialize them?
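Setting those questions aside for a moment, the core for-loop itself is small. The following framework-free sketch assumes each client is a plain Python function (the `make_client` factory is hypothetical) that returns updated parameters and its number of training samples; sending and receiving are reduced to direct function calls, and the aggregation is a sample-weighted average.

```python
from typing import Callable, Dict, List, Tuple

Params = Dict[str, float]
# A "client" here is a hypothetical callable: (global_params) -> (new_params, num_samples)
Client = Callable[[Params], Tuple[Params, int]]


def weighted_average(results: List[Tuple[Params, int]]) -> Params:
    """Average each parameter, weighting every client by its number of samples."""
    total = sum(n for _, n in results)
    keys = results[0][0].keys()
    return {k: sum(p[k] * n for p, n in results) / total for k in keys}


def fedavg(initial: Params, clients: List[Client], num_rounds: int) -> Params:
    global_params = dict(initial)
    for _ in range(num_rounds):
        # "send" the global model to every client and "receive" their updates
        results = [client(global_params) for client in clients]
        # aggregate the updates into the next global model
        global_params = weighted_average(results)
    return global_params


def make_client(target: float, num_samples: int) -> Client:
    # hypothetical local training: move the weight halfway toward a client-specific target
    def client(params: Params) -> Tuple[Params, int]:
        return {"w": params["w"] + 0.5 * (target - params["w"])}, num_samples

    return client


clients = [make_client(1.0, 10), make_client(3.0, 30)]
print(fedavg({"w": 0.0}, clients, num_rounds=1))  # {'w': 1.25}
```

In FLARE, the list comprehension that "sends" the model is replaced by a real communication call, and the parameters travel inside a transfer structure, which is what the next subsections cover.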

Let's dive into these questions.


#### Transfer Structure: FLModel

FLARE defines a high-level data structure, `FLModel`, that holds the model parameters, metrics, and metadata:

```python
from enum import Enum
from typing import Any, Dict, Optional, Union


class ParamsType(str, Enum):
    FULL = "FULL"
    DIFF = "DIFF"


class FLModel:
    def __init__(
        self,
        params_type: Union[None, str, ParamsType] = None,
        params: Any = None,
        optimizer_params: Any = None,
        metrics: Optional[Dict] = None,
        start_round: Optional[int] = 0,
        current_round: Optional[int] = None,
        total_rounds: Optional[int] = None,
        meta: Optional[Dict] = None,
    ):
        # body omitted; the full class lives in nvflare.app_common.abstract.fl_model
        ...
```
The data can be packaged into an `FLModel` and transferred between the clients and the server, as well as among clients.
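The `params_type` field tells the receiver how to interpret `params`: `FULL` carries complete weights, while `DIFF` carries only the change relative to the weights that were sent. The sketch below illustrates the difference with plain dictionaries standing in for tensors; the `apply_update` helper is hypothetical and is not a FLARE function.

```python
from enum import Enum
from typing import Dict


class ParamsType(str, Enum):
    FULL = "FULL"
    DIFF = "DIFF"


def apply_update(
    global_params: Dict[str, float], params: Dict[str, float], params_type: ParamsType
) -> Dict[str, float]:
    """Return new global weights from one client result (hypothetical helper)."""
    if params_type == ParamsType.FULL:
        # FULL: the client sent complete weights, so use them as-is
        return dict(params)
    if params_type == ParamsType.DIFF:
        # DIFF: the client sent deltas, so add them to the current global weights
        return {k: global_params[k] + params.get(k, 0.0) for k in global_params}
    raise ValueError(f"unsupported params_type: {params_type}")


current = {"w": 1.0, "b": 0.5}
print(apply_update(current, {"w": 2.0, "b": 0.0}, ParamsType.FULL))  # {'w': 2.0, 'b': 0.0}
print(apply_update(current, {"w": 0.25}, ParamsType.DIFF))  # {'w': 1.25, 'b': 0.5}
```

Sending a `DIFF` can be cheaper when only part of the model changes, at the cost of the receiver needing the exact base weights the delta was computed from.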


#### Serialization 

Many deep learning frameworks use Python's pickle as their default serialization mechanism. Because unpickling untrusted data can execute arbitrary code, FLARE does not use pickle. Instead, the NVIDIA FLARE Object Serializer (FOBS) uses a [MessagePack](https://msgpack.org/index.html)-based serialization approach.
To serialize and deserialize a particular type with FOBS, the user registers a component called a "Decomposer" for that type.

For PyTorch tensors, we need to register the [TensorDecomposer](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_opt/pt/decomposers.py) component with FOBS:

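```python
from nvflare.app_opt.pt.decomposers import TensorDecomposer
from nvflare.fuel.utils import fobs

# Use FOBS for serializing/deserializing PyTorch tensors
fobs.register(TensorDecomposer)
```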
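The decomposer idea can be illustrated with the standard library alone. In the sketch below, `json` stands in for MessagePack, and the registry, `register`, `dumps`, `loads`, and the `Point` type are all hypothetical; it shows the pattern (register a decompose/recompose pair per type, then serialize only plain data), not FOBS's actual API.

```python
import json
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple

# Hypothetical registry: type name -> (decompose, recompose) pair
_DECOMPOSERS: Dict[str, Tuple[Callable[[Any], Any], Callable[[Any], Any]]] = {}


def register(cls: type, decompose: Callable[[Any], Any], recompose: Callable[[Any], Any]) -> None:
    _DECOMPOSERS[cls.__name__] = (decompose, recompose)


def dumps(obj: Any) -> str:
    name = type(obj).__name__
    if name not in _DECOMPOSERS:
        # refuse unknown types instead of serializing arbitrary objects
        raise TypeError(f"no decomposer registered for {name}")
    decompose, _ = _DECOMPOSERS[name]
    return json.dumps({"__type__": name, "data": decompose(obj)})


def loads(text: str) -> Any:
    payload = json.loads(text)
    # only registered types can be rebuilt; no code from the payload is executed
    _, recompose = _DECOMPOSERS[payload["__type__"]]
    return recompose(payload["data"])


@dataclass
class Point:
    x: float
    y: float


register(Point, lambda p: [p.x, p.y], lambda d: Point(*d))
restored = loads(dumps(Point(1.0, 2.0)))
print(restored)  # Point(x=1.0, y=2.0)
```

Unlike unpickling, loading here can only reconstruct types someone explicitly registered, which is the security property that motivates this design.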

#### Send and Receive Objects

With the high-level API, we can use the following:

```
results = self.send_model_and_wait(targets=clients, data=model)
```
This function sends the `FLModel` to the targeted clients and receives their results. It is a synchronous method that works like scatter and gather: the model is broadcast to all targeted clients, and the call returns once the required clients have sent back their results.

`BaseFedAvg` is derived from `ModelController`, which provides the communication component that lets the controller send the model and wait for the results.
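To make the scatter-and-gather pattern concrete, here is a small FLARE-independent sketch built on `concurrent.futures`: the same model is sent to every target in parallel, and the call blocks until all results are back. The `fake_site` clients are hypothetical stand-ins for remote sites; `send_model_and_wait` additionally handles networking, timeouts, and serialization, which this sketch omits.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, List

Model = Dict[str, float]


def scatter_and_gather(model: Model, targets: Dict[str, Callable[[Model], Model]]) -> List[Model]:
    """Broadcast `model` to every target and block until all have replied."""
    with ThreadPoolExecutor(max_workers=len(targets)) as pool:
        # scatter: one task per client, each receiving its own copy of the model
        futures = {name: pool.submit(fn, dict(model)) for name, fn in targets.items()}
        # gather: wait for every result, keeping the target order stable
        return [futures[name].result() for name in targets]


def fake_site(scale: float) -> Callable[[Model], Model]:
    # hypothetical client: "trains" by scaling the received weights
    return lambda m: {k: v * scale for k, v in m.items()}


sites = {"site-1": fake_site(2.0), "site-2": fake_site(3.0)}
print(scatter_and_gather({"w": 1.0}, sites))  # [{'w': 2.0}, {'w': 3.0}]
```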


Now that we have covered these factors, let's write a class and see how it looks.

We will start with the `BaseFedAvg` class, which provides a core base class for customizing FedAvg. Our subclass implements the `run()` method, which captures the whole workflow logic, and reuses utilities from the base class such as client sampling and aggregation. Let's look at the initial version of the code:
 

```
! cat code/src/fedavg_v0.py
```


Now that we have our own FedAvg version, let's look at how to stop the training early.

### Express and apply the early stop condition

#### Stop Condition

`stop_cond` is a string that represents the stop condition, in the format `"<key> <op> <value>"` (e.g. `"accuracy >= 80"`).

We need to parse this condition before we can compare against it. To do so, we leverage FLARE's `math_utils`:
```

math_utils.parse_compare_criteria(compare_expr: Optional[str] = None) -> Tuple[str, float, Callable]

```
It returns a tuple of:
* `key`: the metric name
* `target_value`: the threshold, as a float
* `op_fn`: a callable comparison operator

For example:

```python
from nvflare.app_common.utils.math_utils import parse_compare_criteria

key, target_value, fn = parse_compare_criteria("accuracy > 80")
print(key, target_value, fn)
accuracy = 90
fn(accuracy, target_value)
```
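To see what such a parser does, here is a minimal standard-library sketch that follows the same `"<key> <op> <value>"` convention and the same `(key, target_value, op_fn)` return shape. The function name `parse_condition` and the exact set of supported operators are assumptions; FLARE's `parse_compare_criteria` may differ in details such as error messages.

```python
import operator
from typing import Callable, Tuple

# Assumed operator set; tokens are matched whole, so ">=" never falls back to ">"
_OPS = {">=": operator.ge, "<=": operator.le, ">": operator.gt, "<": operator.lt, "=": operator.eq}


def parse_condition(expr: str) -> Tuple[str, float, Callable[[float, float], bool]]:
    """Parse "<key> <op> <value>" into (key, target_value, op_fn)."""
    tokens = expr.split()
    if len(tokens) != 3:
        raise ValueError(f"expected '<key> <op> <value>', got {expr!r}")
    key, op, value = tokens
    if op not in _OPS:
        raise ValueError(f"unsupported operator {op!r}")
    return key, float(value), _OPS[op]


key, target, op_fn = parse_condition("accuracy >= 80")
print(key, target, op_fn(90, target))  # accuracy 80.0 True
```

Note that the tokens must be separated by spaces: `"accuracy>=80"` is rejected rather than guessed at.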

#### Integrate the early stop condition

This is straightforward: if the condition is satisfied, we simply break out of the for-loop:

```
            if self.should_stop(model.metrics, self.stop_condition):
                break
```

The `should_stop` function is defined as follows:

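```python
    def should_stop(self, metrics: Optional[Dict] = None, stop_condition: Optional[Tuple] = None) -> bool:
        # stop_condition is the parsed (key, target, op_fn) tuple, not the raw string
        if stop_condition is None or metrics is None:
            return False

        key, target, op_fn = stop_condition
        value = metrics.get(key, None)
        if value is None:
            raise RuntimeError(f"stop criteria key '{key}' doesn't exist in metrics")
        return op_fn(value, target)
```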
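The following self-contained sketch runs the same early-stop check over a simulated sequence of per-round metrics so we can see at which round the loop breaks. The accuracy values are made up for illustration, `run_rounds` is a hypothetical stand-in for the training loop in `run()`, and the condition is pre-parsed into the `(key, target, op_fn)` tuple that `should_stop` expects.

```python
import operator
from typing import Callable, Dict, List, Optional, Tuple

Condition = Tuple[str, float, Callable[[float, float], bool]]


def should_stop(metrics: Optional[Dict], stop_condition: Optional[Condition]) -> bool:
    if stop_condition is None or metrics is None:
        return False
    key, target, op_fn = stop_condition
    value = metrics.get(key)
    if value is None:
        raise RuntimeError(f"stop criteria key '{key}' doesn't exist in metrics")
    return op_fn(value, target)


def run_rounds(per_round_metrics: List[Dict], stop_condition: Optional[Condition]) -> int:
    """Return how many rounds actually ran (hypothetical stand-in for run())."""
    rounds_run = 0
    for metrics in per_round_metrics:
        rounds_run += 1
        # in the real workflow, the metrics come back with the aggregated FLModel
        if should_stop(metrics, stop_condition):
            break
    return rounds_run


simulated = [{"accuracy": 55.0}, {"accuracy": 72.0}, {"accuracy": 83.0}, {"accuracy": 88.0}]
print(run_rounds(simulated, ("accuracy", 80.0, operator.ge)))  # 3
print(run_rounds(simulated, None))  # 4
```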

The code can be found in [fedavg_v1.py](code/src/fedavg_v1.py).


## Customized FedAvg v2

We have successfully modified the FedAvg logic to let users specify an early stop condition.
Now, we want to make two additional changes:

* Implement our own best model selection
* Use our own model saving and loading instead of FLARE's persistor




### Select best model 

We simply write the following two functions and add them to the previous code:

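```python
from typing import Callable

from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.app_common.workflows.base_fedavg import BaseFedAvg


class FedAvgV2(BaseFedAvg):
    # ... __init__(), run(), etc. as before ...
    # self.best_model starts as None; self.stop_condition holds the parsed
    # (key, target, op_fn) tuple, or None when no stop condition is given

    def select_best_model(self, curr_model: FLModel):
        if self.best_model is None:
            self.best_model = curr_model
            return

        if self.stop_condition:
            metric, _, op_fn = self.stop_condition
            if self.is_curr_model_better(self.best_model, curr_model, metric, op_fn):
                self.info("Current model is new best model.")
                self.best_model = curr_model
        else:
            # no metric to compare on: treat the latest model as the best
            self.best_model = curr_model

    def is_curr_model_better(
        self, best_model: FLModel, curr_model: FLModel, target_metric: str, op_fn: Callable
    ) -> bool:
        curr_metrics = curr_model.metrics
        if not curr_metrics or target_metric not in curr_metrics:
            return False

        best_metrics = best_model.metrics
        if not best_metrics or target_metric not in best_metrics:
            # the current best has no score to compare against
            return True
        return op_fn(curr_metrics[target_metric], best_metrics[target_metric])
```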
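To check the selection logic without a running FLARE job, here is a standalone sketch in which a small dataclass, `Model` (hypothetical), stands in for `FLModel` and carries only a name and `metrics`. It tracks the best model across rounds using the comparison operator taken from the stop condition; with `>=`, a tie replaces the earlier model.

```python
import operator
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional


@dataclass
class Model:
    # hypothetical stand-in for FLModel: only the fields this sketch needs
    name: str
    metrics: Optional[Dict[str, float]] = None


def is_better(best: Model, curr: Model, metric: str, op_fn: Callable[[float, float], bool]) -> bool:
    if not curr.metrics or metric not in curr.metrics:
        return False  # nothing to compare: keep the current best
    if not best.metrics or metric not in best.metrics:
        return True  # the old best has no score, so any scored model wins
    return op_fn(curr.metrics[metric], best.metrics[metric])


def track_best(models: List[Model], metric: str, op_fn: Callable[[float, float], bool]) -> Optional[Model]:
    best: Optional[Model] = None
    for m in models:
        if best is None or is_better(best, m, metric, op_fn):
            best = m
    return best


rounds = [
    Model("r0", {"accuracy": 61.0}),
    Model("r1", {"accuracy": 74.0}),
    Model("r2", {"accuracy": 70.0}),
    Model("r3"),  # a round that reported no metrics
]
print(track_best(rounds, "accuracy", operator.ge).name)  # r1
```

Reusing the stop condition's operator keeps "better" consistent with "good enough to stop": for a loss metric with `<=`, the same code keeps the lowest value.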





### Customized save and load model functions
     
The `BaseFedAvg` class provides `save_model()` and `load_model()` functions for users to override.
We use PyTorch's `torch.save` and `torch.load` for the weights, and save the FLModel metadata separately with the `fobs.dumpf` and `fobs.loadf` serialization utilities.



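```python
import torch

from nvflare.app_common.workflows.base_fedavg import BaseFedAvg
from nvflare.fuel.utils import fobs


class FedAvgV2(BaseFedAvg):
    # ... other methods omitted ...

    def save_model(self, model, filepath=""):
        params = model.params
        # PyTorch save: write only the weights to `filepath`
        torch.save(params, filepath)

        # save the FLModel metadata (metrics, meta, round info) without the weights
        model.params = {}
        try:
            fobs.dumpf(model, filepath + ".metadata")
        finally:
            # restore the weights even if serialization fails
            model.params = params

    def load_model(self, filepath=""):
        # PyTorch load
        params = torch.load(filepath)

        # load the FLModel metadata and re-attach the weights
        model = fobs.loadf(filepath + ".metadata")
        model.params = params
        return model
```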
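The same split-file layout can be tried without PyTorch or FOBS. In the sketch below, `json` stands in for both `torch.save` and `fobs.dumpf`, and a plain dict with a `params` key stands in for `FLModel`; only the idea of writing the weights to `filepath` and everything else to `filepath + ".metadata"` carries over from the code above.

```python
import json
import os
import tempfile
from typing import Any, Dict


def save_model(model: Dict[str, Any], filepath: str) -> None:
    # weights go to the main file (torch.save in the real code)
    with open(filepath, "w") as f:
        json.dump(model["params"], f)
    # everything except the weights goes to a sidecar ".metadata" file
    metadata = {k: v for k, v in model.items() if k != "params"}
    with open(filepath + ".metadata", "w") as f:
        json.dump(metadata, f)


def load_model(filepath: str) -> Dict[str, Any]:
    with open(filepath + ".metadata") as f:
        model = json.load(f)
    with open(filepath) as f:
        model["params"] = json.load(f)
    return model


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "global_model.json")
    original = {"params": {"w": [0.1, 0.2]}, "metrics": {"accuracy": 83.0}, "current_round": 2}
    save_model(original, path)
    assert load_model(path) == original
    print(sorted(os.listdir(tmp)))  # ['global_model.json', 'global_model.json.metadata']
```

Because this version copies the metadata instead of temporarily clearing `params`, the caller's object is never modified, so no `try`/`finally` is needed.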


## Add Evaluation to the training code

We need to add evaluation code to the client training script to compute accuracy, which we then use to select the best model.
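A minimal evaluation helper might look like the sketch below. It reports accuracy as a percentage so that it matches the scale of stop conditions such as `"accuracy >= 80"`; the function name and the percentage convention are assumptions. In the actual client script, the predictions would come from the model's outputs on the test set, and the resulting dictionary would be reported back as the `metrics` of the returned `FLModel`.

```python
from typing import Sequence


def accuracy_percent(predictions: Sequence[int], labels: Sequence[int]) -> float:
    """Percentage (0-100) of predictions that match the labels."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not labels:
        return 0.0
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return 100.0 * correct / len(labels)


metrics = {"accuracy": accuracy_percent([3, 1, 4, 1, 5], [3, 1, 4, 0, 5])}
print(metrics)  # {'accuracy': 80.0}
```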





## Running Customized FedAvg

Now that everything is put together in [fedavg_v2](code/src/fedavg_v2.py), let's take a look at the server code:


```
!cat code/src/fedavg_v2.py
```

Let's create a job with our newly modified `FedAvgV2`.

### Create Fed Job

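```python
from nvflare.job_config.api import FedJob
from nvflare.job_config.script_runner import ScriptRunner

# FedAvgV2 (code/src/fedavg_v2.py) and SimpleNetwork are this tutorial's own classes;
# import them from the project's source files.

n_clients = 5
num_rounds = 2

train_script = "src/client.py"

job = FedJob(name="fedavg_v2", n_clients=n_clients)

controller = FedAvgV2(
    num_clients=n_clients,
    num_rounds=num_rounds,
    stop_cond=None,  # e.g. "accuracy >= 80" to enable early stopping
    save_filename="global_model.pt",
    initial_model=SimpleNetwork(),
)

job.to_server(controller)

# Add clients
for i in range(n_clients):
    executor = ScriptRunner(script=train_script, script_args="")
    job.to(executor, f"site-{i + 1}")

job.simulator_run("/tmp/nvflare/jobs/workdir")
```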




### Run job with simulator

```
! pip install nvflare
! pip install -r code/requirements.txt
! python3 code/data/download.py
%cd code
! python3 fl_job.py
```