# Federated MONAI MedNIST Example

This demo demonstrates federated learning training and validation for 2D medical image registration.

It is based on the MONAI [registration_mednist.ipynb](https://github.com/Project-MONAI/tutorials/blob/master/2d_registration/registration_mednist.ipynb) notebook and on [OpenFL](https://github.com/intel/openfl), a federated learning framework.

In [None]:
# Install the workspace requirements listed in workspace_requirements.txt
# ("!" is the IPython shell escape, so this line only works inside Jupyter/IPython).
! pip install -r workspace_requirements.txt

In [None]:
import numpy as np
import torch
import tqdm
from monai.config import USE_COMPILED
from openfl.interface.interactive_api.experiment import (
    DataInterface,
    FLExperiment,
    ModelInterface,
    TaskInterface,
)
from openfl.interface.interactive_api.federation import Federation

## Connect to the Federation

In [None]:
# Create a federation

# Please use the same identifier that was used in the signed certificate
client_id = "api"
cert_dir = "cert"
director_node_fqdn = "localhost"
director_port = 50051
# 1) Run with API layer - Director mTLS
# If the user wants to enable mTLS, they must provide the CA root chain and a signed key pair to the federation interface
# cert_chain = f'{cert_dir}/root_ca.crt'
# api_certificate = f'{cert_dir}/{client_id}.crt'
# api_private_key = f'{cert_dir}/{client_id}.key'

# federation = Federation(client_id=client_id, director_node_fqdn=director_node_fqdn, director_port=director_port,
#                        cert_chain=cert_chain, api_cert=api_certificate, api_private_key=api_private_key)

# --------------------------------------------------------------------------------------------------------------------

# 2) Run with TLS disabled (trusted environment only)
# Federation can also determine the local FQDN automatically
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port,
    tls=False,
)

In [None]:
# Display the federation's target_shape attribute as this cell's output.
federation.target_shape

In [None]:
# Query the federation for its shard registry, keep it in a variable and display it.
shard_registry = federation.get_shard_registry()
shard_registry

In [None]:
# Ask the federation for a dummy shard descriptor (size=10) that holds information
# about the federated dataset, then inspect the first training sample and its target.
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
dummy_shard_dataset = dummy_shard_desc.get_dataset("train")
sample, target = dummy_shard_dataset[0]
# One print with a newline separator emits exactly the same two lines.
print(sample.shape, target.shape, sep="\n")

## Creating a FL experiment using Interactive API

### Register dataset

In [None]:
from monai.data import CacheDataset, DataLoader, Dataset
from monai.transforms import (
    Compose,
    EnsureChannelFirstD,
    EnsureTypeD,
    LoadImageD,
    RandRotateD,
    RandZoomD,
    ScaleIntensityRanged,
)

In [None]:
# Keys of the two images that make up one registration pair.
_hand_keys = ["fixed_hand", "moving_hand"]

# Both images are loaded, made channel-first and rescaled from [0, 255] to
# [0, 1]. Only the moving image is then randomly rotated and zoomed, so the
# pair is misaligned before being converted to the expected output type.
image_transforms = Compose(
    [
        LoadImageD(keys=_hand_keys),
        EnsureChannelFirstD(keys=_hand_keys),
        ScaleIntensityRanged(
            keys=_hand_keys,
            a_min=0.0,
            a_max=255.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,  # clamp values outside [a_min, a_max]
        ),
        # Random rotation of up to +/- pi/4, always applied (prob=1.0).
        RandRotateD(
            keys=["moving_hand"],
            range_x=np.pi / 4,
            prob=1.0,
            keep_size=True,
            mode="bicubic",
        ),
        # Random zoom in [0.9, 1.1], always applied (prob=1.0).
        RandZoomD(
            keys=["moving_hand"],
            min_zoom=0.9,
            max_zoom=1.1,
            prob=1.0,
            mode="bicubic",
            align_corners=False,
        ),
        EnsureTypeD(keys=_hand_keys),
    ]
)

In [None]:
class MedNISTDataset(DataInterface):
    """OpenFL data interface that wraps an Envoy shard in MONAI datasets.

    Keyword arguments passed to the constructor are stored in ``self.kwargs``.
    The loaders read their batch sizes from the ``train_bs`` and ``valid_bs``
    entries.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @property
    def shard_descriptor(self):
        """Return the shard descriptor last assigned to this interface."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor will be set by Envoy. Assigning it also builds
        the training and validation datasets from the shard's splits.
        """
        self._shard_descriptor = shard_descriptor
        self.train_set = self._wrap_split("train")
        self.valid_set = self._wrap_split("validation")

    def _wrap_split(self, split_name):
        """Wrap one shard split's ``data_items`` in a transformed MONAI Dataset."""
        split_items = self._shard_descriptor.get_dataset(split_name).data_items
        return Dataset(data=split_items, transform=image_transforms)

    def get_train_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks with optimizer in contract.

        Batches come from the training set, shuffled, using ``train_bs``.
        """
        dataset = self.train_set
        return DataLoader(dataset, batch_size=self.kwargs["train_bs"], shuffle=True)

    def get_valid_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks without optimizer in contract.

        Batches come from the validation set, in order, using ``valid_bs``.
        """
        dataset = self.valid_set
        return DataLoader(dataset, batch_size=self.kwargs["valid_bs"])

    def get_train_data_size(self):
        """Return the number of training items (reported for aggregation)."""
        return len(self.train_set)

    def get_valid_data_size(self):
        """Return the number of validation items (reported for aggregation)."""
        return len(self.valid_set)

In [None]:
# Data interface whose training and validation loaders both use batches of 16.
fed_dataset = MedNISTDataset(train_bs=16, valid_bs=16)

### Describe the model and optimizer

In [None]:
from monai.networks.blocks import Warp
from monai.networks.nets import GlobalNet
from torch.nn import MSELoss

In [None]:
# Registration network for 64x64 2D inputs. Its input is the moving and
# fixed images stacked along the channel axis.
model_net = GlobalNet(
    image_size=(64, 64),
    spatial_dims=2,
    in_channels=2,  # moving + fixed, concatenated channel-wise
    num_channel_initial=16,
    depth=3,
)

# Image similarity loss between the warped moving image and the fixed image.
image_loss = MSELoss()

# With MONAI's compiled extension, Warp is given mode 3 (presumably an
# interpolation order; check the MONAI Warp docs). Otherwise it uses "bilinear".
# Both variants use "border" padding.
warp_layer = Warp(3, "border") if USE_COMPILED else Warp("bilinear", "border")

optimizer_adam = torch.optim.Adam(model_net.parameters(), lr=1e-5)

### Register model

In [None]:
# Dotted import path of the OpenFL PyTorch framework adapter plugin.
framework_adapter = (
    "openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin"
)
# Bundle the network, its optimizer and the framework plugin for the experiment.
model_interface = ModelInterface(
    framework_plugin=framework_adapter,
    model=model_net,
    optimizer=optimizer_adam,
)

## Define and register FL tasks

In [None]:
# Task registry: the train/validate functions in this cell are attached to it
# through its register_fl_task decorator.
task_interface = TaskInterface()


@task_interface.register_fl_task(
    model="net_model",
    data_loader="train_loader",
    device="device",
    optimizer="optimizer",
)
def train(
    net_model,
    train_loader,
    optimizer,
    device,
    loss_fn=image_loss,
    affine_transform=warp_layer,
):
    """Run one local training epoch of the registration network.

    For each batch, the moving and fixed images are concatenated along the
    channel axis and fed to ``net_model``. ``affine_transform`` applies the
    output (``ddf``) to warp the moving image, and ``loss_fn`` compares the
    result with the fixed image.

    Args:
        net_model: network being trained.
        train_loader: iterable of dict batches with "moving_hand" and
            "fixed_hand" tensors.
        optimizer: optimizer stepping ``net_model``'s parameters.
        device: device the model, warp layer and batch tensors are moved to.
        loss_fn: image loss (defaults to the module-level ``image_loss``).
        affine_transform: warp applied to the moving image (defaults to the
            module-level ``warp_layer``).

    Returns:
        dict with "train_loss", the mean per-batch loss over the epoch
        (0.0 if the loader yields no batches).
    """
    train_loader = tqdm.tqdm(train_loader, desc="train")
    net_model.train()
    net_model.to(device)
    # Move the warp that is actually applied below (the argument, not the
    # global), so a caller-supplied module also lands on ``device``.
    if isinstance(affine_transform, torch.nn.Module):
        affine_transform.to(device)

    epoch_loss = 0.0
    step = 0

    for batch_data in train_loader:
        step += 1
        optimizer.zero_grad()

        moving = batch_data["moving_hand"].to(device)
        fixed = batch_data["fixed_hand"].to(device)
        ddf = net_model(torch.cat((moving, fixed), dim=1))
        pred_image = affine_transform(moving, ddf)

        loss = loss_fn(pred_image, fixed)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # An empty loader (e.g. an empty shard) would otherwise raise ZeroDivisionError.
    if step:
        epoch_loss /= step
    return {
        "train_loss": epoch_loss,
    }


@task_interface.register_fl_task(
    model="net_model", data_loader="val_loader", device="device"
)
def validate(
    net_model, val_loader, device, loss_fn=image_loss, affine_transform=warp_layer
):
    """Evaluate the registration network on the local validation loader.

    Runs under ``torch.no_grad()`` with the model in eval mode. The
    forward/warp/loss computation is the same as in training, but no
    optimizer step is taken.

    Args:
        net_model: network being evaluated.
        val_loader: iterable of dict batches with "moving_hand" and
            "fixed_hand" tensors.
        device: device the model, warp layer and batch tensors are moved to.
        loss_fn: image loss (defaults to the module-level ``image_loss``).
        affine_transform: warp applied to the moving image (defaults to the
            module-level ``warp_layer``).

    Returns:
        dict with "validation_loss", the mean per-batch loss
        (0.0 if the loader yields no batches).
    """
    net_model.eval()
    net_model.to(device)
    # Move the warp that is actually applied below (the argument, not the
    # global), so a caller-supplied module also lands on ``device``.
    if isinstance(affine_transform, torch.nn.Module):
        affine_transform.to(device)

    epoch_loss = 0.0
    step = 0

    val_loader = tqdm.tqdm(val_loader, desc="validate")

    with torch.no_grad():
        for batch_data in val_loader:
            step += 1

            moving = batch_data["moving_hand"].to(device)
            fixed = batch_data["fixed_hand"].to(device)
            ddf = net_model(torch.cat((moving, fixed), dim=1))
            pred_image = affine_transform(moving, ddf)
            loss = loss_fn(pred_image, fixed)
            epoch_loss += loss.item()

    # An empty loader (e.g. an empty shard) would otherwise raise ZeroDivisionError.
    if step:
        epoch_loss /= step
    return {
        "validation_loss": epoch_loss,
    }

## Time to start a federated learning experiment

In [None]:
# Create an experiment in the federation, identified by experiment_name
experiment_name = "mednist_experiment"
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [None]:
# The following command zips the workspace and Python requirements to be transferred to collaborator nodes
fl_experiment.start(
    model_provider=model_interface,
    task_keeper=task_interface,
    data_loader=fed_dataset,
    rounds_to_train=10,
    # Optimizer-state treatment across rounds; see OpenFL's opt_treatment docs.
    opt_treatment="CONTINUE_GLOBAL",
    device_assignment_policy="CUDA_PREFERRED",
)

In [None]:
# To check how the experiment is going, stream its metrics (tensorboard_logs=False).
fl_experiment.stream_metrics(tensorboard_logs=False)