In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
from mpl_toolkits.axes_grid1 import ImageGrid
from skimage.measure import find_contours
from skimage.transform import resize

In [None]:
# Apply matplotlib's dark theme to every figure drawn by this notebook.
plt.style.use(style="dark_background")

In [None]:
# Sub-directory names of the two time points stored for every patient.
pre = 'preRT'
mid = 'midRT'

# Dataset directory, resolved against the parent of the current working directory by
# get_path_and_patients. Built with pathlib instead of a backslash string literal, which
# only acts as a path separator on Windows.
path = Path("data", "HNTSMRG24_train")
# Index into the `patients` list returned by get_path_and_patients.
# NOTE(review): this selects by list position, not by matching a directory name; confirm
# that is intended.
patient_number = 95
task = 1  # 1 selects the `pre` time point, any other value the `mid` one
columns = 5  # slices per row in the grid view
colormap = 'gray'

label1_color = 'chartreuse'
label2_color = 'deeppink'
label1 = 'GTVp'  # legend name for mask value 1
label2 = 'GTVn'  # legend name for mask value 2

label_colors = [label1_color, label2_color]
labels = [label1, label2]

In [None]:
# The contour and legend helpers (get_contours, draw_contours, create_fixed_legend) used
# by display_segmented_image are defined once, after get_path_and_patients in this cell.


def display_segmented_image(images, masks, size=2, columns=6, cmap='bone', patient_number=None, task=None,
                            label1_color='chartreuse',
                            label2_color='deeppink', label1='GTVp', label2='GTVn'):
    """Show every slice of a volume in a grid, with the two mask labels outlined.

    Parameters
    ----------
    images, masks : indexable sequences of 2D arrays
        Image slices and the matching label masks; ``images[i]`` is shown with the
        contours of ``masks[i]`` on top.
    size : float
        Width and height (inches) of one grid cell.
    columns : int
        Number of slices per grid row.
    cmap : str
        Colormap for the image slices.
    patient_number : optional
        Shown as ``"Patient No. <patient_number>"`` in the title; left out when it is
        ``None`` or ``''``.
    task : int, optional
        1 adds "pre-Radiotherapy Treatment" and 2 adds "mid-Radiotherapy Treatment" to
        the title; any other value adds nothing.
    label1_color, label2_color, label1, label2 : str
        Colour and legend name used for mask values 1 and 2.

    Raises
    ------
    ValueError
        If ``images`` is empty (there would be no grid to draw).
    """
    n_slices = len(images)
    if n_slices == 0:
        raise ValueError("display_segmented_image needs at least one slice")
    rows = -(-n_slices // columns)  # ceiling division

    fig = plt.figure(figsize=(columns * size, rows * size))
    grid = ImageGrid(fig, 111, nrows_ncols=(rows, columns), axes_pad=0.05)
    # Hide the unused cells that pad out the last row.
    for i in range(n_slices, len(grid)):
        grid[i].axis('off')

    # The legend uses proxy lines, so it is shown even if no slice contains a label.
    create_fixed_legend(grid[columns - 1], [label1, label2], [label1_color, label2_color])
    for i in range(n_slices):
        cell = grid[i]
        cell.imshow(images[i], cmap=cmap)
        cell.set_xticks([])
        cell.set_yticks([])
        cell.text(20, 50, f'{i:02d}')  # slice index, placed in data coordinates
        draw_contours(cell, masks[i], label_num=1, label_color=label1_color, label_name=label1)
        draw_contours(cell, masks[i], label_num=2, label_color=label2_color, label_name=label2)

    task_titles = {1: 'pre-Radiotherapy Treatment', 2: 'mid-Radiotherapy Treatment'}
    title_parts = []
    # Explicit checks instead of truthiness so that a patient number of 0 is still shown.
    if patient_number is not None and patient_number != '':
        title_parts.append(f'Patient No. {patient_number}')
    if task in task_titles:
        title_parts.append(task_titles[task])
    # Joining only the parts that exist avoids titles like "Patient No. 3 - None".
    fig.suptitle(' - '.join(title_parts), y=.91)
    plt.show()


def get_array(nii_path):
    """Read a NIfTI (``.nii`` / ``.nii.gz``) file and return its voxel data as an array.

    Parameters
    ----------
    nii_path : str or os.PathLike
        File to read. It is converted to ``str`` before being handed to SimpleITK, so
        ``pathlib.Path`` objects work too, including on SimpleITK versions without
        path-like support.

    Returns
    -------
    numpy.ndarray
        Result of ``sitk.GetArrayFromImage`` (SimpleITK orders the array axes in reverse,
        i.e. z, y, x).
    """
    image = sitk.ReadImage(str(nii_path))
    return sitk.GetArrayFromImage(image)


def _first_match(directory, pattern):
    """Return the alphabetically first file in ``directory`` matching ``pattern``, as a str."""
    # Sorted because Path.glob yields entries in filesystem order, which is not guaranteed.
    matches = sorted(directory.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No file matching {pattern!r} in {directory}")
    return str(matches[0])


def get_image_and_mask(path, observation, task):
    """Load the T2 image volume and its segmentation mask for one patient.

    Parameters
    ----------
    path : str or pathlib.Path
        Dataset root containing one sub-directory per patient.
    observation : str
        Name of the patient sub-directory.
    task : int
        1 selects the ``pre`` time-point folder; any other value selects ``mid``.

    Returns
    -------
    tuple
        ``(image, mask)``, each as returned by ``get_array``. The chosen file paths are
        printed before loading.

    Raises
    ------
    FileNotFoundError
        If the time-point folder contains no file matching the image or mask pattern.
    """
    task_name = pre if task == 1 else mid
    obs_path = Path(path) / observation / task_name
    if task_name == pre:
        image_pattern, mask_pattern = "*T2*", "*mask*"
    else:
        # The midRT folder is searched with more specific patterns, presumably because it
        # also holds other T2/mask files -- confirm against the dataset layout.
        image_pattern, mask_pattern = "*midRT_T2*", "*midRT_mask*"

    image_path = _first_match(obs_path, image_pattern)
    mask_path = _first_match(obs_path, mask_pattern)
    print(image_path)
    print(mask_path)
    img = get_array(image_path)
    mask = get_array(mask_path)
    return img, mask


def get_path_and_patients(path, base_dir=None):
    """Resolve the dataset directory and list the patient sub-directories inside it.

    Parameters
    ----------
    path : str or pathlib.Path
        Dataset directory relative to ``base_dir`` (an absolute path is used as-is).
    base_dir : str or pathlib.Path, optional
        Directory ``path`` is resolved against. Defaults to the parent of the current
        working directory.

    Returns
    -------
    tuple of (pathlib.Path, list of str)
        The dataset directory and the names of its sub-directories, sorted by name.
        ``Path.iterdir`` yields entries in an arbitrary, filesystem-dependent order, so
        sorting is what makes ``patients[i]`` select the same patient on every machine.
    """
    if base_dir is None:
        base_dir = Path().absolute().parent
    data_path = Path(base_dir) / path
    # Only directories are patients; stray files (README, .DS_Store, ...) would shift indices.
    patients = sorted(entry.name for entry in data_path.iterdir() if entry.is_dir())
    return data_path, patients


def get_contours(mask, label):
    """Return the contours outlining the pixels of ``mask`` that equal ``label``.

    ``mask == label`` is a binary image, so the contour is traced at level 0.5, half-way
    between background (0) and foreground (1). The level is passed explicitly rather than
    left to ``find_contours``'s automatic choice: scikit-image releases before 0.18
    require it, and it states the intent directly.

    Returns
    -------
    list of numpy.ndarray
        One ``(n, 2)`` array of ``(row, column)`` coordinates per contour, as produced by
        ``skimage.measure.find_contours``; empty when ``label`` does not occur.
    """
    return find_contours(mask == label, level=0.5)


def draw_contours(ax, mask, label_num, label_color, label_name):
    """Outline on ``ax`` every region of ``mask`` whose value is ``label_num``.

    Each contour becomes one line in ``label_color``, and every line carries
    ``label_name`` as its label.
    """
    for outline in get_contours(mask, label=label_num):
        rows, cols = outline[:, 0], outline[:, 1]
        # Contours are (row, col) pairs, while matplotlib plots (x, y) = (col, row).
        ax.plot(cols, rows, color=label_color, label=label_name)


def create_fixed_legend(ax, labels, label_colors, bbox_to_anchor=(1.05, 1.23), ncol=2):
    """Draw a legend for ``labels`` on ``ax`` even if no matching line is plotted.

    One empty proxy line per label is added to ``ax`` and passed explicitly as the legend
    handles, so the legend looks the same on every figure, including slices without any
    contour.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes that receives the proxy lines and the legend.
    labels : sequence of str
        Legend entries, in display order.
    label_colors : sequence of colours
        One colour per entry of ``labels``.
    bbox_to_anchor : tuple
        Passed to ``ax.legend`` to position the legend.
    ncol : int
        Number of legend columns, passed to ``ax.legend``.

    Raises
    ------
    ValueError
        If ``labels`` and ``label_colors`` differ in length (``zip`` would otherwise
        silently drop entries).
    """
    labels = list(labels)
    label_colors = list(label_colors)
    if len(labels) != len(label_colors):
        raise ValueError(
            f"got {len(labels)} labels but {len(label_colors)} colours for the legend")

    proxies = [ax.plot([], [], label=name, color=colour)[0]
               for name, colour in zip(labels, label_colors)]
    ax.legend(handles=proxies, frameon=False, bbox_to_anchor=bbox_to_anchor, ncol=ncol)





In [None]:
# Store the resolved directory under a new name so `path` keeps its configured value and
# re-running this cell always starts from the same input.
data_path, patients = get_path_and_patients(path)
patient = patients[patient_number]  # name of the selected patient directory
img, mask = get_image_and_mask(data_path, patient, task)

In [None]:
# Grid of all slices of the selected patient. The legend names and colours are taken from
# the configuration cell instead of relying on the function's matching defaults.
display_segmented_image(img, mask, columns=columns, cmap=colormap, task=task, patient_number=patient,
                        label1_color=label1_color, label2_color=label2_color,
                        label1=label1, label2=label2)

In [None]:
def _extract_plane(volume, orientation, index, resized_shape, is_label=False):
    """Return the 2D plane of ``volume`` at ``index`` for ``orientation``.

    Axial planes are returned unchanged. Coronal and sagittal planes are resized to
    ``resized_shape`` and flipped along their first axis. For a label mask the resize uses
    nearest-neighbour interpolation without intensity rescaling, so label values survive.
    """
    if orientation == 'axial':
        return volume[index, :, :]
    plane = volume[:, index, :] if orientation == 'coronal' else volume[:, :, index]
    if is_label:
        plane = resize(plane, resized_shape, order=0, preserve_range=True, anti_aliasing=False)
    else:
        plane = resize(plane, resized_shape)
    return np.flip(plane, axis=0)


def display_slice(img, orientation, slice, labels, label_colors, cmap='gray', mask=None, figsize=(8, 8),
                  resized_shape=(65 * 4, 512), dpi=1200, save_dir='figs'):
    """Plot one plane of ``img``, optionally with ``mask`` contours, and save it as a PNG.

    Parameters
    ----------
    img : numpy.ndarray
        3D volume; this function indexes it as ``img[axial, coronal, sagittal]``.
    orientation : {'axial', 'coronal', 'saggital', 'sagittal'}
        Plane to show. 'saggital' and 'sagittal' are equivalent; the string passed is
        used in the output file name.
    slice : int
        Index of the plane along the chosen axis.
    labels, label_colors : sequences
        Names and colours for mask values 1 and 2 (first two entries are used).
    cmap : str
        Colormap for the image.
    mask : numpy.ndarray, optional
        Label volume indexed like ``img``; when given, its label-1 and label-2 regions
        are outlined.
    figsize : tuple
        Figure size in inches.
    resized_shape : tuple of int
        Shape coronal/sagittal planes are stretched to before display.
        NOTE(review): the default (65 * 4, 512) looks tailored to this dataset's volume
        shape / voxel spacing -- confirm before using other data.
    dpi : int
        Resolution of the saved PNG.
    save_dir : str or pathlib.Path
        Output directory, created if missing. The file is ``<orientation>.png``, or
        ``<orientation>_annotated.png`` when ``mask`` is given.

    Raises
    ------
    ValueError
        If ``orientation`` is not one of the accepted names.
    """
    orientations = ('axial', 'coronal', 'saggital', 'sagittal')
    if orientation not in orientations:
        raise ValueError(f"orientation must be one of {orientations}, got {orientation!r}")

    image_plane = _extract_plane(img, orientation, slice, resized_shape)
    # The mask gets exactly the same slicing, resize and flip as the image; otherwise the
    # contours of coronal/sagittal views do not line up with the displayed pixels.
    mask_plane = None
    if mask is not None:
        mask_plane = _extract_plane(mask, orientation, slice, resized_shape, is_label=True)

    fig = plt.figure(frameon=False, figsize=figsize)
    ax = fig.add_axes([0., 0., 1., 1.])  # axes fill the whole figure, no margins
    ax.set_axis_off()
    ax.imshow(image_plane, cmap=cmap)

    if mask_plane is not None:
        draw_contours(ax, mask_plane, label_num=1, label_color=label_colors[0], label_name=labels[0])
        draw_contours(ax, mask_plane, label_num=2, label_color=label_colors[1], label_name=labels[1])

    suffix = '_annotated' if mask is not None else ''
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_dir / f'{orientation}{suffix}.png', dpi=dpi)

In [None]:
# Axial slice 34 with the mask contours overlaid.
display_slice(img, orientation='axial', slice=34, labels=labels, label_colors=label_colors, mask=mask)

In [None]:
# The same axial slice, image only.
display_slice(img, orientation='axial', slice=34, labels=labels, label_colors=label_colors)

In [None]:
# Sagittal slice 256, image only.
display_slice(img, orientation='saggital', slice=256, labels=labels, label_colors=label_colors)

In [None]:
# Coronal slice 256, image only.
display_slice(img, orientation='coronal', slice=256, labels=labels, label_colors=label_colors)