## Setup environment

In [None]:
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    RandAffined,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
    Resized,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet, BasicUNet
from monai.networks.layers import Norm
from torch.optim.lr_scheduler import CosineAnnealingLR

from monai.metrics import DiceMetric, ROCAUCMetric, MSEMetric
from monai.networks.utils import copy_model_state
from monai.optimizers import generate_param_groups
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import numpy as np
import wandb
import copy
import nibabel as nib
%matplotlib inline

In [None]:
# Ask cuDNN to benchmark its convolution algorithms and keep the fastest one.
# NOTE(review): `set_determinism(seed=0)` is called further down in this notebook
# and presumably resets this flag for reproducibility — confirm which is intended.
torch.backends.cudnn.benchmark = True

## Setup imports

## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified a temporary directory will be used.

In [None]:
# Resolve the working directory used for downloads and checkpoints:
#   1. MONAI_DATA_DIRECTORY, when set (created if it does not exist yet);
#   2. '/content/' when it exists (the location this notebook was written for);
#   3. a fresh temporary directory otherwise, so the notebook still runs elsewhere.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory:
    root_dir = directory
    os.makedirs(root_dir, exist_ok=True)
elif os.path.isdir('/content/'):
    root_dir = '/content/'
else:
    root_dir = tempfile.mkdtemp()
print(root_dir)

## Download dataset

Downloads and extracts the dataset.  
The dataset comes from http://medicaldecathlon.com/.

In [None]:
# Remote archive of the liver task and the MD5 checksum used to verify it.
resource_liver = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task03_Liver.tar"
md5_liver = "a90ec6c4aa7f6a3d087205e23d4e6397"

# Local paths: the downloaded tarball and the folder it extracts to.
compressed_file_liver = os.path.join(root_dir, "Task03_Liver.tar")
data_dir_liver = os.path.join(root_dir, "Task03_Liver")

# Only fetch and unpack when the extracted folder is not there yet.
liver_already_extracted = os.path.exists(data_dir_liver)
if not liver_already_extracted:
    download_and_extract(
        url=resource_liver,
        filepath=compressed_file_liver,
        output_dir=root_dir,
        hash_val=md5_liver,
    )

## Set MSD Liver dataset path

In [None]:
# Collect the image and label volumes in sorted order so that the i-th image
# is paired with the i-th label.
train_images = sorted(
    glob.glob(os.path.join(data_dir_liver, "imagesTr", "*.nii.gz")))
train_labels = sorted(
    glob.glob(os.path.join(data_dir_liver, "labelsTr", "*.nii.gz")))

# `zip` silently drops unmatched files, so fail loudly on an incomplete dataset
# instead of training on wrongly sized (or empty) splits.
if not train_images:
    raise FileNotFoundError(
        f"no '*.nii.gz' images found under {os.path.join(data_dir_liver, 'imagesTr')}")
if len(train_images) != len(train_labels):
    raise ValueError(
        f"found {len(train_images)} images but {len(train_labels)} labels "
        f"under {data_dir_liver}")

data_dicts_liver = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
# The last 25 cases are held out for validation; the rest are used for training.
train_files_liver, val_files_liver = data_dicts_liver[:-25], data_dicts_liver[-25:]

## Convert Multilabel to single label for liver

In [None]:
import nibabel as nib

# Collapse every label value greater than 1 into class 1, overwriting the label
# files in place. Each file is read once, keeps its stored dtype and header
# (get_fdata() would promote the labels to floating point), and is left
# untouched when it is already binary, so re-running this cell is cheap.
for image_path in sorted(glob.glob(os.path.join(data_dir_liver, "labelsTr", "*.nii.gz"))):
    image_file = nib.load(image_path)
    image_file_array = np.asanyarray(image_file.dataobj)
    if image_file_array.max() <= 1:
        continue  # nothing to merge in this volume
    image_file_array = np.minimum(image_file_array, 1).astype(image_file_array.dtype)
    image_file_final = nib.Nifti1Image(
        image_file_array, image_file.affine, image_file.header)
    nib.save(image_file_final, image_path)

## Set deterministic training for reproducibility

In [None]:
# Fix the random seed (0) through MONAI so that repeated runs are comparable.
set_determinism(seed=0)

## Setup transforms for training and validation

Here we use several transforms to preprocess and augment the dataset:
1. `LoadImaged` loads the liver CT images and labels from NIfTI format files.
1. `EnsureChannelFirstd` ensures the original data to construct "channel first" shape.
1. `Orientationd` unifies the data orientation based on the affine matrix.
1. `Spacingd` adjusts the spacing by `pixdim=(1.5, 1.5, 2.)` based on the affine matrix.
1. `ScaleIntensityRanged` extracts intensity range [-200, 200] and scales to [0, 1].
1. `CropForegroundd` removes all zero borders to focus on the valid body area of the images and labels (applied in the validation pipeline only; the training pipeline instead uses `Resized` to bring every volume to `(256, 256, 128)`).
1. `RandCropByPosNegLabeld` randomly crops patch samples from the big image based on the pos / neg ratio.  
The image centers of negative samples must be in the valid body area.
1. `RandAffined` efficiently performs `rotate`, `scale`, `shear`, `translate`, etc. together based on PyTorch affine transform (imported above, but not applied in the pipelines below).

In [None]:
def _loading_steps():
    """Return fresh transforms that load a sample, add a channel axis and rescale CT intensities."""
    return [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Clip the image to [-200, 200] and map that window onto [0, 1].
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-200,
            a_max=200,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
    ]


def _resampling_steps():
    """Return fresh transforms that reorient to RAS and resample to 1.5 x 1.5 x 2.0 spacing."""
    return [
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
    ]


# Training: resize every volume to a fixed grid, resample, then draw four
# 128 x 128 x 32 patches per volume, balanced between positive and negative centres.
train_transforms_liver = Compose(
    _loading_steps()
    + [
        Resized(keys=["image"], spatial_size=(256, 256, 128)),
        # Labels use nearest-neighbour interpolation so class ids stay intact.
        Resized(keys=["label"], spatial_size=(256, 256, 128), mode='nearest'),
    ]
    + _resampling_steps()
    + [
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(128, 128, 32),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
    ]
)

# Validation: crop to the image foreground instead of resizing; no random cropping.
val_transform_liver = Compose(
    _loading_steps()
    + [CropForegroundd(keys=["image", "label"], source_key="image")]
    + _resampling_steps()
)

## Check transforms in DataLoader

In [None]:
# Push one sample through the training pipeline and inspect the result.
check_ds = Dataset(data=train_files_liver, transform=train_transforms_liver)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)

# First patch of the batch, channel 0.
image = check_data["image"][0][0]
label = check_data["label"][0][0]
print(f"image shape: {image.shape}, label shape: {label.shape}")

# Anything that is not class 1 counts as background.
label[label != 1] = 0
print(np.unique(label))

# Plot slice [:, :, 16] of the image and of the label side by side.
plt.figure("check", (12, 6))
panels = (
    ("image", image, {"cmap": "gray"}),
    ("label", label, {}),
)
for position, (panel_title, volume, imshow_kwargs) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.title(panel_title)
    plt.imshow(volume[:, :, 16], **imshow_kwargs)
plt.show()

## Logging liver slices to W&B


In [None]:
def wb_mask(bg_img, mask):
    """Build a ``wandb.Image`` of ``bg_img`` with ``mask`` attached as a "ground truth" overlay.

    The overlay maps class 0 to "background" and class 1 to "mask".
    """
    class_labels = {0: "background", 1: "mask"}
    ground_truth = {"mask_data": mask, "class_labels": class_labels}
    return wandb.Image(bg_img, masks={"ground truth": ground_truth})

def log_liver_slices(total_slices=100):
    """Log slices of the first liver volume, with and without its mask, to W&B.

    The first item of ``train_files_liver`` is loaded through
    ``val_transform_liver`` and sliced along its last axis. Two entries are
    logged to the active W&B run: "Liver Image" (plain slices) and
    "Segmentation Liver" (slices with the label overlaid).

    Args:
        total_slices: maximum number of slices to log, starting at index 0.
            Values larger than the volume depth are clamped to the depth so
            shallow volumes no longer raise ``IndexError``.
    """
    check_ds = Dataset(data=train_files_liver, transform=val_transform_liver)
    check_loader = DataLoader(check_ds, batch_size=1)
    check_data = first(check_loader)  # get the first item of the dataloader

    image, label = (check_data["image"][0][0], check_data["label"][0][0])

    # Never ask for more slices than the volume actually has.
    num_slices = min(total_slices, image.shape[-1])

    wandb_img_logs = []
    wandb_mask_logs = []
    for img_slice_no in range(num_slices):
        img = image[:, :, img_slice_no]
        lbl = label[:, :, img_slice_no]

        # plain slice, browsable interactively in the W&B dashboard
        wandb_img_logs.append(wandb.Image(img, caption=f"Slice: {img_slice_no}"))

        # same slice with the mask overlaid on the original image
        wandb_mask_logs.append(wb_mask(img, lbl))

    wandb.log({"Liver Image": wandb_img_logs})
    wandb.log({"Segmentation Liver": wandb_mask_logs})

In [None]:
# Short-lived W&B run used only to upload the exploratory liver slices.
# 🐝 init wandb with appropriate project and run name
wandb.init(project="SegViz_3D", name="slice_image_exploration")
# 🐝 log images to W&B
# log_spleen_slices(total_slices=100)
log_liver_slices(total_slices=100)
# 🐝 finish the run
wandb.finish()

In [None]:
# Hyper-parameters for the liver experiment; the whole dictionary is handed to
# W&B, which also handles the nested model parameters.
config = {
    # data loading
    "num_workers": 0,

    # training settings
    "train_batch_size": 30,
    "val_batch_size": 1,
    "learning_rate": 1e-4,
    "max_epochs": 500,
    "val_interval": 2,  # run validation every n epochs
    "lr_scheduler": "cosine_decay",  # informational only, for tracking

    # network
    "model_type_liver": "unet",
    "model_params_liver": {
        "spatial_dims": 3,
        "in_channels": 1,
        "out_channels": 2,
        "channels": (16, 32, 64, 128, 256),
        "strides": (2, 2, 2, 2),
        "num_res_units": 2,
        "norm": Norm.BATCH,
    },

    # caching
    "cache_rate_liver": 1.0,
}

## Define CacheDataset and DataLoader for training and validation

Here we use CacheDataset to accelerate training and validation process, it's 10x faster than the regular Dataset.  
To achieve best performance, set `cache_rate=1.0` to cache all the data, if memory is not enough, set lower value.  
Users can also set `cache_num` instead of `cache_rate`; the minimum of the two settings is used.  
Set `num_workers` to enable multiple threads during caching.  
If you want to try the regular Dataset, just switch to the commented code below.

In [None]:
# Cached training set.
# Uncached alternative:
# train_ds_liver = Dataset(data=train_files_liver, transform=train_transforms_liver)
train_ds_liver = CacheDataset(
    data=train_files_liver,
    transform=train_transforms_liver,
    cache_rate=config['cache_rate_liver'],
    num_workers=config['num_workers'],
)

# Each batch holds config['train_batch_size'] volumes; RandCropByPosNegLabeld
# contributes several random patches per volume.
train_loader_liver = DataLoader(
    train_ds_liver,
    batch_size=config['train_batch_size'],
    shuffle=True,
    num_workers=config['num_workers'],
)

# Cached validation set.
# Uncached alternative:
# val_ds_liver = Dataset(data=val_files_liver, transform=val_transform_liver)
val_ds_liver = CacheDataset(
    data=val_files_liver,
    transform=val_transform_liver,
    cache_rate=config['cache_rate_liver'],
    num_workers=config['num_workers'],
)
val_loader_liver = DataLoader(
    val_ds_liver,
    batch_size=config['val_batch_size'],
    num_workers=config['num_workers'],
)

## Create Model, Loss, Optimizer

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Liver UNet built from the parameters stored in `config`.
model_liver = UNet(**config['model_params_liver']).to(device)

# Dice loss on softmax outputs against one-hot encoded labels.
loss_function_liver = DiceLoss(to_onehot_y=True, softmax=True)

# Adam over all network parameters.
optimizer_liver = torch.optim.Adam(
    model_liver.parameters(),
    lr=config['learning_rate'],
)

# Validation metric: mean Dice, with the background channel left out.
dice_metric_liver = DiceMetric(include_background=False, reduction="mean")

# Cosine learning-rate decay spanning the configured number of epochs.
scheduler_liver = CosineAnnealingLR(
    optimizer_liver,
    T_max=config['max_epochs'],
    eta_min=1e-9,
)

In [None]:
# Print the name of every parameter tensor registered on the liver UNet.
for name, _ in model_liver.named_parameters():
    print(name)

## Execute a typical PyTorch training process

In [None]:
# Train the liver UNet, validate every `val_interval` epochs, keep the
# checkpoint with the best mean Dice, and upload it to W&B at the end.

# 🐝 initialize a wandb run
wandb.init(
    project="SegViz_Liveronly",
    config=config
)

# 🐝 log gradients of the model to wandb
wandb.watch(model_liver, log_freq=100)

# Read the schedule from `config` so the values logged to W&B (and used as
# T_max of the scheduler) are the ones the loop actually runs with.
max_epochs = config["max_epochs"]
val_interval = config["val_interval"]

# Spleen trackers are not used by this liver-only loop; they are kept only so
# that any other cell still referencing these names keeps working.
best_metric_spleen = -1
best_metric_epoch_spleen = -1
epoch_loss_values_spleen = []
metric_values_spleen = []

best_metric_liver = -1
best_metric_epoch_liver = -1
epoch_loss_values_liver = []
metric_values_liver = []

best_model_path_liver = os.path.join(root_dir, "best_metric_model_liveronly_128_rescrop.pth")

# Post-processing: argmax + one-hot for predictions, one-hot for labels.
post_pred_liver = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label_liver = Compose([AsDiscrete(to_onehot=2)])

# Sliding-window inference settings used during validation.
roi_size = (160, 160, 160)
sw_batch_size = 4

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    # ---- training: one pass over the liver training set ----
    model_liver.train()
    epoch_loss_liver = 0
    step_1 = 0
    for batch_data_liver in train_loader_liver:
        step_1 += 1
        inputs_liver, labels_liver = (
            batch_data_liver["image"].to(device),
            batch_data_liver["label"].to(device),
        )
        # Everything that is not class 1 is treated as background.
        labels_liver[labels_liver != 1] = 0
        optimizer_liver.zero_grad()
        outputs_liver = model_liver(inputs_liver)
        loss_liver = loss_function_liver(outputs_liver, labels_liver)
        loss_liver.backward()
        optimizer_liver.step()
        epoch_loss_liver += loss_liver.item()
        # len(loader) is the real number of steps, including a final partial batch.
        print(
            f"{step_1}/{len(train_loader_liver)}, "
            f"train_loss: {loss_liver.item():.4f}")
        wandb.log({"train/loss liver": loss_liver.item()})
    epoch_loss_liver /= max(step_1, 1)  # guard against an empty loader
    epoch_loss_values_liver.append(epoch_loss_liver)
    print(f"epoch {epoch + 1} average loss liver: {epoch_loss_liver:.4f}")

    # Advance the cosine schedule exactly once per epoch.
    scheduler_liver.step()

    wandb.log({"train/loss_epoch liver": epoch_loss_liver})
    # 🐝 log learning rate after each epoch to wandb
    wandb.log({"learning_rate liver": scheduler_liver.get_last_lr()[0]})

    # ---- validation ----
    if (epoch + 1) % val_interval == 0:
        model_liver.eval()
        with torch.no_grad():
            for val_data_liver in val_loader_liver:
                val_inputs_liver, val_labels_liver = (
                    val_data_liver["image"].to(device),
                    val_data_liver["label"].to(device),
                )
                val_labels_liver[val_labels_liver != 1] = 0
                val_outputs_liver = sliding_window_inference(
                    val_inputs_liver, roi_size, sw_batch_size, model_liver)
                val_outputs_liver = [post_pred_liver(i) for i in decollate_batch(val_outputs_liver)]
                val_labels_liver = [post_label_liver(i) for i in decollate_batch(val_labels_liver)]
                # accumulate the metric for the current volume
                dice_metric_liver(y_pred=val_outputs_liver, y=val_labels_liver)

            # aggregate the final mean dice result
            metric_liver = dice_metric_liver.aggregate().item()
            wandb.log({"val/dice_metric liver": metric_liver})

            # NOTE: the scheduler is deliberately not stepped with the metric
            # here. CosineAnnealingLR is not metric-driven; passing the Dice
            # score would be taken as the (deprecated) `epoch` argument and
            # reset the cosine schedule at every validation.

            # reset the status for next validation round
            dice_metric_liver.reset()

            metric_values_liver.append(metric_liver)
            if metric_liver > best_metric_liver:
                best_metric_liver = metric_liver
                best_metric_epoch_liver = epoch + 1
                torch.save(model_liver.state_dict(), best_model_path_liver)
                print("saved new best metric model for liver dataset")
            print(
                f"current epoch: {epoch + 1} current mean dice for liver: {metric_liver:.4f}"
                f"\nbest mean dice for liver: {best_metric_liver:.4f} "
                f"at epoch: {best_metric_epoch_liver}"
            )


wandb.log({"best_dice_metric liver": best_metric_liver, "best_metric_epoch liver": best_metric_epoch_liver})

# Upload the best checkpoint only if this run actually saved one.
if best_metric_epoch_liver != -1:
    model_artifact = wandb.Artifact(
        "unet", type="model",
        description="Segviz branch liver",
        metadata=dict(config['model_params_liver']))
    model_artifact.add_file(best_model_path_liver)
    wandb.log_artifact(model_artifact)
else:
    print("no liver checkpoint was saved during this run; skipping artifact upload")

## Plot the loss and metric

In [None]:
# Left panel: mean training loss per epoch. Right panel: validation mean Dice,
# which is only recorded every `val_interval` epochs.
curves = (
    ("Epoch Average Loss", 1, epoch_loss_values_liver),
    ("Val Mean Dice", val_interval, metric_values_liver),
)

plt.figure("train", (12, 6))
for panel, (panel_title, epochs_per_point, y) in enumerate(curves, start=1):
    # Epoch number at which each recorded value was produced.
    x = [epochs_per_point * n for n in range(1, len(y) + 1)]
    plt.subplot(1, 2, panel)
    plt.title(panel_title)
    plt.xlabel("epoch")
    plt.plot(x, y)
plt.show()

## Check best model output with the input image and label

In [None]:
# Reload the best checkpoint and show image / label / prediction for the first
# three validation volumes.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
model_liver.load_state_dict(torch.load(
    os.path.join(root_dir, "best_metric_model_liveronly_128_rescrop.pth"),
    map_location=device))
model_liver.eval()

preferred_slice = 200  # slice index shown when the volume is deep enough

with torch.no_grad():
    for i, val_data in enumerate(val_loader_liver):
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(
            val_data["image"].to(device), roi_size, sw_batch_size, model_liver
        )
        prediction = torch.argmax(val_outputs, dim=1).detach().cpu()

        # Volumes differ in depth: fall back to the middle slice rather than
        # raising IndexError when slice 200 does not exist.
        depth = val_data["image"].shape[-1]
        slice_idx = preferred_slice if preferred_slice < depth else depth // 2

        plt.figure("check", (18, 6))
        plt.subplot(1, 3, 1)
        plt.title(f"image {i}")
        plt.imshow(val_data["image"][0, 0, :, :, slice_idx], cmap="gray")
        plt.subplot(1, 3, 2)
        plt.title(f"label {i}")
        plt.imshow(val_data["label"][0, 0, :, :, slice_idx])
        plt.subplot(1, 3, 3)
        plt.title(f"output {i}")
        plt.imshow(prediction[0, :, :, slice_idx])
        plt.show()
        if i == 2:
            break