__Microsoft Computer Vision Accelerator for Business Solutions__
# Brain tumor 3D segmentation with AzureML and MONAI (BRATS21)

Glioma brain tumors are among the most aggressive and lethal types of brain tumors. They can cause a range of symptoms, including headaches, seizures, and difficulty with speech and movement. Gliomas can be difficult to diagnose and treat, and early detection is critical for improving patient outcomes.

Computer vision AI has emerged as a promising tool for supporting the diagnosis and treatment of glioma brain tumors. AI algorithms can analyze medical images of the brain and identify the location and extent of tumors with a high degree of accuracy. This can help clinicians make more informed decisions about treatment options, such as surgery or radiation therapy, and monitor the progress of the disease over time. Additionally, AI algorithms can help researchers better understand the underlying biology of gliomas and develop new therapies for this challenging disease.

This demo is based on the [MONAI 3D brain tumor segmentation tutorial](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/swin_unetr_brats21_segmentation_3d.ipynb) and shows how to construct a training workflow for a multi-label segmentation task.

The sub-regions considered for evaluation in the BraTS 21 challenge are the "enhancing tumor" (ET), the "tumor core" (TC), and the "whole tumor" (WT). The ET is described by areas that show hyper-intensity in T1Gd when compared to T1, but also when compared to “healthy” white matter in T1Gd. The TC describes the bulk of the tumor, which is what is typically resected. The TC entails the ET, as well as the necrotic (NCR) parts of the tumor. The appearance of NCR is typically hypo-intense in T1Gd when compared to T1. The WT describes the complete extent of the disease, as it entails the TC and the peritumoral edematous/invaded tissue (ED), which is typically depicted by the hyper-intense signal in FLAIR [BraTS 21].

![image](./media/fig_brats21.png)
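
The nesting of the three evaluation regions can be illustrated on a toy label array. The sketch below assumes the BraTS 2021 label convention used later in this notebook (1 = NCR, 2 = ED, 4 = ET); the array values and the helper name `to_regions` are made up for illustration.

```python
import numpy as np

def to_regions(label):
    """Map a BraTS label array (values 0, 1, 2, 4) to boolean TC, WT and ET masks."""
    et = label == 4                      # enhancing tumor
    tc = (label == 1) | et               # tumor core = NCR + ET
    wt = tc | (label == 2)               # whole tumor = TC + ED
    return {"TC": tc, "WT": wt, "ET": et}

# Toy 1D "scan": background, edema, necrosis, enhancing tumor, edema
toy_label = np.array([0, 2, 1, 4, 2])
regions = to_regions(toy_label)

for name, mask in regions.items():
    print(name, mask.astype(int))
```

Every ET voxel is also a TC voxel, and every TC voxel is also a WT voxel, which is why the three output channels overlap rather than partition the tumor.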

This notebook has been developed and tested with VSCode connected to an AzureML `STANDARD_D13_V2` Compute Instance using the `azureml_py310_sdkv2` kernel.

## References

[1] Hatamizadeh, A., Nath, V., Tang, Y., Yang, D., Roth, H. and Xu, D., 2022. Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images. arXiv preprint arXiv:2201.01266.

[2] Tang, Y., Yang, D., Li, W., Roth, H.R., Landman, B., Xu, D., Nath, V. and Hatamizadeh, A., 2022. Self-supervised pre-training of swin transformers for 3d medical image analysis. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 20730-20740).

[3] U.Baid, et al., The RSNA-ASNR-MICCAI BraTS 2021 Benchmark on Brain Tumor Segmentation and Radiogenomic Classification, arXiv:2107.02314, 2021.

[4] B. H. Menze, A. Jakab, S. Bauer, J. Kalpathy-Cramer, K. Farahani, J. Kirby, et al. "The Multimodal Brain Tumor Image Segmentation Benchmark (BRATS)", IEEE Transactions on Medical Imaging 34(10), 1993-2024 (2015) DOI: 10.1109/TMI.2014.2377694

[5] S. Bakas, H. Akbari, A. Sotiras, M. Bilello, M. Rozycki, J.S. Kirby, et al., "Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features", Nature Scientific Data, 4:170117 (2017) DOI: 10.1038/sdata.2017.117

[6] S. Bakas, H. Akbari, A. Sotiras, M. Bilello, M. Rozycki, J. Kirby, et al., "Segmentation Labels and Radiomic Features for the Pre-operative Scans of the TCGA-GBM collection", The Cancer Imaging Archive, 2017. DOI: 10.7937/K9/TCIA.2017.KLXWJJ1Q



## Installs and Imports

In [None]:
# based on azureml_py310_sdkv2 kernel
# %pip install torch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0
# %pip install 'monai[nibabel, ignite, tqdm]'
# %pip install itkwidgets

In [None]:
import os
import tempfile
import base64
import json

import numpy as np
import matplotlib.pyplot as plt

import torch

import tarfile
import urllib.request

from itkwidgets import view
from ipywidgets import interact

from azure.identity import DefaultAzureCredential
from azure.ai.ml import MLClient, command, Input
from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.entities import ManagedOnlineEndpoint, ManagedOnlineDeployment, Model, Environment, JobService, Data, CodeConfiguration, OnlineRequestSettings, AmlCompute
from azure.core.exceptions import ResourceNotFoundError

from monai.apps import DecathlonDataset
from monai.data import DataLoader, Dataset
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped, Orientationd, Spacingd, NormalizeIntensityd, MapTransform
from monai.visualize.utils import blend_images

## Define central variables

In [None]:
# AzureML Workspace
subscription_id = '<your Azure subscription id>'
resource_group = '<your Azure resource group name>'
workspace = '<your AzureML Workspace name>'

# Training
experiment = 'brain-tumor-segmentation' # AzureML experiment name
dataset_name="BRATS2021"
train_target = 'NC96adsA100' # name of the AzureML compute cluster (created below if it does not exist)

# Deployment
online_endpoint_name = "brain-tumor-seg-brats21"
registered_model_name = 'BRATS21'
deployment_name = 'blue'

# Visualization and validation sample
sample_image = './samples-2021/BraTS2021_00402/BraTS2021_00402_flair.nii.gz' # pick flair modality
sample_image_t1 = './samples-2021/BraTS2021_00402/BraTS2021_00402_t1.nii.gz'
sample_image_t1ce = './samples-2021/BraTS2021_00402/BraTS2021_00402_t1ce.nii.gz'
sample_image_t2 = './samples-2021/BraTS2021_00402/BraTS2021_00402_t2.nii.gz'
sample_label = './samples-2021/BraTS2021_00402/BraTS2021_00402_seg.nii.gz'

In [None]:
# Connect to AzureML Workspace
ml_client = MLClient(DefaultAzureCredential(), subscription_id, resource_group, workspace)

## Inspect sample image

In [None]:
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert labels to multi channels based on brats 2021 classes:
    label 1 necrotic tumor core (NCR)
    label 2 peritumoral edematous/invaded tissue 
    label 3 is not used in the new dataset version
    label 4 GD-enhancing tumor 
    The possible classes are:
      TC (Tumor core): merge labels 1 and 4
      WT (Whole tumor): merge labels 1,2 and 4
      ET (Enhancing tumor): label 4

    """

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = []
            # merge label 1 and label 4 to construct TC
            result.append(torch.logical_or(d[key] == 1, d[key] == 4))
            # merge labels 1, 2 and 4 to construct WT
            result.append(
                torch.logical_or(
                    torch.logical_or(d[key] == 1, d[key] == 2), d[key] == 4
                )
            )
            # label 4 is ET
            result.append(d[key] == 4)
            d[key] = torch.stack(result, axis=0).float()
        return d

val_transform = Compose(
[
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys="image"),
    EnsureTyped(keys=["image", "label"]),
    ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(
        keys=["image", "label"],
        pixdim=(1.0, 1.0, 1.0),
        mode=("bilinear", "nearest"),
    ),
    NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
])

data_list = [{'image': sample_image, 'label': sample_label}]
val_ds = Dataset(data=data_list, transform=val_transform)

img_vol = val_ds[0]["image"].numpy()
seg_vol = val_ds[0]["label"].numpy()


In [None]:
img_vol.shape

In [None]:
# Inspect 3d structure - viewer works in VSCode 

img_vol_ch = img_vol[0,:,:,:]
seg_vol_ch = seg_vol[0,:,:,:]


viewer = view(image= img_vol_ch * 255,
              label_image= seg_vol_ch * 255,
              gradient_opacity=0.4,
              background = (0.5, 0.5, 0.5))
              
viewer

In [None]:
viewer.close()

In [None]:
def show_slice(slice_index=64):

    img = np.expand_dims(img_vol[0,:,:,slice_index], 0)
    seg_ch0 = np.expand_dims(seg_vol[0,:,:,slice_index], 0)
    seg_ch1 = np.expand_dims(seg_vol[1,:,:,slice_index], 0)
    seg_ch2 = np.expand_dims(seg_vol[2,:,:,slice_index], 0)

    # Image with all segmentations overlaid
    blend = blend_images(img, seg_ch1, cmap='Blues')
    blend = blend_images(blend, seg_ch0, cmap='hsv')
    blend = blend_images(blend, seg_ch2, cmap='Greens')
    over_all = np.transpose(blend, (1,2,0))

    # Individual blends for each segmentation

    blend = blend_images(img, seg_ch0, cmap='hsv')
    over_ch0 = np.transpose(blend, (1,2,0))
    blend = blend_images(img, seg_ch1, cmap='Blues')
    over_ch1 = np.transpose(blend, (1,2,0))
    blend = blend_images(img, seg_ch2, cmap='Greens')
    over_ch2 = np.transpose(blend, (1,2,0))

    fig, ((ax1, ax2, ax3, ax4)) = plt.subplots(1, 4, figsize=(24, 8))

    ax1.imshow(over_all)
    ax1.set_title('All tumor structures')
    ax2.imshow(over_ch0)
    ax2.set_title('Tumor core')
    ax3.imshow(over_ch1)
    ax3.set_title('Whole tumor')
    ax4.imshow(over_ch2)
    ax4.set_title('Enhancing tumor')
    
    plt.tight_layout()
    plt.show()

# Use the interact function to create a slider for the slice index
_ = interact(show_slice, slice_index=(0, img_vol.shape[-1]-1))
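
Scrolling to a slice that actually contains tumor can be tedious. A minimal sketch for picking a sensible default is to choose the axial slice with the largest whole-tumor area. It assumes a label volume laid out as (channel, height, width, slice) with channel 1 holding WT, as produced above; `busiest_slice` is a hypothetical helper and the volume here is synthetic.

```python
import numpy as np

def busiest_slice(seg, channel=1):
    """Return the index along the last axis with the most foreground voxels."""
    area_per_slice = seg[channel].sum(axis=(0, 1))
    return int(np.argmax(area_per_slice))

# Synthetic stand-in for seg_vol: 3 channels, 8 x 8 pixels, 10 slices
demo_seg = np.zeros((3, 8, 8, 10), dtype=np.float32)
demo_seg[1, 2:6, 2:6, 4] = 1.0   # 16 voxels in slice 4
demo_seg[1, 3:5, 3:5, 7] = 1.0   # 4 voxels in slice 7

print(busiest_slice(demo_seg))
```

With the real data, `show_slice(busiest_slice(seg_vol))` would jump straight to that slice.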

## Create compute resources, environments and datasets
__Note:__ Creating compute resources, training/scoring environments, and the dataset __need only be performed once__. If you have executed these steps previously, skip to the next section of this notebook.

This demo uses low-priority compute as the most cost-efficient option. Low-priority VMs are significantly cheaper than standard dedicated compute. However, these resources are not always available, and a training job may be pre-empted.

In [None]:
try:
    _ = ml_client.compute.get(train_target)
    print("Found existing compute target.")
except ResourceNotFoundError:
    print("Creating a new compute target...")
    compute_config = AmlCompute(
        name=train_target,
        type="amlcompute",
        size="STANDARD_NC24RS_V3", # 4 x Tesla V100, 16 GB GPU memory each
        tier="low_priority",
        idle_time_before_scale_down=600,
        min_instances=0,
        max_instances=2,
    )
    ml_client.begin_create_or_update(compute_config)

In [None]:
training_environment = Environment(
    image="mcr.microsoft.com/azureml/" + "openmpi4.1.0-cuda11.1-cudnn8-ubuntu20.04:latest",
    conda_file="./src/train-env.yml",
    name="monai-multigpu-azureml",
    description="Parallel PyTorch training on AzureML with MONAI")

ml_client.environments.create_or_update(training_environment)

In [None]:
scoring_environment = Environment(
    image="mcr.microsoft.com/azureml/" + "openmpi4.1.0-cuda11.1-cudnn8-ubuntu20.04:latest",
    conda_file="./src/scoring-env.yaml",
    name="brats-inference-environment",
    description="Brain tumor segmentation inference environment")

ml_client.environments.create_or_update(scoring_environment)

The commands below download the dataset using the [Kaggle API](https://github.com/Kaggle/kaggle-api). Follow the instructions there to generate your own API token, then fill in your user name and token in the code cell below.

In [None]:
# Export Kaggle configuration variables

%env KAGGLE_USERNAME=<your Kaggle user name>
%env KAGGLE_KEY=<your API token>

In [None]:
# Download and unzip the BRATS dataset.
!kaggle datasets download -d dschettler8845/brats-2021-task1 -p /tmp
!unzip -q /tmp/brats-2021-task1.zip -d /tmp

In [None]:
# Set the path to the local folder where the dataset will be saved
data_dir = "/tmp/brats"
filename = os.path.join("/tmp", "BraTS2021_Training_Data.tar")

# Create the local folder if it doesn't exist
os.makedirs(data_dir, exist_ok=True)

# Extract the contents of the file to the local folder
with tarfile.open(filename, "r") as tar:
    tar.extractall(path=data_dir)

print("Dataset downloaded and extracted to:", data_dir)
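
Before registering the data, it can be worth checking that every case folder is complete. The sketch below assumes the extracted archive follows the same layout as the sample paths defined earlier (`<case>/<case>_<modality>.nii.gz`); `find_incomplete_cases` is a hypothetical helper, the case names are made up, and the demo runs against a throwaway directory rather than the real dataset.

```python
import os
import tempfile

EXPECTED_SUFFIXES = ("flair", "t1", "t1ce", "t2", "seg")

def find_incomplete_cases(root):
    """Return {case_name: [missing suffixes]} for case folders lacking files."""
    incomplete = {}
    for case in sorted(os.listdir(root)):
        case_dir = os.path.join(root, case)
        if not os.path.isdir(case_dir):
            continue
        missing = [
            s for s in EXPECTED_SUFFIXES
            if not os.path.isfile(os.path.join(case_dir, f"{case}_{s}.nii.gz"))
        ]
        if missing:
            incomplete[case] = missing
    return incomplete

# Demo on a temporary folder with one complete and one incomplete case
demo_cases = {
    "BraTS2021_00000": EXPECTED_SUFFIXES,
    "BraTS2021_00001": ("flair", "t1"),
}
with tempfile.TemporaryDirectory() as tmp:
    for case, suffixes in demo_cases.items():
        os.makedirs(os.path.join(tmp, case))
        for s in suffixes:
            with open(os.path.join(tmp, case, f"{case}_{s}.nii.gz"), "w"):
                pass  # create an empty placeholder file
    report = find_incomplete_cases(tmp)

print(report)
```

On the real data, an empty result from `find_incomplete_cases(data_dir)` would indicate that all case folders contain the four modalities and the segmentation, provided the layout assumption holds.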

In [None]:
# Register the dataset in the AzureML Workspace

my_data = Data(
    path=data_dir,
    type=AssetTypes.URI_FOLDER,
    description="Gliomas segmentation necrotic/active tumour and oedema (Source: BRATS 2021 datasets)",
    name=dataset_name,
)

ml_client.data.create_or_update(my_data)

## Submit Parallel Training Job
We use multi-GPU PyTorch Distributed Data Parallel (DDP) training on scalable Azure ML compute resources. Adjust the number of cluster nodes (`instance_count`) and the number of GPUs per node (`process_count_per_instance`) to match the compute SKU you provisioned.  
Note that you can interact with the job for monitoring or debugging using JupyterLab, VSCode, or Tensorboard during training.
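
To make the relationship between these settings concrete: the number of training processes is `instance_count * process_count_per_instance`, and under data-parallel training each process typically handles its own mini-batch. The sketch below works through that arithmetic, assuming one process per GPU and a per-process batch size equal to `--train_batch_size` (how `train-brats21.py` actually interprets that flag is not shown in this notebook). `ddp_layout` is a hypothetical helper.

```python
def ddp_layout(instance_count, process_count_per_instance, per_process_batch_size):
    """Summarize a data-parallel job: total processes and samples per optimizer step."""
    world_size = instance_count * process_count_per_instance
    return {
        "world_size": world_size,
        "effective_batch_size": world_size * per_process_batch_size,
    }

# Values from the job definition in this notebook: 1 node, 4 processes, batch size 1
print(ddp_layout(1, 4, 1))

# Scaling out to 2 nodes, the max_instances of the cluster created earlier
print(ddp_layout(2, 4, 1))
```

Because the effective batch size grows with the number of processes, the learning rate may need retuning when the cluster layout changes.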

In [None]:
# Retrieve latest version of BRATS dataset
latest_version = [dataset.latest_version for dataset in ml_client.data.list() if dataset.name == dataset_name][0]
dataset_asset = ml_client.data.get(name= dataset_name, version= latest_version)
print(f'Latest version of {dataset_name}: {latest_version}')

In [None]:
job = command(
    inputs = {"input_data": Input(type=AssetTypes.URI_FOLDER, path= dataset_asset.path)},
    code = 'src/',
    command = "python train-brats21.py --epochs 2 --initial_lr 0.00025 --train_batch_size 1 --val_batch_size 1 --input_data ${{inputs.input_data}} --best_model_name BRATS21",
    environment = "monai-multigpu-azureml@latest", 
    compute = train_target,
    experiment_name = experiment,
    display_name = "3d brain tumor segmentation based on BRATS21",
    description = "## Brain tumor segmentation on 3D MRI brain scans",
    shm_size='300g',
    resources=dict(instance_count= 1), # cluster nodes 
    distribution=dict(type="PyTorch", process_count_per_instance= 4), # GPUs per node
    environment_variables=dict(AZUREML_ARTIFACTS_DEFAULT_TIMEOUT = 1000),
    services={
    "My_jupyterlab": JobService(job_service_type="jupyter_lab"),
    "My_vscode": JobService(job_service_type="vs_code",),
    "My_tensorboard": JobService(job_service_type="tensor_board",),
        })

returned_job = ml_client.create_or_update(job)

## Deploy model to a Managed Endpoint

In [None]:
# create an online endpoint
endpoint = ManagedOnlineEndpoint(
    name=online_endpoint_name,
    description="MONAI 3d brain tumor segmentation",
    auth_mode="key",
    tags={
        "training_dataset": "BraTS 2021: Brain tumor segmentation",
        "model_type": "pytorch",
        "dataset" : dataset_name,
    },
)

endpoint = ml_client.begin_create_or_update(endpoint)

In [None]:
endpoint = ml_client.online_endpoints.get(online_endpoint_name)
print(f"Endpoint {endpoint.name} provisioning state: {endpoint.provisioning_state}")

In [None]:
# Let's pick the latest version of the model
latest_model_version = max([int(m.version) for m in ml_client.models.list(name= registered_model_name)])

print(f'Latest version of {registered_model_name} found: {latest_model_version}')
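
The `int(...)` conversion above matters because model versions are returned as strings, and strings compare lexicographically. A quick illustration with made-up version values:

```python
versions = ["1", "2", "9", "10"]

# Lexicographic comparison: "9" sorts after "10"
print(max(versions))

# Numeric comparison picks the true latest version
print(max(int(v) for v in versions))
```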

In [None]:
# picking the model to deploy. Here we use the latest version of our registered model
model = ml_client.models.get(name=registered_model_name, version=str(latest_model_version))

In [None]:
# create an online deployment.
deployment = ManagedOnlineDeployment(
    name = deployment_name,
    endpoint_name = online_endpoint_name,
    model = model,
    environment = "brats-inference-environment@latest",
    code_configuration=CodeConfiguration(code= "./src", scoring_script="score-brats21.py"),
    instance_type = "Standard_NC6s_v3",
    instance_count = 1,
    request_settings= OnlineRequestSettings(request_timeout_ms = 90000),

)
deployment = ml_client.begin_create_or_update(deployment)

In [None]:
# existing traffic details
print(endpoint.traffic)

# Get the scoring URI
print(endpoint.scoring_uri)

## Get Predictions

In [None]:
# Encode input images for JSON request file
with open(sample_image, "rb") as image_file:
    flair_encoded = base64.b64encode(image_file.read()).decode('utf-8')

with open(sample_image_t1, "rb") as image_file:
    t1_encoded = base64.b64encode(image_file.read()).decode('utf-8')

with open(sample_image_t1ce, "rb") as image_file:
    t1ce_encoded = base64.b64encode(image_file.read()).decode('utf-8')

with open(sample_image_t2, "rb") as image_file:
    t2_encoded = base64.b64encode(image_file.read()).decode('utf-8')

request_data = {
    "data": [{"flair": flair_encoded, "t1": t1_encoded, 
              "t1ce": t1ce_encoded, "t2": t2_encoded
             }]
}

# Write the JSON request data to a file
with open("request-brats2021.json", "w") as outfile:
    json.dump(request_data, outfile)
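
For orientation, the receiving side has to reverse this encoding before the volumes can be loaded. The following is a minimal sketch of that step, not the contents of `score-brats21.py` (which are not shown here): `decode_modalities` is a hypothetical helper that writes each base64 string back to a `.nii.gz` file, and the byte strings stand in for real NIfTI data.

```python
import base64
import json
import os
import tempfile

MODALITIES = ("flair", "t1", "t1ce", "t2")

def decode_modalities(raw_request, out_dir):
    """Decode the first item of a JSON request into one file per modality."""
    item = json.loads(raw_request)["data"][0]
    paths = {}
    for name in MODALITIES:
        path = os.path.join(out_dir, f"{name}.nii.gz")
        with open(path, "wb") as f:
            f.write(base64.b64decode(item[name]))
        paths[name] = path
    return paths

# Round trip with placeholder bytes instead of real MRI volumes
fake_volumes = {name: f"bytes of {name}".encode() for name in MODALITIES}
payload = json.dumps({"data": [{
    name: base64.b64encode(blob).decode("utf-8")
    for name, blob in fake_volumes.items()
}]})

with tempfile.TemporaryDirectory() as tmp:
    paths = decode_modalities(payload, tmp)
    restored = {}
    for name, p in paths.items():
        with open(p, "rb") as f:
            restored[name] = f.read()

print(restored == fake_volumes)
```

Base64 encodes every 3 bytes as 4 characters, so the JSON request is roughly a third larger than the four files combined. That is worth keeping in mind alongside the `request_timeout_ms` setting of the deployment.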

In [None]:
# Send request to Managed Online Endpoint
response = ml_client.online_endpoints.invoke(
    endpoint_name= online_endpoint_name,
    deployment_name= deployment_name,
    request_file="./request-brats2021.json",
)

# convert response to numpy array with dimensions channel, height, width, slice
json_response = json.loads(response)
pred_vol = np.array(json_response)

## Review Predictions
We inspect the predicted tumor core segmentation and compare it with the ground truth annotations.

In [None]:
# Inspect 3d structure - viewer works in VSCode 

img_vol_ch = img_vol[0,:,:,:]
seg_vol_ch = pred_vol[0,:,:,:]

viewer = view(image= img_vol_ch * 255,
              label_image= seg_vol_ch * 255,
              gradient_opacity=0.4,
              background = (0.5, 0.5, 0.5))
              
viewer

In [None]:
viewer.close()

In [None]:
def show_slice(slice_index=60):

    img = np.expand_dims(img_vol[0,:,:,slice_index], 0) # images
    true_seg = np.expand_dims(seg_vol[0,:,:,slice_index], 0) # annotated ground truth labels
    pred_seg = np.expand_dims(pred_vol[0,:,:,slice_index], 0)  # predicted labels
    
    blend = blend_images(img, true_seg, cmap='hsv')
    over_true = np.transpose(blend, (1,2,0))
    blend = blend_images(img, pred_seg, cmap='Blues')
    over_pred = np.transpose(blend, (1,2,0))
    
    fig, ((ax1, ax2)) = plt.subplots(1, 2, figsize=(14, 7))

    ax1.imshow(over_true)
    ax1.set_title('Ground truth segmentations')
    ax2.imshow(over_pred)
    ax2.set_title('Predicted segmentations')
    
    plt.tight_layout()
    plt.show()

# Use the interact function to create a slider for the slice index
_ = interact(show_slice, slice_index=(0, img_vol.shape[-1]-1))
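
Visual comparison can be complemented with a number. The Dice coefficient, 2|A∩B| / (|A| + |B|), is a common overlap measure for segmentation tasks such as BraTS. The sketch below assumes both volumes have the same shape; whether the endpoint returns thresholded masks or probabilities depends on the scoring script, so a 0.5 threshold is applied as an assumption. `dice_score` is a hypothetical helper and the arrays are synthetic.

```python
import numpy as np

def dice_score(pred, truth, threshold=0.5):
    """Dice overlap between two masks; returns 1.0 when both are empty."""
    p = np.asarray(pred) > threshold
    t = np.asarray(truth) > threshold
    denom = p.sum() + t.sum()
    if denom == 0:
        return 1.0
    return float(2.0 * np.logical_and(p, t).sum() / denom)

# Synthetic 4 x 4 masks: 4 true voxels, 4 predicted voxels, 2 shared
truth = np.zeros((4, 4))
truth[1:3, 1:3] = 1
pred = np.zeros((4, 4))
pred[1:3, 2:4] = 1

print(dice_score(pred, truth))
```

With the notebook's variables, `dice_score(pred_vol[0], seg_vol[0])` would give the tumor core score, provided the two arrays have matching shapes.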