In [None]:
import urllib.request
import pathlib
import shutil
import collections
import itertools
import multiprocessing

import numpy as np
import pydicom
import matplotlib.pyplot as plt
import skimage.measure

In [None]:
# Re-import edited Python modules (e.g. the local `rai` package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from rai.model import load
from rai.data.images import paths_to_image_stack_hfs
from rai.mask.convert import contour_sequence_to_masks, mask_to_contours
from rai.metrics.dice import from_contours_by_slice

from rai.dicom import structures as _dicom_structures

from rai.inference import batch, merge
from rai.data import download as _data_download

from raicontours import cfg, TG263

In [None]:
# Load the RAi contouring model; it is used for `model.predict` further down.
model = load.load_model()

In [None]:
# Obtain the example head-and-neck dataset: the image file paths plus the path of its
# structure-set file. NOTE(review): "data" is presumably the local target directory — confirm.
image_paths, structure_path = _data_download.example_head_and_neck("data")

In [None]:
# Build the image volume (indexed [z_index, :, :] per slice below) with its x/y coordinate
# grids and one UID per slice. NOTE(review): "hfs" presumably means head-first-supine — confirm.
x_grid, y_grid, image_stack, image_uids = paths_to_image_stack_hfs(image_paths)

In [None]:
# Read the DICOM structure set holding the ground-truth contours.
# `pydicom.read_file` was only a deprecated alias of `dcmread` and no longer exists in pydicom 3.0.
structure_ds = pydicom.dcmread(structure_path)

In [None]:
# Structure-set ROI name -> ROI number, read from StructureSetROISequence.
name_to_number_map = dict(
    (roi.ROIName, roi.ROINumber) for roi in structure_ds.StructureSetROISequence
)

name_to_number_map

In [None]:
# HNSCC ground-truth ROI name -> the TG-263 structure(s) whose predicted contours are
# pooled and compared against it (e.g. 'Eyes' is scored against Eye_L and Eye_R combined).
name_map = {
    'Eyes': [TG263.Eye_L, TG263.Eye_R],
    'L Optic Nerve': [TG263.OpticNrv_L],
    'R Optic Nerve': [TG263.OpticNrv_R]
}

In [None]:
# ROI number -> that ROI's ContourSequence, read from ROIContourSequence.
number_to_contour_sequence_map = dict(
    (roi_contour.ReferencedROINumber, roi_contour.ContourSequence)
    for roi_contour in structure_ds.ROIContourSequence
)

In [None]:
# Chain the two lookups (ROI name -> ROI number -> ContourSequence) for just the
# ground-truth structures listed in `name_map`.
structure_name_to_contour_sequence_map = {
    hnscc_name: number_to_contour_sequence_map[
        name_to_number_map[hnscc_name]
    ]
    for hnscc_name in name_map
}

In [None]:
# Inspect the selected ground-truth contour sequences.
structure_name_to_contour_sequence_map

In [None]:
# Patch centre points as (z, y, x) voxel indices on a coarse grid over the region of
# interest. Each centre is jittered by -1, 0 or +1 voxel per axis. The jitter uses a
# fixed seed so that the points, and therefore every downstream prediction, sanity
# check and Dice score, are identical on each Restart & Run All.
JITTER_SEED = 42
rng = np.random.default_rng(JITTER_SEED)

z = [35, 45, 55]
y = [155, 175, 195, 215]
x = [210, 230, 250, 270, 290, 310]

points = [
    # integers(-1, 2) draws from {-1, 0, 1} (upper bound exclusive), like randint.
    tuple((rng.integers(-1, 2, size=3) + np.array(centre)).tolist())
    for centre in itertools.product(z, y, x)
]

In [None]:
# Cut a model-input batch out of the image stack around the centre points
# (presumably one patch per point — see rai.inference.batch.create_batch) and show its shape.
model_input = batch.create_batch(image_stack, points)
model_input.shape

In [None]:
# Run the model on the whole batch of patches.
model_output = model.predict(model_input)

In [None]:
# Volume-space accumulators for merging patch predictions: one uint8 channel per configured
# structure in `merged`, and a single float32 channel in `counts`
# (NOTE(review): presumably how many patches covered each voxel — confirm in merge_predictions).
merged = np.zeros((*image_stack.shape, len(cfg["structures"])), dtype=np.uint8)
counts = np.zeros((*image_stack.shape, 1), dtype=np.float32)

In [None]:
# Write each patch's prediction back into the full-volume `merged`/`counts` buffers at its centre point.
merged, counts = merge.merge_predictions(merged, counts, points, model_output)

In [None]:
# Expected: image_stack.shape + (number of configured structures,).
merged.shape

In [None]:
# Sanity check: on each spatial axis, the patch centre points sent to the model should
# bracket every voxel predicted to be inside a structure. `merged` is uint8 (0-255),
# so 127.5 is the half-way threshold.
where_mask = np.where(merged > 127.5)
# Fail with a clear message instead of numpy's "zero-size array to reduction operation"
# error when no voxel was predicted at all.
assert where_mask[0].size > 0, "No voxel of `merged` exceeds 127.5; nothing to check"
min_where_mask = np.min(where_mask, axis=1)
max_where_mask = np.max(where_mask, axis=1)

points_array = np.array(points)
min_points = points_array.min(axis=0)
max_points = points_array.max(axis=0)

# Axes 0-2 of `merged` are (z, y, x), the same order as each point. The last axis
# indexes structures and is not compared.
for i in range(3):
    min_point = min_points[i]
    max_point = max_points[i]

    assert min_point < min_where_mask[i], (
        f"Axis {i}: lowest patch centre {min_point} is not below the lowest "
        f"mask voxel {min_where_mask[i]}"
    )
    assert max_point > max_where_mask[i], (
        f"Axis {i}: highest patch centre {max_point} is not above the highest "
        f"mask voxel {max_where_mask[i]}"
    )

    print(f"Patch centre points modelled had range of {[min_point, max_point]} which appropriately "
          f"encompassed the range of the found masks {[min_where_mask[i], max_where_mask[i]]}")

In [None]:
# Structure names the model predicts, in channel order of `merged`'s last axis.
cfg["structures"]

In [None]:
# Convert each predicted structure mask to contours, one slice at a time:
# structure name -> list with one entry (that slice's contours) per z slice.
contours_by_structure_pd = {
    structure_name: [
        mask_to_contours(x_grid, y_grid, merged[z_index, ..., structure_index])
        for z_index in range(image_stack.shape[0])
    ]
    for structure_index, structure_name in enumerate(cfg["structures"])
}

In [None]:
# contours_by_structure_pd  # uncomment to inspect the raw predicted contours (output may be very long)

In [None]:
# For each HNSCC ground-truth structure:
#   1. convert its DICOM contour sequence into per-slice contour lists,
#   2. pool the predicted contours of its TG-263 counterpart(s) slice by slice,
#   3. score prediction against ground truth with the Dice helper.
contours_by_structure_gt = {}
dice = {}

for hnscc_name, tg263_names in name_map.items():
    contours_by_slice_gt = _dicom_structures.contour_sequence_to_contours_by_slice(
        image_uids, structure_name_to_contour_sequence_map[hnscc_name],
    )
    contours_by_structure_gt[hnscc_name] = contours_by_slice_gt

    # e.g. 'Eyes' pools the Eye_L and Eye_R predictions on every slice.
    contours_by_slice_pd = [
        [
            contour
            for tg263_name in tg263_names
            for contour in contours_by_structure_pd[tg263_name][z_index]
        ]
        for z_index, _ in enumerate(image_uids)
    ]

    dice[hnscc_name] = from_contours_by_slice(contours_by_slice_gt, contours_by_slice_pd)

dice

In [None]:
# Number of slices is image_stack.shape[0]; compare with the per-slice list lengths below.
image_stack.shape

In [None]:
# The ground-truth dict is keyed by HNSCC ROI names (the keys of `name_map`), not by TG-263
# names, so indexing it with TG263.OpticNrv_R raised KeyError. Look the structure up by its HNSCC name.
# NOTE(review): presumably one entry per slice UID, i.e. equal to image_stack.shape[0] — confirm.
len(contours_by_structure_gt["R Optic Nerve"])

In [None]:
# One entry per z slice (built over range(image_stack.shape[0])); predictions are keyed by TG-263 name.
len(contours_by_structure_pd[TG263.OpticNrv_R])

In [None]:
# Same listing as earlier; shown again next to the plot styling below for reference.
cfg["structures"]

In [None]:
# Plot styling per structure. Model (RAi) predictions are keyed by TG-263 name and
# HNSCC ground truth by its ROI name. Each row is (key, matplotlib colour, legend label).
_style_table = [
    (TG263.Eye_L, "C0", "RAi Eye_L"),
    (TG263.Eye_R, "C1", "RAi Eye_R"),
    (TG263.OpticNrv_L, "C2", "RAi OpticNrv_L"),
    (TG263.OpticNrv_R, "C3", "RAi OpticNrv_R"),
    ("Eyes", "C4", "HNSCC Eyes"),
    ("L Optic Nerve", "C5", "HNSCC OpticNrv_L"),
    ("R Optic Nerve", "C6", "HNSCC OpticNrv_R"),
]

colours = {key: colour for key, colour, _ in _style_table}
labels = {key: label for key, _, label in _style_table}

# Drop the helper table so it does not linger in the notebook namespace.
del _style_table

In [None]:
def _plot_model_result(
    image_stack,
    contours_by_structure,
    colours,
    labels,
    x_grid=None,
    y_grid=None,
    vmin=0.2,
    vmax=0.4,
    margin=0.2,
):
    """Overlay structure contours on every image slice that has at least one contour.

    One figure is drawn per slice ``image_stack[z_index, :, :]`` on which any
    structure has a non-empty contour list. All figures share the same axis
    limits: the bounding box of every plotted contour across all slices, padded
    on each side by ``margin`` times the box's extent. The y limits are set
    largest-first, so y increases downwards as in an image.

    Args:
        image_stack: Volume indexed ``[z_index, :, :]`` per slice, drawn with
            ``pcolormesh`` on ``x_grid``/``y_grid``.
        contours_by_structure: Mapping of structure name to a per-slice sequence
            of contours. Each contour must be a list of ``(x, y)`` points, because
            it is closed with ``contour + [contour[0]]``.
        colours: Mapping of structure name to a matplotlib colour.
        labels: Mapping of structure name to its legend label.
        x_grid, y_grid: Coordinates passed to ``pcolormesh``. When omitted, the
            notebook-level ``x_grid``/``y_grid`` are used, which is what earlier
            versions of this function read implicitly.
        vmin, vmax: Grey-level display window for the image.
        margin: Fractional padding added around the contours' bounding box.
    """
    # Backward-compatible fallback to the notebook globals this function used to depend on silently.
    if x_grid is None:
        x_grid = globals()["x_grid"]
    if y_grid is None:
        y_grid = globals()["y_grid"]

    x_min, x_max = np.inf, -np.inf
    y_min, y_max = np.inf, -np.inf
    axs = []

    for z_index in range(image_stack.shape[0]):
        slice_has_contour = any(
            len(contours_by_slice[z_index]) > 0
            for contours_by_slice in contours_by_structure.values()
        )
        if not slice_has_contour:
            continue

        _, ax = plt.subplots()
        axs.append(ax)

        ax.pcolormesh(
            x_grid,
            y_grid,
            image_stack[z_index, :, :],
            vmin=vmin,
            vmax=vmax,
            shading="nearest",
            cmap="gray",
        )

        for structure_name, contours_by_slice in contours_by_structure.items():
            for contour_index, contour in enumerate(contours_by_slice[z_index]):
                # Repeat the first point so the drawn outline is closed.
                contour_array = np.array(contour + [contour[0]])

                ax.plot(
                    contour_array[:, 0],
                    contour_array[:, 1],
                    # Label only the first contour of each structure. matplotlib skips
                    # "_"-prefixed labels, so a structure drawn as several contours
                    # gets one legend entry instead of one per contour.
                    label=labels[structure_name] if contour_index == 0 else "_nolegend_",
                    c=colours[structure_name],
                )

                x_min = min(x_min, np.min(contour_array[:, 0]))
                x_max = max(x_max, np.max(contour_array[:, 0]))
                y_min = min(y_min, np.min(contour_array[:, 1]))
                y_max = max(y_max, np.max(contour_array[:, 1]))

        ax.set_aspect("equal", "box")
        ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
        ax.set_title(f"Slice: {z_index}")

    # Only compute shared limits when something was drawn; otherwise the bounds are still ±inf.
    if axs:
        x_pad = (x_max - x_min) * margin
        y_pad = (y_max - y_min) * margin
        for ax in axs:
            # Largest y first keeps the image-style orientation (y increasing downwards).
            ax.set_ylim(y_max + y_pad, y_min - y_pad)
            ax.set_xlim(x_min - x_pad, x_max + x_pad)

    plt.show()

In [None]:
# Overlay ground truth (keyed by HNSCC ROI name) and predictions (keyed by TG-263 name)
# on the same slices. If a key appeared in both, the prediction would win.
all_contours_by_structure = dict(contours_by_structure_gt)
all_contours_by_structure.update(contours_by_structure_pd)

_plot_model_result(image_stack, all_contours_by_structure, colours, labels)