# Getting Started with NVFlare (Lightning)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/NVFlare/blob/main/examples/getting_started/pt/nvflare_lightning_getting_started.ipynb)

NVFlare is an open-source framework that allows researchers and
data scientists to seamlessly move their machine learning and deep
learning workflows into a federated paradigm.

## Basic Concepts
At the heart of NVFlare lies the concept of collaboration through "tasks." An FL controller assigns tasks (e.g., training on local data) to one or more FL clients and processes returned results (e.g., model weight updates). Based on these results and other factors (e.g., a pre-configured number of training rounds), the controller may assign additional tasks.

<img src="../../../docs/resources/controller_executor_no_filter.png" alt="NVIDIA FLARE Controller and Executor" width=75% height=75% />
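The task loop described above can be sketched in plain Python. This is a toy illustration only: `Task`, `ToyExecutor`, and `ToyController` are hypothetical names that are not part of NVFlare, and a single float stands in for the model weights.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Task:
    name: str
    data: float  # stand-in for model weights


class ToyExecutor:
    """Hypothetical client: handles a task and returns a result."""

    def __init__(self, site: str, local_value: float):
        self.site = site
        self.local_value = local_value

    def execute(self, task: Task) -> float:
        # Toy "training": move the received value halfway toward the local data.
        return task.data + 0.5 * (self.local_value - task.data)


class ToyController:
    """Hypothetical server: assigns tasks and processes the returned results."""

    def __init__(self, executors: List[ToyExecutor], num_rounds: int):
        self.executors = executors
        self.num_rounds = num_rounds

    def run(self, initial: float) -> float:
        value = initial
        for round_idx in range(self.num_rounds):
            task = Task(name="train", data=value)
            # Assign the same task to every client and collect the results.
            results = [ex.execute(task) for ex in self.executors]
            # Process the results (here: a plain average) before the next round.
            value = sum(results) / len(results)
            print(f"round {round_idx}: value={value:.3f}")
        return value


controller = ToyController(
    [ToyExecutor("site-1", 1.0), ToyExecutor("site-2", 3.0)], num_rounds=2
)
final_value = controller.run(initial=0.0)
```

The pre-configured number of rounds plays the role of the "other factors" that decide whether the controller assigns another task.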

## Setup Environment

Install NVFlare and dependencies:

! pip install nvflare~=2.7.0rc pytorch_lightning

If running in Google Colab, download the source code for this example.

## Federated Averaging with NVFlare
Given the flexible controller and executor concepts, it is easy to implement different computing and communication patterns with NVFlare, such as [FedAvg](https://proceedings.mlr.press/v54/mcmahan17a?ref=https://githubhelp.com) and [cyclic weight transfer](https://academic.oup.com/jamia/article/25/8/945/4956468).

The controller's `run()` routine is responsible for assigning tasks and processing task results from the executors.
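To contrast the two patterns, here is a minimal sketch of cyclic weight transfer, in which the model visits the sites one after another instead of being averaged across them. The function names and the scalar "model" are assumptions made for illustration and do not reflect NVFlare's implementation.

```python
def local_update(weight: float, local_mean: float, lr: float = 0.5) -> float:
    """Toy stand-in for local training: move the weight toward the local mean."""
    return weight + lr * (local_mean - weight)


def cyclic_weight_transfer(weight: float, site_means, num_rounds: int) -> float:
    # Each round, the model is passed from site to site in a fixed order;
    # every site continues from the weights left by the previous site.
    for _ in range(num_rounds):
        for local_mean in site_means:
            weight = local_update(weight, local_mean)
    return weight


cyclic_result = cyclic_weight_transfer(0.0, site_means=[1.0, 3.0], num_rounds=1)
print(cyclic_result)
```

Unlike FedAvg, there is no aggregation step here: the result depends on the order in which the sites are visited.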

### Server Code
First, we provide a simple implementation of the [FedAvg](https://proceedings.mlr.press/v54/mcmahan17a?ref=https://githubhelp.com) algorithm with NVFlare. The `run()` routine implements the main algorithmic logic. Subroutines such as `sample_clients()` and `send_model_and_wait()` use the controller's built-in communicator object to get the list of available clients, distribute the current global model to them, and collect their results.

The FedAvg controller implements these main steps:
1. The FL server initializes a model using `self.load_model()`.
2. For each round (global iteration):
    - The FL server samples available clients using `self.sample_clients()`.
    - The FL server sends the global model to clients and waits for their updates using `self.send_model_and_wait()`.
    - The FL server aggregates all the `results` using `self.aggregate()`, produces a new global model using `self.update_model()`, and saves it using `self.save_model()`.

```python
from nvflare.app_common.workflows.base_fedavg import BaseFedAvg


class FedAvg(BaseFedAvg):
    def run(self) -> None:
        self.info("Start FedAvg.")

        model = self.load_model()
        model.start_round = self.start_round
        model.total_rounds = self.num_rounds

        for self.current_round in range(self.start_round, self.start_round + self.num_rounds):
            self.info(f"Round {self.current_round} started.")
            model.current_round = self.current_round

            clients = self.sample_clients(self.num_clients)

            results = self.send_model_and_wait(targets=clients, data=model)

            aggregate_results = self.aggregate(results)

            model = self.update_model(model, aggregate_results)

            self.save_model(model)

        self.info("Finished FedAvg.")
```
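The aggregation step can be illustrated with a sample-weighted average, as in the FedAvg paper. This is a sketch under stated assumptions, not the code behind `self.aggregate()`: the `weighted_average` helper, the plain-list weight format, and the sample counts are all hypothetical.

```python
from typing import Dict, List, Tuple

Weights = Dict[str, List[float]]


def weighted_average(results: List[Tuple[Weights, int]]) -> Weights:
    """Average client weights, weighting each client by its number of samples."""
    total_samples = sum(n for _, n in results)
    first_weights, _ = results[0]
    averaged: Weights = {}
    for name, values in first_weights.items():
        averaged[name] = [
            sum(w[name][i] * n for w, n in results) / total_samples
            for i in range(len(values))
        ]
    return averaged


# Two hypothetical clients: (weights, number of local training samples)
client_results = [
    ({"fc.weight": [1.0, 2.0], "fc.bias": [0.0]}, 100),
    ({"fc.weight": [3.0, 6.0], "fc.bias": [1.0]}, 300),
]
new_global = weighted_average(client_results)
print(new_global)
```

The client with 300 samples pulls the average three times as strongly as the client with 100 samples.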

### Client Code 
Given a CIFAR10 [PyTorch Lightning](https://lightning.ai/) code example with the network wrapped in a `LitNet` LightningModule class (see [model.py](./model.py)), we will adapt this centralized training code to run in a federated setting.

On the client side, the training workflow is as follows:
1. Receive the model from the FL server.
2. Perform local training on the received global model and/or evaluate the model for model selection.
3. Send the updated model back to the FL server.

NVFlare's Client Lightning API makes this adaptation simple with the `patch()` function:
- `flare.patch(trainer)`: Patches the Lightning trainer. After calling `flare.patch()`, methods like `trainer.fit()` and `trainer.validate()` will automatically retrieve the global model from the FL server and send back the updated model after training.
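Conceptually, patching wraps an existing method so that extra work happens before and after it. The sketch below shows the idea with toy classes; `ToyTrainer`, `ToyServer`, and this `patch()` function are hypothetical and are not how NVFlare implements `flare.patch()`.

```python
class ToyTrainer:
    """Hypothetical trainer; not a Lightning Trainer."""

    def __init__(self):
        self.weights = 0.0

    def fit(self, local_target: float) -> None:
        # Toy "training": move halfway toward the local target.
        self.weights += 0.5 * (local_target - self.weights)


class ToyServer:
    """Hypothetical in-process stand-in for the FL server."""

    def __init__(self, global_weights: float):
        self.global_weights = global_weights
        self.received = []

    def receive(self) -> float:
        return self.global_weights

    def send(self, weights: float) -> None:
        self.received.append(weights)


def patch(trainer: ToyTrainer, server: ToyServer) -> None:
    """Wrap trainer.fit so it syncs with the server before and after training."""
    original_fit = trainer.fit

    def patched_fit(*args, **kwargs):
        trainer.weights = server.receive()  # load the global model first
        original_fit(*args, **kwargs)       # run the unchanged local training
        server.send(trainer.weights)        # report the updated model

    trainer.fit = patched_fit


server = ToyServer(global_weights=2.0)
trainer = ToyTrainer()
patch(trainer, server)
trainer.fit(local_target=4.0)
print(server.received)
```

The calling code still just invokes `trainer.fit(...)`, which is why the centralized training script needs so few changes.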


With this method, developers can adapt their centralized training code for federated learning with these simple changes:
```python
# (1) import the NVFlare Lightning client API
import nvflare.client.lightning as flare

# (2) patch the existing Lightning trainer from the centralized script
flare.patch(trainer)

while flare.is_running():

    # Optionally receive the FLModel from NVFlare.
    # We don't need to pass input_model to the trainer: after flare.patch(),
    # trainer.fit()/validate() get the global model internally.
    input_model = flare.receive()

    trainer.validate(...)

    trainer.fit(...)

    trainer.test(...)

    trainer.predict(...)
```

The full client training script is saved in [client.py](./client.py), which performs CNN training on the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.

## Run an NVFlare Job
We have now defined the FedAvg controller, which runs the federated workflow on the FL server, and the client training script, which receives global models, runs local training, and sends results back to the FL server. We can put everything together using NVFlare's Job Recipe API.

#### 1. Define the Initial Model
First, we define the global model used to initialize training on the FL server. See [model.py](./model.py).

```python
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import LightningModule
from torchmetrics import Accuracy

NUM_CLASSES = 10
criterion = nn.CrossEntropyLoss()

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class LitNet(LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.model = Net()
        self.train_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        self.valid_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        # (optional) pass additional information via self.__fl_meta__
        self.__fl_meta__ = {}

    def forward(self, x):
        out = self.model(x)
        return out

    def training_step(self, batch, batch_idx):
        x, labels = batch
        outputs = self(x)
        loss = criterion(outputs, labels)
        self.train_acc(outputs, labels)
        self.log("train_loss", loss)
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=False)
        return loss

    def evaluate(self, batch, stage=None):
        x, labels = batch
        outputs = self(x)
        loss = criterion(outputs, labels)
        self.valid_acc(outputs, labels)

        if stage:
            self.log(f"{stage}_loss", loss)
            self.log(f"{stage}_acc", self.valid_acc, on_step=True, on_epoch=True)
        return outputs

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        return self.evaluate(batch)

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        return {"optimizer": optimizer}
```
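The input size of `fc1` (`16 * 5 * 5`) follows from the layer shapes in `Net`. The arithmetic below checks it for 32x32 CIFAR-10 images; `conv_out` is a helper introduced here for illustration, and it assumes unpadded convolutions with the stride and kernel sizes shown in the model.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output width/height of a convolution or pooling layer on a square input."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                                   # CIFAR-10 images are 32x32
size = conv_out(size, kernel=5)             # conv1: 32 -> 28
size = conv_out(size, kernel=2, stride=2)   # pool:  28 -> 14
size = conv_out(size, kernel=5)             # conv2: 14 -> 10
size = conv_out(size, kernel=2, stride=2)   # pool:  10 -> 5

flattened = 16 * size * size  # conv2 produces 16 output channels
print(flattened)  # 400
```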

#### 2. Job Recipe
Next, we create a `FedAvgRecipe` that combines the initial model, the client training script, and the FedAvg settings (number of clients and rounds).


```python
import torchvision.datasets as datasets

from model import LitNet  # LitNet is defined in model.py (see step 1)
from nvflare.app_opt.pt.recipes.fedavg import FedAvgRecipe
from nvflare.recipe.sim_env import SimEnv

DATASET_ROOT = "/tmp/nvflare/data"

# Download the CIFAR-10 training and test splits
datasets.CIFAR10(root=DATASET_ROOT, train=True, download=True)
datasets.CIFAR10(root=DATASET_ROOT, train=False, download=True)

n_clients = 2
num_rounds = 2
batch_size = 24

recipe = FedAvgRecipe(
    min_clients=n_clients,
    num_rounds=num_rounds,
    initial_model=LitNet(),
    train_script="client.py",
    train_args=f"--batch_size {batch_size}",
)

# Run the job in the simulation environment
env = SimEnv(num_clients=n_clients, num_threads=n_clients)
recipe.execute(env=env)
```
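The `train_args` string is passed to the client script as command-line arguments. A minimal sketch of how `client.py` might parse it is shown below; the `parse_args` helper and the default value of 4 are assumptions, since the actual script is not reproduced here.

```python
import argparse


def parse_args(argv=None):
    # Hypothetical argument parser for client.py; the default is arbitrary.
    parser = argparse.ArgumentParser(description="Hypothetical client.py arguments")
    parser.add_argument("--batch_size", type=int, default=4)
    return parser.parse_args(argv)


batch_size = 24
train_args = f"--batch_size {batch_size}"

# Simulate the command line that the recipe would produce for the client script.
args = parse_args(train_args.split())
print(args.batch_size)  # 24
```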

#### 3. Execute Recipe
You can execute the same recipe in different environments: `SimEnv`, `PoCEnv`, or `ProdEnv`. Here we will run it in the simulation environment.

! python job.py
