# Federated Learning for Medical Image Analysis

This tutorial demonstrates how to use NVIDIA FLARE for medical image analysis applications. For local training on medical images, we will use **[MONAI](https://github.com/Project-MONAI/MONAI)**, a PyTorch-based framework for deep learning in medical imaging applications. We will work with two tasks:

- **MedNIST Classification Task**: a 2D classification task on medical images. The dataset was gathered from several sets from [TCIA](https://wiki.cancerimagingarchive.net/display/Public/Data+Usage+Policies+and+Restrictions),
[the RSNA Bone Age Challenge](http://rsnachallenges.cloudapp.net/competitions/4),
and [the NIH Chest X-ray dataset](https://cloud.google.com/healthcare/docs/resources/public-datasets/nih-chest). It is kindly made available by [Dr. Bradley J. Erickson M.D., Ph.D.](https://www.mayo.edu/research/labs/radiology-informatics/overview) (Department of Radiology, Mayo Clinic)
under the Creative Commons [CC BY-SA 4.0 license](https://creativecommons.org/licenses/by-sa/4.0/).

A 2D DenseNet is trained to classify each image into its corresponding class. Some example images are shown below:
![](./figs/MedNIST.png)

- **Prostate Segmentation Task**: a 3D segmentation of the prostate in T2-weighted MRIs. For tutorial purposes, we illustrate the process with only a few images from the [**MSD Dataset**](http://medicaldecathlon.com/), without downloading the full multi-source datasets. Please refer to the [advanced example](../../../../advanced/prostate/README.md) for the full experiment.

The [3D U-Net](https://arxiv.org/abs/1606.06650) model is trained to segment the whole prostate region (binary) in a T2-weighted MRI scan. 

![](./figs/Prostate3D.png)




## General Federated Learning Flow with MONAI 

In this example, the **server** uses the [`FedAvg`](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_common/workflows/fedavg.py) controller, which performs the following steps:
1. Initialize the global model. This is achieved through the method `load_model()`
  from the base class
  [`ModelController`](https://github.com/NVIDIA/NVFlare/blob/fa4d00f76848fe4eb356dcde417c136047eeab36/nvflare/app_common/workflows/model_controller.py#L292),
  which relies on the
  [`ModelPersistor`](https://nvflare.readthedocs.io/en/main/glossary.html#persistor). 
2. During each training round, the global model is sent to the
  list of participating clients to perform a training task. Under the
  hood, this is done using the
  [`send_model()`](https://github.com/NVIDIA/NVFlare/blob/d6827bca96d332adb3402ceceb4b67e876146067/nvflare/app_common/workflows/model_controller.py#L99)
  method from the `ModelController` base class. Once
  the clients finish their local training, their results are collected
  and sent back to the server as [`FLModel`](https://nvflare.readthedocs.io/en/main/programming_guide/fl_model.html#flmodel) objects.
3. Results sent by clients will be aggregated based on the
  [`WeightedAggregationHelper`](https://github.com/NVIDIA/NVFlare/blob/fa4d00f76848fe4eb356dcde417c136047eeab36/nvflare/app_common/aggregators/weighted_aggregation_helper.py#L20),
  which weighs the contribution from each client based on the number
  of local training samples. The aggregated updates are
  returned as a new `FLModel`.
4. After getting the aggregated results, the global model is [updated](https://github.com/NVIDIA/NVFlare/blob/724140e7dc9081eca7a912a818817f89aadfef5d/nvflare/app_common/workflows/fedavg.py#L63).
5. The last step is to save the updated global model, again through
  the [`ModelPersistor`](https://nvflare.readthedocs.io/en/main/glossary.html#persistor).
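
The sample-count weighting used in the aggregation step can be illustrated with a small, framework-free sketch. This is not NVFlare's `WeightedAggregationHelper`: the function name `weighted_average`, the parameter name `fc.weight`, and the use of plain Python lists in place of tensors are assumptions made purely for illustration.

```python
from typing import Dict, List, Tuple

# Parameter name -> flattened values (lists stand in for tensors here).
Params = Dict[str, List[float]]


def weighted_average(results: List[Tuple[Params, int]]) -> Params:
    """Average client parameters, weighting each client by its sample count."""
    total = sum(n_samples for _, n_samples in results)
    if total <= 0:
        raise ValueError("at least one client must report training samples")

    aggregated: Params = {}
    for params, n_samples in results:
        weight = n_samples / total
        for name, values in params.items():
            acc = aggregated.setdefault(name, [0.0] * len(values))
            for i, value in enumerate(values):
                acc[i] += weight * value
    return aggregated


# Two hypothetical clients: site-1 holds three times as much data as site-2.
site_1 = ({"fc.weight": [1.0, 2.0]}, 300)
site_2 = ({"fc.weight": [5.0, 6.0]}, 100)
global_params = weighted_average([site_1, site_2])
print(global_params)  # {'fc.weight': [2.0, 3.0]}
```

The result sits closer to site-1's values because that site contributes 75% of the total samples.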

The **clients** implement the local training logic using NVFlare's [Client
API](https://nvflare.readthedocs.io/en/main/programming_guide/execution_api_type.html#client-api)
[here](./code/monai_mednist_train.py). The Client API
lets the user turn a typical centralized training script into a federated
client-side local training script by adding a minimal amount of
`nvflare`-specific code.
1. During local training, each client receives a copy of the global
  model sent by the server using the `flare.receive()` API. The received
  global model is an instance of `FLModel`. The integration with the MONAI trainer handles the local training and validation.
2. A local validation is performed first, and the validation metrics
  (accuracy and precision) are streamed to the server using the
  [`SummaryWriter`](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.client.tracking.html#nvflare.client.tracking.SummaryWriter). The
  streamed metrics can be loaded and visualized using [TensorBoard](https://www.tensorflow.org/tensorboard) or [MLflow](https://mlflow.org/).
3. Then, each client performs local training as in the non-federated training [notebook](./monai_101.ipynb). At the end of each FL round, each client sends the computed results (always in
  `FLModel` format) to the server for aggregation, using the `flare.send()`
  API.
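
The client-side steps can be sketched as a single loop. This is a minimal outline rather than the contents of the tutorial's training script: `run_client` and `train_round` are hypothetical names, the Client API module is passed in as an argument so the loop can be read and exercised without a running FL job, and `NUM_STEPS_CURRENT_ROUND` is assumed to be the meta key the server uses to weight each client's contribution.

```python
from typing import Any, Callable, Dict, Tuple

# (new_params, validation_metrics, number_of_local_steps)
RoundResult = Tuple[Dict[str, Any], Dict[str, float], int]


def run_client(flare: Any, train_round: Callable[[Dict[str, Any]], RoundResult]) -> int:
    """Run the receive -> validate/train -> send loop; return the rounds completed."""
    flare.init()
    rounds = 0
    while flare.is_running():
        # Receive the current global model (an FLModel) from the server.
        global_model = flare.receive()

        # Validate the global model, then train locally starting from it.
        # `train_round` is a hypothetical callback wrapping the MONAI trainer.
        new_params, metrics, num_steps = train_round(global_model.params)

        # Send the local result back to the server for aggregation.
        flare.send(
            flare.FLModel(
                params=new_params,
                metrics=metrics,
                meta={"NUM_STEPS_CURRENT_ROUND": num_steps},
            )
        )
        rounds += 1
    return rounds
```

In an actual client script, the module would be imported with `import nvflare.client as flare` and handed over as `run_client(flare, train_round)`.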


## Setup
First, let's set up our environment with the necessary packages.

In [None]:
# Install required packages
!pip install nibabel mlflow
!pip install "monai-weekly[ignite, tqdm]"
!pip install --upgrade --no-cache-dir gdown

## Part 1: MedNIST
For the MedNIST experiment, everything is handled by MONAI, including the data download. We let MONAI create the temporary folder and files.

### MedNIST Training Script
The MONAI code that handles data, preprocessing, network definition, and training can be found in the [local training code](./mednist_fedavg/app/custom/monai_mednist_train.py):

- Data download:
```
root_dir = tempfile.mkdtemp()
print(root_dir)
dataset = MedNISTDataset(root_dir=root_dir, transform=transform, section="training", download=True)
```
- Preprocessing:
```
transform = Compose(
    [
        LoadImageD(keys="image", image_only=True),
        EnsureChannelFirstD(keys="image"),
        ScaleIntensityD(keys="image"),
    ]
)
```
- Network definition:
```
model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(DEVICE)
```
- Trainer definition:
```
trainer = SupervisedTrainer(
    device=torch.device(DEVICE),
    max_epochs=max_epochs,
    train_data_loader=train_loader,
    network=model,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-5),
    loss_function=torch.nn.CrossEntropyLoss(),
    inferer=SimpleInferer(),
    train_handlers=StatsHandler(),
)
```

### Use the NVFlare Job API to run the federated experiments
We use the NVFlare [Job API](https://github.com/NVIDIA/NVFlare/blob/main/examples/advanced/job_api/pt/README.md) to run the FL training experiments.

In [None]:
from src.densenet import DenseNet121

from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob
from nvflare.job_config.script_runner import ScriptRunner

if __name__ == "__main__":
    n_clients = 2
    num_rounds = 5
    train_script = "src/monai_mednist_train.py"

    job = FedAvgJob(
        name="mednist_fedavg",
        n_clients=n_clients,
        num_rounds=num_rounds,
        initial_model=DenseNet121(spatial_dims=2, in_channels=1, out_channels=6),
    )

    # Add clients
    executor = ScriptRunner(script=train_script, script_args="")
    job.to_clients(executor)

    job.export_job("/tmp/nvflare/jobs/")
    job.simulator_run("/tmp/nvflare/workspaces/mednist_fedavg", n_clients=n_clients, gpu="0")

### Training result visualization
Let's visualize the training curves.

In [None]:
%load_ext tensorboard
%tensorboard --logdir /tmp/nvflare/workspaces/mednist_fedavg

## Part 2: Prostate
The second task is more practical: 3D segmentation.

### Data download
Let's first set up our directory structure and download the MSD_Prostate dataset.

In [None]:
import os
# Create necessary directories
data_folder='/tmp/nvflare/datasets/MSD/Raw'
os.makedirs(data_folder, exist_ok=True)

In [None]:
!gdown -O '/tmp/nvflare/datasets/MSD/Raw/Task05_Prostate.tar' "1Ff7c21UksxyT4JfETjaarmuKEjdqe1-a&confirm=t" 

In [None]:
!tar xf /tmp/nvflare/datasets/MSD/Raw/Task05_Prostate.tar -C /tmp/nvflare/datasets/MSD/Raw/

### Preprocessing
Now let's convert our data to the appropriate format. We'll use the provided conversion script to select the T2 channel and convert the labels to binary:

In [None]:
# Run conversion scripts
!bash data_conversion.sh
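
Conceptually, the conversion does two things per case. The sketch below shows them on in-memory arrays; it is not the contents of `data_conversion.sh`. It assumes the image's last axis holds the modalities with T2 at index 0, and that the label map uses 0 for background and non-zero values for the prostate zones. The function names `select_t2` and `binarize_label` are hypothetical.

```python
import numpy as np


def select_t2(image: np.ndarray, t2_index: int = 0) -> np.ndarray:
    """Keep only the T2 channel of an (X, Y, Z, modality) volume."""
    if image.ndim != 4:
        raise ValueError(f"expected a 4D volume, got shape {image.shape}")
    return image[..., t2_index]


def binarize_label(label: np.ndarray) -> np.ndarray:
    """Merge all non-background classes into a single foreground class."""
    return (label > 0).astype(np.uint8)


# Tiny synthetic case standing in for one MRI scan and its annotation.
image = np.random.rand(4, 4, 2, 2).astype(np.float32)  # two modalities
label = np.array([0, 1, 2, 0], dtype=np.uint8)  # background + two zones

print(select_t2(image).shape)  # (4, 4, 2)
print(binarize_label(label))  # [0 1 1 0]
```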

### Datalist Generation
With the prepared data, let's generate the data splits. We'll use a 50 : 25 : 25 split for training : validation : testing.

In [None]:
# Generate data lists
!bash datalists_gen.sh
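
A 50 : 25 : 25 split can be produced along the following lines. This is a sketch and not the logic of `datalists_gen.sh`: the function name `split_cases`, the fixed seed, the file names, and the `training` / `validation` / `testing` keys are assumptions for illustration, and the generated JSON may be laid out differently.

```python
import json
import random
from typing import Dict, List


def split_cases(case_ids: List[str], seed: int = 0) -> Dict[str, List[str]]:
    """Shuffle case IDs and split them 50 : 25 : 25."""
    cases = sorted(case_ids)
    random.Random(seed).shuffle(cases)  # deterministic for a given seed
    n_train = len(cases) // 2
    n_valid = len(cases) // 4
    return {
        "training": cases[:n_train],
        "validation": cases[n_train : n_train + n_valid],
        "testing": cases[n_train + n_valid :],
    }


# Hypothetical file names; the real ones come from the converted dataset.
case_ids = [f"prostate_{i:02d}.nii.gz" for i in range(8)]
datalist = split_cases(case_ids)
print(json.dumps({k: len(v) for k, v in datalist.items()}))
# {"training": 4, "validation": 2, "testing": 2}
```

Sorting before shuffling with a seeded generator keeps the split reproducible regardless of the order in which the files were listed.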

Let's take a look at the datalist JSON:

In [None]:
# show the content of the datalist json
!cat /tmp/nvflare/datasets/MSD/datalist/site-1.json

### Federated Training

Now that we have prepared our data, we can proceed to the federated training:

In [None]:
from src.learners.supervised_monai_prostate_learner import SupervisedMonaiProstateLearner
from src.unet import UNet

from nvflare.app_common.executors.learner_executor import LearnerExecutor
from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob

if __name__ == "__main__":
    n_clients = 4
    num_rounds = 3

    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=[16, 32, 64, 128, 256],
        strides=[2, 2, 2, 2],
        num_res_units=2,
    )

    job = FedAvgJob(name="prostate_fedavg", n_clients=n_clients, num_rounds=num_rounds, initial_model=model)

    # Add clients
    learner = SupervisedMonaiProstateLearner(
        train_config_filename="../custom/src/config/config_train.json", aggregation_epochs=10
    )
    job.to_clients(learner, id="prostate-learner")
    executor = LearnerExecutor(learner_id="prostate-learner")
    job.to_clients(executor)
    job.to_clients("src/config/config_train.json")

    job.export_job("/tmp/nvflare/jobs/")
    job.simulator_run("/tmp/nvflare/workspaces/prostate_fedavg", n_clients=n_clients, gpu="0")

For demonstration purposes, we only run 3 rounds. Let's visualize the training curves; an increase in validation accuracy can be observed.

In [None]:
%load_ext tensorboard
%tensorboard --logdir /tmp/nvflare/workspaces/prostate_fedavg