# FedAvg with SAG workflow using Model Learner

In this example, we will demonstrate the FedAvg SAG workflow on the CIFAR10 dataset using the ModelLearner API.

While the previous example [FedAvg with SAG workflow](../sag/sag.ipynb#title) utilized the Client API, here we will demonstrate how to convert the original training code into a ModelLearner trainer, showcase its capabilities, and recommend the best use cases.

For an overview of Federated Averaging and SAG, see the section from the previous example: [Understanding FedAvg and SAG](../sag/sag.ipynb#sag).

## ModelLearner

The main goal of the ModelLearner is to make it easier to write learning logic by minimizing the FLARE-specific concepts that the user is exposed to. The ModelLearner defines familiar learning functions for training and validation, and uses the FLModel object for transferring learning information.

Key Concepts:
- Learning
    - `FLModel` object defines structure to contain essential information about the learning task, such as `params`, `metrics`, `meta`, etc.
    - learning logic implemented in the `train()` and `validate()` methods, which both receive and return an `FLModel` object
    - return the requested model via `get_model()`
- Lifecycle
    - `initialize()` for logic before the learning job starts and `finalize()` for logic once the learning job is finished
    - handle aborts gracefully via the `abort()` hook or by checking `is_aborted()`
- Convenience
    - various logging methods such as `info()`, `debug()`, `error()`, etc.
    - contextual information available in the learner, such as `site_name`, `current_round`, and `total_rounds`


Here are the full definitions of the APIs for the [ModelLearner](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_common/abstract/model_learner.py) and [FLModel](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_common/abstract/fl_model.py).
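To make the method contract above concrete, here is a minimal, framework-free sketch. `SketchFLModel` and `SketchLearner` are hypothetical stand-ins that only mimic the method names listed above; they are not NVFlare classes, and the toy "training" step, the `"ABORTED"` return string, and the `"best_model"` name are assumptions for illustration. In a real job you would subclass `ModelLearner` from `nvflare/app_common/abstract/model_learner.py`, and the `ModelLearnerExecutor` would call these methods for you.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Union


@dataclass
class SketchFLModel:
    """Hypothetical stand-in for FLModel with only the fields discussed above."""

    params: Optional[Dict[str, float]] = None
    metrics: Optional[Dict[str, float]] = None
    meta: Dict[str, Any] = field(default_factory=dict)


class SketchLearner:
    """Hypothetical stand-in that mirrors the ModelLearner method names."""

    def __init__(self) -> None:
        self._aborted = False
        self.weights: Dict[str, float] = {}

    # Lifecycle
    def initialize(self) -> None:
        # A real learner would build data loaders, the network, the optimizer, etc.
        self.weights = {"w": 0.0}

    def abort(self) -> None:
        # Request a graceful stop; long-running work should poll is_aborted().
        self._aborted = True

    def is_aborted(self) -> bool:
        return self._aborted

    def finalize(self) -> None:
        # Release resources once the job is finished.
        self.weights = {}

    # Learning
    def train(self, model: SketchFLModel) -> Union[str, SketchFLModel]:
        if self.is_aborted():
            return "ABORTED"  # hypothetical: a string return plays the role of a return code
        # Toy "training": move every weight halfway toward 1.0.
        self.weights = {k: v + 0.5 * (1.0 - v) for k, v in model.params.items()}
        return SketchFLModel(
            params=dict(self.weights),
            metrics={"accuracy": self.weights["w"]},
            meta={"NUM_STEPS_CURRENT_ROUND": 1},
        )

    def validate(self, model: SketchFLModel) -> Union[str, SketchFLModel]:
        # Toy metric: report the received weight itself as the validation score.
        return SketchFLModel(metrics={"val_accuracy": model.params["w"]})

    def get_model(self, model_name: str) -> Union[str, SketchFLModel]:
        if model_name != "best_model":  # hypothetical model name
            raise ValueError(f"Unknown model_name: {model_name}")
        return SketchFLModel(params=dict(self.weights))


if __name__ == "__main__":
    learner = SketchLearner()
    learner.initialize()
    global_model = SketchFLModel(params={"w": 0.0})
    for _ in range(2):  # two "rounds"; a server would aggregate in between
        result = learner.train(global_model)
        global_model = SketchFLModel(params=result.params)
    print(learner.validate(global_model).metrics)  # {'val_accuracy': 0.75}
    learner.finalize()
```

The `__main__` driver stands in for what the executor and server do in a real job: call `initialize()` once, `train()` every round with the current global model, and `finalize()` at the end.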

### When to use ModelLearner

The ModelLearner is best used when working with standard machine learning code that fits well into the train and validate methods and can be easily adapted to the ModelLearner structure. This separates the FLARE-specific communication constructs from the machine-learning-specific tasks, and provides the FLModel object for data transfer.

On the other hand, if the user would rather not adapt the code structure, we recommend using the [Client API](https://github.com/NVIDIA/NVFlare/blob/main/examples/hello-world/ml-to-fl/README.md) for an even simpler conversion to FL code, at the cost of losing some convenience functionality.

Finally, if the user wishes to implement something more specific that the ModelLearner does not support, we recommend writing an Executor, which gives greater freedom for defining logic and tasks. The main trade-off is that this requires the use of more FLARE concepts, such as FLContext, Shareable, and DXO.


## Converting DL training code to FL ModelLearner training code
We will use the original [Training a Classifier](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html) example
in PyTorch as our base [DL code](../code/dl/train.py).

With the FLARE ModelLearner API, we transform the existing PyTorch classifier into a federated classifier by restructuring our code to subclass ModelLearner and implementing the required methods. The converted code can be found at [FL ModelLearner code](../code/fl/model_learner.py).

Key Changes:
- Subclass ModelLearner with the appropriate init args
- Encapsulate the original DL train and validate code inside `_local_train()` and `_local_validate()`, and set up the dataset and PyTorch training utilities in `initialize()`
- Implement the `train()` and `validate()` methods by wrapping the local learning methods, converting between `FLModel` params and PyTorch tensors, and returning an `FLModel`
- Implement the `get_model()` method to load and return the best local model, so that it can be sent to other sites for validation (via the cross-site evaluation workflow)

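```python
# Excerpt: methods of the ModelLearner subclass in model_learner.py.
# Other members (__init__, initialize(), _local_train(), _local_validate()) are omitted.
from typing import Union

import torch

from nvflare.app_common.abstract.fl_model import FLModel, ParamsType
from nvflare.app_common.app_constant import ModelName
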
def get_model(self, model_name: str) -> Union[str, FLModel]:
    # Retrieve the best local model saved during training.
    if model_name == ModelName.BEST_MODEL:
        try:
            model_data = torch.load(self.model_path, map_location="cpu")
            np_model_data = {k: v.cpu().numpy() for k, v in model_data.items()}

            return FLModel(params_type=ParamsType.FULL, params=np_model_data)
        except Exception as e:
            raise ValueError("Unable to load best model") from e
    else:
        raise ValueError(f"Unknown model_name: {model_name}")  # Raised errors are caught in the ModelLearnerExecutor class.

def train(self, model: FLModel) -> Union[str, FLModel]:
    self.info(f"Current/Total Round: {self.current_round + 1}/{self.total_rounds}")
    self.info(f"Client identity: {self.site_name}")

    pt_input_params = {k: torch.as_tensor(v) for k, v in model.params.items()}
    self._local_train(pt_input_params)

    pt_output_params = {k: torch.as_tensor(v) for k, v in self.net.cpu().state_dict().items()}
    accuracy = self._local_validate(pt_output_params)

    if accuracy > self.best_acc:
        self.best_acc = accuracy
        torch.save(self.net.state_dict(), self.model_path)

    np_output_params = {k: v.cpu().numpy() for k, v in self.net.cpu().state_dict().items()}
    return FLModel(
        params=np_output_params,
        metrics={"accuracy": accuracy},
        meta={"NUM_STEPS_CURRENT_ROUND": 2 * len(self.trainloader)},
    )

def validate(self, model: FLModel) -> Union[str, FLModel]:
    pt_params = {k: torch.as_tensor(v) for k, v in model.params.items()}
    val_accuracy = self._local_validate(pt_params)

    return FLModel(metrics={"val_accuracy": val_accuracy})

...
    
```
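The `meta` entry `NUM_STEPS_CURRENT_ROUND` returned by `train()` matters for aggregation: as we understand NVFlare's weighted aggregators (e.g. `InTimeAccumulateWeightedAggregator`), each client's update is weighted by this step count when the server averages. The sketch below shows that idea in plain Python with a hypothetical `weighted_fedavg()` helper; it is a simplification, not NVFlare's aggregator code, and it assumes every client sends the same parameter keys and shapes (flattened to lists of floats).

```python
from typing import Dict, List, Tuple

# Parameter name -> flattened values (a stand-in for numpy arrays).
Params = Dict[str, List[float]]


def weighted_fedavg(updates: List[Tuple[Params, int]]) -> Params:
    """Average client params, weighting each client by its NUM_STEPS_CURRENT_ROUND."""
    total_steps = sum(steps for _, steps in updates)
    if total_steps <= 0:
        raise ValueError("total number of steps must be positive")
    result: Params = {}
    for key in updates[0][0]:
        acc = [0.0] * len(updates[0][0][key])
        for params, steps in updates:
            # Each client contributes in proportion to how many steps it trained.
            for i, value in enumerate(params[key]):
                acc[i] += value * steps
        result[key] = [v / total_steps for v in acc]
    return result


if __name__ == "__main__":
    site_1 = ({"w": [1.0, 2.0]}, 100)  # (params, NUM_STEPS_CURRENT_ROUND)
    site_2 = ({"w": [3.0, 4.0]}, 300)
    print(weighted_fedavg([site_1, site_2]))  # {'w': [2.5, 3.5]}
```

Because `site_2` trained three times as many steps, the average lands three quarters of the way toward its values. This is why `train()` reports `2 * len(self.trainloader)` rather than leaving the step count out.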

## Job Configuration

Now we must install the ModelLearner on the training client. We use the predefined `ModelLearnerExecutor`, which handles setting up the learner and executing the tasks using the ModelLearner methods. In the client configuration, the `learner_id` of the `ModelLearnerExecutor` is mapped to the `id` of the ModelLearner trainer component that we implemented.
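As a rough picture of that mapping, the following Python dict mirrors the parts of `config_fed_client.conf` that matter here. The real file is HOCON and the generated one may contain more fields, so compare it against the configuration printed further down. The executor path and task names follow NVFlare's conventions as we understand them; the component path `model_learner.MyModelLearner` and the `resolve_learner()` helper are hypothetical.

```python
# Sketch of the relevant parts of config_fed_client.conf as a Python dict.
client_config = {
    "executors": [
        {
            "tasks": ["train", "validate", "submit_model"],
            "executor": {
                "path": "nvflare.app_common.executors.model_learner_executor.ModelLearnerExecutor",
                # learner_id must equal the id of the learner component below.
                "args": {"learner_id": "learner"},
            },
        }
    ],
    "components": [
        {
            "id": "learner",
            "path": "model_learner.MyModelLearner",  # hypothetical class name
            "args": {},
        }
    ],
}


def resolve_learner(config: dict) -> dict:
    """Return the component referenced by the first executor's learner_id."""
    components = {c["id"]: c for c in config["components"]}
    learner_id = config["executors"][0]["executor"]["args"]["learner_id"]
    if learner_id not in components:
        raise KeyError(f"no component with id {learner_id!r}")
    return components[learner_id]


if __name__ == "__main__":
    print(resolve_learner(client_config)["path"])  # model_learner.MyModelLearner
```

If the `learner_id` and the component `id` drift apart, the executor has no learner to call, which is the kind of mismatch `resolve_learner()` surfaces here.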

Let's use the Job CLI to create the job from a ModelLearner template:

! nvflare job create -j /tmp/nvflare/jobs/sag_pt_model_learner -w sag_pt_model_learner -sd ../code/fl -force

We can take a look at the server and client configurations and make any changes as desired:

! cat /tmp/nvflare/jobs/sag_pt_model_learner/app/config/config_fed_server.conf

! cat /tmp/nvflare/jobs/sag_pt_model_learner/app/config/config_fed_client.conf

Ensure that our ModelLearner trainer code is correctly installed with the ModelLearnerExecutor. Also, since the ModelLearnerExecutor supports the `train`, `validate`, and `submit_model` tasks, we can use the CrossSiteModelEval workflow (see the [CSE example](../cse/cse.ipynb)) in the server configuration in addition to the ScatterAndGather workflow.
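The reasoning above amounts to a coverage check: every task a server workflow sends must be handled by a client executor. The sketch below encodes that check with a hypothetical `missing_tasks()` helper and a simplified table of the task names each workflow uses (NVFlare's default names `train`, `validate`, and `submit_model`, as we understand them; a job can rename them in its configs).

```python
from typing import Dict, Iterable, Set

# Simplified: the client tasks each server workflow relies on by default.
WORKFLOW_TASKS: Dict[str, Set[str]] = {
    "ScatterAndGather": {"train"},
    "CrossSiteModelEval": {"submit_model", "validate"},
}


def missing_tasks(workflows: Iterable[str], executor_tasks: Iterable[str]) -> Set[str]:
    """Return tasks required by the workflows that no client executor handles."""
    required: Set[str] = set()
    for workflow in workflows:
        required |= WORKFLOW_TASKS[workflow]
    return required - set(executor_tasks)


if __name__ == "__main__":
    workflows = ["ScatterAndGather", "CrossSiteModelEval"]
    # ModelLearnerExecutor handles all three tasks, so nothing is missing.
    print(missing_tasks(workflows, ["train", "validate", "submit_model"]))  # set()
    # A train-only executor could not serve cross-site evaluation.
    print(sorted(missing_tasks(workflows, ["train"])))  # ['submit_model', 'validate']
```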

## Prepare Data

Make sure the CIFAR10 dataset is downloaded with the following script:

! python ../data/download.py

## Run the Job

Now we can run the job with the simulator:

! nvflare simulator /tmp/nvflare/jobs/sag_pt_model_learner -w /tmp/nvflare/sag_pt_model_learner -t 2 -n 2 

As an additional resource, see also the [CIFAR10 examples](../../../../advanced/cifar10/README.md) for a comprehensive implementation of a PyTorch ModelLearner.