In [None]:
import sys
sys.version

In [None]:
import tqdm
import numpy as np
import skimage.measure
import matplotlib.pyplot as plt

In [None]:
import pydicom
pydicom.__version__

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import rai
rai.__version__

In [None]:
import raicontours

from raicontours import TG263

raicontours.__version__

In [None]:
# Load the raicontours configuration; later cells read entries such as
# cfg["reduce_block_sizes"] from it.
cfg = raicontours.get_config()

In [None]:
# Display the loaded configuration for reference.
cfg

In [None]:
# Two networks are used below: the starting model runs without any prior
# masks, the dependent model is given existing masks via `masks_stack`.
loaded_models = rai.load_model(cfg=cfg)
rai_starting_model, rai_dependent_model = loaded_models

In [None]:
# Obtain the example image series paths and its structure-set file
# (presumably downloaded on first use — see rai.download_deepmind_example).
example_files = rai.download_deepmind_example()
image_paths, structure_path = example_files
structure_path

In [None]:
# Build the image volume plus its coordinate grids and per-slice UIDs from the
# downloaded image files.
(
    x_grid,
    y_grid,
    z_grid,
    original_image_stack,
    image_uids,
) = rai.paths_to_image_stack_hfs(cfg=cfg, paths=image_paths)

In [None]:
# Shape of the volume before any slice padding is applied.
original_image_stack.shape

In [None]:
original_num_slices = original_image_stack.shape[0]
slice_reduction = cfg["reduce_block_sizes"][0][0]

# Round the slice count up to the next multiple of the first-pass z block size
# so block_reduce only ever sees whole blocks along the slice axis.
desired_num_slices = int(np.ceil(original_num_slices / slice_reduction) * slice_reduction)

# `take(..., mode="clip")` clamps out-of-range indices to the last slice, so
# any extra slices are copies of the final one.
needs_padding = desired_num_slices != original_num_slices
image_stack = (
    original_image_stack.take(np.arange(desired_num_slices), axis=0, mode="clip")
    if needs_padding
    else original_image_stack
)

num_slices = image_stack.shape[0]
assert num_slices == desired_num_slices
assert image_stack.shape[1:3] == (512, 512)

image_stack.shape

In [None]:
# Block-average the padded volume down to the first-pass resolution.
first_pass_block_size = cfg["reduce_block_sizes"][0]
reduced_image_stack = skimage.measure.block_reduce(
    image_stack,
    block_size=first_pass_block_size,
    func=np.mean,
)
reduced_image_stack.shape

In [None]:
# Reduced slice count implied by the unpadded volume. A fractional result means
# the stack had to be padded to a multiple of `slice_reduction` first.
# (Replaces a hard-coded `310 / 4` that did not track the loaded data/config.)
original_num_slices / slice_reduction

In [None]:
# Split the reduced slice axis into near-equal chunks no longer than
# `max_step_size`; the boundaries (including the end) form the z grid.
max_step_size = 40

reduced_num_slices = reduced_image_stack.shape[0]
num_chunks = np.ceil(reduced_num_slices / max_step_size)
step_size = int(np.ceil(reduced_num_slices / num_chunks))
z = [*range(0, reduced_num_slices, step_size), reduced_num_slices]
z

In [None]:
# In-plane grid positions (in reduced-voxel units) for the first pass.
y = [30, 60, 90]
x = [30, 60, 90]
first_pass_grid = (z, y, x)

# First pass: the starting model predicts masks with no prior masks supplied.
predicted_masks = rai.inference_over_jittered_grid(
    cfg=cfg,
    model=rai_starting_model,
    grid=first_pass_grid,
    image_stack=reduced_image_stack,
    max_batch_size=10,
)

In [None]:
# Lower corner (per-axis minimum index) of the voxels above the mask threshold.
mask_threshold = 127.5  # NOTE(review): presumably the midpoint of 0–255 mask values; confirm
where_mask = np.where(predicted_masks > mask_threshold)
# np.min on an empty selection fails with an opaque "zero-size array" error;
# fail with an explicit message instead.
assert where_mask[0].size > 0, f"No mask voxels above {mask_threshold}"
np.min(where_mask, axis=1)

In [None]:
# Upper corner (per-axis maximum index) of the thresholded voxels.
np.asarray(where_mask).max(axis=1)

In [None]:
# predicted_masks = rai.inference_over_jittered_grid(
#     cfg=cfg, model=rai_dependent_model, grid=(z, y, x), image_stack=reduced_image_stack, masks_stack=predicted_masks
# )

In [None]:
# Block sizes per reduction pass: [0] for the first (coarser) pass, [1] for the
# second pass.
cfg["reduce_block_sizes"]

In [None]:
# Upsample the first-pass masks onto the second-pass grid so they can be given
# to the dependent model as `masks_stack`. The per-axis repeat factor is the
# ratio of the two configured block sizes, rather than a hard-coded 2, so the
# cell keeps working if cfg["reduce_block_sizes"] changes.
coarse_block_size = np.asarray(cfg["reduce_block_sizes"][0])
fine_block_size = np.asarray(cfg["reduce_block_sizes"][1])
assert np.all(coarse_block_size % fine_block_size == 0), (
    "First-pass block sizes must be integer multiples of second-pass block sizes"
)
upscale_factors = coarse_block_size // fine_block_size

upscaled = predicted_masks
for axis, factor in enumerate(upscale_factors):
    upscaled = np.repeat(upscaled, repeats=int(factor), axis=axis)

# The result must match the second-pass reduced grid on every spatial axis
# (ceil division mirrors block_reduce's padding of partial blocks).
expected_shape = tuple(
    int(-(-size // block)) for size, block in zip(image_stack.shape, fine_block_size)
)
assert upscaled.shape[: len(expected_shape)] == expected_shape, (upscaled.shape, expected_shape)

upscaled.shape

In [None]:
# Lower corner (per-axis minimum index) of the upscaled voxels above threshold.
mask_threshold = 127.5  # NOTE(review): presumably the midpoint of 0–255 mask values; confirm
where_mask = np.where(upscaled > mask_threshold)
# Explicit failure instead of numpy's opaque zero-size reduction error.
assert where_mask[0].size > 0, f"No mask voxels above {mask_threshold}"
np.min(where_mask, axis=1)

In [None]:
# Upper corner (per-axis maximum index) of the thresholded voxels.
np.asarray(where_mask).max(axis=1)

In [None]:
# Block-average the padded volume down to the second-pass resolution.
second_pass_block_size = cfg["reduce_block_sizes"][1]
reduced_image_stack = skimage.measure.block_reduce(
    image_stack,
    block_size=second_pass_block_size,
    func=np.mean,
)
reduced_image_stack.shape

In [None]:
# Same chunking rule as the first pass, applied to the second-pass volume.
max_step_size = 40

reduced_num_slices = reduced_image_stack.shape[0]
num_chunks = np.ceil(reduced_num_slices / max_step_size)
step_size = int(np.ceil(reduced_num_slices / num_chunks))
z = [*range(0, reduced_num_slices, step_size), reduced_num_slices]
z

In [None]:
# In-plane grid positions (in second-pass voxel units).
y = [60, 95, 130, 165]
x = [65, 105, 145, 185]
second_pass_grid = (z, y, x)

# Second pass: the dependent model refines using the upscaled first-pass masks.
predicted_masks = rai.inference_over_jittered_grid(
    cfg=cfg,
    model=rai_dependent_model,
    grid=second_pass_grid,
    image_stack=reduced_image_stack,
    masks_stack=upscaled,
    max_batch_size=10,
)

In [None]:
# Shape of the second-pass predictions.
predicted_masks.shape

In [None]:
# Largest predicted mask value (sanity check on the output scale).
predicted_masks.max()

In [None]:
# Upsample the second-pass masks back to the full-resolution (slice-padded)
# image grid. Repeat factors come from the second-pass block size rather than a
# hard-coded 2 on the in-plane axes; an axis with block size 1 is unchanged.
fine_block_size = cfg["reduce_block_sizes"][1]

upscaled = predicted_masks
for axis, factor in enumerate(fine_block_size):
    upscaled = np.repeat(upscaled, repeats=int(factor), axis=axis)

# These masks are later passed alongside `image_stack`, so spatial shapes must agree.
assert upscaled.shape[:3] == image_stack.shape[:3], (upscaled.shape, image_stack.shape)

upscaled.shape

In [None]:
# `pydicom.read_file` is deprecated (removed in pydicom 3.0); `dcmread` is the
# supported reader and exists in all maintained pydicom versions.
structure_ds = pydicom.dcmread(structure_path)

# ROI names present in the reference structure set.
[item.ROIName for item in structure_ds.StructureSetROISequence]

In [None]:
# Pair each ROI name in the reference structure set with the TG-263 structure
# it is compared against. Each value is wrapped in a list, as
# rai.merge_contours_by_structure receives this mapping.
roi_name_pairs = (
    ("Brain", TG263.Brain),
    ("Brainstem", TG263.Brainstem),
    ("Cochlea-Lt", TG263.Cochlea_L),
    ("Cochlea-Rt", TG263.Cochlea_R),
    ("Lacrimal-Lt", TG263.Glnd_Lacrimal_L),
    ("Lacrimal-Rt", TG263.Glnd_Lacrimal_R),
    ("Lens-Lt", TG263.Lens_L),
    ("Lens-Rt", TG263.Lens_R),
    ("Lung-Lt", TG263.Lung_L),
    ("Lung-Rt", TG263.Lung_R),
    ("Mandible", TG263.Bone_Mandible),
    ("Optic-Nerve-Lt", TG263.OpticNrv_L),
    ("Optic-Nerve-Rt", TG263.OpticNrv_R),
    ("Orbit-Lt", TG263.Eye_L),
    ("Orbit-Rt", TG263.Eye_R),
    ("Parotid-Lt", TG263.Parotid_L),
    ("Parotid-Rt", TG263.Parotid_R),
    ("Spinal-Cord", TG263.SpinalCord),
    ("Submandibular-Lt", TG263.Glnd_Submand_L),
    ("Submandibular-Rt", TG263.Glnd_Submand_R),
)
align_map = {roi_name: [tg263_name] for roi_name, tg263_name in roi_name_pairs}
structure_names = [*align_map]

# Reference contours for the mapped ROIs, keyed by ROI name.
dicom_contours_by_structure = rai.dicom_to_contours_by_structure(
    ds=structure_ds,
    image_uids=image_uids,
    structure_names=structure_names,
)



In [None]:
# Same chunking rule as the reduced passes, now on the full-resolution volume.
max_step_size = 40

num_slices = image_stack.shape[0]
num_chunks = np.ceil(num_slices / max_step_size)
step_size = int(np.ceil(num_slices / num_chunks))
z = [*range(0, num_slices, step_size), num_slices]
z

In [None]:
# Lower corner (per-axis minimum index) of full-resolution voxels above threshold.
mask_threshold = 127.5  # NOTE(review): presumably the midpoint of 0–255 mask values; confirm
where_mask = np.where(upscaled > mask_threshold)
# Explicit failure instead of numpy's opaque zero-size reduction error.
assert where_mask[0].size > 0, f"No mask voxels above {mask_threshold}"
np.min(where_mask, axis=1)

In [None]:
# Upper corner (per-axis maximum index) of the thresholded voxels.
np.asarray(where_mask).max(axis=1)

In [None]:
# Full-resolution y grid positions; presumably chosen to span the mask extent
# printed above — TODO confirm.
y = [*range(125, 330, 40)]
y

In [None]:
# Full-resolution x grid positions; presumably chosen to span the mask extent
# printed above — TODO confirm.
x = [*range(130, 380, 40)]
x

In [None]:
# Refine the full-resolution masks with the dependent model, then score each
# pass against the reference DICOM contours. Each pass feeds its output back in
# as `masks_stack`; set `num_refinement_iterations` > 1 to iterate.
num_refinement_iterations = 1

predicted_masks = upscaled

looped_dice = []
for _ in range(num_refinement_iterations):
    predicted_masks = rai.inference_over_jittered_grid(
        cfg=cfg,
        model=rai_dependent_model,
        grid=(z, y, x),
        image_stack=image_stack,
        masks_stack=predicted_masks,
        max_batch_size=10,
        verify=False,  # TODO: Remove this before publishing
    )

    # Only the original slices are contoured; any trailing slices were padding.
    predicted_contours_by_structure = rai.masks_to_contours_by_structure(
        cfg=cfg,
        x_grid=x_grid,
        y_grid=y_grid,
        masks=predicted_masks[0:original_num_slices, ...],
    )

    # Re-key predictions by the reference ROI names so they can be compared.
    aligned_predicted_contours_by_structure = rai.merge_contours_by_structure(
        predicted_contours_by_structure, align_map
    )

    looped_dice.append(
        {
            name: rai.dice_from_contours_by_slice(
                dicom_contours_by_structure[name],
                aligned_predicted_contours_by_structure[name],
            )
            for name in align_map
        }
    )

In [None]:
# Dice per structure across refinement iterations. Uses an explicit Axes and
# puts the (20-entry) legend outside the plot so it doesn't hide the data.
fig_dice, ax_dice = plt.subplots(figsize=(10, 6))

for name in structure_names:
    dice_per_iteration = [scores[name] for scores in looped_dice]
    ax_dice.plot(dice_per_iteration, "-o", label=name)

ax_dice.set(
    title="Dice per structure by refinement iteration",
    xlabel="Refinement iteration",
    ylabel="Dice",
)
ax_dice.legend(loc="center left", bbox_to_anchor=(1.0, 0.5), fontsize="small");

In [None]:
# predicted_contours_by_structure

In [None]:
# looped_dice

In [None]:
# TODO: Create an ipywidget slider for the slices

In [None]:
# TODO: Allow creation over other axis
# First, verify that this can be used within plotly on transverse

# Contour only the original slices; any trailing slices were padding.
masks_in_original_extent = predicted_masks[:original_num_slices]
predicted_contours_by_structure = rai.masks_to_contours_by_structure(
    cfg=cfg,
    x_grid=x_grid,
    y_grid=y_grid,
    masks=masks_in_original_extent,
)

In [None]:
import plotly.graph_objects as go

# Overlay the predicted Brain contours for a single slice index.
slice_index = 50
contours = predicted_contours_by_structure[TG263.Brain][slice_index]

fig = go.Figure()
for contour in contours:
    points = np.asarray(contour)
    if len(points) == 0:
        continue
    # Repeat the first vertex so each polygon is drawn closed. np.vstack works
    # whether `contour` is a list of points or an ndarray; list-style
    # `contour + [contour[0]]` would broadcast-add on an ndarray instead.
    closed = np.vstack([points, points[:1]])
    fig.add_trace(
        go.Scatter(
            x=closed[:, 0],
            y=closed[:, 1],
            hoverinfo="skip",
            mode="lines",
        )
    )

# Equal x/y scaling so contour shapes are not distorted by the plot aspect.
fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig.show()

In [None]:
# Concise summary of the figure (trace type and vertex count) instead of
# printing its full JSON, which includes every contour coordinate.
[(trace.type, len(trace.x)) for trace in fig.data]

In [None]:
# predicted_contours_by_structure[TG263.Brain]

In [None]:
from rai.display import interactive

In [None]:
# Intensity limits passed to collect_slices for display
# (NOTE(review): presumably a display window on the normalised image — confirm).
display_vmin, display_vmax = 0.2, 0.4
transverse, coronal, sagittal = interactive.collect_slices(
    original_image_stack,
    vmin=display_vmin,
    vmax=display_vmax,
)

grids = (z_grid, y_grid, x_grid)

# Initial slice indices for the three views (presumably ordered like `grids`).
initial_slice_indices = (50, 256, 256)
images = interactive.create_plotly_layout_images(
    grids, initial_slice_indices, transverse, coronal, sagittal
)

# Axis ranges handed to interactive.draw; the descending pairs are kept as-is.
view_ranges = [
    [-150, 100],
    [80, -170],
    [125, -125],
]
interactive.draw(grids=grids, images=images, ranges=view_ranges)

In [None]:
# TODO:
# * Change figures to clickable interactive transverse/coronal/sagittal bokeh views
# * Calculate and report Hausdorff distance and surface Dice as well