## Setup environment

In [None]:
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    RandAffined,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
    Resized,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet, BasicUNet
from monai.networks.layers import Norm
from torch.optim.lr_scheduler import CosineAnnealingLR

from monai.metrics import DiceMetric, ROCAUCMetric, MSEMetric
from monai.networks.utils import copy_model_state
from monai.optimizers import generate_param_groups
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import random
import glob
import numpy as np
import wandb
import copy
import nibabel as nib
%matplotlib inline

In [None]:
# Ask cuDNN to benchmark the available convolution algorithms and pick the
# fastest one for the input shapes it encounters.
# NOTE(review): `set_determinism(seed=0)` is called further down in this file;
# per the MONAI documentation it reconfigures the cuDNN flags for determinism,
# which would override this setting - confirm which behaviour is intended.
torch.backends.cudnn.benchmark = True

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)
print()

## Setup imports

## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified, the default directory `/content` is used.

In [None]:
# Output folder for checkpoints. Honour MONAI_DATA_DIRECTORY when it is set
# to a non-empty value; otherwise keep the original default of '/content'.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = directory if directory else '/content'
print(root_dir)

## Download dataset

Downloads and extracts the dataset.  
The dataset comes from http://medicaldecathlon.com/.

## Set the combined (LSPK) dataset path

In [None]:
# Root folder of the dataset; the file listing below reads the
# "imagesTr" and "labelsTr" sub-folders of this directory.
data_dir_spleen = '/content/LSPK_combined'

In [None]:
# Collect image and label volumes. Both listings are sorted so that the i-th
# image is paired with the i-th label.
# NOTE(review): this pairing assumes images and labels share a naming scheme
# that sorts identically - confirm against the dataset layout.
train_images = sorted(
    glob.glob(os.path.join(data_dir_spleen, "imagesTr", "*.nii.gz")))
train_labels = sorted(
    glob.glob(os.path.join(data_dir_spleen, "labelsTr", "*.nii.gz")))

# zip() silently drops the surplus entries of the longer list, which would
# shift or lose image/label pairs without any warning. Fail loudly instead.
if len(train_images) != len(train_labels):
    raise ValueError(
        f"Found {len(train_images)} images but {len(train_labels)} labels "
        f"under {data_dir_spleen}; every image needs exactly one label."
    )

data_dicts_spleen = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]

# Fixed seed so the train/validation split is the same on every run.
random.seed(108)
random.shuffle(data_dicts_spleen)

# Hold out the last `num_val_cases` shuffled cases for validation.
num_val_cases = 132
train_files_spleen = data_dicts_spleen[:-num_val_cases]
val_files_spleen = data_dicts_spleen[-num_val_cases:]

## Set deterministic training for reproducibility

In [None]:
# Seed the random number generators through MONAI so that runs are repeatable.
# NOTE(review): this runs after `random.seed(108)` above, so the train/val
# shuffle has already happened by the time the seed is changed here.
set_determinism(seed=0)

## Setup transforms for training and validation

Here we use several transforms to preprocess and augment the dataset:
1. `LoadImaged` loads the CT images and labels from NIfTI format files.
1. `EnsureChannelFirstd` ensures the loaded data has a "channel first" shape.
1. `Orientationd` unifies the data orientation based on the affine matrix.
1. `Spacingd` adjusts the spacing to `pixdim=(1.5, 1.5, 2.)` based on the affine matrix.
1. `ScaleIntensityRanged` clips intensities to the range [-200, 304] and scales them to [0, 1].
1. `CropForegroundd` removes all zero borders to focus on the valid body area of the images and labels (currently commented out in the code below).
1. `RandCropByPosNegLabeld` randomly crops patch samples from the big image based on the pos / neg ratio.  
The image centers of negative samples must be in the valid body area.
1. `RandAffined` efficiently performs `rotate`, `scale`, `shear`, `translate`, etc. together based on PyTorch affine transform (imported, but not used in the pipeline below).
1. `Resized` resizes each image and label to a fixed spatial size of (256, 256, 128); labels use nearest-neighbour interpolation.

In [None]:
def _deterministic_preprocessing():
    """Return a fresh list of the preprocessing steps shared by both pipelines.

    A new list with new transform instances is built on every call, so the
    training and validation pipelines never share transform objects.
    """
    return [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Clip intensities to [-200, 304] and rescale them to [0, 1].
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-200,
            a_max=304,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        # Bring every volume to one fixed grid; labels use nearest-neighbour
        # interpolation so no new label values are introduced.
        Resized(keys=["image"], spatial_size=(256, 256, 128)),
        Resized(keys=["label"], spatial_size=(256, 256, 128), mode='nearest'),
        # CropForegroundd(keys=["image", "label"], source_key="image") is
        # currently disabled.
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
    ]


# Training: shared preprocessing followed by random patch sampling
# (4 patches of 128 x 128 x 32 per volume, pos/neg ratio 1:1).
_train_steps = _deterministic_preprocessing()
_train_steps.append(
    RandCropByPosNegLabeld(
        keys=["image", "label"],
        label_key="label",
        spatial_size=(128, 128, 32),
        pos=1,
        neg=1,
        num_samples=4,
        image_key="image",
        image_threshold=0,
    )
)
train_transforms_spleen = Compose(_train_steps)

# Validation: shared preprocessing only, applied to the whole volume.
val_transform_spleen = Compose(_deterministic_preprocessing())

## Check transforms in DataLoader

In [None]:
# Pull one validation sample through the pipeline as a sanity check.
check_ds = Dataset(data=val_files_spleen, transform=val_transform_spleen)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)

# Drop the batch and channel dimensions to get the bare 3D volumes.
image = check_data["image"][0][0]
label = check_data["label"][0][0]
print(f"image shape: {image.shape}, label shape: {label.shape}")
print(np.unique(label))

# plot the slice [:, :, 16] of the image next to the same slice of the label
slice_index = 16
panels = (
    ("image", image, "gray"),
    ("label", label, None),  # None keeps matplotlib's default colormap
)
plt.figure("check", (12, 6))
for position, (title, volume, cmap) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.title(title)
    plt.imshow(volume[:, :, slice_index], cmap=cmap)
plt.show()

In [None]:
# Keyword arguments for the U-Net constructor; the model is later built with
# `UNet(**config['model_params_spleen'])`.
unet_params = {
    "spatial_dims": 3,
    "in_channels": 1,
    "out_channels": 5,
    "channels": (16, 32, 64, 128, 256),
    "strides": (2, 2, 2, 2),
    "num_res_units": 2,
    "norm": Norm.BATCH,
}

# Single run configuration; it is also passed to `wandb.init` so the
# hyper-parameters are recorded with the run.
config = {
    # data
    "cache_rate_spleen": 1.0,
    "num_workers": 14,

    # train settings
    "train_batch_size": 2,
    "val_batch_size": 1,
    "learning_rate": 1e-4,
    "max_epochs": 500,
    "val_interval": 2,  # check validation score after n epochs
    "lr_scheduler": "cosine_decay",  # just to keep track

    # Unet model (you can even use nested dictionary and this will be handled by W&B automatically)
    "model_type_spleen": "unet",  # just to keep track
    "model_params_spleen": unet_params,
}

## Define CacheDataset and DataLoader for training and validation

Here we use CacheDataset to accelerate the training and validation process; it's 10x faster than the regular Dataset.  
To achieve the best performance, set `cache_rate=1.0` to cache all the data; if memory is not enough, set a lower value.  
Users can also set `cache_num` instead of `cache_rate`; the minimum of the two settings will be used.  
Set `num_workers` to enable multi-threading during caching.  
If you want to try the regular Dataset, just switch to the commented code below.

In [None]:
# Caching options shared by the training and validation datasets.
_cache_options = {
    "cache_rate": config['cache_rate_spleen'],
    "num_workers": config['num_workers'],
}

train_ds_spleen = CacheDataset(
    data=train_files_spleen,
    transform=train_transforms_spleen,
    **_cache_options,
)
# train_ds_spleen = Dataset(data=train_files_spleen, transform=train_transforms_spleen)

# Each loaded volume yields 4 random patches (RandCropByPosNegLabeld), so a
# batch of `train_batch_size` volumes produces train_batch_size x 4 patches.
train_loader_spleen = DataLoader(
    train_ds_spleen,
    batch_size=config['train_batch_size'],
    shuffle=True,
    num_workers=config['num_workers'],
)

val_ds_spleen = CacheDataset(
    data=val_files_spleen,
    transform=val_transform_spleen,
    **_cache_options,
)
# val_ds = Dataset(data=val_files, transform=val_transforms)
val_loader_spleen = DataLoader(
    val_ds_spleen,
    batch_size=config['val_batch_size'],
    num_workers=config['num_workers'],
)

## Create Model, Loss, Optimizer

In [None]:
# Network: a U-Net built from the arguments stored in the run configuration,
# moved to the selected device.
model_spleen = UNet(**config['model_params_spleen'])
model_spleen = model_spleen.to(device)

# Optimizer plus a cosine learning-rate schedule spanning the whole run.
optimizer_spleen = torch.optim.Adam(
    model_spleen.parameters(),
    lr=config['learning_rate'],
)
scheduler_spleen = CosineAnnealingLR(
    optimizer_spleen,
    T_max=config['max_epochs'],
    eta_min=1e-9,
)

# Training loss: Dice on softmax outputs against one-hot encoded labels.
loss_function_spleen = DiceLoss(to_onehot_y=True, softmax=True)

# Validation metric: mean Dice with the background channel excluded.
dice_metric_spleen = DiceMetric(include_background=False, reduction="mean")


## Execute a typical PyTorch training process

In [None]:
# 🐝 initialize a wandb run; the whole `config` dict is recorded with the run
wandb.init(
    project="SegViz_LSPK_combined",
    config=config
)

# 🐝 log gradients of the model to wandb
wandb.watch(model_spleen, log_freq=100)

# Take the schedule from `config` so the values logged to W&B (and the
# `T_max` given to the LR scheduler) cannot drift from what the loop uses.
max_epochs = config["max_epochs"]
val_interval = config["val_interval"]

# Best validation score seen so far and the epoch it was reached in.
# Only the `*_spleen` variants are updated by the training loop below.
best_metric = -1
best_metric_spleen = -1
best_metric_liver = -1
best_metric_pan = -1

best_metric_epoch = -1
best_metric_epoch_spleen = -1
best_metric_epoch_liver = -1
best_metric_epoch_pan = -1

# Per-epoch training loss and per-validation Dice history.
epoch_loss_values = []
metric_values = []

epoch_loss_values_spleen = []
metric_values_spleen = []

epoch_loss_values_liver = []
metric_values_liver = []

epoch_loss_values_pan = []
metric_values_pan = []

# Post-processing for the Dice metric: predictions are arg-maxed and then
# one-hot encoded into 5 channels; labels are one-hot encoded directly.
post_pred_spleen = Compose([AsDiscrete(argmax=True, to_onehot=5)])
post_label_spleen = Compose([AsDiscrete(to_onehot=5)])

# (keyword looked up in the file name, label value kept for that source).
# Order matters: the first matching keyword wins, like an if/elif chain.
_SOURCE_LABELS = (
    ("liver", 2),
    ("pancreas", 3),
    ("spleen", 1),
    ("case", 4),
)


def _keep_source_organ_only(labels, filenames):
    """Zero every label value except the one tied to each sample's source file.

    The source is detected from keywords in the file's base name (see
    `_SOURCE_LABELS`). Each sample in the batch is handled on its own, so a
    batch that mixes several sources is masked correctly. Samples whose file
    name matches no keyword are left untouched.

    Args:
        labels: batched label tensor, indexed by sample along dimension 0;
            modified in place.
        filenames: one file name per sample (a single string is accepted too).
            Assumes `filename_or_obj` is collated in the same order as the
            batch dimension - TODO confirm for the MONAI version in use.

    Returns:
        The same `labels` tensor, after in-place masking.
    """
    if isinstance(filenames, str):
        filenames = [filenames]
    for sample_idx, filename in zip(range(labels.shape[0]), filenames):
        basename = str(filename).split('/')[-1]
        for keyword, organ_label in _SOURCE_LABELS:
            if keyword in basename:
                # Integer indexing returns a view, so this edits `labels`.
                sample_labels = labels[sample_idx]
                sample_labels[sample_labels != organ_label] = 0
                break
    return labels


# Sliding-window inference settings never change, so define them once.
roi_size = (160, 160, 160)
sw_batch_size = 4

for epoch in range(max_epochs):
    epoch_loss_spleen = 0

    step_0 = 0
    step_1 = 0

    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    # ---- Training: one pass over the combined dataset ----
    model_spleen.train()
    for batch_data_spleen in train_loader_spleen:
        step_0 += 1
        inputs_spleen, labels_spleen = (
            batch_data_spleen["image"].to(device),
            batch_data_spleen["label"].to(device),
        )
        # Mask per sample: a shuffled batch can mix source datasets, so the
        # first file name alone does not describe the whole batch.
        labels_spleen = _keep_source_organ_only(
            labels_spleen,
            batch_data_spleen['image_meta_dict']['filename_or_obj'],
        )

        optimizer_spleen.zero_grad()
        outputs_spleen = model_spleen(inputs_spleen)
        loss_spleen = loss_function_spleen(outputs_spleen, labels_spleen)
        loss_spleen.backward()
        optimizer_spleen.step()

        epoch_loss_spleen += loss_spleen.item()
        # len(loader) counts a partial last batch, unlike floor division.
        print(
            f"{step_0}/{len(train_loader_spleen)}, "
            f"train_loss: {loss_spleen.item():.4f}")
        wandb.log({"train/loss combined": loss_spleen.item()})

    epoch_loss_spleen /= step_0
    epoch_loss_values_spleen.append(epoch_loss_spleen)
    print(f"epoch {epoch + 1} average loss combined: {epoch_loss_spleen:.4f}")

    # Advance the cosine schedule exactly once per epoch.
    scheduler_spleen.step()

    wandb.log({"train/loss_epoch comb": epoch_loss_spleen})
    # 🐝 log learning rate after each epoch to wandb
    wandb.log({"learning_rate comb": scheduler_spleen.get_last_lr()[0]})

    # ---- Validation: every `val_interval` epochs ----
    if (epoch + 1) % val_interval == 0:
        model_spleen.eval()
        with torch.no_grad():
            for val_data_spleen in val_loader_spleen:
                val_inputs_spleen, val_labels_spleen = (
                    val_data_spleen["image"].to(device),
                    val_data_spleen["label"].to(device),
                )
                val_labels_spleen = _keep_source_organ_only(
                    val_labels_spleen,
                    val_data_spleen['image_meta_dict']['filename_or_obj'],
                )

                val_outputs_spleen = sliding_window_inference(
                    val_inputs_spleen, roi_size, sw_batch_size, model_spleen)
                val_outputs_spleen = [post_pred_spleen(i) for i in decollate_batch(val_outputs_spleen)]
                val_labels_spleen = [post_label_spleen(i) for i in decollate_batch(val_labels_spleen)]
                # compute metric for current iteration
                dice_metric_spleen(y_pred=val_outputs_spleen, y=val_labels_spleen)

            # aggregate the final mean dice result
            metric_spleen = dice_metric_spleen.aggregate().item()
            wandb.log({"val/dice_metric combined": metric_spleen})
            # No scheduler step here: CosineAnnealingLR.step() takes an
            # (optional, deprecated) epoch index, not a metric. Passing the
            # Dice score would be read as an epoch near 0 and push the
            # learning rate back to its starting value.
            # reset the status for next validation round
            dice_metric_spleen.reset()

            metric_values_spleen.append(metric_spleen)
            if metric_spleen > best_metric_spleen:
                best_metric_spleen = metric_spleen
                best_metric_epoch_spleen = epoch + 1
                torch.save(model_spleen.state_dict(), os.path.join(
                    root_dir, "best_metric_model_LSPK_combined_trial2.pth"))
                print("saved new best metric model for combined dataset")
            print(
                f"current epoch: {epoch + 1} current mean dice for combined: {metric_spleen:.4f}"
                f"\nbest mean dice for combined: {best_metric_spleen:.4f} "
                f"at epoch: {best_metric_epoch_spleen}"
            )

## Plot the loss and metric

In [None]:
plt.figure("train", (12, 6))

# Left panel: mean training loss of every epoch (x axis is 1-based).
plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
y = epoch_loss_values_spleen
x = list(range(1, len(y) + 1))
plt.xlabel("epoch")
plt.plot(x, y)

# Right panel: validation Dice, recorded once every `val_interval` epochs.
plt.subplot(1, 2, 2)
plt.title("Val Mean Dice")
y = metric_values_spleen
x = [val_interval * k for k in range(1, len(y) + 1)]
plt.xlabel("epoch")
plt.plot(x, y)

plt.show()