# Simple Federation

## Configure Environment

In [None]:
# Run this to change the working directory to the parent of the current one.
# NOTE: the path is relative, so re-running this cell moves up one more level
# each time -- run it exactly once per kernel session.
# Presumably this makes the `imagiq` package importable from the repository
# root -- confirm against the project layout.
# TODO: pip install will resolve this issue...
import os
os.chdir("..")

In [None]:
import imagiq.federated as iqf
import numpy as np
import torch
from imagiq.models import Model
from imagiq.datasets import NIHDataset
from monai.transforms import (
    Compose,
    LoadImaged,
    ScaleIntensityd,
    SqueezeDimd,
    AddChanneld,
    AsChannelFirstd,
    Lambdad,
    ToTensord,
    Resized, 
    RandRotated, 
    RandFlipd, 
    RandZoomd
)
from monai.networks.nets import densenet121, densenet201, densenet264, se_resnet50, se_resnet101, se_resnet152
from monai.data import CacheDataset
import sys

## Create and start local nodes

In [None]:
# Two local nodes, each standing in for a separate (virtual) computer,
# listening on ports 8000 and 8001 of localhost.
node1_port, node2_port = 8000, 8001
node1, node2 = (
    iqf.nodes.Node("localhost", port) for port in (node1_port, node2_port)
)

# Start the nodes in the same order they were created.
for node in (node1, node2):
    node.start()

## Establish Connections

In [None]:
# Connect each node to its peer's port so the link exists in both directions.
for node, peer_port in ((node1, node2_port), (node2, node1_port)):
    node.connect_to("localhost", peer_port)

In [None]:
# Build the three NIH dataset sections, then print each one.
# TODO: Read all, not just the test section
# TODO: Train test split
train_dataset, val_dataset, test_dataset = (
    NIHDataset(section=section_name, download=[0])
    for section_name in ("training", "validation", "test")
)

for split in (train_dataset, val_dataset, test_dataset):
    print(split)

In [None]:
# Distribute the training samples over the two nodes with a deliberate bias:
#   * atelectasis (label index 1)  -> 80% node 1, 20% node 2
#   * infiltration (label index 4) -> 80% node 2, 20% node 1
#   * everything else              -> 50% / 50%
# Atelectasis is checked first, so a sample with both findings follows the
# atelectasis rule.
# TODO: bias towards AP/Lateral views
# TODO: Bias towards male/female
N_normal = train_dataset.class_count[0]

train_node1 = []
train_node2 = []
for sample in train_dataset:
    findings = sample["label"]

    # Decide which node is favoured and how strongly.
    if findings[1]:
        favoured, fallback, keep_prob = train_node1, train_node2, 0.8
    elif findings[4]:
        favoured, fallback, keep_prob = train_node2, train_node1, 0.8
    else:
        favoured, fallback, keep_prob = train_node1, train_node2, 0.5

    # Collapse the label vector to a single value: 1 - label[0]
    # (presumably index 0 is the "normal" class, giving 1 = abnormal -- confirm).
    sample["label"] = 1 - findings[0]

    # Exactly one random draw per sample selects the destination.
    destination = favoured if np.random.rand() < keep_prob else fallback
    destination.append(sample)

In [None]:
# Split the validation samples between the two nodes using the same bias as
# the training split: atelectasis (index 1) favours node 1, infiltration
# (index 4) favours node 2, all other samples are split evenly.
val_normal = val_dataset.class_count[0]

val_node1 = []
val_node2 = []
for record in val_dataset:
    tags = record["label"]

    if tags[1]:  # atelectasis takes precedence over infiltration
        home, away, threshold = val_node1, val_node2, 0.8
    elif tags[4]:  # infiltration
        home, away, threshold = val_node2, val_node1, 0.8
    else:  # any other finding
        home, away, threshold = val_node1, val_node2, 0.5

    # Replace the label vector by 1 - label[0] before storing the record.
    record["label"] = 1 - tags[0]

    if np.random.rand() < threshold:
        home.append(record)
    else:
        away.append(record)

In [None]:
# Split the test samples between the two nodes with the same bias used for
# the training and validation splits:
#   * atelectasis (label index 1)  -> 80% node 1, 20% node 2
#   * infiltration (label index 4) -> 80% node 2, 20% node 1
#   * everything else              -> 50% / 50%
# Every sample lands in one of the two *test* lists. The destination is
# chosen in a single place so a sample can never be routed into a
# validation list by mistake, which would leak test data into validation.
test_normal = test_dataset.class_count[0]
test_node1 = []
test_node2 = []
for data in test_dataset:
    label = data["label"]

    if label[1]:  # atelectasis, more likely to go to node1
        favoured, fallback, keep_prob = test_node1, test_node2, 0.8
    elif label[4]:  # infiltration, more likely to go to node2
        favoured, fallback, keep_prob = test_node2, test_node1, 0.8
    else:  # for other findings, split it half and half
        favoured, fallback, keep_prob = test_node1, test_node2, 0.5

    # Collapse the label vector to a single value: 1 - label[0]
    # (presumably index 0 is the "normal" class -- confirm).
    data["label"] = 1 - label[0]

    # One random draw per sample picks the destination node.
    if np.random.rand() < keep_prob:
        favoured.append(data)
    else:
        fallback.append(data)

In [None]:
# Training pipeline applied to the "image" entry of every sample:
#   1. load the image,
#   2. average over axis 2 when the array is 3-D (leaves 2-D arrays untouched),
#   3. move the last axis first, then add a leading channel axis,
#   4. rescale intensities,
#   5. resize to 224x224 with nearest-neighbour interpolation,
#   6. random augmentation (rotation, flip, zoom), each applied with prob 0.5.
train_transforms = Compose([
    LoadImaged("image"),
    Lambdad("image", func=lambda x: np.mean(x, axis=2) if len(x.shape) == 3 else x),
    AsChannelFirstd("image"),
    AddChanneld("image"),
    ScaleIntensityd("image"),
    Resized("image", spatial_size=(224, 224), mode="nearest"),
    # MONAI interprets range_x in radians; np.pi / 12 is +/-15 degrees.
    # (A literal 15 would allow rotations of roughly +/-859 degrees.)
    RandRotated("image", range_x=np.pi / 12, prob=0.5, keep_size=True),
    RandFlipd("image", spatial_axis=0, prob=0.5),
    RandZoomd("image", min_zoom=0.9, max_zoom=1.1, prob=0.5, keep_size=True),
])

# Validation/test pipeline: same preprocessing, no random augmentation.
val_transforms = Compose([
    LoadImaged("image"),
    Lambdad("image", func=lambda x: np.mean(x, axis=2) if len(x.shape) == 3 else x),
    AsChannelFirstd("image"),
    AddChanneld("image"),
    ScaleIntensityd("image"),
    Resized("image", spatial_size=(224, 224), mode="nearest"),
])

# Rebind each per-node sample list to a CacheDataset wrapping it with the
# matching pipeline. After this cell the *_node1 / *_node2 names are
# datasets, no longer plain lists.
train_node1 = CacheDataset(train_node1, train_transforms)
train_node2 = CacheDataset(train_node2, train_transforms)
val_node1 = CacheDataset(val_node1, val_transforms)
val_node2 = CacheDataset(val_node2, val_transforms)
test_node1 = CacheDataset(test_node1, val_transforms)
test_node2 = CacheDataset(test_node2, val_transforms)

## Add models and datasets to the hospitals

In [None]:
# Hospital A (node 1): two SE-ResNet variants, single-channel 2-D input,
# two output classes.
hospital_a_models = [
    Model(
        se_resnet50(spatial_dims=2, in_channels=1, num_classes=2),
        'hospitalA_se_resnet50',
    ),
    Model(
        se_resnet101(spatial_dims=2, in_channels=1, num_classes=2),
        'hospitalA_se_resnet101',
    ),
]
node1.add_model(hospital_a_models)

# Hospital B (node 2): one DenseNet and one SE-ResNet with the same
# input/output configuration.
hospital_b_models = [
    Model(
        densenet201(spatial_dims=2, in_channels=1, out_channels=2),
        'hospitalB_densenet201',
    ),
    Model(
        se_resnet152(spatial_dims=2, in_channels=1, num_classes=2),
        'hospitalB_se_resnet152',
    ),
]
node2.add_model(hospital_b_models)

In [None]:
# Register each node's own training dataset with that node.
for node, local_train_set in ((node1, train_node1), (node2, train_node2)):
    node.add_dataset(local_train_set)

In [None]:
# Free as much memory as possible before training starts.
import gc
import time  # NOTE(review): not used by any cell visible here

# Run a full garbage-collection pass over unreachable Python objects.
gc.collect()
# Ask PyTorch to release cached CUDA memory it is no longer using.
torch.cuda.empty_cache()

## Training @ Hospital 1

In [None]:
# Train every model on hospital 1's bench with its local data.
# Each model gets its own Adam optimizer (learning rate 5e-3) and its own
# ReduceLROnPlateau scheduler (factor 0.1, patience 5, mode 'min').
# TODO: this should happen in the node class (e.g. node1.train_all())
for bench_model in node1.model_bench:
    adam = torch.optim.Adam(bench_model.net.parameters(), lr=5e-3)
    plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        adam,
        mode="min",
        factor=0.1,
        patience=5,
    )
    criterion = torch.nn.CrossEntropyLoss()
    bench_model.train(
        train_node1,
        criterion,
        adam,
        epochs=10,
        metrics=["AUC"],
        batch_size=16,
        device="cpu",
        validation_dataset=val_node1,
        dirpath="path/to/save/model/",
        scheduler=plateau_scheduler,
    )

In [None]:
# Print each hospital 1 model's name, then run it on hospital 1's test data
# in batches of 16.
for bench_model in node1.model_bench:
    print(bench_model.name)
    bench_model.predict(test_node1, batch_size=16)

## Training @ Hospital 2

In [None]:
# Train every model on hospital 2's bench with its local data.
# Each model gets its own Adam optimizer (learning rate 5e-3) and its own
# ReduceLROnPlateau scheduler (factor 0.1, patience 5, mode 'min').
for bench_model in node2.model_bench:
    adam = torch.optim.Adam(bench_model.net.parameters(), lr=5e-3)
    plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        adam,
        mode="min",
        factor=0.1,
        patience=5,
    )
    criterion = torch.nn.CrossEntropyLoss()
    bench_model.train(
        train_node2,
        criterion,
        adam,
        epochs=10,
        metrics=["AUC"],
        batch_size=16,
        device="cpu",
        validation_dataset=val_node2,
        dirpath="path/to/save/model/",
        scheduler=plateau_scheduler,
    )

In [None]:
# Print each hospital 2 model's name, then run it on hospital 2's test data
# in batches of 5.
for bench_model in node2.model_bench:
    print(bench_model.name)
    bench_model.predict(test_node2, batch_size=5)

## Terminate the Connections and Destroy the Virtual Hospitals

In [None]:
# Shut the nodes down in two passes: first call stop() on both nodes, then
# call join() on both.
# TODO: merge stop and join
both_nodes = (node1, node2)
for node in both_nodes:
    node.stop()
for node in both_nodes:
    node.join()

In [None]:
# Destroy both virtual hospitals, node 1 first.
for node in (node1, node2):
    node.destroy()