In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import jit

from monai.apps.deepgrow.transforms import (
    AddGuidanceFromPointsd,
    AddGuidanceSignald,
    Fetch2DSliced,
    ResizeGuidanced,
    RestoreCroppedLabeld,
    SpatialCropGuidanced,
)
from monai.transforms import (
    AsChannelFirstd,
    Spacingd,
    LoadNiftid,
    AddChanneld,
    NormalizeIntensityd,
    ToTensord,
    ToNumpyd,
    Activationsd,
    AsDiscreted,
    Resized
)


def draw_points(guidance, markersize=30):
    """Overlay guidance clicks on the current matplotlib axes.

    Args:
        guidance: ``None`` (nothing is drawn) or a sequence whose first item
            holds the foreground points (drawn as red ``+``) and whose second
            item holds the background points (drawn as blue ``+``). Any items
            after the second are ignored because ``zip`` stops after the two
            colours. Each point is indexed from its end: ``p[-1]`` is used as
            the plot x coordinate and ``p[-2]`` as the plot y coordinate, so any
            leading entries of a point are skipped.
        markersize: size of the ``+`` markers (defaults to 30).
    """
    if guidance is None:
        return
    colors = ['r+', 'b+']
    for color, points in zip(colors, guidance):
        for p in points:
            # For a 2-D array shown with imshow, columns (the last axis) run
            # along x and rows along y.
            # markersize must be a keyword: a MATLAB-style ('MarkerSize', 30)
            # pair is parsed by matplotlib as an extra x/y data series.
            plt.plot(p[-1], p[-2], color, markersize=markersize)


def show_image(image, label, guidance=None):
    """Plot an image with an optional label overlay and guidance clicks.

    Uses the figure named "check" (12x6 inches) laid out as a 1x2 grid:

    * left panel: ``image`` in grayscale. When ``label`` is given, its non-zero
      pixels are overlaid with the 'jet' colormap at alpha 0.7. Guidance points
      are then drawn with ``draw_points`` and a colorbar is added.
    * right panel: only when ``label`` is given, the raw label with its own
      colorbar.

    Args:
        image: 2-D array-like accepted by ``plt.imshow``.
        label: 2-D array-like, or ``None`` to skip the overlay and right panel.
        guidance: forwarded unchanged to ``draw_points``.
    """
    with_label = label is not None

    plt.figure("check", (12, 6))

    # Left panel: image, optional label overlay, guidance clicks.
    plt.subplot(1, 2, 1)
    plt.title("image")
    plt.imshow(image, cmap="gray")
    if with_label:
        # Mask zero pixels so only the labelled region gets tinted.
        overlay = np.ma.masked_where(label == 0, label)
        plt.imshow(overlay, 'jet', interpolation='none', alpha=0.7)
    draw_points(guidance)
    plt.colorbar()

    if not with_label:
        plt.show()
        return

    # Right panel: the label on its own.
    plt.subplot(1, 2, 2)
    plt.title("label")
    plt.imshow(label)
    plt.colorbar()
    plt.show()

In [None]:
# Pre Processing
# NOTE(review): absolute, machine-specific path — point this at your copy of spleen_19.nii.gz.
image_path = '/salle/Downloads/spleen_19.nii.gz'
roi_size = [256, 256]
pixdim = (1.0, 1.0)
dimensions = 2

data = {
    'image': image_path,
    'foreground': [[354, 336, 40]],  # more clicks can be added, e.g. [259, 381, 40]
    'background': [],
    'spatial_size': [384, 384]
}
# The third coordinate of the first foreground click selects the slice to display.
slice_idx = original_slice_idx = data['foreground'][0][2]

pre_transforms = [
    LoadNiftid(keys='image'),
    AsChannelFirstd(keys='image'),
    Spacingd(keys='image', pixdim=pixdim, mode='bilinear'),

    AddGuidanceFromPointsd(ref_image='image', guidance='guidance', foreground='foreground', background='background',
                           dimensions=dimensions),
    Fetch2DSliced(keys='image', guidance='guidance'),
    AddChanneld(keys='image'),

    SpatialCropGuidanced(keys='image', guidance='guidance', spatial_size=roi_size),
    Resized(keys='image', spatial_size=roi_size, mode='area'),
    ResizeGuidanced(guidance='guidance', ref_image='image'),
    NormalizeIntensityd(keys='image', subtrahend=208.0, divisor=388.0),
    AddGuidanceSignald(image='image', guidance='guidance'),
    ToTensord(keys='image')
]


def _display_slice(arr, tname, idx):
    """Return the 2-D view of ``arr`` to plot after the transform named ``tname``.

    - ``None`` stays ``None`` (e.g. when the data has no label).
    - After ``Fetch2DSliced`` the array is shown as-is.
    - After ``LoadNiftid`` the slice is taken along the last axis.
    - After any other transform it is taken along the first axis.
    """
    if arr is None:
        return None
    if tname == 'Fetch2DSliced':
        return arr
    if tname == 'LoadNiftid':
        return arr[:, :, idx]
    return arr[idx, :, :]


def _fallback_guidance(d):
    """Build ``[foreground, background]`` display points from the raw clicks.

    Used while ``d`` has no usable 'guidance' entry. Each click is rolled by one
    position *within itself* (``axis=-1``): click ``[a, b, c]`` becomes
    ``[c, a, b]``, so ``draw_points`` plots ``b`` as x and ``a`` as y. Without
    ``axis`` numpy would flatten all clicks first and mix coordinates between
    them.
    """
    return [np.roll(d.get(key, []), 1, axis=-1).tolist() for key in ('foreground', 'background')]


original_image = None
original_image_slice = None
for t in pre_transforms:
    tname = type(t).__name__

    data = t(data)
    image = data['image']
    label = data.get('label')
    guidance = data.get('guidance')

    print("{} => image shape: {}, label shape: {}".format(
        tname, image.shape, label.shape if label is not None else None))

    image = _display_slice(image, tname, slice_idx)
    label = _display_slice(label, tname, slice_idx)

    # `len` rather than truthiness so a numpy-array guidance does not raise.
    if guidance is None or len(guidance) == 0:
        guidance = _fallback_guidance(data)
    print('Guidance: {}'.format(guidance))

    show_image(image, label, guidance)
    if tname == 'Fetch2DSliced':
        # Later steps are displayed via arr[0, :, :].
        slice_idx = 0
    if tname == 'LoadNiftid':
        original_image = data['image']
    if tname == 'AddChanneld':
        original_image_slice = data['image']


In [None]:
# Evaluation
# NOTE(review): absolute, machine-specific path to the TorchScript model.
model_path = '/workspace/Downloads/models/roi_b8_256x256_c32.ts'
# Use the GPU when present; otherwise fall back to CPU instead of failing in `.cuda()`.
# NOTE(review): a model traced on GPU may embed device-specific ops — confirm CPU inference works.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = jit.load(model_path, map_location=device)
model.to(device)
model.eval()

inputs = data['image'][None].to(device)  # prepend a batch axis
with torch.no_grad():
    outputs = model(inputs)
outputs = outputs[0]  # drop the batch axis again
data['pred'] = outputs

post_transforms = [
    Activationsd(keys='pred', sigmoid=True),
    AsDiscreted(keys='pred', threshold_values=True, logit_thresh=0.5),
    ToNumpyd(keys='pred'),
    RestoreCroppedLabeld(keys='pred', ref_image='image', mode='nearest'),
]


def _to_numpy(arr):
    """Return ``arr`` as a numpy array, detaching and moving tensors to the CPU first."""
    return arr.detach().cpu().numpy() if torch.is_tensor(arr) else arr


for t in post_transforms:
    tname = type(t).__name__

    data = t(data)
    image = data['image']
    label = data['pred']
    print("{} => image shape: {}, pred shape: {}".format(tname, image.shape, label.shape))

    if tname == 'RestoreCroppedLabeld':
        # Show the restored prediction on the original (un-cropped) slice.
        image = original_image[:, :, original_slice_idx]
        label = _to_numpy(label[0, :, :]) if torch.is_tensor(label) else label[original_slice_idx]
    else:
        image = _to_numpy(image[0, :, :]) if torch.is_tensor(image) else image[0]
        label = _to_numpy(label[0, :, :]) if torch.is_tensor(label) else label[0]

    print("PLOT:: {} => image shape: {}, pred shape: {}; min: {}, max: {}, sum: {}".format(
        tname, image.shape, label.shape, np.min(label), np.max(label), np.sum(label)))
    show_image(image, label)
