# Federated Learning for Medical Image Analysis

This tutorial demonstrates how to use NVIDIA FLARE for medical image analysis applications. We will work with:
- **Prostate Segmentation Task**: Both 2D (axial slices) and 3D (volumes) segmentation of the prostate in T2-weighted MRIs. For tutorial purposes, we only illustrate the process with a few images from the [**MSD Dataset**](http://medicaldecathlon.com/), without downloading the full multi-source datasets. Please refer to the [advanced example](../../../../advanced/prostate/README.md) for the full experiment.
- **[MONAI](https://github.com/Project-MONAI/MONAI)**: A PyTorch-based framework for deep learning in medical imaging applications.

## Setup
First, let's set up our environment with the necessary packages:

```
# Install required packages
!pip install nibabel
!pip install --upgrade --no-cache-dir gdown
```

## Data Preparation
### Download
Let's first set up our directory structure and download the MSD_Prostate dataset.

```
import os

# Create the directory for the raw MSD data
data_folder = '/tmp/nvflare/datasets/MSD/Raw'
os.makedirs(data_folder, exist_ok=True)
```

```
# Download the Task05_Prostate archive from Google Drive
!gdown -O '/tmp/nvflare/datasets/MSD/Raw/Task05_Prostate.tar' "1Ff7c21UksxyT4JfETjaarmuKEjdqe1-a&confirm=t"
```

```
# Extract the archive into the same folder
!tar xf /tmp/nvflare/datasets/MSD/Raw/Task05_Prostate.tar -C /tmp/nvflare/datasets/MSD/Raw/
```
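
The download and extraction can also be checked from Python with the standard-library `tarfile` module. This is a minimal sketch, assuming the archive unpacks to a `Task05_Prostate/` folder containing `dataset.json`, `imagesTr/` and `labelsTr/` (the usual MSD task layout); `extract_and_check` is a hypothetical helper, not part of the tutorial scripts.

```python
import os
import tarfile

# Usual MSD task layout; assumed here, the notebook itself does not check it.
EXPECTED_ENTRIES = ("dataset.json", "imagesTr", "labelsTr")


def extract_and_check(tar_path, out_dir, task_name="Task05_Prostate"):
    """Extract an MSD task archive into out_dir and return the expected entries that are missing."""
    with tarfile.open(tar_path) as tar:
        if hasattr(tarfile, "data_filter"):
            # Newer Python versions support extraction filters that reject unsafe paths.
            tar.extractall(out_dir, filter="data")
        else:
            tar.extractall(out_dir)
    task_dir = os.path.join(out_dir, task_name)
    return [name for name in EXPECTED_ENTRIES
            if not os.path.exists(os.path.join(task_dir, name))]


# Usage with the paths from the cells above:
# extract_and_check("/tmp/nvflare/datasets/MSD/Raw/Task05_Prostate.tar",
#                   "/tmp/nvflare/datasets/MSD/Raw")  # [] means the layout looks complete
```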

### Preprocessing
Now let's convert our data to the appropriate format. We'll use the provided conversion script to select the T2 channel and convert the labels to binary:

```
# Run conversion scripts
!bash data_conversion.sh
```
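
The two conversion steps can be sketched at the array level. This is a hypothetical illustration, not the contents of `data_conversion.sh`; it assumes, following the MSD Task05 `dataset.json`, that images are channel-last 4D arrays with channel 0 = T2 and channel 1 = ADC, and that labels use 0 = background, 1 = peripheral zone, 2 = transition zone. `select_t2`, `binarize_label` and `convert_case` are illustrative names.

```python
import numpy as np

# Assumptions (from the MSD Task05 dataset.json, not from this notebook):
# images are channel-last 4D arrays (H, W, D, 2) with channel 0 = T2 and 1 = ADC,
# and labels use 0 = background, 1 = peripheral zone, 2 = transition zone.
T2_CHANNEL = 0


def select_t2(image_4d):
    """Return the T2-weighted volume from a channel-last multi-modal image."""
    image_4d = np.asarray(image_4d)
    if image_4d.ndim != 4:
        raise ValueError(f"expected a 4D image, got shape {image_4d.shape}")
    return image_4d[..., T2_CHANNEL]


def binarize_label(label):
    """Merge all foreground zones into one whole-prostate class (0 or 1)."""
    return (np.asarray(label) > 0).astype(np.uint8)


def convert_case(image_path, label_path, out_image_path, out_label_path):
    """Hypothetical per-case wrapper: load with nibabel, convert, and save as NIfTI."""
    import nibabel as nib  # lazy import so the array helpers above work without nibabel

    img = nib.load(image_path)
    lbl = nib.load(label_path)
    t2 = select_t2(img.get_fdata(dtype=np.float32))
    nib.save(nib.Nifti1Image(t2, img.affine), out_image_path)
    nib.save(nib.Nifti1Image(binarize_label(lbl.get_fdata()), lbl.affine), out_label_path)
```

Merging both zones into one class matches the binary "whole prostate" target used later in this tutorial.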

### Datalist Generation
With the data prepared, let's generate the data splits. We'll use a 50:25:25 split for training, validation, and testing.

```
# Generate data lists
!bash datalists_gen.sh
```
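
A 50:25:25 split can be sketched in a few lines of Python. This is an illustration under assumptions, not the logic of `datalists_gen.sh`: the case IDs, the `{"image": ..., "label": ...}` entry format, the `training`/`validation`/`testing` keys, and the `.nii.gz` file naming are hypothetical choices modeled on common MONAI datalists.

```python
import json
import random


def split_cases(case_ids, ratios=(0.5, 0.25, 0.25), seed=0):
    """Shuffle case IDs reproducibly and split them into training/validation/testing."""
    if abs(sum(ratios) - 1.0) > 1e-9:
        raise ValueError("ratios must sum to 1")
    ids = sorted(case_ids)  # sort first so the result does not depend on input order
    random.Random(seed).shuffle(ids)
    n_train = int(len(ids) * ratios[0])
    n_val = int(len(ids) * ratios[1])
    return {
        "training": ids[:n_train],
        "validation": ids[n_train:n_train + n_val],
        "testing": ids[n_train + n_val:],  # rounding remainder goes to testing
    }


def write_datalist(splits, image_dir, label_dir, out_path):
    """Write a MONAI-style datalist with {"image": ..., "label": ...} entries per split."""
    datalist = {
        key: [{"image": f"{image_dir}/{cid}.nii.gz", "label": f"{label_dir}/{cid}.nii.gz"}
              for cid in ids]
        for key, ids in splits.items()
    }
    with open(out_path, "w") as f:
        json.dump(datalist, f, indent=4)
    return datalist
```

Fixing the random seed keeps the split reproducible across runs, so every client's datalist stays stable between experiments.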

Let's take a look at the generated datalist JSON:

```
# show the content of the datalist json
!cat /tmp/nvflare/datasets/MSD/datalist/client_0.json
```
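
Instead of reading the raw JSON, a quick per-split count can confirm the split sizes. This is a minimal sketch, assuming each list-valued top-level key (for example `training` or `validation`) holds one entry per case; `summarize_datalist` is a hypothetical helper.

```python
import json


def summarize_datalist(path):
    """Count the entries under each list-valued top-level key of a datalist JSON file."""
    with open(path) as f:
        datalist = json.load(f)
    return {key: len(value) for key, value in datalist.items() if isinstance(value, list)}


# Usage with the file shown above:
# summarize_datalist("/tmp/nvflare/datasets/MSD/datalist/client_0.json")
```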

## Federated Training

Now that we have prepared our data, we can proceed to federated training.

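The application shown in this example is volumetric (3D) segmentation of the prostate in T2-weighted MRIs. The full experiment is based on three datasets that can be split into four clients of comparable size; in this tutorial, we use only the MSD images, distributed across four clients (`client_0` to `client_3`).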

The [3D U-Net](https://arxiv.org/abs/1606.06650) model is trained to segment the whole prostate region (binary) in a T2-weighted MRI scan. 
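
A common way to score a binary segmentation like this is the Dice coefficient, $2|A \cap B| / (|A| + |B|)$. This section does not state which metric the job reports, so the sketch below is only an illustrative NumPy implementation for two masks of any shape (2D slices or 3D volumes); `dice_score` is a hypothetical name.

```python
import numpy as np


def dice_score(pred, target):
    """Dice coefficient 2|A ∩ B| / (|A| + |B|) between two binary masks of the same shape."""
    pred = np.asarray(pred).astype(bool)
    target = np.asarray(target).astype(bool)
    if pred.shape != target.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {target.shape}")
    denom = pred.sum() + target.sum()
    if denom == 0:
        return 1.0  # both masks empty: count as perfect agreement (a convention, not a rule)
    return float(2.0 * np.logical_and(pred, target).sum() / denom)
```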

![](./figs/Prostate3D.png)

### Use NVFlare simulator to run the experiments
We use the NVFlare simulator to run the FL training experiments, following the pattern:
```
nvflare simulator job_configs/[job] -w ${PWD}/workspaces/[job] -c [clients] -gpu [gpu] -t [thread]
```
`[job]` is the job to be submitted for FL training; in this example, it is `prostate_fedavg`, which runs FedAvg on the prostate segmentation task.  
The combination of `-c` and `-gpu`/`-t` controls the resource allocation. In this example, we run four clients (`-c client_0,client_1,client_2,client_3`), each in a separate thread (`-t 4`), all on GPU 0 (`-gpu 0,0,0,0`).
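
To see how the flags fit together, the command can be assembled programmatically. This sketch only uses the flags shown in the pattern above; `simulator_command` is a hypothetical helper, and the values reproduce the command used in this example.

```python
import shlex


def simulator_command(job, workspace, clients, gpus, threads):
    """Return the argument list for `nvflare simulator`, following the pattern above."""
    return [
        "nvflare", "simulator", job,
        "-w", workspace,
        "-c", ",".join(clients),
        "-gpu", ",".join(str(g) for g in gpus),
        "-t", str(threads),
    ]


clients = [f"client_{i}" for i in range(4)]
cmd = simulator_command("prostate_fedavg", "/tmp/nvflare/workspaces/prostate_fedavg",
                        clients, gpus=[0] * len(clients), threads=len(clients))
print(shlex.join(cmd))  # shell-safe string, e.g. to paste into a terminal
# To launch from Python instead: subprocess.run(cmd, check=True)
```

Keeping the arguments as a list avoids shell-quoting issues when launching with `subprocess.run`.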

```
!nvflare simulator prostate_fedavg -w /tmp/nvflare/workspaces/prostate_fedavg -c client_0,client_1,client_2,client_3 -gpu 0,0,0,0 -t 4
```
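
In each FedAvg round, the server combines the clients' model updates into a weighted average. NVFlare performs this aggregation inside the job, so the sketch below only illustrates the arithmetic. It assumes each client's weights are a dict of NumPy arrays (or floats) and weights clients by local sample count, which is one common choice; the actual weighting used by `prostate_fedavg` is defined by its job configuration, not shown here.

```python
import numpy as np


def fedavg(client_weights, num_samples):
    """Average client parameter dicts, weighting each client by its local sample count."""
    if not client_weights or len(client_weights) != len(num_samples):
        raise ValueError("need one sample count per client")
    total = float(sum(num_samples))
    return {
        name: sum(weights[name] * (n / total) for weights, n in zip(client_weights, num_samples))
        for name in client_weights[0]
    }
```

For example, with two clients holding 1 and 3 training cases, the second client contributes three quarters of the averaged model.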

For demonstration purposes, we only run 3 rounds. Let's visualize the training curves; an increase in validation accuracy should be observable.

```
%load_ext tensorboard
%tensorboard --logdir /tmp/nvflare/workspaces/prostate_fedavg
```
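
If TensorBoard shows no curves, it helps to check which event files actually exist under the workspace. This is a minimal sketch that relies only on TensorBoard's standard `events.out.tfevents` file-name prefix; the exact directory layout of the simulator workspace is not assumed, and `find_event_files` is a hypothetical helper.

```python
import os


def find_event_files(logdir):
    """Return sorted paths of TensorBoard event files (events.out.tfevents.*) under logdir."""
    matches = []
    for root, _dirs, files in os.walk(logdir):
        matches.extend(os.path.join(root, name) for name in files
                       if name.startswith("events.out.tfevents"))
    return sorted(matches)


# Usage: each directory holding event files appears as a separate run in TensorBoard.
# for path in find_event_files("/tmp/nvflare/workspaces/prostate_fedavg"):
#     print(path)
```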