# Hello FedAvg

Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier
using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629))
and [PyTorch](https://pytorch.org/) as the deep learning training framework.
This example highlights the flexibility of the ModelController API by showing how to write a Federated Averaging workflow with model selection, early stopping, and model saving and loading. We use the train script [cifar10_fl.py](src/cifar10_fl.py) and network [net.py](src/net.py) from the src directory.

## 1. Setup
Install nvflare and dependencies.

In [None]:
! pip install nvflare~=2.5.0rc torch torchvision tensorboard

Download the source code for this example if running in Colab.


In [None]:
! npx degit NVIDIA/NVFlare/examples/hello-world/hello-fedavg/src src

## 2. PTFedAvgEarlyStopping using ModelController API

The ModelController API makes it easy to customize a workflow.
We implement additional functionality on top of the BaseFedAvg class in [PTFedAvgEarlyStopping](https://github.com/NVIDIA/NVFlare/tree/main/nvflare/app_opt/pt/fedavg_early_stopping.py).

### 2.1 FedAvg
We subclass the BaseFedAvg class to leverage its predefined aggregation functions, and run our own logic at the end of each round: select the best model, save it, and check the early stop condition.

```python
self.select_best_model(model)

self.save_model(self.best_model, os.path.join(os.getcwd(), self.save_filename))

if self.should_stop(model.metrics, self.stop_condition):
    self.info(
        f"Stopping at round={self.current_round} out of total_rounds={self.num_rounds}. Early stop condition satisfied: {self.stop_condition}"
    )
    break
```
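
For readers unfamiliar with the aggregation step that BaseFedAvg provides, the following is a minimal sketch of sample-weighted averaging in plain Python. It is not NVFLARE's implementation: the function name `weighted_average`, the list-based parameters, and the sample counts are all hypothetical and chosen only to illustrate the idea.

```python
from typing import Dict, List, Tuple


def weighted_average(
    updates: List[Tuple[Dict[str, List[float]], int]]
) -> Dict[str, List[float]]:
    """Average client parameters, weighting each client by its sample count.

    Hypothetical helper: real workflows aggregate tensors, not Python lists.
    """
    total_samples = sum(n for _, n in updates)
    if total_samples <= 0:
        raise ValueError("total number of samples must be positive")

    averaged: Dict[str, List[float]] = {}
    for params, n_samples in updates:
        weight = n_samples / total_samples
        for name, values in params.items():
            # Start each parameter at zero, then accumulate weighted values.
            acc = averaged.setdefault(name, [0.0] * len(values))
            for i, value in enumerate(values):
                acc[i] += weight * value
    return averaged


# Two made-up clients: the second holds three times as much data.
client_a = ({"fc.weight": [1.0, 2.0]}, 100)
client_b = ({"fc.weight": [3.0, 6.0]}, 300)
global_params = weighted_average([client_a, client_b])
print(global_params)  # {'fc.weight': [2.5, 5.0]}
```

The client with more samples pulls the global parameters closer to its own values, which is the core idea behind FedAvg.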

### 2.2 Model Selection
Instead of using an [IntimeModelSelector](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_common/widgets/intime_model_selector.py) component for model selection, we compare the metrics of the models directly in the workflow and select the best model each round.

```python
def select_best_model(self, curr_model: FLModel):
    if self.best_model is None:
        self.best_model = curr_model
        return

    if self.stop_condition:
        metric, _, op_fn = self.stop_condition
        if self.is_curr_model_better(self.best_model, curr_model, metric, op_fn):
            self.info("Current model is new best model.")
            self.best_model = curr_model
    else:
        self.best_model = curr_model

def is_curr_model_better(
    self, best_model: FLModel, curr_model: FLModel, target_metric: str, op_fn: Callable
) -> bool:
    curr_metrics = curr_model.metrics
    if curr_metrics is None:
        return False
    if target_metric not in curr_metrics:
        return False

    best_metrics = best_model.metrics
    return op_fn(curr_metrics.get(target_metric), best_metrics.get(target_metric))
```
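
To see the comparison logic in isolation, here is a small self-contained sketch. `ModelStub` is a hypothetical stand-in for `FLModel` that models only the `metrics` field, and the accuracy values are made up. The extra guard for a best model without metrics is an assumption of this sketch and is not part of the excerpt above.

```python
import operator
from dataclasses import dataclass
from typing import Callable, Dict, Optional


@dataclass
class ModelStub:
    """Hypothetical stand-in for FLModel; only metrics are modeled."""

    metrics: Optional[Dict[str, float]] = None


def is_better(best: ModelStub, curr: ModelStub, metric: str, op_fn: Callable) -> bool:
    # A model that does not report the target metric can never win.
    if curr.metrics is None or metric not in curr.metrics:
        return False
    # Assumption of this sketch: a measured model beats an unmeasured one.
    if best.metrics is None or metric not in best.metrics:
        return True
    return op_fn(curr.metrics[metric], best.metrics[metric])


# Made-up global metrics for three rounds.
rounds = [
    ModelStub({"accuracy": 31.0}),
    ModelStub({"accuracy": 38.5}),
    ModelStub({"accuracy": 36.2}),
]
best = rounds[0]
for candidate in rounds[1:]:
    if is_better(best, candidate, "accuracy", operator.ge):
        best = candidate
print(best.metrics)  # {'accuracy': 38.5}
```

Reusing the operator from the stop condition means that "better" follows the direction of the target: `>=` favors higher values, while `<=` (for example, for a loss) favors lower ones.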

### 2.3 Early Stopping
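We add a `stop_cond` argument (e.g., `"accuracy >= 80"`) and end the workflow early if the corresponding global model metric meets the condition. The snippets in this section use the parsed form, `self.stop_condition`, which unpacks into a metric key, a target value, and a comparison function.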

```python
def should_stop(self, metrics: Optional[Dict] = None, stop_condition: Optional[Tuple] = None):
    if stop_condition is None or metrics is None:
        return False

    key, target, op_fn = stop_condition
    value = metrics.get(key, None)

    if value is None:
        raise RuntimeError(f"stop criteria key '{key}' doesn't exist in metrics")

    return op_fn(value, target)
```
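
The function above unpacks the stop condition into a key, a target, and a comparison function. As an illustration of how a string such as `"accuracy >= 40"` could be turned into that triple, here is a minimal sketch. The helper `parse_stop_condition` is hypothetical and is not NVFLARE's parser; it assumes the three parts are separated by whitespace.

```python
import operator
from typing import Callable, Tuple

# Supported comparison symbols and their operator-module equivalents.
_OPS = {
    ">=": operator.ge,
    "<=": operator.le,
    ">": operator.gt,
    "<": operator.lt,
    "=": operator.eq,
}


def parse_stop_condition(text: str) -> Tuple[str, float, Callable]:
    """Split '<key> <op> <target>' into (key, target, op_fn). Hypothetical helper."""
    tokens = text.split()
    if len(tokens) != 3:
        raise ValueError(f"expected '<key> <op> <target>', got: {text!r}")
    key, op, target = tokens
    if op not in _OPS:
        raise ValueError(f"unsupported operator: {op!r}")
    return key, float(target), _OPS[op]


key, target, op_fn = parse_stop_condition("accuracy >= 40")
print(key, target)           # accuracy 40.0
print(op_fn(42.5, target))   # True  -> the workflow would stop
print(op_fn(35.0, target))   # False -> training continues
```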

### 2.4 PyTorch Saving and Loading
Rather than configuring a persistor such as the [PTFileModelPersistor](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_opt/pt/file_model_persistor.py) component, we choose to utilize PyTorch's save and load functions and save the metadata of the FLModel separately. We load the `initial_model` into a class variable, which requires us to register the [TensorDecomposer](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_opt/pt/decomposers.py) for serialization of PyTorch tensors.

```python
if self.initial_model:
    # Use FOBS for serializing/deserializing PyTorch tensors (self.initial_model)
    fobs.register(TensorDecomposer)
    # PyTorch weights
    initial_weights = self.initial_model.state_dict()
else:
    initial_weights = {}

model = FLModel(params=initial_weights)
```

We use `torch.save` and `torch.load` for the weights, and save the FLModel metadata separately with the `fobs.dumpf` and `fobs.loadf` serialization utilities.

```python
def save_model(self, model, filepath=""):
    params = model.params
    # PyTorch save
    torch.save(params, filepath)

    # save FLModel metadata
    model.params = {}
    fobs.dumpf(model, filepath + ".metadata")
    model.params = params

def load_model(self, filepath=""):
    # PyTorch load
    params = torch.load(filepath)

    # load FLModel metadata
    model = fobs.loadf(filepath + ".metadata")
    model.params = params
    return model
```
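
The same two-file pattern (weights in one file, metadata in a sidecar file) can be tried without PyTorch or FOBS. The sketch below is only an analogy: `pickle` stands in for `torch.save`/`torch.load`, `json` stands in for `fobs.dumpf`/`fobs.loadf`, and `CheckpointStub` is a hypothetical stand-in for `FLModel` with made-up field values.

```python
import json
import os
import pickle
import tempfile
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class CheckpointStub:
    """Hypothetical stand-in for FLModel: weights plus lightweight metadata."""

    params: Dict[str, Any] = field(default_factory=dict)
    metrics: Dict[str, float] = field(default_factory=dict)
    current_round: int = 0


def save_model(model: CheckpointStub, filepath: str) -> None:
    # Weights go to the main file (pickle stands in for torch.save).
    with open(filepath, "wb") as f:
        pickle.dump(model.params, f)
    # Everything except the weights goes to a sidecar metadata file.
    meta = {"metrics": model.metrics, "current_round": model.current_round}
    with open(filepath + ".metadata", "w") as f:
        json.dump(meta, f)


def load_model(filepath: str) -> CheckpointStub:
    with open(filepath, "rb") as f:
        params = pickle.load(f)
    with open(filepath + ".metadata") as f:
        meta = json.load(f)
    # Reattach the weights to the restored metadata.
    return CheckpointStub(params=params, **meta)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "best_model.pt")
    original = CheckpointStub({"fc.weight": [0.1, 0.2]}, {"accuracy": 41.3}, 3)
    save_model(original, path)
    restored = load_model(path)

print(restored == original)  # True
```

Keeping the weights in their own file means they can be loaded on their own, while the sidecar file preserves the round number and metrics that go with them.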

## 3. Run the script

Use the Job API to define and run the example with the simulator.
(Note: We set `key_metric=None` so that our own model selection logic is used instead of the `IntimeModelSelector`, which is configured whenever `key_metric` is set.)

In [None]:
from nvflare import FedJob, ScriptExecutor
from nvflare.app_opt.pt.fedavg_early_stopping import PTFedAvgEarlyStopping

job = FedJob(name="cifar10_fedavg_early_stopping", key_metric=None)

Define the `PTFedAvgEarlyStopping` controller workflow with the `stop_cond` and `initial_model` args and send it to the server.

In [None]:
from src.net import Net

n_clients = 2

# Define the controller workflow and send to server
controller = PTFedAvgEarlyStopping(
    num_clients=n_clients,
    num_rounds=5,
    stop_cond="accuracy >= 40",
    initial_model=Net(),
)
job.to(controller, "server")

Create a `ScriptExecutor` for the train script and send it to each of the clients.

In [None]:
train_script = "src/cifar10_fl.py"

# Add clients
for i in range(n_clients):
    executor = ScriptExecutor(task_script_path=train_script, task_script_args="")
    job.to(executor, f"site-{i}", gpu=0)

Optionally export the job to run in other modes.

In [None]:
job.export_job("/tmp/nvflare/jobs/job_config")

Run the FedJob using the simulator. View the results in the job workspace: `/tmp/nvflare/jobs/workdir`.

In [None]:
job.simulator_run("/tmp/nvflare/jobs/workdir")

### Visualize the Training Results
By default, we enable TensorBoard metric [streaming](https://nvflare.readthedocs.io/en/main/examples/tensorboard_streaming.html) using NVFlare's `SummaryWriter` in [src/cifar10_fl.py](src/cifar10_fl.py). 

The server receives the TensorBoard metrics. To visualize the training progress, run the following in a new terminal:
```commandline
tensorboard --logdir=/tmp/nvflare/jobs/workdir/server/simulate_job/tb_events
```