# Getting Started with NVFlare (PyTorch)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/NVFlare/blob/main/examples/hello-world/hello-pt/hello-pt.ipynb)

NVFlare is an open-source framework that allows researchers and
data scientists to seamlessly move their machine learning and deep
learning workflows into a federated paradigm.

## Basic Concepts
At the heart of NVFlare lies the concept of collaboration through
"tasks." An FL controller assigns tasks (e.g., training on local data) to one or more FL clients, processes the returned
results (e.g., model weight updates), and may assign additional
tasks based on these results and other factors (e.g., a pre-configured
number of training rounds). Each client runs executors that listen for tasks and perform the necessary computations locally, such as model training. This task-based interaction repeats
until the experiment's objectives are met.
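
The following toy, framework-free sketch shows the task flow just described. The `ToyController` and `ToyExecutor` classes, the `"train"` task name, and the "move halfway toward a local target" training step are all hypothetical stand-ins, not NVFlare APIs. They only mimic a controller that assigns a task each round, processes the returned results, and stops after a pre-configured number of rounds.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyExecutor:
    """Hypothetical client-side executor: receives a task and computes locally."""
    name: str
    local_target: float  # stands in for what this client's local data pulls the model toward

    def execute(self, task_name: str, weight: float) -> float:
        if task_name != "train":
            raise ValueError(f"unsupported task: {task_name}")
        # "Local training": move the received weight halfway toward the local target
        return weight + 0.5 * (self.local_target - weight)


@dataclass
class ToyController:
    """Hypothetical server-side controller: assigns tasks and processes results."""
    executors: List[ToyExecutor]
    num_rounds: int

    def run(self, initial_weight: float) -> float:
        weight = initial_weight
        for _ in range(self.num_rounds):  # stop after a pre-configured number of rounds
            results = [ex.execute("train", weight) for ex in self.executors]
            weight = sum(results) / len(results)  # turn the results into the next global model
        return weight


controller = ToyController([ToyExecutor("site-1", 1.0), ToyExecutor("site-2", 3.0)], num_rounds=3)
print(controller.run(initial_weight=0.0))  # prints 1.75
```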

<img src="../../../docs/resources/controller_executor_no_filter.png" alt="NVIDIA FLARE Controller and Executor" width=75% height=75% />

## Setup environment

If running in Google Colab, download the source code for this example:

In [None]:
%pip install --ignore-installed blinker

In [None]:
! npx --yes degit -f NVIDIA/NVFlare/examples/hello-world/hello-pt .

Install nvflare and dependencies:

In [None]:
%pip install -r requirements.txt

> **Note:** Depending on the number of clients, you might run into errors if several clients try to download the CIFAR-10 dataset at the same time. If this happens, try downloading the dataset once before starting the job, or reduce the number of concurrent clients.
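
One way to pre-download the dataset is to fetch it once with torchvision before running the job. This is a minimal sketch. `predownload_cifar10` is a hypothetical helper, and the `root` you pass must match the dataset path that `client.py` gives to `torchvision.datasets.CIFAR10` (check the script; this sketch does not assume a specific path). If the paths differ, the clients will still download their own copies.

```python
import os


def predownload_cifar10(root: str) -> None:
    """Download the CIFAR-10 train and test splits once into `root`.

    `root` must be the same directory the client training script reads from.
    """
    import torchvision  # imported lazily so the helper can be defined without torchvision installed

    os.makedirs(root, exist_ok=True)
    torchvision.datasets.CIFAR10(root=root, train=True, download=True)
    torchvision.datasets.CIFAR10(root=root, train=False, download=True)
```

In the notebook, call `predownload_cifar10(...)` with the client's dataset directory in a cell before running the job.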


## Federated Averaging with NVFlare
Given the flexible controller and executor concepts, it is easy to implement different computing and communication patterns with NVFlare, such as [FedAvg](https://proceedings.mlr.press/v54/mcmahan17a?ref=https://githubhelp.com) and [cyclic weight transfer](https://academic.oup.com/jamia/article/25/8/945/4956468).

The controller's `run()` routine is responsible for assigning tasks and processing task results from the executors.

### Server-Side Workflow

This example uses the [FedAvgRecipe](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.app_opt.pt.recipes.fedavg.html), which implements the [FedAvg](https://proceedings.mlr.press/v54/mcmahan17a?ref=https://githubhelp.com) algorithm. The Recipe API handles all server-side logic automatically:

1. Initialize the global model
2. For each training round:
   - Sample available clients
   - Send the global model to selected clients
   - Wait for client updates
   - Aggregate client models into a new global model

With the Recipe API, **there is no need to write custom server code**. The federated averaging workflow is provided by NVFlare using the `ScatterAndGather` controller.
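
For intuition, the "aggregate" step of FedAvg is a weighted average of the client models, with each client weighted by how much local training it did (the FedAvg paper uses the number of local examples). The sketch below is illustrative and framework-free: models are plain dicts mapping parameter names to lists of floats, and `fedavg_aggregate` is a hypothetical helper, not NVFlare's implementation, which the recipe provides for you.

```python
from typing import Dict, List, Sequence

Params = Dict[str, List[float]]  # parameter name -> flattened values


def fedavg_aggregate(client_params: Sequence[Params], client_weights: Sequence[float]) -> Params:
    """Weighted average of client models (e.g., weights = number of local examples)."""
    if not client_params or len(client_params) != len(client_weights):
        raise ValueError("need exactly one weight per client model")
    total = float(sum(client_weights))
    aggregated: Params = {}
    for name in client_params[0]:
        length = len(client_params[0][name])
        aggregated[name] = [
            sum(w * params[name][i] for params, w in zip(client_params, client_weights)) / total
            for i in range(length)
        ]
    return aggregated


site_1 = {"fc.weight": [1.0, 2.0], "fc.bias": [0.0]}
site_2 = {"fc.weight": [3.0, 6.0], "fc.bias": [4.0]}
print(fedavg_aggregate([site_1, site_2], client_weights=[1, 3]))
# prints {'fc.weight': [2.5, 5.0], 'fc.bias': [3.0]}
```

Because `site_2` carries three times the weight of `site_1`, the result lands three quarters of the way toward its values.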

### Client Code
We take a CIFAR-10 example directly from the [PyTorch tutorials](https://github.com/pytorch/tutorials/blob/main/beginner_source/blitz/cifar10_tutorial.py) with some minor modifications, such as removing comments, moving the network to [model.py](model.py), and adding a main method and GPU support.

Now, we need to adapt this centralized training code to something that can run in a federated setting.

On the client side, the training workflow is as follows:
1. Receive the model from the FL server.
2. Perform local training on the received global model
and/or evaluate the received global model for model
selection.
3. Send the new model back to the FL server.

Using NVFlare's Client API, we can easily adapt machine learning code that was written for centralized training and apply it in a federated scenario.
In the general case, the Client API needs only three essential methods:
- `init()`: Initializes NVFlare Client API environment.
- `receive()`: Receives model from the FL server.
- `send()`: Sends the model to the FL server.

With these simple methods, developers can adapt their centralized training code to an FL scenario with only a few lines of changes, as shown below.

```python
import nvflare.client as flare


def local_train(params):
    # Hypothetical placeholder: replace with your original centralized training loop
    return params


def evaluate(params):
    # Hypothetical placeholder: replace with your original evaluation code
    return 0.0


flare.init()  # 1. Initialize NVFlare Client API environment

while flare.is_running():  # handle tasks until the server ends the job
    input_model = flare.receive()  # 2. Receive model from the FL server
    params = input_model.params  # 3. Extract parameters from the received model

    # (optional) Handle cross-site evaluation tasks
    if flare.is_evaluate():
        accuracy = evaluate(params)  # Evaluate the model
        flare.send(flare.FLModel(metrics={"accuracy": accuracy}))
    else:
        # Original local training code
        new_params = local_train(params)

        output_model = flare.FLModel(params=new_params)  # 4. Package results in FLModel
        flare.send(output_model)  # 5. Send the model to the FL server
```

The full client training script is saved in [client.py](client.py), which performs CNN training on the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.

## Run an NVFlare Job
Now that the FedAvg workflow on the FL server is provided by the recipe, and our client training script receives the global model, runs local training, and sends the results back to the FL server, we can put everything together using NVFlare's Recipe API.

#### 1. Define the initial model
First, we define the global model used to initialize the model on the FL server. See [model.py](model.py).

This `SimpleNetwork` is a convolutional neural network (CNN) with:
- Two convolutional layers (`conv1`, `conv2`) for feature extraction
- Max pooling for dimensionality reduction
- Three fully connected layers (`fc1`, `fc2`, `fc3`) for classification into 10 CIFAR-10 classes


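```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # 3 input channels (RGB) -> 6 feature maps, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)  # 2x2 max pooling, stride 2 (halves height and width)
        self.conv2 = nn.Conv2d(6, 16, 5)  # 6 -> 16 feature maps, 5x5 kernel
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 maps of 5x5 after the second pooling
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)  # one output per CIFAR-10 class

    def forward(self, x):
        # x: (batch, 3, 32, 32)
        x = self.pool(F.relu(self.conv1(x)))  # -> (batch, 6, 14, 14)
        x = self.pool(F.relu(self.conv2(x)))  # -> (batch, 16, 5, 5)
        x = torch.flatten(x, 1)  # flatten all dimensions except batch -> (batch, 400)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)  # raw logits; no softmax is applied here
        return x
```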
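
The `16 * 5 * 5` input size of `fc1` follows from the standard output-size formula for convolution and pooling layers, `out = (in + 2 * padding - kernel) // stride + 1`, applied to 32x32 CIFAR-10 images. A small pure-Python check (the helper name `out_size` is ours, for illustration only):

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a Conv2d or MaxPool2d layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


side = 32                                  # CIFAR-10 images are 3 x 32 x 32
side = out_size(side, kernel=5)            # conv1 -> 6 x 28 x 28
side = out_size(side, kernel=2, stride=2)  # pool  -> 6 x 14 x 14
side = out_size(side, kernel=5)            # conv2 -> 16 x 10 x 10
side = out_size(side, kernel=2, stride=2)  # pool  -> 16 x 5 x 5
print(16 * side * side)  # prints 400, i.e. 16 * 5 * 5, the in_features of fc1
```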

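#### 2. Define a FedAvg Recipe

The recipe bundles the job name, the minimum number of clients, the number of training rounds, the initial model, and the client training script (with its command-line arguments) into a single job definition.
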
In [None]:
from model import SimpleNetwork

from nvflare.app_opt.pt.recipes.fedavg import FedAvgRecipe
from nvflare.recipe import SimEnv, add_experiment_tracking
from nvflare.recipe.utils import add_cross_site_evaluation

n_clients = 2
num_rounds = 2
batch_size = 16

recipe = FedAvgRecipe(
    name="hello-pt",
    min_clients=n_clients,
    num_rounds=num_rounds,
    initial_model=SimpleNetwork(),
    train_script="client.py",
    train_args=f"--batch_size {batch_size}",
)
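
`train_args` is forwarded to `client.py` as its command-line arguments. Assuming the script reads them with `argparse` (check `client.py` for the flags and defaults it actually defines), the receiving side could look like this minimal sketch. Only `--batch_size` comes from this example; `parse_client_args` and the default value of 4 are hypothetical.

```python
import argparse
import shlex


def parse_client_args(argv=None) -> argparse.Namespace:
    """Parse the client script's arguments (argv=None means sys.argv[1:])."""
    parser = argparse.ArgumentParser(description="Client training script arguments")
    parser.add_argument("--batch_size", type=int, default=4)  # hypothetical default
    return parser.parse_args(argv)


# Simulate the string built in the recipe cell: f"--batch_size {batch_size}"
args = parse_client_args(shlex.split("--batch_size 16"))
print(args.batch_size)  # prints 16
```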

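#### 3. Add experiment tracking
Enable TensorBoard tracking so that the training metrics logged by each client are collected for visualization (see step 6).
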
In [None]:
add_experiment_tracking(recipe, tracking_type="tensorboard")

#### 4. (Optional) Add Cross-Site Evaluation

To evaluate trained models across all client sites after training, you can add cross-site evaluation:

```python
# Uncomment to enable cross-site evaluation
# add_cross_site_evaluation(recipe)
```

This runs an additional evaluation phase after training completes, in which each client evaluates the models from all sites. The ML framework (PyTorch here) is auto-detected from the recipe.
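
Conceptually, cross-site evaluation produces a matrix of results: every evaluating site scores every model it receives on its own local data. The toy sketch below illustrates only that all-pairs structure. `cross_site_matrix`, the per-site evaluation callbacks, and the use of single numbers as "models" are illustrative assumptions, not how NVFlare's cross-site evaluation workflow is implemented.

```python
from typing import Callable, Dict, Mapping


def cross_site_matrix(
    models: Mapping[str, float],
    evaluators: Mapping[str, Callable[[float], float]],
) -> Dict[str, Dict[str, float]]:
    """Return results[evaluating_site][model_name] for every (site, model) pair."""
    return {
        site: {model_name: evaluate(model) for model_name, model in models.items()}
        for site, evaluate in evaluators.items()
    }


# Hypothetical setup: each "model" is a single number, and each site scores a model
# by its absolute error against a local target value (lower is better).
models = {"global": 2.0, "site-1": 1.0, "site-2": 3.0}
evaluators = {
    "site-1": lambda m: abs(m - 1.0),
    "site-2": lambda m: abs(m - 3.0),
}
results = cross_site_matrix(models, evaluators)
print(results["site-1"])  # prints {'global': 1.0, 'site-1': 0.0, 'site-2': 2.0}
```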


#### 5. Run Job
Here, we run the job in a simulation environment.

In [None]:
env = SimEnv(num_clients=n_clients)
run = recipe.execute(env)
print()
print("Job Status is:", run.get_status())
print("Result can be found in:", run.get_result())
print()

#### 6. Visualize the Training

TensorBoard will show training metrics collected from each client, including:
- Training loss curves over time
- Per-client and aggregated metrics
- Comparison across different rounds

You can launch TensorBoard by running:

```bash
tensorboard --bind_all --logdir /tmp/nvflare/simulation/hello-pt
```
in another terminal, or display the training curves directly in the next notebook cell.

If you enabled cross-site evaluation, you can view the validation results with:
```python
import json
with open('/tmp/nvflare/simulation/hello-pt/server/simulate_job/cross_site_val/cross_val_results.json') as f:
    print(json.dumps(json.load(f), indent=2))
```
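
If you prefer a flat listing over the raw JSON, a small helper can walk the nested dictionaries and print one line per metric. This sketch makes no assumption about nesting depth. Confirm the exact layout of `cross_val_results.json` (for example, evaluating site, then model name, then metric) against your own output file. `flatten_results` and `print_cross_val_results` are hypothetical helper names.

```python
import json
from typing import Any, Iterator, Tuple


def flatten_results(node: Any, path: Tuple[str, ...] = ()) -> Iterator[Tuple[str, Any]]:
    """Yield ("key1 / key2 / ...", value) pairs for every leaf of nested dicts."""
    if isinstance(node, dict):
        for key, child in node.items():
            yield from flatten_results(child, path + (str(key),))
    else:
        yield " / ".join(path), node


def print_cross_val_results(path: str) -> None:
    """Load a cross-site validation JSON file and print one line per leaf value."""
    with open(path) as f:
        for key, value in flatten_results(json.load(f)):
            print(f"{key}: {value}")
```

Call it with the same path as above, e.g. `print_cross_val_results("/tmp/nvflare/simulation/hello-pt/server/simulate_job/cross_site_val/cross_val_results.json")`.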

In [None]:
%load_ext tensorboard
%tensorboard --bind_all --logdir /tmp/nvflare/simulation/hello-pt