# Getting Started with NVFlare (Lightning)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/NVFlare/blob/main/examples/getting_started/pt/nvflare_lightning_getting_started.ipynb)

NVFlare is an open-source framework that allows researchers and
data scientists to seamlessly move their machine learning and deep
learning workflows into a federated paradigm.

## Basic Concepts
At the heart of NVFlare lies the concept of collaboration through
"tasks." An FL controller assigns tasks (e.g., training on local data) to one or more FL clients, processes the returned
results (e.g., model weight updates), and may assign additional
tasks based on these results and other factors (e.g., a pre-configured
number of training rounds). Each client runs executors that listen for tasks and perform the necessary computations locally, such as model training. This task-based interaction repeats
until the experiment's objectives are met.

<img src="../../../docs/resources/controller_executor_no_filter.png" alt="NVIDIA FLARE Controller and Executor" width=75% height=75% />
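
The task loop described above can be sketched in plain Python. This is a toy illustration only: `make_executor` and `run_controller` are hypothetical names, not NVFlare classes, and the "model" is a single number that each client nudges by a fixed local offset.

```python
from typing import Callable, Dict, List

# Hypothetical stand-ins for illustration only; these are not NVFlare classes.
Task = Dict[str, float]
Executor = Callable[[Task], float]


def make_executor(local_offset: float) -> Executor:
    """Return a toy executor that 'trains' by nudging the received value."""

    def execute(task: Task) -> float:
        return task["value"] + local_offset

    return execute


def run_controller(executors: Dict[str, Executor], num_rounds: int) -> List[float]:
    """Assign one task per client in every round and average the returned results."""
    value = 0.0
    history = []
    for round_idx in range(num_rounds):
        task = {"round": float(round_idx), "value": value}
        # Each client's executor handles the task locally and returns a result.
        results = [execute(task) for execute in executors.values()]
        # The controller processes the results before assigning the next task.
        value = sum(results) / len(results)
        history.append(value)
    return history


executors = {"site-1": make_executor(1.0), "site-2": make_executor(3.0)}
history = run_controller(executors, num_rounds=2)
print(history)
```

In a real deployment the controller and executors run in separate processes and exchange tasks over the network; the sketch only mirrors the order of operations.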

## Setup environment

Install nvflare and dependencies:

In [None]:
! pip install --ignore-installed blinker
! pip install nvflare~=2.5.0rc pytorch_lightning

If running in Google Colab, download the source code for this example:

In [None]:
! npx degit NVIDIA/NVFlare/examples/getting_started/pt/src src

## Federated Averaging with NVFlare
Given the flexible controller and executor concepts, it is easy to implement different computing & communication patterns with NVFlare, such as [FedAvg](https://proceedings.mlr.press/v54/mcmahan17a?ref=https://githubhelp.com) and [cyclic weight transfer](https://academic.oup.com/jamia/article/25/8/945/4956468). 

The controller's `run()` routine is responsible for assigning tasks and processing task results from the executors.

### Server Code
First, we provide a simple implementation of the [FedAvg](https://proceedings.mlr.press/v54/mcmahan17a?ref=https://githubhelp.com) algorithm with NVFlare.
The `run()` routine implements the main algorithmic logic.
Subroutines such as `sample_clients()` and `send_model_and_wait()` use the communicator object, native to each controller, to get the list of available clients,
distribute the current global model to the clients, and collect their results.

The FedAvg controller implements these main steps:
1. The FL server loads the initial model using `self.load_model()`.
2. For each round (global iteration):
    - The FL server samples available clients using `self.sample_clients()`.
    - The FL server sends the global model to the clients and waits for their updates using `self.send_model_and_wait()`.
    - The FL server aggregates all the `results` using `self.aggregate()` and produces a new global model using `self.update_model()`.
    - The FL server saves the new global model using `self.save_model()`.

```python
class FedAvg(BaseFedAvg):
    def run(self) -> None:
        self.info("Start FedAvg.")

        model = self.load_model()
        model.start_round = self.start_round
        model.total_rounds = self.num_rounds

        for self.current_round in range(self.start_round, self.start_round + self.num_rounds):
            self.info(f"Round {self.current_round} started.")
            model.current_round = self.current_round

            clients = self.sample_clients(self.num_clients)

            results = self.send_model_and_wait(targets=clients, data=model)

            aggregate_results = self.aggregate(results)

            model = self.update_model(model, aggregate_results)

            self.save_model(model)

        self.info("Finished FedAvg.")
```
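
To make the aggregation step concrete, here is a minimal sketch of the weighted average described in the FedAvg paper, written in plain Python. It assumes each client returns its weights together with its local sample count. The helper `weighted_average` and the parameter names `fc.weight` and `fc.bias` are hypothetical, and the exact weighting that NVFlare's `aggregate()` applies is not shown in this notebook.

```python
from typing import Dict, List, Tuple

Weights = Dict[str, List[float]]


def weighted_average(updates: List[Tuple[Weights, int]]) -> Weights:
    """Average client weights, weighting each client by its number of samples."""
    total_samples = sum(n_samples for _, n_samples in updates)
    reference = updates[0][0]
    averaged: Weights = {}
    for key, values in reference.items():
        averaged[key] = [
            sum(weights[key][i] * n_samples for weights, n_samples in updates) / total_samples
            for i in range(len(values))
        ]
    return averaged


# Two hypothetical client updates: (weights, number of local training samples).
updates = [
    ({"fc.weight": [1.0, 2.0], "fc.bias": [0.0]}, 100),
    ({"fc.weight": [3.0, 6.0], "fc.bias": [4.0]}, 300),
]
new_global = weighted_average(updates)
print(new_global)
```

The second client holds three times as much data, so the averaged values land three quarters of the way toward its weights.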

### Client Code 
Given a CIFAR-10 [PyTorch Lightning](https://lightning.ai/) code example with the network wrapped in a LightningModule class, [LitNet](src/lit_net.py), we wish to adapt this centralized training code so that it can run in a federated setting.

On the client side, the training workflow is as follows:
1. Receive the model from the FL server.
2. Perform local training on the received global model and/or evaluate the received global model for model selection.
3. Send the new model back to the FL server.

Using NVFlare's Client Lightning API, we can easily adapt machine learning code that was written for centralized training and apply it in a federated scenario.
For general use cases, we can use the Client Lightning API patch function:
- `flare.patch(trainer)`: Patch the Lightning trainer. After `flare.patch()`, calls such as `trainer.fit()` and `trainer.validate()` get the global model internally and automatically send the resulting model back to the FL server.

With this approach, developers can convert centralized training code to a federated one with the few code changes shown below.
```python
    # (1) import nvflare lightning client API
    import nvflare.client.lightning as flare

    # (2) patch the lightning trainer
    flare.patch(trainer)

    while flare.is_running():
        
        # Note that we can optionally receive the FLModel from NVFLARE.
        # We don't need to pass this input_model to trainer because after flare.patch 
        # the trainer.fit/validate will get the global model internally
        input_model = flare.receive()

        trainer.validate(...)

        trainer.fit(...)

        trainer.test(...)

        trainer.predict(...)
```

The full client training script is saved in a separate file, [./src/cifar10_lightning_fl.py](./src/cifar10_lightning_fl.py), which trains a CNN on the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.

## Run an NVFlare Job
We have now defined the FedAvg controller, which runs the federated compute workflow on the FL server, and the client training script, which receives the global model, runs local training, and sends the results back to the FL server. We can put everything together using NVFlare's Job API.

#### 1. Define the initial model
First, we define the global model used to initialize the model on the FL server. See [src/lit_net.py](src/lit_net.py).

```python
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import LightningModule
from torchmetrics import Accuracy

NUM_CLASSES = 10
criterion = nn.CrossEntropyLoss()

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class LitNet(LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.model = Net()
        self.train_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        self.valid_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        # (optional) pass additional information via self.__fl_meta__
        self.__fl_meta__ = {}

    def forward(self, x):
        out = self.model(x)
        return out

    def training_step(self, batch, batch_idx):
        x, labels = batch
        outputs = self(x)
        loss = criterion(outputs, labels)
        self.train_acc(outputs, labels)
        self.log("train_loss", loss)
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=False)
        return loss

    def evaluate(self, batch, stage=None):
        x, labels = batch
        outputs = self(x)
        loss = criterion(outputs, labels)
        self.valid_acc(outputs, labels)

        if stage:
            self.log(f"{stage}_loss", loss)
            self.log(f"{stage}_acc", self.valid_acc, on_step=True, on_epoch=True)
        return outputs

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        return self.evaluate(batch)

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        return {"optimizer": optimizer}
```
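
The `16 * 5 * 5` input size of `fc1` follows from the layer shapes. As a quick check, the sketch below applies the standard output-size formula for unpadded convolution and pooling layers to a 32x32 CIFAR-10 image. The helpers `conv_out` and `pool_out` are hypothetical names introduced only for this calculation.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size: int, kernel: int, stride: int) -> int:
    """Spatial output size of a max-pooling layer along one dimension."""
    return (size - kernel) // stride + 1


size = 32  # CIFAR-10 images are 32x32 pixels
size = pool_out(conv_out(size, kernel=5), kernel=2, stride=2)  # conv1: 28, pool: 14
size = pool_out(conv_out(size, kernel=5), kernel=2, stride=2)  # conv2: 10, pool: 5

conv2_channels = 16  # output channels of conv2 in Net
flat_features = conv2_channels * size * size
print(flat_features)  # 400, i.e. 16 * 5 * 5
```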

#### 2. Define a FedJob
The `FedJob` is used to define how controllers and executors are placed within a federated job using the `to(object, target)` routine.

In [None]:
from nvflare import FedJob
from nvflare.app_common.executors.script_executor import ScriptExecutor
from nvflare.app_common.workflows.fedavg import FedAvg

job = FedJob(name="cifar10_fedavg_lightning")

#### 3. Define the Controller Workflow
Define the controller workflow and send it to the server.

In [None]:
n_clients = 2

controller = FedAvg(
    num_clients=n_clients,
    num_rounds=2,
)
job.to(controller, "server")

#### 4. Create Global Model
Now, we create the initial global model and send it to the server.

In [None]:
from src.lit_net import LitNet
from nvflare.app_opt.pt.job_config.model import PTModel

job.to(PTModel(LitNet()), "server")

That completes the components that need to be defined on the server.

#### 5. Add clients
Next, we can use the `ScriptExecutor` and send it to each of the clients to run our training script.

Note that our script could have additional input arguments, such as batch size or data path, but we don't use them here for simplicity.
We can also specify which GPU should be used to run each client, which is helpful for simulated environments.

In [None]:
for i in range(n_clients):
    executor = ScriptExecutor(
        task_script_path="src/cifar10_lightning_fl.py", task_script_args=""  # f"--batch_size 32 --data_path /tmp/data/site-{i}"
    )
    job.to(executor, f"site-{i+1}")
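
If per-site arguments were needed, the strings passed as `task_script_args` could be built up front. The sketch below is plain Python and only formats strings. It reuses the `--batch_size` and `--data_path` flags from the commented-out example above, which assumes the training script actually parses them. The helper `site_args` is a hypothetical name.

```python
n_clients = 2


def site_args(index: int, batch_size: int = 32) -> str:
    """Format the argument string for one client (flags assumed to exist in the script)."""
    return f"--batch_size {batch_size} --data_path /tmp/data/site-{index}"


# Site names are 1-based, while the data path index in the comment above is 0-based.
plan = {f"site-{i + 1}": site_args(i) for i in range(n_clients)}
for site, args in plan.items():
    print(site, "->", args)
```

Each value in `plan` could then be passed as `task_script_args` when creating the corresponding `ScriptExecutor`.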

That's it!

#### 6. Optionally export the job
Now, we could export the job and submit it to a real NVFlare deployment using the [Admin client](https://nvflare.readthedocs.io/en/main/real_world_fl/operation.html) or [FLARE API](https://nvflare.readthedocs.io/en/main/real_world_fl/flare_api.html). 

In [None]:
job.export_job("/tmp/nvflare/jobs/job_config")

#### 7. Run FL Simulation
Finally, we can run our FedJob in simulation using NVFlare's [simulator](https://nvflare.readthedocs.io/en/main/user_guide/nvflare_cli/fl_simulator.html) under the hood. The results will be saved in the specified `workdir`.

In [None]:
job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0")

If using Google Colab and the output is not showing correctly, export the job and run it with the simulator command instead:

In [None]:
! nvflare simulator -w /tmp/nvflare/jobs/workdir -n 2 -t 2 -gpu 0 /tmp/nvflare/jobs/job_config/cifar10_fedavg_lightning
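
To keep the command line consistent with the values used earlier in the notebook, the same command can be assembled from variables. This is a small sketch that only builds and prints the string. The flags are copied from the command above, and the variable names are illustrative.

```python
import shlex

workdir = "/tmp/nvflare/jobs/workdir"
job_dir = "/tmp/nvflare/jobs/job_config/cifar10_fedavg_lightning"
n_clients = 2

# Same flags as the simulator command above: workspace, clients, threads, GPU, job folder.
command = [
    "nvflare", "simulator",
    "-w", workdir,
    "-n", str(n_clients),
    "-t", str(n_clients),
    "-gpu", "0",
    job_dir,
]
command_line = shlex.join(command)
print(command_line)
```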