# Spleen 3D segmentation with MONAI

This tutorial shows how to run SageMaker managed training using MONAI for 3D segmentation.
It also shows how to run SageMaker managed inference after model training.



This notebook and the `train.py` script in the source folder were derived from the [spleen_segmentation_3d notebook](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d.ipynb).

Key features demonstrated here:
1. SageMaker managed training with S3 integration
2. SageMaker hosted inference 

The Spleen dataset can be downloaded from https://registry.opendata.aws/msd/.

![spleen](http://medicaldecathlon.com/img/spleen0.png)

Target: Spleen  
Modality: CT  
Size: 61 3D volumes (41 Training + 20 Testing)  
Source: Memorial Sloan Kettering Cancer Center  
Challenge: Large ranging foreground size
    

### Install and import monai libraries 

In [26]:
# Install MONAI pinned to 0.8.0 together with all of its optional extras.
!pip install  "monai[all]==0.8.0"
# Fallback: if `import monai` still fails, install the monai-weekly build with a smaller set of extras.
!python -c "import monai" || pip install -q "monai-weekly[gdown, nibabel, tqdm, ignite]"
# Install matplotlib only when it cannot be imported.
!python -c "import matplotlib" || pip install -q matplotlib
# Render matplotlib figures inline in the notebook.
%matplotlib inline



Error: mkl-service + Intel(R) MKL: MKL_THREADING_LAYER=INTEL is incompatible with libgomp.so.1 library.
	Try to import numpy first or set the threading layer accordingly. Set MKL_SERVICE_FORCE_INTEL to force it.


In [27]:
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    Spacingd,
    EnsureTyped,
    EnsureType,
    Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob




In [28]:
# Set up the SageMaker session and collect the values the later cells rely on.
import sagemaker
from sagemaker import get_execution_role

sess = sagemaker.Session()
bucket = sess.default_bucket()  # default S3 bucket of this session
region = sess.boto_session.region_name
role = get_execution_role()  # execution role handed to the training job and the models

## Prepare the dataset: Spleen dataset
+ Download the Spleen dataset if it is not available locally
+ Transform the images using Compose from MONAI
+ Visualize the image 

In [29]:
# Download the images
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
md5 = "410d4a301da4e5b2f6f86ec3ddba524e"
compressed_file = "./Task09_Spleen.tar"

# Despite its name, this holds a local folder path by default.
MONAILabelServerIP = "../Spleen3D" ## you can change it to IP address of the MONAI Label Server if deployed
data_dir = MONAILabelServerIP

# Test for the extracted dataset folder rather than for data_dir itself:
# data_dir also holds the train/ and test/ folders created further down, so
# its mere existence does not prove that the dataset has been downloaded.
dataset_root = os.path.join(data_dir, "datasets")
if not os.path.exists(os.path.join(dataset_root, "Task09_Spleen")):
    download_and_extract(resource, compressed_file, dataset_root, md5)

In [None]:
## Preprocessing pipeline built with Compose; it contains no random transforms.
image_and_label = ["image", "label"]  ## image first, then label

val_transforms = Compose(
    [
        LoadImaged(keys=image_and_label),
        EnsureChannelFirstd(keys=image_and_label),
        ## resample to 1.5 x 1.5 x 2.0 spacing: bilinear for the image, nearest for the label
        Spacingd(
            keys=image_and_label,
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=image_and_label, axcodes="RAS"),
        ## map the intensity window [-57, 164] to [0, 1] and clip values outside it (image only)
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        ## crop both volumes using the image as the foreground source
        CropForegroundd(keys=image_and_label, source_key="image"),
        EnsureTyped(keys=image_and_label),
    ]
)

In [None]:
## divide the images into training and testing dataset

from monai.apps import download_and_extract
import os
import glob

train_images = sorted(
    glob.glob(os.path.join(data_dir, "datasets/Task09_Spleen/imagesTr", "*.nii.gz")))
train_labels = sorted(
    glob.glob(os.path.join(data_dir, "datasets/Task09_Spleen/labelsTr", "*.nii.gz")))

## Fail early with a clear message: an empty or uneven listing would otherwise
## be silently truncated by zip() and only surface as an obscure error later.
if not train_images:
    raise FileNotFoundError(
        f"No training images found under {data_dir!r}; run the download cell first."
    )
if len(train_images) != len(train_labels):
    raise ValueError(
        f"Found {len(train_images)} images but {len(train_labels)} labels; "
        "every image needs exactly one label."
    )

## pair each image with its label; both lists are sorted so they line up by position
data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
## hold out the last sample for testing, train on everything else
train_files, val_files = data_dicts[:-1], data_dicts[-1:]

In [None]:
# Visualization: load the held-out sample through the transform pipeline and
# show one slice of the image next to the matching slice of the label.

check_ds = Dataset(data=val_files, transform=val_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)

# drop the batch and channel dimensions
image = check_data["image"][0][0]
label = check_data["label"][0][0]
print(f"image shape: {image.shape}, label shape: {label.shape}")

# plot only the slice [:, :, 80]
plt.figure("check", (12, 6))
for position, (panel_title, volume, imshow_options) in enumerate(
    [("image", image, {"cmap": "gray"}), ("label", label, {})], start=1
):
    plt.subplot(1, 2, position)
    plt.title(panel_title)
    plt.imshow(volume[:, :, 80], **imshow_options)
plt.show()

## Model training 

+ Divide the dataset into training and testing
+ Upload the dataset into S3 
+ SageMaker training job

In [None]:
## create the local folders that will hold the training split (images and labels)
!mkdir -p ../Spleen3D/train/imagesTr
!mkdir -p ../Spleen3D/train/labelsTr

## create the local folders that will hold the testing split (images and labels)
!mkdir -p ../Spleen3D/test/imagesTr
!mkdir -p ../Spleen3D/test/labelsTr

In [None]:
## copy every training image/label pair into the local training folders
train_image_dir = "../Spleen3D/train/imagesTr"
train_label_dir = "../Spleen3D/train/labelsTr"

for sample in train_files:
    shutil.copy(sample["image"], train_image_dir)
    shutil.copy(sample["label"], train_label_dir)

In [None]:
## copy every held-out image/label pair into the local testing folders
test_image_dir = "../Spleen3D/test/imagesTr"
test_label_dir = "../Spleen3D/test/labelsTr"

for sample in val_files:
    shutil.copy(sample["image"], test_image_dir)
    shutil.copy(sample["label"], test_label_dir)

In [None]:
## upload both splits to the session's default bucket under a common prefix
prefix = "MONAI-Segmentation"
bucket = sess.default_bucket()

## training split; the returned value is later passed to estimator.fit
S3_inputs = sess.upload_data(
    path="../Spleen3D/train",
    bucket=bucket,
    key_prefix=f"{prefix}/train",
)

## testing split
S3_test = sess.upload_data(
    path="../Spleen3D/test",
    bucket=bucket,
    key_prefix=f"{prefix}/test",
)


### SageMaker training job

In [None]:
%%time 
import sagemaker
from sagemaker.inputs import FileSystemInput
from sagemaker.pytorch import PyTorch

metrics=[
   {'Name': 'train:average epoch loss', 'Regex': 'average loss: ([0-9\\.]*)'},
   {'Name': 'train:current mean dice', 'Regex': 'current mean dice: ([0-9\\.]*)'},
   {'Name': 'train:best mean dice', 'Regex': 'best mean dice: ([0-9\\.]*)'}
]

estimator = PyTorch(source_dir='code',
                    entry_point='train.py',
                    role=role,
                    framework_version='1.6.0',
                    py_version='py3',
                    instance_count=1,
                    instance_type='ml.p2.xlarge',
                    hyperparameters={
                       "seed": 123,
                       "lr": 0.001,
                       "epochs": 10
                    },
                    metric_definitions=metrics,
#                     ### spot instance training ###
#                    use_spot_instances=True,
#                     max_run=2400,
#                     max_wait=2400
                )


estimator.fit(S3_inputs)


## Inference 

+ deploy the model with a customized inference script
+ run inference on a testing image stored in S3
+ visualize the results
+ deploy either from the trained estimator or from the model artifact in S3


### Type of inference
SageMaker provides the flexibility of deploying endpoints with the following 3 options: 
+ real-time inference
+ asynchronous inference
+ batch transform

Customers can select the corresponding option based on the budget and requirements. 

The following cells demonstrate the real-time and asynchronous deployment options.
Due to the size of the model output, we chose to use asynchronous inference to save the output in an S3 bucket.

In [None]:
## Fetch the S3 location of the trained model artifact from the description of
## the latest training job. This relies on the public API only; assembling the
## path by hand from the estimator's private attributes
## (<output_path><job name>/output/model.tar.gz) is unnecessary.
model_data = estimator.latest_training_job.describe()["ModelArtifacts"]["S3ModelArtifacts"]

In [None]:
## display the S3 location of the model artifact
model_data

In [None]:
from sagemaker.pytorch.model import PyTorchModel

## Wrap the trained artifact in a model that is served with the custom inference script.
model = PyTorchModel(
    entry_point="inference.py", ## inference code with customization
    role=role,
    model_data=model_data,
    ## Keep the serving container on the same PyTorch version as the training
    ## job (framework_version='1.6.0' in the estimator) so that the saved
    ## weights are loaded by the version that wrote them.
    framework_version="1.6.0",
    py_version="py3",
)

## Real-time endpoint that exchanges JSON requests and responses.
predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.m5.2xlarge",
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer(),
)

In [30]:
%%time
payload={"bucket": 'sagemaker-us-east-1-741261399688',
    "key":"MONAI-Segmentation/test/imagesTr",
    "file":'spleen_9.nii.gz',
    "nslice":80}


response=predictor.predict(payload)

CPU times: user 24.1 ms, sys: 0 ns, total: 24.1 ms
Wall time: 9.25 s


In [31]:
# display the reply returned by the endpoint
response

{'s3_path': 'S3://sagemaker-us-east-1-741261399688/inference_output/results_slides1651611874.4674125.json',
 'pred': [[0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
  

In [32]:
%%time
payload={"bucket": 'sagemaker-us-east-1-741261399688',
    "key":"MONAI-Segmentation/test/imagesTr",
    "file":'spleen_9.nii.gz',
    "nslice":100}


response=predictor.predict(payload)

CPU times: user 4.1 ms, sys: 0 ns, total: 4.1 ms
Wall time: 9.18 s


In [33]:
# display the reply returned by the endpoint
response

{'s3': 'S3://sagemaker-us-east-1-741261399688/inference_output/results_slides1651611884.3146865.json'}

In [36]:
%%time
payload={"bucket": 'sagemaker-us-east-1-741261399688',
    "key":"MONAI-Segmentation/test/imagesTr",
    "file":'spleen_9.nii.gz',
    "nslice":200}


response=predictor.predict(payload)

CPU times: user 17.7 ms, sys: 0 ns, total: 17.7 ms
Wall time: 10 s


In [37]:
# display the reply returned by the endpoint
response

{'s3': 'S3://sagemaker-us-east-1-741261399688/inference_output/results_slides1651617715.1333244.json'}

## Inference 

+ deploy the model with a customized inference script
+ run inference on a testing image stored in S3
+ visualize the results
+ deploy either from the trained estimator or from the model artifact in S3


### Type of inference
SageMaker provides the flexibility of deploying endpoints with the following 3 options: 
+ real-time inference
+ asynchronous inference
+ batch transform

Customers can select the corresponding option based on the budget and requirements. 

The following cells demonstrate the real-time and asynchronous deployment options.

In [None]:
## realtime endpoint, deployed straight from the trained estimator
## NOTE: this rebinds `predictor`, which the PyTorchModel deployment cell also assigns.

predictor = estimator.deploy(
    initial_instance_count=1,
    entry_point="inference.py",
    instance_type="ml.m5.2xlarge",
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer(),
)

In [None]:
%%time
payload={"bucket": 'sagemaker-us-east-1-741261399688',
    "key":"MONAI-Segmentation/test/imagesTr",
    "file":'spleen_9.nii.gz',
    "nslice":80}
response=predictor.predict(payload)

### Asynchronous inference endpoint



In [None]:
from sagemaker.async_inference import AsyncInferenceConfig

## S3 location that receives the output of the asynchronous inference
s3_bucket = bucket
bucket_prefix = "Inference_output"
output_path = f"s3://{s3_bucket}/{bucket_prefix}/segmentation/output"

async_config = AsyncInferenceConfig(
    max_concurrent_invocations_per_instance=5,
    output_path=output_path,
)

In [None]:
## Deploy the trained estimator behind an asynchronous endpoint that uses the
## async-specific inference script and the async configuration defined earlier.
async_deploy_options = dict(
    initial_instance_count=1,
    entry_point="inference_async.py",
    instance_type="ml.m5.2xlarge",
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer(),
    async_inference_config=async_config,
)
predictor_async = estimator.deploy(**async_deploy_options)


In [None]:
%%time
payload={"bucket": 'sagemaker-us-east-1-741261399688',
    "key":"MONAI-Segmentation/test/imagesTr",
    "file":'spleen_9.nii.gz',
    "nslice":200}


response=predictor_async.predict(payload)

In [None]:
%%time
payload={"bucket": 'sagemaker-us-east-1-741261399688',
    "key":"MONAI-Segmentation/test/imagesTr",
    "file":'spleen_9.nii.gz',
    "nslice":70}


response=predictor_async.predict(payload)

In [None]:
import json

## write the current payload to a local JSON file
with open('input.json', 'w') as payload_file:
    payload_file.write(json.dumps(payload))

## upload the JSON payload (not a dataset) so the async endpoint can read it from S3
input_payload = sess.upload_data(
    path="input.json",
    bucket=bucket,
    key_prefix=f"{prefix}/test_json",
)

In [None]:
## display the value returned by the payload upload
input_payload

In [None]:
%%time
# Invoke the async endpoint with the uploaded JSON payload (passed as input_path) instead of an in-memory dict.
response=predictor_async.predict(input_path=input_payload)

In [None]:
# Tears down the SageMaker endpoint and endpoint configuration if no longer needed.
# Only the asynchronous endpoint is removed here; the endpoint behind `predictor` is not touched.
predictor_async.delete_endpoint()

## Visualize the result

In [None]:
# Convert the "pred" entry of the latest response into a tensor and display it.
# NOTE: not every response shown in this notebook contains a "pred" key (some only carry an S3 path).
torch.Tensor(response["pred"])

In [None]:
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The predicted slice does not change between iterations, so convert it once.
prediction = torch.Tensor(response["pred"])

with torch.no_grad():
    # `check_loader` is the loader over the held-out sample built in the
    # visualization cell; no `val_loader` is defined anywhere in this notebook.
    for i, val_data in enumerate(check_loader):

        # plot the slice [:, :, 80] of the image and label next to the model output
        plt.figure("check", (18, 6))
        plt.subplot(1, 3, 1)
        plt.title(f"image {i}")
        plt.imshow(val_data["image"][0, 0, :, :, 80], cmap="gray")
        plt.subplot(1, 3, 2)
        plt.title(f"label {i}")
        plt.imshow(val_data["label"][0, 0, :, :, 80])
        plt.subplot(1, 3, 3)
        plt.title(f"output {i}")
        plt.imshow(prediction)

        plt.show()

## Clean up the resources

+ delete all the endpoints to save cost

In [None]:
import boto3

# List the SageMaker endpoints visible to the current credentials and display them.
client = boto3.client("sagemaker")
endpoints = client.list_endpoints()["Endpoints"]
endpoints

In [None]:
# Only delete the endpoints behind this notebook's predictors. Deleting every
# entry returned by list_endpoints() would also tear down unrelated endpoints
# that happen to share the AWS account and region.
notebook_endpoint_names = {
    candidate.endpoint_name
    for candidate in (globals().get("predictor"), globals().get("predictor_async"))
    if candidate is not None  # the deployment cell may not have been run
}

for endpoint in endpoints:
    endpoint_name = endpoint["EndpointName"]
    if endpoint_name not in notebook_endpoint_names:
        # Left running on purpose; delete it manually if it belongs to this tutorial
        # (for example an endpoint whose predictor variable was reassigned).
        print(f"Skipping endpoint not tracked by this notebook: {endpoint_name}")
        continue
    client.delete_endpoint(EndpointName=endpoint_name)
    print(f"Deleted endpoint: {endpoint_name}")