## DeepEdit Inference Tutorial

DeepEdit is an algorithm that combines the power of two models in a single architecture. It allows the user to perform inference as a standard segmentation method (e.g. UNet), and also to interactively segment part of an image using clicks (Sakinis et al.). DeepEdit aims to facilitate the user experience and, at the same time, the development of new active learning techniques.


This notebook shows the performance of a model trained to segment the spleen.

**We recommend importing the pretrained model into the [DeepEdit App in MONAI Label](https://github.com/Project-MONAI/MONAILabel/tree/main/sample-apps/radiology#deepedit) for the full experience.**

Sakinis et al., Interactive segmentation of medical images through fully convolutional neural networks. (2019) https://arxiv.org/abs/1903.08205

In [None]:
# !python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm]"

#### Using library versions:

monai==0.8.1 nibabel==3.2.2 numpy==1.22.3 pytorch-ignite==0.4.8 scikit-image==0.19.2 scipy==1.8.0 tensorboard==2.8.0 torch==1.11.0 tqdm==4.64.0


In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import jit

import monai

from monai.apps.deepedit.transforms import (
    AddGuidanceSignalCustomd,
    AddGuidanceFromPointsCustomd,
    ResizeGuidanceMultipleLabelCustomd,
)


from monai.transforms import (
    Activationsd,
    AsDiscreted,
    EnsureChannelFirstd,
    EnsureTyped,
    LoadImaged,
    Orientationd,
    Resized,
    ScaleIntensityRanged,
    SqueezeDimd,
    ToNumpyd,
    ToTensord,
)

### Plotting functions

In [None]:
def draw_points(guidance, slice_idx):
    """Draw guidance clicks as red crosses on the current matplotlib axes.

    Args:
        guidance: iterable of points, each indexable with ``p[0]`` and
            ``p[1]``. ``None`` means there is nothing to draw.
        slice_idx: accepted so callers can pass the slice being shown;
            it is not used when drawing.
    """
    if guidance is None:
        return
    for p in guidance:
        # p[1] goes on the horizontal axis and p[0] on the vertical one.
        # The marker size has to be passed as a keyword: extra positional
        # arguments are read by matplotlib as another x/y data series.
        plt.plot(p[1], p[0], "r+", markersize=30)


def show_image(image, label, guidance=None, slice_idx=None):
    """Display a 2D image and, when given, its label map side by side.

    The left panel shows ``image`` in grayscale. If ``label`` is provided,
    its non-zero entries are overlaid on the left panel and the label is
    also shown on its own in the right panel.

    Args:
        image: 2D array-like passed to ``plt.imshow``.
        label: 2D array-like compared against 0 to build the overlay,
            or ``None`` to show the image only.
        guidance: optional points forwarded to ``draw_points``.
        slice_idx: optional slice index forwarded to ``draw_points``.
    """
    has_label = label is not None

    plt.figure("check", (12, 6))

    # Left panel: image, optional overlay, guidance points
    plt.subplot(1, 2, 1)
    plt.title("image")
    plt.imshow(image, cmap="gray")
    if has_label:
        # Mask out zeros so only labelled pixels are painted over the image
        overlay = np.ma.masked_where(label == 0, label)
        plt.imshow(overlay, 'jet', interpolation='none', alpha=0.7)
    draw_points(guidance, slice_idx)
    plt.colorbar()

    # Right panel: the label on its own
    if has_label:
        plt.subplot(1, 2, 2)
        plt.title("label")
        plt.imshow(label)
        plt.colorbar()

    plt.show()


def print_data(data):
    """Print a short description of every entry in ``data``.

    Entries stored under ``'image_meta_dict'`` or ``'label_meta_dict'`` are
    expanded and printed one field per line. Any other entry is printed on
    a single line as its value (for int, float, bool, str, dict, tuple),
    its ``shape`` attribute if it has one, or otherwise its type.

    Args:
        data: mapping from key to value.
    """
    meta_keys = ('image_meta_dict', 'label_meta_dict')
    plain_types = (int, float, bool, str, dict, tuple)

    for key, value in data.items():
        if key in meta_keys:
            for field in value:
                print(f'{key} Meta:: {field} => {value[field]}')
            continue

        # Exact type match on purpose: subclasses fall through to the
        # shape/type description below.
        if type(value) in plain_types:
            summary = value
        elif hasattr(value, 'shape'):
            summary = value.shape
        else:
            summary = type(value)
        print(f'Data key: {key} = {summary}')

In [None]:
# Download data and model

# (remote URL, local file name) for each file; a file is only fetched
# when it is not already present in the working directory.
downloads = (
    (
        "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/_image.nii.gz",
        "_image.nii.gz",
    ),
    (
        "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/pretrained_deepedit_dynunet-final.ts",
        "pretrained_deepedit_dynunet-final.ts",
    ),
)

for resource, dst in downloads:
    if os.path.exists(dst):
        continue
    monai.apps.download_url(resource, dst)

In [None]:
# Label names and the integer value assigned to each
labels = {
    'spleen': 1,
    'background': 0,
}

# Pre-processing
# Target size handed to Resized below
spatial_size = [128, 128, 128]

# Input sample: path of the image plus the user clicks listed per label
data = {
    'image': '_image.nii.gz',
    'guidance': {
        'spleen': [[66, 180, 105], [66, 180, 145]],
        'background': [],
    },
}

# Third coordinate of the first spleen click selects the slice to preview
original_slice_idx = data['guidance']['spleen'][0][2]
slice_idx = original_slice_idx

pre_transforms = [
    # Loading the image
    LoadImaged(keys="image", reader="ITKReader"),
    # Ensure channel first
    EnsureChannelFirstd(keys="image"),
    # Change image orientation
    Orientationd(keys="image", axcodes="RAS"),
    # Scaling image intensity - works well for CT images
    ScaleIntensityRanged(keys="image", a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
    # DeepEdit transforms for inference
    # Add guidance (points) in the form of tensors based on the user input
    AddGuidanceFromPointsCustomd(ref_image="image", guidance="guidance", label_names=labels),
    # Resize the image
    Resized(keys="image", spatial_size=spatial_size, mode="area"),
    # Resize the guidance based on the image resizing
    ResizeGuidanceMultipleLabelCustomd(guidance="guidance", ref_image="image"),
    # Add the guidance to the input image
    AddGuidanceSignalCustomd(keys="image", guidance="guidance"),
    # Convert image to tensor
    ToTensord(keys="image"),
]

original_image = None

# Apply the pre-transforms one at a time so each intermediate shape is reported
for t in pre_transforms:
    tname = type(t).__name__
    data = t(data)
    image = data['image']
    label = data.get('label')
    guidance = data.get('guidance')

    print(f"{tname} => image shape: {image.shape}")

    if tname != 'LoadImaged':
        continue

    # Right after loading: keep the loaded image and preview the clicked slice
    original_image = data['image']
    label = None
    tmp_image = image[:, :, slice_idx]
    show_image(tmp_image, label, [guidance['spleen'][0]], slice_idx)

transformed_image = data['image']
guidance = data.get('guidance')

In [None]:
# Evaluation
model_path = 'pretrained_deepedit_dynunet-final.ts'

# Use the GPU when one is available and fall back to the CPU otherwise, so
# the cell does not fail on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = jit.load(model_path, map_location=device)
model.to(device)
model.eval()

# Add a leading batch dimension for the forward pass
inputs = data['image'][None].to(device)
with torch.no_grad():
    outputs = model(inputs)
# Drop the batch dimension again
outputs = outputs[0]
data['pred'] = outputs

post_transforms = [
    EnsureTyped(keys="pred"),
    # Softmax over the prediction, then argmax to get discrete labels
    Activationsd(keys="pred", softmax=True),
    AsDiscreted(keys="pred", argmax=True),
    # Remove dimension 0 so the prediction can be sliced as [:, :, i] below
    SqueezeDimd(keys="pred", dim=0),
    ToNumpyd(keys="pred"),
]

# Kept so `pred` stays defined in the notebook namespace
pred = None
for t in post_transforms:
    tname = type(t).__name__
    data = t(data)
    image = data['image']
    label = data['pred']
    print("{} => image shape: {}, pred shape: {}".format(tname, image.shape, label.shape))

# Walk over the slices along the last axis, which is the axis indexed below.
# Iterating over shape[0] instead would only be correct for cubic volumes.
for i in range(data['pred'].shape[2]):
    image = transformed_image[0, :, :, i]  # Taking the first channel which is the main image
    label = data['pred'][:, :, i]
    # Skip slices where nothing was segmented
    if np.sum(label) == 0:
        continue

    print("Final PLOT:: {} => image shape: {}, pred shape: {}; min: {}, max: {}, sum: {}".format(
        i, image.shape, label.shape, np.min(label), np.max(label), np.sum(label)))
    show_image(image, label)

In [None]:
# Remove downloaded files. Files that are already gone are skipped so the
# cell can be re-run without raising FileNotFoundError.
for downloaded_file in ('_image.nii.gz', 'pretrained_deepedit_dynunet-final.ts'):
    if os.path.exists(downloaded_file):
        os.remove(downloaded_file)