# Spleen 3D segmentation with MONAI

This tutorial shows how to integrate MONAI into an existing PyTorch medical deep learning program.

It demonstrates the following features:
1. Transforms for dictionary-format data.
1. Loading NIfTI images with metadata.
1. Adding a channel dimension to data that lacks one.
1. Scaling medical image intensity to an expected range.
1. Cropping a batch of balanced patches based on the positive / negative label ratio.
1. Caching IO and transforms to accelerate training and validation.
1. A 3D UNet model, Dice loss function, and Mean Dice metric for 3D segmentation.
1. Sliding window inference.
1. Deterministic training for reproducibility.

The Spleen dataset can be downloaded from http://medicaldecathlon.com/.

![spleen](http://medicaldecathlon.com/img/spleen0.png)

Target: Spleen  
Modality: CT  
Size: 61 3D volumes (41 Training + 20 Testing)  
Source: Memorial Sloan Kettering Cancer Center  
Challenge: Large ranging foreground size

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d.ipynb)

## Setup environment

In [None]:
!pip install 'monai[all]'

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[gdown, nibabel, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

import tensorflow as tf
tf.__version__

In [None]:
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    RandAffined,
    ScaleIntensityRanged,
    Spacingd,
    EnsureTyped,
    EnsureType,
    Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet, UNETR
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import numpy as np

## Setup imports

In [None]:
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


print_config()

## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified, a temporary directory will be used.

In [None]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

## Download dataset

Downloads and extracts the dataset.  
The dataset comes from http://medicaldecathlon.com/.

In [None]:
#resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
#md5 = "410d4a301da4e5b2f6f86ec3ddba524e"

#compressed_file = os.path.join(root_dir, "Task09_Spleen.tar")
#data_dir = os.path.join(root_dir, "Task09_Spleen")
#if not os.path.exists(data_dir):
#    download_and_extract(resource, compressed_file, root_dir, md5)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## Set MSD Spleen dataset path
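
The cell below pairs each image with its label by zipping two sorted file lists, then slices the pairs into training, test, and validation subsets. As a minimal sketch of that pattern, the hypothetical helper `split_contiguous` (not part of MONAI or of this notebook) splits a list into three contiguous parts so that every item lands in exactly one subset. The file names are made up for illustration.

```python
def split_contiguous(items, n_train, n_test):
    """Split a list into train / test / val parts with no gaps or overlaps."""
    train = items[:n_train]
    test = items[n_train:n_train + n_test]
    val = items[n_train + n_test:]  # whatever remains
    return train, test, val


# Hypothetical file names; sorting both lists keeps image and label aligned.
images = sorted(f"img_{i:03d}.nii.gz" for i in range(10))
labels = sorted(f"seg_{i:03d}.nii.gz" for i in range(10))
pairs = [{"image": im, "label": lb} for im, lb in zip(images, labels)]

train, test, val = split_contiguous(pairs, n_train=6, n_test=3)
print(len(train), len(test), len(val))
```

Because Python slice ends are exclusive, reusing the previous end index as the next start index is what keeps the subsets contiguous.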

In [None]:
#del train_files
#del val_files

train_images = sorted(
    glob.glob(os.path.join("/content/drive/MyDrive/Head _Neck Hector/NewHector_final/New Fusion/method8_10000_100_1", "*.nii.gz")))

#train_images = sorted(
#    glob.glob(os.path.join("/content/drive/MyDrive/Head _Neck Hector/NewHector_final/New Fusion/method9_1_1_1", "*.nii.gz")))

#train_images = sorted(
#    glob.glob(os.path.join("/content/drive/MyDrive/HN_Cancer/Head _Neck Hector/NewHector_final/Fusions/method13", "*.nii.gz")))

#train_images = sorted(
#    glob.glob(os.path.join("/content/drive/MyDrive/Head _Neck Hector/NewHector_final/New Fusion/method2_500_0.1_1", "*.nii.gz")))


#train_images = sorted(
#    glob.glob(os.path.join("/content/drive/MyDrive/Head _Neck Hector/NewHector_final/New Fusion/method7_100_50_10", "*.nii.gz")))

train_images = sorted(
    glob.glob(os.path.join("/content/drive/MyDrive/HN_Cancer/Head _Neck Hector/NewHector_final/New Fusion/method8_1000_100_6", "*.nii.gz")))


#train_images = sorted(
#    glob.glob(os.path.join("/content/drive/MyDrive/Head _Neck Hector/NewHector_final/Fusions/method1", "*.nii.gz")))


#train_images = sorted(
#    glob.glob(os.path.join("/content/drive/MyDrive/HN_Cancer/Head _Neck Hector/NewHector_final/New Fusion/method9_1_1_1", "*.nii.gz")))


train_labels = sorted(
    glob.glob(os.path.join("/content/drive/MyDrive/HN_Cancer/Head _Neck Hector/NewHector_final/gtv", "*.nii.gz")))
data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
#train_files, val_files = data_dicts[0:210], data_dicts[211:223]
#train_files, test_files, val_files = data_dicts[0:50],data_dicts[101:110], data_dicts[120:125]
# NOTE: slice ends are exclusive, so indices 135 and 200 fall in none of the three subsets.
train_files, test_files, val_files = data_dicts[0:135], data_dicts[136:200], data_dicts[201:223]
#train_files, test_files, val_files = data_dicts[0:160],data_dicts[161:223]


## Set deterministic training for reproducibility
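
`set_determinism(seed=0)` seeds the random number generators used during training so that random transforms behave the same way on every run. The sketch below illustrates the general idea with Python's standard `random` module only; the function `sample_crop_centers` and its parameters are hypothetical and stand in for any seeded random augmentation.

```python
import random


def sample_crop_centers(seed, n=4, size=96):
    """Draw n pseudo-random positions in [0, size) from a seeded generator."""
    rng = random.Random(seed)  # private generator, independent of global state
    return [rng.randrange(size) for _ in range(n)]


first_run = sample_crop_centers(seed=0)
second_run = sample_crop_centers(seed=0)

# The same seed reproduces the same sequence of "random" choices.
print(first_run == second_run)
```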

In [None]:
set_determinism(seed=0)

In [None]:
train_images[0:4]

#train_labels[0:4]

## Setup transforms for training and validation

Here we use several transforms to preprocess and augment the dataset:
1. `LoadImaged` loads the images and labels from NIfTI format files.
1. `EnsureChannelFirstd` adds a channel dimension when the original data has none, producing a "channel first" shape.
1. `Spacingd` resamples to `pixdim=(1.5, 1.5, 1)` based on the affine matrix.
1. `Orientationd` unifies the data orientation (`RAS`) based on the affine matrix.
1. `ScaleIntensityRanged` extracts the intensity range [0, 1] and scales it to [0, 1]; with these settings it effectively clips the image to [0, 1].
1. `CropForegroundd` removes all zero borders to focus on the valid body area of the images and labels.
1. `RandCropByPosNegLabeld` randomly crops patch samples from the full image based on the pos / neg ratio.  
The image centers of negative samples must be in the valid body area.
1. `RandAffined` efficiently performs `rotate`, `scale`, `shear`, `translate`, etc. together based on the PyTorch affine transform (commented out in the cell below).
1. `EnsureTyped` converts the NumPy array to a PyTorch Tensor for further steps.
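
To make the intensity scaling step concrete, here is a minimal pure-Python sketch of the linear mapping that `ScaleIntensityRanged` is configured with (`a_min`, `a_max`, `b_min`, `b_max`, `clip`). The function name `scale_intensity_range` is hypothetical, and the code is an illustration of the arithmetic rather than MONAI's implementation.

```python
def scale_intensity_range(values, a_min, a_max, b_min=0.0, b_max=1.0, clip=True):
    """Linearly map [a_min, a_max] onto [b_min, b_max], optionally clipping."""
    scaled = []
    for v in values:
        out = (v - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
        if clip:
            # Values outside the source range are pinned to the target bounds.
            out = min(max(out, b_min), b_max)
        scaled.append(out)
    return scaled


# Settings used in this notebook: source range [0, 1] mapped to [0, 1].
print(scale_intensity_range([-0.5, 0.0, 0.25, 1.0, 2.0], a_min=0.0, a_max=1.0))
# → [0.0, 0.0, 0.25, 1.0, 1.0]
```

With a CT-style window such as `a_min=-57, a_max=164`, the same formula maps the window midpoint 53.5 to 0.5.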

In [None]:
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Spacingd(keys=["image", "label"], pixdim=(
            1.5, 1.5, 1), mode=("bilinear", "nearest")),
         #1.5, 1.5, 2), mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys=["image"], a_min=0, a_max=1.0,
            #keys=["image"], a_min=-3100, a_max=3100,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        # user can also add other random transforms
         #RandAffined(
         #    keys=['image', 'label'],
         #    mode=('bilinear', 'nearest'),
         #    prob=1.0, spatial_size=(96, 96, 96),
         #    rotate_range=(0, 0, np.pi/15),
         #    scale_range=(0.1, 0.1, 0.1)),
        #EnsureTyped(keys=["image", "label"]),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Spacingd(keys=["image", "label"], pixdim=(
            1.5, 1.5, 1), mode=("bilinear", "nearest")),
        # 1.5, 1.5, 2), mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys=["image"], a_min=0, a_max=1.0,
            #keys=["image"], a_min=-3100, a_max=3100,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        EnsureTyped(keys=["image", "label"]),
    ]
)

## Check transforms in DataLoader

In [None]:
check_ds = Dataset(data=val_files, transform=val_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)

image, label = (check_data["image"][0][0], check_data["label"][0][0])
print(f"image shape: {image.shape}, label shape: {label.shape}")
# plot the slice [:, :, 1]
plt.figure("check", (12, 6))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image[:, :, 1], cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[:, :, 1])
plt.show()

In [None]:
check_loader

## Define CacheDataset and DataLoader for training and validation

Here we use CacheDataset to accelerate training and validation; it is about 10x faster than the regular Dataset.  
To achieve the best performance, set `cache_rate=1.0` to cache all the data; if memory is insufficient, set a lower value.  
Users can also set `cache_num` instead of `cache_rate`; the minimum of the two settings is used.  
Set `num_workers` to enable multiple threads during caching.  
To try the regular Dataset, switch to the commented code below.
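
As a rough illustration of how `cache_rate` and `cache_num` interact, the sketch below computes how many items would end up cached under the "minimum of the two settings" rule described above. The helper `effective_cache_count` is hypothetical and only mirrors that rule; it is not MONAI code.

```python
def effective_cache_count(n_items, cache_rate=1.0, cache_num=None):
    """Number of items cached when both a rate and an absolute cap may be set."""
    by_rate = int(n_items * cache_rate)          # fraction of the dataset
    limit = n_items if cache_num is None else int(cache_num)  # absolute cap
    return min(limit, by_rate, n_items)


print(effective_cache_count(135, cache_rate=1.0))                # → 135
print(effective_cache_count(135, cache_rate=0.5))                # → 67
print(effective_cache_count(135, cache_rate=1.0, cache_num=50))  # → 50
```

Items beyond the cached count would still be usable; they are simply loaded and transformed on demand instead of being served from memory.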

In [14]:
train_ds = CacheDataset(
    data=train_files, transform=train_transforms, cache_rate=1, num_workers=2)
# train_ds = Dataset(data=train_files, transform=train_transforms)

# use batch_size=4 to load images and use RandCropByPosNegLabeld
# to generate 4 x 4 images for network training
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=2)



test_ds = CacheDataset(
    data=test_files, transform=val_transforms, cache_rate=1, num_workers=2)
# test_ds = Dataset(data=test_files, transform=val_transforms)
test_loader = DataLoader(test_ds, batch_size=1, num_workers=2)



val_ds = CacheDataset(
    data=val_files, transform=val_transforms, cache_rate=1, num_workers=2)
# val_ds = Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=2)

Loading dataset: 100%|██████████| 135/135 [02:05<00:00,  1.07it/s]
Loading dataset: 100%|██████████| 64/64 [00:57<00:00,  1.11it/s]
Loading dataset: 100%|██████████| 22/22 [00:18<00:00,  1.19it/s]


## Create Model, Loss, Optimizer
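
The cell below uses `DiceLoss` for training and `DiceMetric` for evaluation. Both are built on the Dice overlap between a predicted mask and a reference mask, 2|A ∩ B| / (|A| + |B|). The following pure-Python sketch computes that overlap for two flat binary masks; the function `dice_score` and the choice to return 1.0 for two empty masks are assumptions made for this illustration, not MONAI's behavior.

```python
def dice_score(pred, target):
    """Dice overlap of two equal-length binary masks given as flat 0/1 lists."""
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    if total == 0:
        return 1.0  # convention chosen here: two empty masks agree perfectly
    return 2.0 * intersection / total


pred = [0, 1, 1, 1, 0, 0]
target = [0, 0, 1, 1, 1, 0]

dice = dice_score(pred, target)  # intersection = 2, |pred| + |target| = 6
print(f"dice = {dice:.4f}, dice loss = {1.0 - dice:.4f}")
# → dice = 0.6667, dice loss = 0.3333
```

A loss of the form `1 - dice` falls as the overlap improves, which is why the training loss and the validation Dice move in opposite directions.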

In [16]:
# standard PyTorch program style: create UNet, DiceLoss and Adam optimizer
device = torch.device("cuda:0")
#device = torch.device("cpu:0")

model = UNet(
    dimensions=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=20,
    norm=Norm.BATCH,
).to(device)

#model = UNETR(
#    in_channels=1,
#    out_channels=2,
#    img_size=(96, 96, 96),
#    feature_size=16,
#    hidden_size=768,
#    mlp_dim=3072,
#    num_heads=12,
#    pos_embed="perceptron",
#    norm_name="instance",
#    res_block=True,
#    dropout_rate=0.0,
#).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
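# NOTE: the scheduler is stored here, but it only changes the learning rate
# if scheduler.step() is called (typically once per epoch) in the training loop.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.99)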
dice_metric = DiceMetric(include_background=False, reduction="mean")
#dice_metric = DiceMetric(include_background=False,reduction="mean_channel")
#model.load_state_dict(torch.load("/content/drive/MyDrive/Fatan_Hecktor/best_metric_model_F8_10000_100_1_croped_final_20layer22.pth"))
model.load_state_dict(torch.load("/content/drive/MyDrive/Fatan_Hecktor/best_metric_model_method8_10000_100_1_croped_start_again_20layer.pth"))

#model.load_state_dict(torch.load("/content/drive/MyDrive/Fatan_Hecktor/best_metric_model_F8_10000_100_1_croped_final_20layer26.pth"))
#model.load_state_dict(torch.load("/content/drive/MyDrive/Fatan_Hecktor/best_metric_model_F8_10000_100_1_croped_final_20layer28.pth"))
#model.load_state_dict(torch.load("/content/drive/MyDrive/Fatan_Hecktor/best_metric_model_F8_10000_100_1_croped_final_20layer33.pth"))
#model.load_state_dict(torch.load("/content/drive/MyDrive/Fatan_Hecktor/best_metric_model_F9_100_50_1_croped_final_20layer26.pth"))
#model.load_state_dict(torch.load("/content/drive/MyDrive/Fatan_Hecktor/best_metric_model_F13_croped_final_20layer2.pth"))
#model.load_state_dict(torch.load("/content/drive/MyDrive/Fatan_Hecktor/best_metric_model_F13_croped_final_20layer.pth"))
#model.load_state_dict(torch.load("/content/drive/MyDrive/Fatan_Hecktor/best_metric_model_F9_1_1_1_croped_final_20layer.pth"))
#model.load_state_dict(torch.load("/content/drive/MyDrive/Fatan_Hecktor/best_metric_model_F8_10000_100_1_croped_final_20layer.pth"))
#model.load_state_dict(torch.load("/content/drive/MyDrive/Fatan_Hecktor/best_metric_model_F8new_croped_final.pth"))

<All keys matched successfully>

## Execute a typical PyTorch training process
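
During validation, the loop below calls `sliding_window_inference` because a full volume is larger than the 96 x 96 x 96 patches the network was trained on. As a simplified sketch of the idea along a single axis, the hypothetical helper `window_starts` lists where overlapping windows could begin so that the whole axis is covered. The 0.25 overlap is an illustrative value, and the placement logic is an assumption, not MONAI's exact algorithm.

```python
def window_starts(image_size, roi_size, overlap=0.25):
    """Start indices of overlapping windows covering one axis of a volume."""
    if image_size <= roi_size:
        return [0]  # a single window already covers the axis
    step = max(int(roi_size * (1 - overlap)), 1)
    starts = list(range(0, image_size - roi_size, step))
    starts.append(image_size - roi_size)  # final window flush with the edge
    return starts


print(window_starts(200, 96))  # → [0, 72, 104]
print(window_starts(96, 96))   # → [0]
```

In 3D the same placement is applied per axis, the model is run on each window (in groups of `sw_batch_size`), and predictions in overlapping regions are combined into one output volume.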

In [18]:
#import cv2
#import numpy as np
import random

max_epochs = 2000
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, n_classes=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, n_classes=2)])

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        #lr=random.random()/100.0
        #optimizer = torch.optim.Adam(model.parameters(), lr)
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )

        #print(inputs.min())
        #print(inputs.max())
        #print(np.shape(inputs))
        #d1,d2,d3,d4,d5=np.shape(inputs)


        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        #print(
        #    f"{step}/{len(train_ds) // train_loader.batch_size}, "
        #    f"train_loss: {loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                #roi_size = (160, 160, 160)
                #roi_size = (96, 96, 96)
                roi_size = (96, 96, 96)
                #roi_size = (96, 96, 96)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(
                    val_inputs, roi_size, sw_batch_size, model)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)

            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()

            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                #torch.save(model.state_dict(),"/content/drive/MyDrive/Fatan_Hecktor/best_metric_model_F8_10000_100_1_croped_final_20layer22.pth")
                torch.save(model.state_dict(),"/content/drive/MyDrive/Fatan_Hecktor/best_metric_model_method8_1000_100_6_croped_start_again_20layer.pth")
            #    torch.save(model.state_dict(), os.path.join(
            #        root_dir, "best_metric_model.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )

----------
epoch 1/2000
epoch 1 average loss: 0.0814
----------
epoch 2/2000
epoch 2 average loss: 0.0730


NameError: ignored

In [None]:
print(
    f"train completed, best_metric: {best_metric:.4f} "
    f"at epoch: {best_metric_epoch}")

## Plot the loss and metric

In [None]:
plt.figure("train", (12, 6))
plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
plt.xlabel("epoch")
plt.plot(x, y)
plt.subplot(1, 2, 2)
plt.title("Val Mean Dice")
x = [val_interval * (i + 1) for i in range(len(metric_values))]
y = metric_values
plt.xlabel("epoch")
plt.plot(x, y)
plt.show()

## Check best model output with the input image and label

In [None]:
#model.load_state_dict(torch.load(
#    os.path.join(root_dir, "/content/drive/MyDrive/best_metric_model.pth")))

model.load_state_dict(torch.load("/content/drive/MyDrive/Fatan_Hecktor/best_metric_method7_100_50_10_croped_start_again_20layer.pth"))

model.eval()
with torch.no_grad():
    for i, test_data in enumerate(test_loader):
        roi_size = (96, 96, 96)
        sw_batch_size = 4
        test_outputs = sliding_window_inference(
            test_data["image"].to(device), roi_size, sw_batch_size, model
        )
        # plot the slice [:, :, 100]
        plt.figure("check", (18, 6))
        plt.subplot(1, 3, 1)
        plt.title(f"image {i}")
        plt.imshow(test_data["image"][0, 0, :, :, 100], cmap="gray")
        plt.subplot(1, 3, 2)
        plt.title(f"label {i}")
        plt.imshow(test_data["label"][0, 0, :, :, 100])
        plt.subplot(1, 3, 3)
        plt.title(f"output {i}")
        plt.imshow(torch.argmax(
            test_outputs, dim=1).detach().cpu()[0, :, :, 100])
        plt.show()
        if i == 2:
            break

## Evaluation on original image spacings

In [20]:
test_org_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Spacingd(keys=["image"], pixdim=(
            1.5, 1.5, 1.0), mode="bilinear"),
        Orientationd(keys=["image"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys=["image"], a_min=0.0, a_max=1.0,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image"], source_key="image"),
        EnsureTyped(keys=["image", "label"]),
    ]
)

test_org_ds = Dataset(
    data=test_files, transform=test_org_transforms)
test_org_loader = DataLoader(test_org_ds, batch_size=1, num_workers=4)

post_transforms = Compose([
    EnsureTyped(keys="pred"),
    Invertd(
        keys="pred",
        transform=test_org_transforms,
        orig_keys="image",
        meta_keys="pred_meta_dict",
        orig_meta_keys="image_meta_dict",
        meta_key_postfix="meta_dict",
        nearest_interp=False,
        to_tensor=True,
    ),
    AsDiscreted(keys="pred", argmax=True, to_onehot=True, n_classes=2),
    AsDiscreted(keys="label", to_onehot=True, n_classes=2),
])

In [34]:
dice_metric = DiceMetric(include_background=False, reduction="mean")
model.load_state_dict(torch.load("/content/drive/MyDrive/Fatan_Hecktor/best_metric_model_method8_1000_100_6_croped_start_again_20layer.pth"))
model.eval()

with torch.no_grad():
    for test_data in test_org_loader:
        test_inputs = test_data["image"].to(device)
        roi_size = (96, 96, 96)
        sw_batch_size = 1
        test_data["pred"] = sliding_window_inference(
            test_inputs, roi_size, sw_batch_size, model)
        test_data = [post_transforms(i) for i in decollate_batch(test_data)]
        test_outputs, test_labels = from_engine(["pred", "label"])(test_data)
        # compute metric for current iteration
        dice_metric(y_pred=test_outputs, y=test_labels)

    # aggregate the final mean dice result
    metric_org = dice_metric.aggregate().item()
    # reset the status for next validation round
    dice_metric.reset()

print("Metric on original image spacing: ", metric_org)

Metric on original image spacing:  0.5795199871063232


## Cleanup data directory

Remove the directory if a temporary one was used.

In [None]:
if directory is None:
    shutil.rmtree(root_dir)