# FedAvg using Executor

In this example, we demonstrate the FedAvg algorithm on the CIFAR10 dataset using an Executor.

While the previous example [FedAvg with SAG workflow](../sag/sag.ipynb#title) used the Client API, here we demonstrate how to convert the original training code into an Executor trainer, showcase its capabilities, and recommend the best use cases.

We build on top of the previous [FedAvg algorithm](../sag/sag.ipynb#title) example.

Please review the following before proceeding to the next section:
  * [Understanding FedAvg and SAG](../sag/sag.ipynb#sag)

## Executor

In FLARE, an `Executor` is a client-side `FLComponent` that executes tasks: its `execute` method receives a task name and an input `Shareable`, and returns a `Shareable` as the result.

Key Concepts:
- Executor is a client-side FLComponent for executing tasks
- Produces `Shareable` from input `Shareable` and handles `DXO` object conversion for standardized data passing
- Directly uses FLARE-specific communication concepts, and therefore serves as the basis for higher-level learning APIs that abstract these concepts away
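To make the contract concrete, here is a framework-free sketch of the dispatch pattern an `execute` method follows. `ToyShareable`, `ToyExecutor`, and the return-code strings are hypothetical stand-ins, not FLARE classes; in FLARE you would subclass `nvflare.apis.executor.Executor` and exchange real `Shareable` and `DXO` objects.

```python
from dataclasses import dataclass, field
from typing import Any, Dict

# Hypothetical stand-ins for FLARE's return codes, for illustration only.
OK = "OK"
TASK_UNKNOWN = "TASK_UNKNOWN"


@dataclass
class ToyShareable:
    """Simplified stand-in for a Shareable: a payload plus a return code."""

    data: Dict[str, Any] = field(default_factory=dict)
    return_code: str = OK


class ToyExecutor:
    """Mimics the Executor contract: execute(task_name, shareable) -> shareable."""

    def __init__(self):
        self.weights = {"w": 0.0}

    def execute(self, task_name: str, shareable: ToyShareable) -> ToyShareable:
        if task_name == "train":
            incoming = shareable.data["weights"]
            # Pretend "training" nudges every received weight by +1.
            self.weights = {k: v + 1.0 for k, v in incoming.items()}
            return ToyShareable(data={"weights": dict(self.weights)})
        if task_name == "get_weights":
            # Report current weights without training.
            return ToyShareable(data={"weights": dict(self.weights)})
        # Unknown tasks get an empty reply carrying an error code.
        return ToyShareable(return_code=TASK_UNKNOWN)


executor = ToyExecutor()
reply = executor.execute("train", ToyShareable(data={"weights": {"w": 2.0}}))
print(reply.data)  # {'weights': {'w': 3.0}}
print(executor.execute("foo", ToyShareable()).return_code)  # TASK_UNKNOWN
```

The key point is that the task name alone decides which branch runs, and every branch, including the error path, returns a reply object rather than raising.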

See the [documentation](https://nvflare.readthedocs.io/en/main/programming_guide/executor.html#executor) for more information about Executors and other FLARE-specific constructs.

### When to use Executors

The Executor is best used when implementing tasks and logic that do not fit the standard learning methods of higher-level APIs such as the ModelLearner or Client API. In this example, in addition to the `train`, `validate`, and `submit_model` tasks, we also introduce the `get_weights` task. This pre-train task allows us to perform the `InitializeGlobalWeights` workflow, which would otherwise not be supported.
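The idea behind the `get_weights` pre-train task can be sketched without FLARE: a client answers the task by reporting its current, untrained weights, and the server uses a reply to seed the global model. `ToyClient` and `initialize_global_weights` are hypothetical, and adopting the first reply is an assumption of this sketch, not a description of how `InitializeGlobalWeights` chooses among replies.

```python
import random
from typing import Dict, List

Weights = Dict[str, List[float]]


class ToyClient:
    """Hypothetical client whose model is randomly initialized locally."""

    def __init__(self, seed: int):
        rng = random.Random(seed)
        self.weights: Weights = {"fc.weight": [rng.uniform(-1, 1) for _ in range(3)]}

    def execute(self, task_name: str) -> Weights:
        if task_name == "get_weights":
            # Pre-train task: report a copy of the current weights, no training.
            return {k: list(v) for k, v in self.weights.items()}
        raise ValueError(f"unknown task {task_name}")


def initialize_global_weights(clients: List[ToyClient]) -> Weights:
    """Ask every client for weights and adopt the first reply (assumed policy)."""
    replies = [c.execute("get_weights") for c in clients]
    return replies[0]


clients = [ToyClient(seed=1), ToyClient(seed=2)]
global_weights = initialize_global_weights(clients)
print(global_weights == clients[0].weights)  # True
```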

## Converting DL training code to FL Executor training code
We will use the original [Training a Classifier](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html) example
in PyTorch as our base [DL code](../code/dl/train.py).

To transform the existing PyTorch classifier training code into federated classifier training code, we must restructure it to implement the tasks to execute and to handle the data exchange formats. The converted code can be found at [FL Executor code](../code/fl/executor.py).

Key changes:
- Encapsulate the original DL train and validate code inside `_local_train()` and `_local_validate()`, and the dataset and PyTorch training utilities inside `initialize()`
- Implement the `execute` method to handle the `get_weights`, `train`, `validate`, and `submit_model` tasks
- Process incoming and outgoing `Shareable` objects, and convert them to and from `DXO` objects
- Implement `_save_local_model()` and `_load_local_model()` using the `PTPersistenceManager` to handle `ModelLearnable` object and manage the format for PyTorch model persistence.
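The refactored structure can be sketched as a plain-Python skeleton. `ToyLocalTrainer` is hypothetical: lists of floats stand in for PyTorch tensors, and `pickle` stands in for `PTPersistenceManager`, so this shows only the shape of setup, local train/validate, and save/load, not FLARE's persistence format.

```python
import os
import pickle
import tempfile
from typing import Dict, List

Weights = Dict[str, List[float]]


class ToyLocalTrainer:
    """Hypothetical skeleton mirroring the refactor: setup, train/validate, persist."""

    def __init__(self, model_dir: str):
        self.model_path = os.path.join(model_dir, "local_model.pkl")
        self.weights: Weights = {}

    def initialize(self) -> None:
        # Stand-in for building the dataset, model, loss, and optimizer once.
        self.weights = {"w": [0.0, 0.0]}

    def _local_train(self, global_weights: Weights) -> None:
        # Start from the received global weights, then "train" (add 0.5).
        self.weights = {k: [x + 0.5 for x in v] for k, v in global_weights.items()}

    def _local_validate(self, weights: Weights) -> float:
        # Toy metric: fraction of parameters that are positive.
        values = [x for v in weights.values() for x in v]
        return sum(x > 0 for x in values) / len(values)

    def _save_local_model(self) -> None:
        with open(self.model_path, "wb") as f:
            pickle.dump(self.weights, f)

    def _load_local_model(self) -> Weights:
        with open(self.model_path, "rb") as f:
            return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    trainer = ToyLocalTrainer(tmp)
    trainer.initialize()
    trainer._local_train({"w": [1.0, -2.0]})
    trainer._save_local_model()
    restored = trainer._load_local_model()
    print(restored)  # {'w': [1.5, -1.5]}
    print(trainer._local_validate(restored))  # 0.5
```

Keeping persistence in its own pair of methods is what lets a task like `submit_model` reload the best local model later without rerunning training.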

```python
import torch

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.model import model_learnable_to_dxo
from nvflare.app_common.app_constant import AppConstants


class CIFAR10Executor(Executor):
    # Excerpt: only execute() is shown; see executor.py for the full class.

    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        try:
            if task_name == self.pre_train_task_name:
                # Pre-train task: send the current state dict as weights, without training
                return self._get_model_weights()
            elif task_name == self.train_task_name:
                # Get model weights
                try:
                    dxo = from_shareable(shareable)
                except Exception:
                    self.log_error(fl_ctx, "Unable to extract dxo from shareable.")
                    return make_reply(ReturnCode.BAD_TASK_DATA)

                # Ensure data kind is weights.
                if dxo.data_kind != DataKind.WEIGHTS:
                    self.log_error(fl_ctx, f"data_kind expected WEIGHTS but got {dxo.data_kind} instead.")
                    return make_reply(ReturnCode.BAD_TASK_DATA)

                # Convert weights to tensors and run local training.
                torch_weights = {k: torch.as_tensor(v) for k, v in dxo.data.items()}
                self._local_train(fl_ctx, torch_weights)

                # Check the abort_signal after training.
                if abort_signal.triggered:
                    return make_reply(ReturnCode.TASK_ABORTED)

                # Save the local model after training.
                self._save_local_model(fl_ctx)

                # Get the new state dict and send as weights
                return self._get_model_weights()
            elif task_name == self.validate_task_name:
                model_owner = "?"
                try:
                    try:
                        dxo = from_shareable(shareable)
                    except Exception:
                        self.log_error(fl_ctx, "Error in extracting dxo from shareable.")
                        return make_reply(ReturnCode.BAD_TASK_DATA)

                    # Ensure data_kind is weights.
                    if dxo.data_kind != DataKind.WEIGHTS:
                        self.log_error(fl_ctx, f"DXO is of type {dxo.data_kind} but expected type WEIGHTS.")
                        return make_reply(ReturnCode.BAD_TASK_DATA)

                    # Extract weights and ensure they are tensors on the training device.
                    model_owner = shareable.get_header(AppConstants.MODEL_OWNER, "?")
                    weights = {k: torch.as_tensor(v, device=self.device) for k, v in dxo.data.items()}

                    # Get validation accuracy
                    val_accuracy = self._local_validate(fl_ctx, weights)
                    if abort_signal.triggered:
                        return make_reply(ReturnCode.TASK_ABORTED)

                    self.log_info(
                        fl_ctx,
                        f"Accuracy when validating {model_owner}'s model on"
                        f" {fl_ctx.get_identity_name()}"
                        f"'s data: {val_accuracy}",
                    )

                    dxo = DXO(data_kind=DataKind.METRICS, data={"val_acc": val_accuracy})
                    return dxo.to_shareable()
                except Exception:
                    self.log_exception(fl_ctx, f"Exception in validating model from {model_owner}")
                    return make_reply(ReturnCode.EXECUTION_EXCEPTION)
            elif task_name == self.submit_model_task_name:
                # Load local model
                ml = self._load_local_model(fl_ctx)

                # Get the model parameters and create dxo from it
                dxo = model_learnable_to_dxo(ml)
                return dxo.to_shareable()
            else:
                return make_reply(ReturnCode.TASK_UNKNOWN)
        except Exception as e:
            self.log_exception(fl_ctx, f"Exception in simple trainer: {e}.")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

    ...
```
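The `DXO` checks at the top of the `train` and `validate` branches can be isolated as follows. This is a stdlib-only sketch: `ToyDataKind`, `ToyDXO`, `toy_from_shareable`, `extract_weights`, and `make_metrics_reply` are hypothetical stand-ins for `DataKind`, `DXO`, `from_shareable`, and the reply logic above, and a plain dict plays the role of a `Shareable`.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Tuple


class ToyDataKind(str, Enum):
    """Hypothetical subset of data kinds, mirroring DataKind.WEIGHTS and DataKind.METRICS."""

    WEIGHTS = "WEIGHTS"
    METRICS = "METRICS"


@dataclass
class ToyDXO:
    """Hypothetical stand-in for a DXO: a data kind plus a payload dict."""

    data_kind: ToyDataKind
    data: Dict[str, Any]

    def to_shareable(self) -> Dict[str, Any]:
        # A plain dict plays the role of a Shareable here.
        return {"kind": self.data_kind.value, "data": self.data}


def toy_from_shareable(shareable: Dict[str, Any]) -> ToyDXO:
    # Raises KeyError on a missing field, ValueError on an unknown kind.
    return ToyDXO(ToyDataKind(shareable["kind"]), shareable["data"])


def extract_weights(shareable: Dict[str, Any]) -> Tuple[str, Dict[str, List[float]]]:
    """Return (return_code, weights), mirroring the checks in execute()."""
    try:
        dxo = toy_from_shareable(shareable)
    except (KeyError, ValueError):
        return "BAD_TASK_DATA", {}
    if dxo.data_kind != ToyDataKind.WEIGHTS:
        return "BAD_TASK_DATA", {}
    # The real code uses torch.as_tensor; floats stand in for tensors here.
    return "OK", {k: [float(x) for x in v] for k, v in dxo.data.items()}


def make_metrics_reply(val_acc: float) -> Dict[str, Any]:
    return ToyDXO(ToyDataKind.METRICS, {"val_acc": val_acc}).to_shareable()


code, weights = extract_weights({"kind": "WEIGHTS", "data": {"w": [1, 2]}})
print(code, weights)  # OK {'w': [1.0, 2.0]}
print(extract_weights({"kind": "METRICS", "data": {}})[0])  # BAD_TASK_DATA
print(make_metrics_reply(0.5))  # {'kind': 'METRICS', 'data': {'val_acc': 0.5}}
```

Validating the data kind before touching the payload means a malformed or mistyped message is rejected with a clear return code instead of failing deep inside training.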

## Prepare Data

Make sure the CIFAR10 dataset is downloaded with the following script:

```
! python ../data/download.py --dataset_path /tmp/nvflare/data/cifar10
```

## Job Configuration

Now we must install the Executor on the training clients. We define our CIFAR10Executor in the client configuration and list the tasks it implements.

Let's first copy the required files over:

```
! cp ../code/fl/net.py net.py
! cp ../code/fl/executor.py executor.py
```

We can use the Job API to create a job and run it in the simulator:

```python
from net import Net
from executor import CIFAR10Executor

from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob


if __name__ == "__main__":
    n_clients = 2
    num_rounds = 5

    job = FedAvgJob(
        name="fedavg_executor",
        n_clients=n_clients,
        num_rounds=num_rounds,
        initial_model=Net(),
    )

    # Add one CIFAR10Executor to each client site
    for i in range(n_clients):
        executor = CIFAR10Executor()
        job.to(executor, f"site-{i + 1}")

    job.export_job("/tmp/nvflare/jobs")
    job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0")
```


## Run Job

The previous cell exports the job config and executes the job in the NVFlare simulator.

To run in a production system, submit the exported job folder to the NVFlare system.
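Each round, the server side of `FedAvgJob` aggregates the weights returned by the clients' `train` tasks. A minimal sketch of FedAvg-style weighted averaging follows; `fedavg` is a hypothetical helper, and weighting clients by their number of training samples is an assumption here, since FLARE's FedAvg controller may weight updates differently.

```python
from typing import Dict, List, Tuple

Weights = Dict[str, List[float]]


def fedavg(updates: List[Tuple[Weights, int]]) -> Weights:
    """Average client weights, weighting each client by its sample count.

    `updates` holds (weights, num_samples) pairs; using sample counts as the
    weighting factor is an assumption of this sketch.
    """
    total = sum(n for _, n in updates)
    if total == 0:
        raise ValueError("at least one client must report a positive weight")
    first = updates[0][0]
    return {
        k: [sum(w[k][i] * n for w, n in updates) / total for i in range(len(first[k]))]
        for k in first
    }


site_1 = ({"fc": [1.0, 2.0]}, 100)
site_2 = ({"fc": [3.0, 6.0]}, 300)
print(fedavg([site_1, site_2]))  # {'fc': [2.5, 5.0]}
```

With equal weights this reduces to a plain mean; larger clients pull the global model further toward their local update.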


For additional resources, take a look at the various other executors with different use cases in the app_common, app_opt, and examples folders.

In the previous examples, we have covered each of the Execution API types: the Client API, Model Learner, and Executor.
In future examples, we will use the Client API to highlight other features and workflows.

Next we have the [sag_mlflow](../sag_mlflow/sag_mlflow.ipynb) example, which shows how to enable MLflow experiment tracking logs.