In [None]:
import urllib.request
import pathlib
import shutil
import collections

import numpy as np
import pydicom
import matplotlib.pyplot as plt
import skimage.measure

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from rai.model import load
from rai.data.images import paths_to_image_stack_hfs
from rai.mask.convert import contour_sequence_to_masks, mask_to_contours
from rai.metrics.dice import from_contours_by_slice

from rai.dicom import structures as _dicom_structures

from rai.inference import batch, merge

from raicontours import cfg, TG263

In [None]:
# Load the model used below for batched inference (rai.model.load.load_model).
model = load.load_model()

In [None]:
# Relative path: resolved against the kernel's current working directory.
data_path = pathlib.Path('data')

In [None]:
# Fetch the CT slices and RTSTRUCT for case 0522c0659 from the pinned
# data-tcia-deepmind commit. Files already present in data_path are skipped,
# so re-running this cell is cheap and works offline once the data exists.
# TODO: This can be downloaded in parallel.

data_path.mkdir(exist_ok=True)

data_root = 'https://github.com/RadiotherapyAI/data-tcia-deepmind/raw/61fd2525f9880c8b201758f43c773e515572be92/0522c0659'

filenames = [f"CT-{item:03d}.dcm" for item in range(165)] + ["RS.dcm"]

for filename in filenames:
    destination = data_path / filename
    if destination.exists():
        continue
    # Download to a temporary name and rename on success so an interrupted
    # download never leaves a truncated file that would be skipped next time.
    partial = destination.with_name(destination.name + ".part")
    urllib.request.urlretrieve(f"{data_root}/{filename}", partial)
    partial.replace(destination)

In [None]:
# One path per CT slice file: CT-000.dcm ... CT-164.dcm.
# NOTE(review): the slice count (165) is hard-coded; it must match the files in data_path.
image_paths = [data_path / f"CT-{item:03d}.dcm" for item in range(165)]

In [None]:
# Load the CT slices into a volume. x_grid/y_grid are the plotting/contouring
# coordinates, image_stack is indexed as image_stack[z_index, :, :] below, and
# image_uids are used to align ground-truth contours to slices.
# NOTE(review): "hfs" presumably means head-first-supine orientation — confirm in rai.data.images.
x_grid, y_grid, image_stack, image_uids = paths_to_image_stack_hfs(image_paths)

In [None]:
# Read the RTSTRUCT. `pydicom.read_file` is a deprecated alias (removed in
# pydicom 3.0); `dcmread` is the supported reader and accepts path-like objects.
structure_ds = pydicom.dcmread(data_path / "RS.dcm")

In [None]:
# ROI name -> ROI number, taken from the RTSTRUCT's StructureSetROISequence
# (if a name repeats, the last occurrence wins).
name_to_number_map = dict(
    (roi.ROIName, roi.ROINumber) for roi in structure_ds.StructureSetROISequence
)

name_to_number_map

In [None]:
# TG-263 structure -> ROI name used in this TCIA/DeepMind RTSTRUCT.
# Every entry of cfg["structures"] is later looked up through this mapping.
TG263_to_deepmind_map = dict(
    [
        (TG263.Eye_L, 'Orbit-Lt'),
        (TG263.Eye_R, 'Orbit-Rt'),
        (TG263.OpticNrv_L, 'Optic-Nerve-Lt'),
        (TG263.OpticNrv_R, 'Optic-Nerve-Rt'),
        (TG263.Lens_L, 'Lens-Lt'),
        (TG263.Lens_R, 'Lens-Rt'),
    ]
)

In [None]:
# ROI number -> ContourSequence from the RTSTRUCT's ROIContourSequence.
# ContourSequence is optional (Type 3) in the ROI Contour module, so an ROI with
# no contours yields an empty list instead of raising AttributeError and
# breaking the notebook for structures we never use.
number_to_contour_sequence_map = {
    item.ReferencedROINumber: getattr(item, "ContourSequence", [])
    for item in structure_ds.ROIContourSequence
}

In [None]:
# Resolve each TG-263 structure to its RTSTRUCT ContourSequence via
# ROI name -> ROI number -> ROIContourSequence entry.
structure_name_to_contour_sequence_map = {
    tg263_name: number_to_contour_sequence_map[name_to_number_map[deepmind_name]]
    for tg263_name, deepmind_name in TG263_to_deepmind_map.items()
}

In [None]:
# Display the TG-263 structure -> ContourSequence mapping for inspection.
structure_name_to_contour_sequence_map

In [None]:
# Patch centres passed to batch.create_batch: a 2x2 grid at axis-0 index 45.
# NOTE(review): presumably (z, y, x) voxel indices into image_stack — confirm
# against rai.inference.batch.
points = [
    (45, axis1_index, axis2_index)
    for axis1_index in (163, 183)
    for axis2_index in (280, 230)
]

In [None]:
# Build the model input from image_stack around each entry of `points`
# (presumably one patch per point — confirm in rai.inference.batch.create_batch).
model_input = batch.create_batch(image_stack, points)
model_input.shape

In [None]:
# Run inference. max_batch_size=4 equals len(points), so presumably the whole
# batch runs in a single pass.
model_output = batch.run_batch(model=model, model_input=model_input, max_batch_size=4)

In [None]:
# Zero-initialised full-volume arrays handed to merge.merge_predictions:
# `merged` has one channel per configured structure, `counts` a single channel.
# NOTE(review): `merged` is uint8 — confirm merge_predictions stores values that
# fit (e.g. binarised masks); fractional predictions would be truncated.
merged = np.zeros((*image_stack.shape, len(cfg["structures"])), np.uint8)
counts = np.zeros((*image_stack.shape, 1), np.float32)

In [None]:
# Merge the patch predictions into the full-volume `merged` / `counts` arrays.
# NOTE(review): this cell feeds its own outputs back in; if merge_predictions
# accumulates, re-running it without first re-running the np.zeros allocation
# cell double-counts the patches.
merged, counts = merge.merge_predictions(merged, counts, points, model_output)

In [None]:
# Expected: image_stack.shape + (len(cfg["structures"]),), assuming merge_predictions preserves shape.
merged.shape

In [None]:
# Configured structures; their order is the channel order of `merged` (last axis).
cfg["structures"]

In [None]:
# Convert the merged prediction masks into contours, slice by slice, for every
# configured structure:
#   {structure: [contours at z_index 0, contours at z_index 1, ...]}
contours_by_structure_pd = {
    structure_name: [
        mask_to_contours(x_grid, y_grid, merged[z_index, ..., structure_index])
        for z_index in range(image_stack.shape[0])
    ]
    for structure_index, structure_name in enumerate(cfg["structures"])
}

In [None]:
# Uncomment to inspect the raw predicted contours (one list per slice per structure; verbose).
# contours_by_structure_pd

In [None]:
# Build per-slice ground-truth contours from the RTSTRUCT (aligned via
# image_uids) and score the prediction for each configured structure with
# rai.metrics.dice.from_contours_by_slice.
contours_by_structure_gt = {}
dice = {}

for structure in cfg["structures"]:
    gt_contours_by_slice = _dicom_structures.contour_sequence_to_contours_by_slice(
        image_uids, structure_name_to_contour_sequence_map[structure]
    )
    contours_by_structure_gt[structure] = gt_contours_by_slice
    dice[structure] = from_contours_by_slice(
        gt_contours_by_slice, contours_by_structure_pd[structure]
    )

dice

In [None]:
# Volume shape; axis 0 is iterated as z_index by the contour and plotting cells.
image_stack.shape

In [None]:
# Sanity check: number of per-slice ground-truth entries; expected to equal image_stack.shape[0].
len(contours_by_structure_gt[TG263.OpticNrv_R])

In [None]:
# Sanity check: predicted entries are built one per z_index, so this equals image_stack.shape[0].
len(contours_by_structure_pd[TG263.OpticNrv_R])

In [None]:
def _plot_model_result(
    image_stack, contours_by_structure_gt, contours_by_structure_pd
):
    """Plot each slice that has a contour, overlaying ground truth and prediction.

    One figure is created per index along axis 0 of ``image_stack`` for which
    at least one structure has a ground-truth or predicted contour. Each figure
    shows the slice in grayscale (``vmin``..``vmax`` window). Ground-truth
    contours are drawn dashed and predicted contours solid, with the same colour
    for a given structure in both. All figures share x/y limits fitted to every
    drawn contour plus a 20% margin. The y axis is inverted, so larger y is at
    the bottom.

    Parameters
    ----------
    image_stack : numpy.ndarray
        Volume indexed as ``image_stack[z_index, :, :]``.
    contours_by_structure_gt, contours_by_structure_pd : dict
        Map each structure to a per-slice list of contours. Each contour is a
        sequence of points whose first two columns are x and y.
        NOTE(review): assumes both dicts can be indexed by slice position
        (``contours_by_slice[z_index]``) for every slice of ``image_stack``.

    Notes
    -----
    Uses the module-level ``x_grid`` / ``y_grid`` globals as the pcolormesh
    coordinates. Structure keys must expose ``.value``, which is used as the
    legend label.
    """
    vmin = 0.2
    vmax = 0.4
    margin = 0.2

    # One colour per structure, shared by its GT and predicted contours.
    structure_names = dict.fromkeys(
        [*contours_by_structure_gt, *contours_by_structure_pd]
    )
    colours = {
        name: f"C{index % 10}" for index, name in enumerate(structure_names)
    }
    # (legend tag, contours by structure, line style) for each contour source.
    sources = (
        ("GT", contours_by_structure_gt, "--"),
        ("PD", contours_by_structure_pd, "-"),
    )

    x_min, x_max = np.inf, -np.inf
    y_min, y_max = np.inf, -np.inf
    axs = []

    for z_index in range(image_stack.shape[0]):
        # Include slices where only the ground truth has contours, so missed
        # structures are visible too.
        has_a_contour = any(
            len(contours_by_slice[z_index]) > 0
            for _, contours_by_structure, _ in sources
            for contours_by_slice in contours_by_structure.values()
        )
        if not has_a_contour:
            continue

        _, ax = plt.subplots()
        axs.append(ax)

        ax.pcolormesh(
            x_grid, y_grid,
            image_stack[z_index, :, :],
            vmin=vmin,
            vmax=vmax,
            shading="nearest",
            cmap="gray",
        )

        for tag, contours_by_structure, linestyle in sources:
            for structure_name, contours_by_slice in contours_by_structure.items():
                # Only the first contour of each structure/source gets a legend
                # entry; matplotlib ignores labels that start with "_".
                label = f"{structure_name.value} ({tag})"
                for contour in contours_by_slice[z_index]:
                    contour_array = np.asarray(contour)
                    xs = contour_array[:, 0]
                    ys = contour_array[:, 1]
                    ax.plot(
                        xs, ys,
                        linestyle=linestyle,
                        color=colours[structure_name],
                        label=label,
                    )
                    label = "_nolegend_"

                    x_min = min(x_min, float(np.min(xs)))
                    x_max = max(x_max, float(np.max(xs)))
                    y_min = min(y_min, float(np.min(ys)))
                    y_max = max(y_max, float(np.max(ys)))

        ax.set_aspect("equal", "box")
        ax.set_title(f"z_index = {z_index}")
        ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")

    if axs:
        x_pad = (x_max - x_min) * margin
        y_pad = (y_max - y_min) * margin
        for ax in axs:
            # Largest y first so the y axis increases downwards.
            ax.set_ylim(y_max + y_pad, y_min - y_pad)
            ax.set_xlim(x_min - x_pad, x_max + x_pad)

    plt.show()

In [None]:
# Plot the model result for every slice that contains contours.
_plot_model_result(image_stack, contours_by_structure_gt, contours_by_structure_pd)

In [None]:
# TODO: Calculate a few patches and merge