# Notebook to calculate the Dice metrics for all the trained models

In [None]:
from monai.utils import first, set_determinism
from medpy.metric.binary import dc as dice_coef
from medpy.metric.binary import jc as jacc
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureType,
    Activationsd,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    RandAffined,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
    Resized,
    RandFlipd,
    RandShiftIntensityd,
    Spacingd,
    RandRotate90d,
    EnsureTyped,
    Transpose,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet, BasicUNet
from monai.networks.layers import Norm
from torch.optim.lr_scheduler import CosineAnnealingLR

from monai.metrics import DiceMetric, ROCAUCMetric, MSEMetric
from monai.networks.utils import copy_model_state
from monai.optimizers import generate_param_groups
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import time
import glob
import numpy as np
import wandb
import nibabel as nib
import copy

%matplotlib inline

In [None]:
!nvidia-smi

In [None]:
# Let cuDNN time the available convolution algorithms and keep the fastest one;
# this pays off because the inputs are resized to a fixed spatial size later on.
cudnn = torch.backends.cudnn
cudnn.benchmark = True

## Setup imports

## Setup data directory

The cell below reads the `MONAI_DATA_DIRECTORY` environment variable; `root_dir` must end up pointing at the folder that holds the trained model checkpoints (replace the `PATH` placeholder with your own location).  
No temporary directory is created.

In [None]:
# Checkpoints are looked up under root_dir. MONAI_DATA_DIRECTORY, when set and
# non-empty, takes precedence; otherwise the placeholder path below is used
# (replace PATH with your own location).
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = directory if directory else 'PATH/SegViz/'
print(root_dir)

Example path to one dataset — in this case KiTS19 (`Task040_KiTS`).

In [None]:
# Collect the KiTS image/label pairs (the *_spleen variable names are kept for
# the cells below, but the data is Task040_KiTS). Both listings are sorted so
# the i-th image lines up with the i-th label. zip() would silently drop
# unmatched files, so the counts are checked explicitly.
kits_dir = 'PATH/SegViz/Task040_KiTS'
train_images = sorted(glob.glob(os.path.join(kits_dir, "imagesTr", "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(kits_dir, "labelsTr", "*.nii.gz")))
if len(train_images) != len(train_labels):
    raise ValueError(
        f"Found {len(train_images)} images but {len(train_labels)} labels "
        f"under {kits_dir}"
    )
data_dicts_spleen = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
# The last 42 cases are the ones evaluated in this notebook.
train_files_spleen = data_dicts_spleen[-42:]

In [None]:
# Deterministic evaluation preprocessing (no random augmentation):
#   1. load image + label and move the channel axis first,
#   2. clip image intensities to [-79, 304] and rescale them to [0, 1],
#   3. reorient to RAS and resample to 1.5 x 1.5 x 2.0 voxel spacing
#      (bilinear for the image, nearest for the label),
#   4. resize both to a fixed 256 x 256 x 128 grid, using nearest
#      interpolation for the label so class ids are not blended.
# Foreground cropping (CropForegroundd) is not applied.
_eval_steps = [
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    ScaleIntensityRanged(
        keys=["image"],
        a_min=-79,
        a_max=304,
        b_min=0.0,
        b_max=1.0,
        clip=True,
    ),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(
        keys=["image", "label"],
        pixdim=(1.5, 1.5, 2.0),
        mode=("bilinear", "nearest"),
    ),
    Resized(keys=["image"], spatial_size=(256, 256, 128)),
    Resized(keys=["label"], spatial_size=(256, 256, 128), mode='nearest'),
]
train_transforms_spleen = Compose(_eval_steps)

In [None]:
# Pre-compute and cache the transformed cases (cache_rate=1.0 caches all of
# them), then iterate one case per batch in the order of train_files_spleen.
# Worker processes are disabled for both caching and loading.
loader_opts = dict(batch_size=1, shuffle=False, num_workers=0)
train_ds_spleen = CacheDataset(
    data=train_files_spleen,
    transform=train_transforms_spleen,
    cache_rate=1.0,
    num_workers=0,
)
train_loader_spleen = DataLoader(train_ds_spleen, **loader_opts)


In [None]:
def _unet_kwargs(out_channels):
    """Build the keyword arguments shared by the 3-D UNets in this notebook.

    Only the number of output channels differs between the two models.
    """
    return dict(
        spatial_dims=3,
        in_channels=1,
        out_channels=out_channels,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    )


# Run configuration. Only the "model_params_*" entries are read in this
# notebook (to build the UNets); the remaining entries are bookkeeping.
config = {
    # data
    "cache_rate_spleen": 1.0,
    "num_workers": 5,
    # training settings (recorded for reference)
    "train_batch_size": 2,
    "val_batch_size": 1,
    "learning_rate": 1e-4,
    "max_epochs": 1000,
    "val_interval": 2,  # validation every n epochs
    "lr_scheduler": "cosine_decay",  # label only
    # model definitions
    "model_type_spleen": "unet",  # label only
    "model_type_liver": "unet",
    "model_params_spleen": _unet_kwargs(out_channels=2),
    "model_params_liver": _unet_kwargs(out_channels=5),
    # data
    "cache_rate_liver": 0.4,
}

In [None]:
# Build both UNets on the GPU when one is available, otherwise on the CPU.
# model_spleen has 2 output channels, model_liver has 5 (see config).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_spleen, model_liver = (
    UNet(**config[params_key]).to(device)
    for params_key in ("model_params_spleen", "model_params_liver")
)

# One-hot post-processing (2 classes) for predictions and labels; the
# evaluation cells below do not reference these.
post_pred_spleen = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label_spleen = Compose([AsDiscrete(to_onehot=2)])

# Base

In [None]:
# Base: evaluate the KiTS-only checkpoint on the held-out KiTS cases.
# Per case: sliding-window inference, argmax to a label map, binary Dice and
# Jaccard (medpy counts every non-zero voxel as foreground), and PNG snapshots
# of image / label / prediction at one axial slice.
out_dir = "PATH//Images/kits/base"
os.makedirs(out_dir, exist_ok=True)  # savefig fails when the folder is missing

# map_location lets a GPU-saved checkpoint load on the CPU fallback device too.
model_spleen.load_state_dict(torch.load(
    os.path.join(root_dir, "best_metric_model_kitsonly_128_rescrop.pth"),
    map_location=device))
model_spleen.eval()
dice_metric_all_base = []
jacc_all_only = []

roi_size = (160, 160, 160)
sw_batch_size = 4
slice_idx = 85  # last axis has size 128 after Resized, so 85 is in range
start_time = time.time()

with torch.no_grad():
    for val_data in train_loader_spleen:
        img_name = os.path.split(val_data['image'].meta["filename_or_obj"][0])[1]
        val_outputs = sliding_window_inference(
            val_data["image"].to(device), roi_size, sw_batch_size, model_spleen
        )
        val_outputs_spleen = torch.argmax(
            val_outputs, dim=1).detach().cpu()[0, :, :, :]
        val_labels_spleen = val_data["label"][0, 0, :, :, :]

        assert val_labels_spleen.shape == val_outputs_spleen.shape

        pred = np.asarray(val_outputs_spleen)
        label = np.asarray(val_labels_spleen)
        dice_metric_all_base.append(dice_coef(pred, label))
        jacc_all_only.append(jacc(pred, label))

        panels = (
            ("img", val_data["image"][0, 0, :, :, slice_idx]),
            ("label", label[:, :, slice_idx]),
            ("pred", pred[:, :, slice_idx]),
        )
        for suffix, panel in panels:
            plt.figure()
            plt.imshow(panel, cmap="gray")
            plt.axis('off')
            plt.savefig(f"{out_dir}/{img_name}_{suffix}.png")
        plt.show()
        plt.close('all')  # release the figures so they do not pile up across cases

print(f"{len(dice_metric_all_base)} cases evaluated in "
      f"{time.time() - start_time:.1f} s")

In [None]:
# Baseline summary: mean and population std (numpy default ddof=0) of per-case Dice.
dice_base_summary = (np.mean(dice_metric_all_base), np.std(dice_metric_all_base))
print(*dice_base_summary)

# FedAvg

In [None]:
# FedAvg: evaluate the federated-averaging checkpoint on the held-out KiTS cases.
# Per case: sliding-window inference, argmax to a label map, binary Dice and
# Jaccard (medpy counts every non-zero voxel as foreground), and PNG snapshots
# of image / label / prediction at one axial slice.
out_dir = "PATH//Images/kits/fedavg"
os.makedirs(out_dir, exist_ok=True)  # savefig fails when the folder is missing

# map_location lets a GPU-saved checkpoint load on the CPU fallback device too.
model_spleen.load_state_dict(torch.load(
    os.path.join(root_dir, "best_metric_model_kits_128_segviz_LSPK.pth"),
    map_location=device))
model_spleen.eval()
dice_metric_all_fedavg = []
jacc_all_spleenonly = []

roi_size = (160, 160, 160)
sw_batch_size = 4
slice_idx = 85  # last axis has size 128 after Resized, so 85 is in range
start_time = time.time()

with torch.no_grad():
    for val_data in train_loader_spleen:
        img_name = os.path.split(val_data['image'].meta["filename_or_obj"][0])[1]
        val_outputs = sliding_window_inference(
            val_data["image"].to(device), roi_size, sw_batch_size, model_spleen
        )
        val_outputs_spleen = torch.argmax(
            val_outputs, dim=1).detach().cpu()[0, :, :, :]
        val_labels_spleen = val_data["label"][0, 0, :, :, :]

        assert val_labels_spleen.shape == val_outputs_spleen.shape

        pred = np.asarray(val_outputs_spleen)
        label = np.asarray(val_labels_spleen)
        dice_metric_all_fedavg.append(dice_coef(pred, label))
        jacc_all_spleenonly.append(jacc(pred, label))

        panels = (
            ("img", val_data["image"][0, 0, :, :, slice_idx]),
            ("label", label[:, :, slice_idx]),
            ("pred", pred[:, :, slice_idx]),
        )
        for suffix, panel in panels:
            plt.figure()
            plt.imshow(panel, cmap="gray")
            plt.axis('off')
            plt.savefig(f"{out_dir}/{img_name}_{suffix}.png")
        plt.show()
        plt.close('all')  # release the figures so they do not pile up across cases

print(f"{len(dice_metric_all_fedavg)} cases evaluated in "
      f"{time.time() - start_time:.1f} s")

In [None]:
# FedAvg summary: mean and population std (numpy default ddof=0) of per-case Dice.
dice_fedavg_summary = (np.mean(dice_metric_all_fedavg), np.std(dice_metric_all_fedavg))
print(*dice_fedavg_summary)

# FedAvg + FT

In [None]:
# FedAvg + FT: evaluate the fine-tuned FedAvg checkpoint on the held-out KiTS cases.
# Per case: sliding-window inference, argmax to a label map, binary Dice and
# Jaccard (medpy counts every non-zero voxel as foreground), and PNG snapshots
# of image / label / prediction at one axial slice.
# Each method gets its own sub-folder so snapshots from different checkpoints
# never overwrite each other.
out_dir = "PATH//Images/kits/fedavgft"
os.makedirs(out_dir, exist_ok=True)  # savefig fails when the folder is missing

# map_location lets a GPU-saved checkpoint load on the CPU fallback device too.
model_spleen.load_state_dict(torch.load(
    os.path.join(root_dir, "best_metric_model_kits_128segviz_finetuned_LSPK.pth"),
    map_location=device))
model_spleen.eval()
dice_metric_all_fedavgft = []
jacc_all_spleenonly = []

roi_size = (160, 160, 160)
sw_batch_size = 4
# Same slice index (85) as the Base / FedAvg / FedBN / Central cells so the
# snapshots are directly comparable; the last axis has size 128 after Resized.
slice_idx = 85
start_time = time.time()

with torch.no_grad():
    for val_data in train_loader_spleen:
        img_name = os.path.split(val_data['image'].meta["filename_or_obj"][0])[1]
        val_outputs = sliding_window_inference(
            val_data["image"].to(device), roi_size, sw_batch_size, model_spleen
        )
        val_outputs_spleen = torch.argmax(
            val_outputs, dim=1).detach().cpu()[0, :, :, :]
        val_labels_spleen = val_data["label"][0, 0, :, :, :]

        assert val_labels_spleen.shape == val_outputs_spleen.shape

        pred = np.asarray(val_outputs_spleen)
        label = np.asarray(val_labels_spleen)
        dice_metric_all_fedavgft.append(dice_coef(pred, label))
        jacc_all_spleenonly.append(jacc(pred, label))

        panels = (
            ("img", val_data["image"][0, 0, :, :, slice_idx]),
            ("label", label[:, :, slice_idx]),
            ("pred", pred[:, :, slice_idx]),
        )
        for suffix, panel in panels:
            plt.figure()
            plt.imshow(panel, cmap="gray")
            plt.axis('off')
            plt.savefig(f"{out_dir}/{img_name}_{suffix}.png")
        plt.show()
        plt.close('all')  # release the figures so they do not pile up across cases

print(f"{len(dice_metric_all_fedavgft)} cases evaluated in "
      f"{time.time() - start_time:.1f} s")

In [None]:
# FedAvg + FT summary: mean and population std (numpy default ddof=0) of per-case Dice.
dice_fedavgft_summary = (np.mean(dice_metric_all_fedavgft), np.std(dice_metric_all_fedavgft))
print(*dice_fedavgft_summary)

# FedBN


In [None]:
# FedBN: evaluate the FedBN checkpoint on the held-out KiTS cases.
# Per case: sliding-window inference, argmax to a label map, binary Dice and
# Jaccard (medpy counts every non-zero voxel as foreground), and PNG snapshots
# of image / label / prediction at one axial slice.
out_dir = "PATH//Images/kits/fedbn"
os.makedirs(out_dir, exist_ok=True)  # savefig fails when the folder is missing

# map_location lets a GPU-saved checkpoint load on the CPU fallback device too.
model_spleen.load_state_dict(torch.load(
    os.path.join(root_dir, "best_metric_model_kits_128_segviz_LSPK_fedbn.pth"),
    map_location=device))
model_spleen.eval()
dice_metric_all_fedbn = []
jacc_all_spleenonly = []

roi_size = (160, 160, 160)
sw_batch_size = 4
slice_idx = 85  # last axis has size 128 after Resized, so 85 is in range
start_time = time.time()

with torch.no_grad():
    for val_data in train_loader_spleen:
        img_name = os.path.split(val_data['image'].meta["filename_or_obj"][0])[1]
        val_outputs = sliding_window_inference(
            val_data["image"].to(device), roi_size, sw_batch_size, model_spleen
        )
        val_outputs_spleen = torch.argmax(
            val_outputs, dim=1).detach().cpu()[0, :, :, :]
        val_labels_spleen = val_data["label"][0, 0, :, :, :]

        assert val_labels_spleen.shape == val_outputs_spleen.shape

        pred = np.asarray(val_outputs_spleen)
        label = np.asarray(val_labels_spleen)
        dice_metric_all_fedbn.append(dice_coef(pred, label))
        jacc_all_spleenonly.append(jacc(pred, label))

        panels = (
            ("img", val_data["image"][0, 0, :, :, slice_idx]),
            ("label", label[:, :, slice_idx]),
            ("pred", pred[:, :, slice_idx]),
        )
        for suffix, panel in panels:
            plt.figure()
            plt.imshow(panel, cmap="gray")
            plt.axis('off')
            plt.savefig(f"{out_dir}/{img_name}_{suffix}.png")
        plt.show()
        plt.close('all')  # release the figures so they do not pile up across cases

print(f"{len(dice_metric_all_fedbn)} cases evaluated in "
      f"{time.time() - start_time:.1f} s")

In [None]:
# FedBN summary: mean and population std (numpy default ddof=0) of per-case Dice.
dice_fedbn_summary = (np.mean(dice_metric_all_fedbn), np.std(dice_metric_all_fedbn))
print(*dice_fedbn_summary)

# FedBN + FT


In [None]:
# FedBN + FT: evaluate the fine-tuned FedBN checkpoint on the held-out KiTS cases.
# Per case: sliding-window inference, argmax to a label map, binary Dice and
# Jaccard (medpy counts every non-zero voxel as foreground), and PNG snapshots
# of image / label / prediction at one axial slice (all three are saved).
# Each method gets its own sub-folder so snapshots from different checkpoints
# never overwrite each other.
out_dir = "PATH//Images/kits/fedbnft"
os.makedirs(out_dir, exist_ok=True)  # savefig fails when the folder is missing

# map_location lets a GPU-saved checkpoint load on the CPU fallback device too.
model_spleen.load_state_dict(torch.load(
    os.path.join(root_dir, "best_metric_model_kits_finetuned_LSPK_fedbn.pth"),
    map_location=device))
model_spleen.eval()
dice_metric_all_fedbnft = []
jacc_all_spleenonly = []

roi_size = (160, 160, 160)
sw_batch_size = 4
# Same slice index (85) as the Base / FedAvg / FedBN / Central cells so the
# snapshots are directly comparable; the last axis has size 128 after Resized.
slice_idx = 85
start_time = time.time()

with torch.no_grad():
    for val_data in train_loader_spleen:
        img_name = os.path.split(val_data['image'].meta["filename_or_obj"][0])[1]
        val_outputs = sliding_window_inference(
            val_data["image"].to(device), roi_size, sw_batch_size, model_spleen
        )
        val_outputs_spleen = torch.argmax(
            val_outputs, dim=1).detach().cpu()[0, :, :, :]
        val_labels_spleen = val_data["label"][0, 0, :, :, :]

        assert val_labels_spleen.shape == val_outputs_spleen.shape

        pred = np.asarray(val_outputs_spleen)
        label = np.asarray(val_labels_spleen)
        dice_metric_all_fedbnft.append(dice_coef(pred, label))
        jacc_all_spleenonly.append(jacc(pred, label))

        panels = (
            ("img", val_data["image"][0, 0, :, :, slice_idx]),
            ("label", label[:, :, slice_idx]),
            ("pred", pred[:, :, slice_idx]),
        )
        for suffix, panel in panels:
            plt.figure()
            plt.imshow(panel, cmap="gray")
            plt.axis('off')
            plt.savefig(f"{out_dir}/{img_name}_{suffix}.png")
        plt.show()
        plt.close('all')  # release the figures so they do not pile up across cases

print(f"{len(dice_metric_all_fedbnft)} cases evaluated in "
      f"{time.time() - start_time:.1f} s")

In [None]:
# FedBN + FT summary: mean and population std (numpy default ddof=0) of per-case Dice.
dice_fedbnft_summary = (np.mean(dice_metric_all_fedbnft), np.std(dice_metric_all_fedbnft))
print(*dice_fedbnft_summary)

# Central

In [None]:
# Central: evaluate the 5-class centrally trained checkpoint on the held-out
# KiTS cases. Only output class 4 is kept before scoring. Presumably that is
# the kidney class of the combined "LSPK" model; TODO confirm against the
# training label map.
# Per case: sliding-window inference, argmax to a label map, binary Dice and
# Jaccard (medpy counts every non-zero voxel as foreground), and PNG snapshots
# of image / label / prediction at one axial slice.
out_dir = "PATH//Images/kits/central"
os.makedirs(out_dir, exist_ok=True)  # savefig fails when the folder is missing

# map_location lets a GPU-saved checkpoint load on the CPU fallback device too.
model_liver.load_state_dict(torch.load(
    os.path.join(root_dir, "best_metric_model_LSPK_combined.pth"),
    map_location=device))
model_liver.eval()
dice_metric_all_central = []
jacc_all_spleenonly = []

roi_size = (160, 160, 160)
sw_batch_size = 4
slice_idx = 85  # last axis has size 128 after Resized, so 85 is in range
start_time = time.time()

with torch.no_grad():
    for val_data in train_loader_spleen:
        img_name = os.path.split(val_data['image'].meta["filename_or_obj"][0])[1]
        val_outputs = sliding_window_inference(
            val_data["image"].to(device), roi_size, sw_batch_size, model_liver
        )
        val_outputs_spleen = torch.argmax(
            val_outputs, dim=1).detach().cpu()[0, :, :, :]
        val_labels_spleen = val_data["label"][0, 0, :, :, :]

        # Zero every class except 4 so the binary metrics only see that class.
        val_outputs_spleen[val_outputs_spleen != 4] = 0

        assert val_labels_spleen.shape == val_outputs_spleen.shape

        pred = np.asarray(val_outputs_spleen)
        label = np.asarray(val_labels_spleen)
        dice_metric_all_central.append(dice_coef(pred, label))
        jacc_all_spleenonly.append(jacc(pred, label))

        panels = (
            ("img", val_data["image"][0, 0, :, :, slice_idx]),
            ("label", label[:, :, slice_idx]),
            ("pred", pred[:, :, slice_idx]),
        )
        for suffix, panel in panels:
            plt.figure()
            plt.imshow(panel, cmap="gray")
            plt.axis('off')
            plt.savefig(f"{out_dir}/{img_name}_{suffix}.png")
        plt.show()
        plt.close('all')  # release the figures so they do not pile up across cases

print(f"{len(dice_metric_all_central)} cases evaluated in "
      f"{time.time() - start_time:.1f} s")

In [None]:
# Central summary: mean and population std (numpy default ddof=0) of per-case Dice.
dice_central_summary = (np.mean(dice_metric_all_central), np.std(dice_metric_all_central))
print(*dice_central_summary)