In [None]:

import copy

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML
from celluloid import Camera  # getting the camera

from byoc.transforms import (
    SpatialCropForegroundd,
    AddInitialSeedPointd,
    AddGuidanceSignald,
    FindAllValidSlicesd,
)
from monai.transforms import (
    LoadNiftid,
    AsChannelFirstd,
    Spacingd,
    Orientationd,
    AddChanneld,
    NormalizeIntensityd,
    Resized,
)

def draw_points(guidance, slice_idx):
    """Overlay the guidance points lying on ``slice_idx`` onto the current axes.

    Args:
        guidance: ``None`` (nothing is drawn) or an iterable of point groups.
            Only the first two groups are drawn: the first with red ``+``
            markers, the second with blue ``+`` markers (presumably
            foreground / background clicks — TODO confirm against
            ``AddInitialSeedPointd``). Each point is indexed from the end:
            ``p[-3]`` is compared with ``slice_idx``, and ``(p[-1], p[-2])``
            are used as the plot's ``(x, y)``.
        slice_idx: index of the slice currently displayed; points whose
            ``p[-3]`` differs are skipped.
    """
    if guidance is None:
        return
    colors = ('r+', 'b+')
    for color, points in zip(colors, guidance):
        for p in points:
            if p[-3] != slice_idx:
                continue
            # imshow puts columns on x and rows on y, so the last coordinate is x.
            # markersize must be a keyword: matplotlib's plot() does not accept
            # MATLAB-style ('MarkerSize', 30) pairs and would parse them as
            # another x/y data series.
            plt.plot(p[-1], p[-2], color, markersize=30)


def show_image(image, label, slice_idx=None, guidance=None):
    """Plot an image slice and, when available, its label side by side.

    Uses the pyplot figure labelled "check" (12 x 6 inches). The left panel
    shows ``image`` in grayscale with guidance markers from ``draw_points``;
    the right panel, added only when ``label`` is not None, shows ``label``
    with the default colormap. Each panel gets a colorbar, then the figure
    is displayed with ``plt.show()``.

    Args:
        image: array handed to ``imshow`` for the left panel (a 2-D slice in
            this notebook).
        label: array for the right panel, or None to omit that panel.
        slice_idx: slice index forwarded to ``draw_points``.
        guidance: guidance points forwarded to ``draw_points`` (None = none).
    """
    panels = [("image", image, {"cmap": "gray"})]
    if label is not None:
        panels.append(("label", label, {}))

    plt.figure("check", (12, 6))
    for position, (title, pixels, imshow_kwargs) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(title)
        plt.imshow(pixels, **imshow_kwargs)
        if position == 1:
            # Guidance markers are overlaid on the image panel only.
            draw_points(guidance, slice_idx)
        plt.colorbar()
    plt.show()

In [None]:
# Configuration used by this and the following cells.
roi_size = [128, 256, 256]    # spatial_size for SpatialCropForegroundd (later cell)
model_size = [128, 192, 192]  # spatial_size for Resized (later cell)
data = {
    'image': '/workspace/data/52432/Training/img/img0001.nii.gz',
    'label': '/workspace/data/52432/Training/label/label0001.nii.gz',
}
slice_idx = 111  # slice displayed after every transform
region = 6  # label value kept as the foreground mask; liver = 6

transforms = [
    LoadNiftid(keys=('image', 'label')),
    AsChannelFirstd(keys=('image', 'label')),
    Spacingd(keys=('image', 'label'), pixdim=(1.0, 1.0, 1.0), mode=('bilinear', 'nearest')),
    Orientationd(keys=('image', 'label'), axcodes="RAS"),
]

# Apply the transforms one at a time and show the same slice after each step.
for t in transforms:
    tname = type(t).__name__
    data = t(data)
    image = data['image']
    label = data['label']

    print(f"{tname} => image shape: {image.shape}, label shape: {label.shape}")

    # After LoadNiftid the slice axis is last; AsChannelFirstd moves the last
    # axis to the front, so every later stage is indexed on axis 0.
    # (Exact equality: the previous `tname in ('LoadNiftid')` was a substring
    # test against a plain string, not a tuple membership test.)
    if tname == 'LoadNiftid':
        image = image[:, :, slice_idx]
        label = label[:, :, slice_idx]
    else:
        image = image[slice_idx, :, :]
        label = label[slice_idx, :, :]
    show_image(image, label)


In [None]:
pre_transforms = [
    AddChanneld(keys=('image', 'label')),
    SpatialCropForegroundd(keys=('image', 'label'), source_key='label', spatial_size=roi_size),
    Resized(keys=('image', 'label'), spatial_size=model_size, mode=('area', 'nearest')),
    NormalizeIntensityd(keys='image', subtrahend=208.0, divisor=388.0),
    FindAllValidSlicesd(label='label', sids='sids'),
]

# Work on a copy so `data` from the loading cell is untouched and this cell can
# be re-run; reduce the label to a boolean mask of the selected region.
pdata = copy.deepcopy(data)
pdata['label'] = pdata['label'] == region
original_label = None  # label right after AddChanneld; reference depth for rescaling

for t in pre_transforms:
    tname = type(t).__name__
    pdata = t(pdata)

    image = pdata['image']
    label = pdata['label']
    guidance = pdata.get('guidance')

    if tname == 'AddChanneld':
        original_label = label

    # Map slice_idx proportionally onto the (possibly cropped/resized) axis 1,
    # unless guidance is present, in which case take the slice from its first point.
    factor = 1 if original_label is None else label.shape[1] / original_label.shape[1]
    sid = guidance[0][0][1] if guidance is not None else int(slice_idx * factor)
    print(f"{tname} => {sid} => image shape: {image.shape}, label shape: {label.shape}")

    # Display the mapped slice (channel 0) after each stage; the slice was
    # previously extracted but never shown.
    show_image(image[0][sid], label[0][sid], sid, guidance)

In [None]:
valid_sids = pdata['sids'].tolist()
print('Total {} valid slices: {}'.format(len(valid_sids), valid_sids))

rand_transforms = [
    AddInitialSeedPointd(label='label', guidance='guidance'),
]

# Run the random seed-point pipeline repeatedly and tally which slice the first
# guidance point lands on, to compare how many valid slices actually get used.
n_samples = 200
sid_counts = {}
for _ in range(n_samples):
    rdata = copy.deepcopy(pdata)
    for t in rand_transforms:
        rdata = t(rdata)

    # Tally once per sample, after the whole pipeline has run, so adding more
    # transforms to `rand_transforms` does not multiply the counts.
    guidance = rdata['guidance']
    sid = guidance[0][0][1]
    sid_counts[sid] = sid_counts.get(sid, 0) + 1

print('Used sid count: {} of {}'.format(len(sid_counts), len(pdata['sids'])))

# Show the chosen slice of the last sample together with its guidance points.
show_image(rdata['image'][0][sid], rdata['label'][0][sid], sid, guidance)

In [None]:
fig, ax = plt.subplots(figsize=(18, 12))  # make it bigger
camera = Camera(fig)  # celluloid captures one animation frame per camera.snap()

# Animate every slice along axis 1 of the last sampled volume that contains any
# label, with the label (channel 0) overlaid on the image.
image = rdata['image']
label = rdata['label']
for i in range(image.shape[1]):
    label_slice = label[0][i]
    if not np.any(label_slice):
        continue

    ax.imshow(image[0][i], cmap="gray")
    masked = np.ma.masked_where(label_slice == 0, label_slice)
    ax.imshow(masked, 'hsv', interpolation='none', alpha=0.7)
    camera.snap()

animation = camera.animate()
video_html = animation.to_html5_video()
# Close the figure after rendering so the inline backend does not also display
# a static copy of the last frame below the video.
plt.close(fig)
HTML(video_html)