In [None]:
# Record interpreter and library versions at the top of the run so the
# notebook's outputs can be tied to a specific environment.
import sys
sys.version

In [None]:
# NOTE(review): tqdm is imported but not used in any visible cell.
import tqdm
import numpy as np
import skimage.measure
import matplotlib.pyplot as plt

In [None]:
import pydicom
pydicom.__version__

In [None]:
# Automatically reload edited modules (e.g. a locally developed rai) before
# each cell executes, so changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import rai
rai.__version__

In [None]:
import raicontours

# TG263 supplies the structure identifiers that align_map (below) maps the
# DICOM ROI names onto.
from raicontours import TG263

raicontours.__version__

In [None]:
# Configuration shared by every rai call below (including the per-stage
# "reduce_block_sizes" used for the coarse-to-fine inference).
cfg = raicontours.get_config()

In [None]:
cfg

In [None]:
# Two models are used: the starting model is run first on the coarse stack
# without prior masks, and the dependent model refines masks passed in via
# masks_stack (see the inference cells below).
rai_starting_model, rai_dependent_model = rai.load_model(cfg=cfg)

In [None]:
# Example image series plus a DICOM structure set, which serves as the
# reference contours for the DICE comparison further down.
image_paths, structure_path = rai.download_deepmind_example()
structure_path

In [None]:
# Load the series as a stacked array plus its x/y/z coordinate grids and the
# per-slice image UIDs (later passed to rai.dicom_to_contours_by_structure).
x_grid, y_grid, z_grid, image_stack, image_uids = rai.paths_to_image_stack_hfs(
    cfg=cfg, paths=image_paths
)

In [None]:
image_stack.shape

In [None]:
# Pad the slice axis so that every block_reduce stage divides it exactly
# (skimage's block_reduce otherwise pads the last block with zeros, which
# corrupts its mean). Both cfg["reduce_block_sizes"][0] and [1] are applied to
# image_stack below, so the target is the LCM of all stages' slice factors.
#
# original_num_slices is taken from z_grid, which is never modified, rather
# than from image_stack. Re-running this cell therefore does not mistake the
# already-padded stack for the original one.
# NOTE(review): assumes len(z_grid) equals the number of loaded slices; the
# assert below fails loudly on the first run if that is not the case.
original_num_slices = len(z_grid)

slice_reduction = int(np.lcm.reduce([int(block[0]) for block in cfg["reduce_block_sizes"]]))
# Round up to the next multiple of slice_reduction using integer arithmetic.
desired_num_slices = -(-original_num_slices // slice_reduction) * slice_reduction

assert image_stack.shape[0] in (original_num_slices, desired_num_slices), (
    f"expected {original_num_slices} (unpadded) or {desired_num_slices} (padded) "
    f"slices, got {image_stack.shape[0]}"
)

if image_stack.shape[0] != desired_num_slices:
    # mode="clip" repeats the final slice to fill the padding.
    image_stack = image_stack.take(range(desired_num_slices), axis=0, mode="clip")

num_slices = image_stack.shape[0]
assert num_slices == desired_num_slices
assert image_stack.shape[1:3] == (512, 512)

image_stack.shape

In [None]:
# Stage 0: average-pool the padded stack by the first (stage-0) block size.
reduced_image_stack = skimage.measure.block_reduce(
    image_stack,
    block_size=cfg["reduce_block_sizes"][0],
    func=np.mean,
)
reduced_image_stack.shape

In [None]:
# NOTE(review): scratch arithmetic with hard-coded numbers; 310 does not appear
# anywhere else in this notebook (presumably a slice count checked by hand).
# Derive it from image_stack / cfg or delete this cell.
310 / 4

In [None]:
# Split the reduced slice axis into near-equal chunks of at most 40 slices;
# z holds the chunk boundaries, starting at 0 and ending at reduced_num_slices.
max_step_size = 40

reduced_num_slices = reduced_image_stack.shape[0]
num_chunks = np.ceil(reduced_num_slices / max_step_size)
step_size = int(np.ceil(reduced_num_slices / num_chunks))
z = [*range(0, reduced_num_slices, step_size), reduced_num_slices]
z

In [None]:
# Coarse-stage inference: run the starting model over the stage-0 reduced
# stack on a (z, y, x) grid.
# NOTE(review): the y/x values are hard-coded positions, presumably in
# reduced-pixel units chosen for this example — confirm against
# rai.inference_over_jittered_grid before reusing on other data.
y = [30, 60, 90]
x = [30, 60, 90]

predicted_masks = rai.inference_over_jittered_grid(
    cfg=cfg, model=rai_starting_model, grid=(z, y, x), image_stack=reduced_image_stack, max_batch_size=10
)

In [None]:
# Per-axis index bounds of voxels whose mask value exceeds 127.5.
# NOTE(review): 127.5 is the midpoint of a 0-255 range, which suggests masks
# are on that scale — TODO confirm.
where_mask = np.where(predicted_masks > 127.5)
np.min(where_mask, axis=1)

In [None]:
np.max(where_mask, axis=1)

In [None]:
# Reduce the x/y coordinate grids by the same stage-0 factors as the image
# (index 2 is the x axis and index 1 the y axis of the z, y, x block size) so
# contours line up with the reduced image.
# NOTE(review): unlike the later plotting cells, the padded slices are not
# trimmed with original_num_slices here.
reduced_x_grid = skimage.measure.block_reduce(x_grid, block_size=cfg["reduce_block_sizes"][0][2], func=np.mean)
reduced_y_grid = skimage.measure.block_reduce(y_grid, block_size=cfg["reduce_block_sizes"][0][1], func=np.mean)

reduced_predicted_contours_by_structure = rai.masks_to_contours_by_structure(
    cfg=cfg, x_grid=reduced_x_grid, y_grid=reduced_y_grid, masks=predicted_masks
)
rai.plot_contours_by_structure(
    reduced_x_grid, reduced_y_grid, reduced_image_stack, reduced_predicted_contours_by_structure
)

In [None]:
# NOTE(review): disabled experiment (dependent model on the stage-0 stack);
# delete it or turn it into a documented option.
# predicted_masks = rai.inference_over_jittered_grid(
#     cfg=cfg, model=rai_dependent_model, grid=(z, y, x), image_stack=reduced_image_stack, masks_stack=predicted_masks
# )

In [None]:
cfg["reduce_block_sizes"]

In [None]:
# Upsample the stage-0 masks onto the stage-1 grid so they can seed the
# dependent model via masks_stack. The per-axis factor is derived from cfg
# (stage-0 block / stage-1 block) instead of being hard-coded as 2.
coarse_block = np.asarray(cfg["reduce_block_sizes"][0])
fine_block = np.asarray(cfg["reduce_block_sizes"][1])
assert np.all(coarse_block % fine_block == 0), "stage-1 blocks must divide stage-0 blocks"
upscale_factors = coarse_block // fine_block

upscaled = predicted_masks
for axis, factor in enumerate(upscale_factors):
    # Only the (z, y, x) axes covered by the block size are repeated; any
    # further axis of the masks (if present) is left untouched.
    upscaled = np.repeat(upscaled, repeats=int(factor), axis=axis)

# The upsampled masks must match the stage-1 reduced image on every spatial axis.
expected_spatial_shape = tuple(n // int(b) for n, b in zip(image_stack.shape[:3], fine_block))
assert upscaled.shape[:3] == expected_spatial_shape, (upscaled.shape, expected_spatial_shape)

upscaled.shape

In [None]:
# Per-axis index bounds of the upsampled stage-0 masks (same 127.5 threshold
# as the earlier bounding-box cell).
where_mask = np.where(upscaled > 127.5)
np.min(where_mask, axis=1)

In [None]:
np.max(where_mask, axis=1)

In [None]:
# Stage 1: average-pool the padded stack by the second (stage-1) block size.
reduced_image_stack = skimage.measure.block_reduce(
    image_stack,
    block_size=cfg["reduce_block_sizes"][1],
    func=np.mean,
)
reduced_image_stack.shape

In [None]:
# Split the stage-1 slice axis into near-equal chunks of at most 40 slices;
# z holds the chunk boundaries, starting at 0 and ending at reduced_num_slices.
max_step_size = 40

reduced_num_slices = reduced_image_stack.shape[0]
num_chunks = np.ceil(reduced_num_slices / max_step_size)
step_size = int(np.ceil(reduced_num_slices / num_chunks))
z = [*range(0, reduced_num_slices, step_size), reduced_num_slices]
z

In [None]:
# Stage-1 refinement: the dependent model runs on the stage-1 reduced stack,
# seeded with the stage-0 masks upsampled onto that grid (upscaled).
# NOTE(review): the y/x grid positions are hard-coded for this example.
y = [60, 95, 130, 165]
x = [65, 105, 145, 185]

predicted_masks = rai.inference_over_jittered_grid(
    cfg=cfg, model=rai_dependent_model, grid=(z, y, x), image_stack=reduced_image_stack, masks_stack=upscaled, max_batch_size=10
)

In [None]:
# Draw the stage-1 contours over the stage-1 reduced image, dropping the
# padded slices.
# NOTE(review): slicing the *reduced* arrays with original_num_slices is only
# correct when stage 1 does not reduce the slice axis
# (cfg["reduce_block_sizes"][1][0] == 1).
reduced_x_grid = skimage.measure.block_reduce(x_grid, block_size=cfg["reduce_block_sizes"][1][2], func=np.mean)
reduced_y_grid = skimage.measure.block_reduce(y_grid, block_size=cfg["reduce_block_sizes"][1][1], func=np.mean)

reduced_predicted_contours_by_structure = rai.masks_to_contours_by_structure(
    cfg=cfg, x_grid=reduced_x_grid, y_grid=reduced_y_grid, masks=predicted_masks[0:original_num_slices, ...]
)
rai.plot_contours_by_structure(
    reduced_x_grid, reduced_y_grid, reduced_image_stack[0:original_num_slices,...], reduced_predicted_contours_by_structure
)

In [None]:
# Sanity checks on the stage-1 output: its shape and peak value.
# NOTE(review): the 127.5 thresholds used in this notebook suggest a 0-255
# scale — TODO confirm.
predicted_masks.shape

In [None]:
np.max(predicted_masks)

In [None]:
# Upsample the stage-1 masks to the full-resolution grid of image_stack.
# cfg["reduce_block_sizes"][1] is exactly the per-axis factor between the two
# grids, so it replaces the hard-coded "repeat by 2 on the y and x axes".
upscaled = predicted_masks
for axis, factor in enumerate(cfg["reduce_block_sizes"][1]):
    if factor != 1:
        upscaled = np.repeat(upscaled, repeats=int(factor), axis=axis)

assert upscaled.shape[:3] == image_stack.shape[:3], (upscaled.shape, image_stack.shape)

upscaled.shape

In [None]:
# Read the reference structure set and list its ROI names.
# pydicom.read_file is a deprecated alias of pydicom.dcmread (removed in
# pydicom 3.0), so dcmread is used directly.
structure_ds = pydicom.dcmread(structure_path)
[item.ROIName for item in structure_ds.StructureSetROISequence]

In [None]:
# Map each ROI name in the reference DICOM structure set onto TG263 structure
# identifiers. Values are lists, presumably so that several predicted
# structures can be merged into one reference ROI by
# rai.merge_contours_by_structure.
align_map = {
    "Brain": [TG263.Brain],
    "Brainstem": [TG263.Brainstem],
    "Cochlea-Lt": [TG263.Cochlea_L],
    "Cochlea-Rt": [TG263.Cochlea_R],
    "Lacrimal-Lt": [TG263.Glnd_Lacrimal_L],
    "Lacrimal-Rt": [TG263.Glnd_Lacrimal_R],
    "Lens-Lt": [TG263.Lens_L],
    "Lens-Rt": [TG263.Lens_R],
    "Lung-Lt": [TG263.Lung_L],
    "Lung-Rt": [TG263.Lung_R],
    "Mandible": [TG263.Bone_Mandible],
    "Optic-Nerve-Lt": [TG263.OpticNrv_L],
    "Optic-Nerve-Rt": [TG263.OpticNrv_R],
    "Orbit-Lt": [TG263.Eye_L],
    "Orbit-Rt": [TG263.Eye_R],
    "Parotid-Lt": [TG263.Parotid_L],
    "Parotid-Rt": [TG263.Parotid_R],
    "Spinal-Cord": [TG263.SpinalCord],
    "Submandibular-Lt": [TG263.Glnd_Submand_L],
    "Submandibular-Rt": [TG263.Glnd_Submand_R],
}
# Only the mapped reference ROIs are extracted from the structure set.
structure_names = list(align_map.keys())

dicom_contours_by_structure = rai.dicom_to_contours_by_structure(
    ds=structure_ds, image_uids=image_uids, structure_names=structure_names
)



In [None]:
# Chunk the full-resolution slice axis the same way as the reduced stacks:
# near-equal chunks of at most 40 slices, boundaries from 0 to num_slices.
num_slices = image_stack.shape[0]
num_chunks = np.ceil(num_slices / 40)
step_size = int(np.ceil(num_slices / num_chunks))
z = [*range(0, num_slices, step_size), num_slices]
z

In [None]:
# Per-axis index bounds of the full-resolution seed masks; the y/x grid ranges
# below appear to have been chosen by hand to cover them.
where_mask = np.where(upscaled > 127.5)
np.min(where_mask, axis=1)

In [None]:
np.max(where_mask, axis=1)

In [None]:
# NOTE(review): hard-coded full-resolution y grid positions — confirm they
# still cover the bounds above for other datasets.
y = list(range(125, 330, 40))
y

In [None]:
# NOTE(review): hard-coded full-resolution x grid positions (see note on y).
x = list(range(130, 380, 40))
x

In [None]:
# Refine the full-resolution masks with the dependent model. After each pass,
# record a per-structure score from rai.dice_from_contours_by_slice against
# the DICOM reference contours.
num_refinement_passes = 1

predicted_masks = upscaled
looped_dice = []
for _ in range(num_refinement_passes):
    predicted_masks = rai.inference_over_jittered_grid(
        cfg=cfg,
        model=rai_dependent_model,
        grid=(z, y, x),
        image_stack=image_stack,
        masks_stack=predicted_masks,
        max_batch_size=10,
        verify=False,  # TODO: Remove this before publishing
    )

    # Drop the padded slices before converting masks to contours.
    predicted_contours_by_structure = rai.masks_to_contours_by_structure(
        cfg=cfg, x_grid=x_grid, y_grid=y_grid, masks=predicted_masks[0:original_num_slices, ...]
    )
    aligned_predicted_contours_by_structure = rai.merge_contours_by_structure(
        predicted_contours_by_structure, align_map
    )

    dice = {
        name: rai.dice_from_contours_by_slice(
            dicom_contours_by_structure[name],
            aligned_predicted_contours_by_structure[name],
        )
        for name in align_map
    }
    looped_dice.append(dice)

In [None]:
# Plot each structure's DICE across refinement passes on an explicit Axes, with
# labels so the figure stands alone when the notebook is skimmed.
fig, ax = plt.subplots()
for name in structure_names:
    dice_history = [pass_dice[name] for pass_dice in looped_dice]
    ax.plot(dice_history, "-o", label=name)

ax.set(xlabel="Refinement pass", ylabel="DICE", title="DICE per structure by refinement pass")
ax.legend();

In [None]:
# NOTE(review): the next two cells are commented-out inspection leftovers;
# delete or re-enable them.
# predicted_contours_by_structure

In [None]:
# looped_dice

In [None]:
# Plot the final predicted contours over the un-padded full-resolution image.
rai.plot_contours_by_structure(
    x_grid, y_grid, image_stack[0:original_num_slices, ...], predicted_contours_by_structure, align_map
)

In [None]:
# Prefix the reference contours' names so they cannot collide with the
# predicted structures' keys when both are merged into one dict below.
renamed_dicom_contours_by_structure = {
    f"DICOM {key}": item for key, item in dicom_contours_by_structure.items()
}

In [None]:
# TODO: Create an ipywidget slider for the slices

In [None]:
# Overlay predictions and the renamed DICOM reference on the un-padded image.
combined_contours_by_structure = {
    **predicted_contours_by_structure,
    **renamed_dicom_contours_by_structure,
}

rai.plot_contours_by_structure(
    x_grid, y_grid, image_stack[0:original_num_slices, ...], combined_contours_by_structure, align_map
)

In [None]:
import plotly.offline as pyo
# Load plotly.js into the notebook page; the HTML rendered later with
# include_plotlyjs=False presumably relies on this.
pyo.init_notebook_mode()

In [None]:
# TODO: Create interactive bokeh view; press a button to centre on a structure.
#
# Collect the un-padded volume, its coordinate grids and the predicted masks
# for interactive display. (The bare `cfg` expression that used to sit here was
# not the cell's last line, so it displayed nothing and has been removed.)
# NOTE(review): `images` is rebound later to a list of plotly layout-image dicts.
grids = (z_grid, y_grid, x_grid)
images = image_stack[0:original_num_slices, ...]
masks = predicted_masks[0:original_num_slices, ...]

In [None]:
from rai.display import interactive
import pandas

In [None]:
# NOTE(review): interactive.main() is called without arguments, so it cannot
# use the grids/images/masks collected above — confirm what it displays.
fig = interactive.main()
fig.show()

In [None]:
import numpy as np
import plotly.graph_objects as go

# Plotly polar-bar demo. A seeded generator keeps the colours identical across
# Restart & Run All (the original used the unseeded global RNG).
rng = np.random.default_rng(42)
r, theta = np.mgrid[0.1:1:10j, 0:360:20j]
color = rng.random(r.shape)
fig = go.Figure(
    go.Barpolar(
        r=r.ravel(),
        theta=theta.ravel(),
        marker_color=color.ravel(),
    )
)
fig.update_layout(polar_bargap=0)
fig.show()

In [None]:
import plotly.graph_objects as go

In [None]:
from plotly.subplots import make_subplots

In [None]:
from rai.mask import convert

In [None]:
from plotly.subplots import make_subplots
from skimage import data

# Demo: an RGB sample photo next to overlaid per-channel intensity histograms.
img = data.chelsea()
fig = make_subplots(1, 2)
# We use go.Image because subplots require traces, whereas px functions return a figure
fig.add_trace(go.Image(z=img), 1, 1)

channel_colors = ("red", "green", "blue")
for channel, color in enumerate(channel_colors):
    channel_values = img[..., channel].ravel()
    histogram = go.Histogram(
        x=channel_values,
        opacity=0.5,
        marker_color=color,
        name=f"{color} channel",
    )
    fig.add_trace(histogram, 1, 2)

fig.update_layout(height=400)
fig.show()

In [None]:
# Shape of the demo image from the previous cell.
img.shape

In [None]:
# NOTE(review): this value of img is not read by any later cell; the name is
# reassigned before use.
img = image_stack[50, ...]

In [None]:
# NOTE(review): relies on a private helper (convert._grid_to_transform), and
# the values are overwritten by the explicit origin/spacing cell that follows.
x0, dx = convert._grid_to_transform(x_grid)
y0, dy = convert._grid_to_transform(y_grid)

In [None]:
y_grid[-1]

In [None]:
# Origin and signed spacing of each axis for the plotly traces/images below.
# The y origin is the *last* y_grid entry, so dy is the step from y_grid[-1]
# back towards y_grid[0].
x0, dx = x_grid[0], x_grid[1] - x_grid[0]
y0, dy = y_grid[-1], y_grid[-2] - y_grid[-1]
z0, dz = z_grid[0], z_grid[1] - z_grid[0]

In [None]:
import base64
import imageio
from io import BytesIO


In [None]:
def convert_to_b64(image, window_min=0.2, window_width=0.2):
    """Encode a 2-D image as a base64 PNG ``data:`` URI.

    Intensities are windowed linearly so that ``window_min`` maps to 0 and
    ``window_min + window_width`` maps to 255; values outside the window are
    clipped. The defaults reproduce the original fixed 0.2-0.4 window, which
    matches the transverse heatmap's ``zmin``/``zmax``.

    Args:
        image: 2-D array of intensities.
        window_min: Intensity rendered as black.
        window_width: Span of intensities mapped onto 0..255.

    Returns:
        A ``data:image/png;base64,...`` string, e.g. for a plotly layout
        image ``source``.
    """
    scaled_img = np.round(((np.asarray(image) - window_min) / window_width) * 255)
    scaled_img = np.clip(scaled_img, 0, 255).astype(np.uint8)

    in_memory_file = BytesIO()
    imageio.imsave(in_memory_file, scaled_img, format="png")

    # b64encode, unlike encodebytes, emits no line breaks. Line breaks are not
    # valid inside a URI and only inflate the payload embedded in the figure.
    b64 = base64.b64encode(in_memory_file.getvalue()).decode("ascii")
    return f"data:image/png;base64,{b64}"

In [None]:
# Pre-render every transverse, coronal and sagittal slice of the *un-padded*
# volume as PNG data URIs.
# The coronal and sagittal views are stretched over a z extent computed from
# z_grid, so they must be cut from image_stack[:original_num_slices]. Cutting
# them from the padded stack included the repeated padding slices and
# distorted those views vertically.
# NOTE(review): assumes len(z_grid) == original_num_slices.
unpadded_stack = image_stack[:original_num_slices, ...]

transverse = [convert_to_b64(unpadded_stack[i, :, :]) for i in range(unpadded_stack.shape[0])]

coronal = [convert_to_b64(unpadded_stack[:, i, :]) for i in range(unpadded_stack.shape[1])]

# The y axis is reversed for the sagittal view (as in the original rendering).
sagittal = [convert_to_b64(unpadded_stack[:, ::-1, i]) for i in range(unpadded_stack.shape[2])]

In [None]:
# NOTE(review): duplicate of the earlier origin/spacing cell; it recomputes the
# same values and can be removed.
x0 = x_grid[0]
dx = x_grid[1] - x_grid[0]
y0 = y_grid[-1]
dy = y_grid[-2] - y_grid[-1]
z0 = z_grid[0]
dz = z_grid[1] - z_grid[0]

In [None]:
# Placement of the layout images: (x, y, z) are the image edges offset half a
# pixel from the first grid centre, and each size is the distance between the
# first and last grid centres plus one full pixel.
# The spacings are signed (dy = y_grid[-2] - y_grid[-1] is negative for an
# increasing y_grid), so their magnitude is added. Adding a negative dy made
# the extent two pixels too short.
sizex = np.abs(x_grid[-1] - x_grid[0]) + np.abs(dx)
x = x0 - dx / 2

sizey = np.abs(y_grid[-1] - y_grid[0]) + np.abs(dy)
y = y0 - dy / 2

sizez = np.abs(z_grid[-1] - z_grid[0]) + np.abs(dz)
z = z0 - dz / 2

In [None]:
# go.Heatmap?

In [None]:
def _slice_layout_images(view, encoded_slices, visible_index, xref, yref, left, top, width, height):
    """Build one plotly layout-image dict per pre-rendered slice of a view.

    Every image shares the same placement; only the slice at
    ``visible_index`` starts visible.
    """
    return [
        dict(
            name=f"{view}_{index}",
            visible=index == visible_index,
            source=source,
            xref=xref,
            yref=yref,
            x=left,
            y=top,
            sizex=width,
            sizey=height,
            sizing="stretch",
            layer="below",
        )
        for index, source in enumerate(encoded_slices)
    ]


images = [
    *_slice_layout_images("transverse", transverse, 50, "x", "y", x, y, sizex, sizey),
    *_slice_layout_images("coronal", coronal, 256, "x3", "y3", x, z, sizex, sizez),
    *_slice_layout_images(
        "sagittal", sagittal, 256, "x4", "y4", y_grid[0] - dy / 2, z, sizey, sizez
    ),
]

In [None]:
# Coordinate matrices over the transverse plane ("xy" indexing, numpy's default).
# NOTE(review): xx / yy are not used by any visible cell.
xx, yy = np.meshgrid(x_grid, y_grid, indexing="xy")

In [None]:
# 2x2 viewer: transverse heatmap in panel (1, 1); the coronal and sagittal
# slices are the layout images built above, drawn on axes x3/y3 and x4/y4.
# NOTE(review): axes x2/y2 (panel (1, 2)) are not referenced, so that panel
# stays empty.
fig = make_subplots(
    rows=2,
    cols=2,
    vertical_spacing=0,
    horizontal_spacing=0,
)

# Transverse slice 50 as a grey-scale heatmap windowed to 0.2-0.4 (the same
# window convert_to_b64 applies to the slice images), positioned with the
# x0/dx and y0/dy origins and spacings computed above.
fig.add_trace(
    go.Heatmap(
        visible=True,
        x0=x0,
        dx=dx,
        y0=y0,
        dy=dy,
        z=image_stack[50, ...],
        colorscale="gray",
        name="transverse",
        zmin=0.2,
        zmax=0.4,
        xaxis="x",
        yaxis="y",
        # hoverinfo='skip',
    ), 
    1,
    1,
)


# Styling shared by every axis: constrain to the domain (so scaleanchor
# shrinks the domain rather than the range), hide tick labels, and configure
# spike lines.
axis_common = {
    "constrain":"domain", 
    "showticklabels": False,
    "spikesnap": "data", 
    "spikemode": "across", 
    "spikedash": "solid", 
    "spikethickness": 0,
}

# Axis linking: xaxis3 matches the transverse x axis, the sagittal horizontal
# axis (xaxis4) matches the transverse y axis, and yaxis4 matches the coronal
# vertical axis, so panning one view keeps the others aligned. scaleanchor
# fixes a 1:1 aspect ratio on the transverse and coronal panels.
# NOTE(review): the axis ranges are hard-coded for this example dataset.
fig.update_layout(
    {
        # "hoverdistance": 0,
        "height": 900,
        "width": 900,
        "images": images,
        "dragmode": "pan",
        'xaxis': {"range": [-120, 150], 'scaleanchor': 'y', **axis_common},
        'yaxis': {"range": [-100, 170],  **axis_common},
        'xaxis3': {'matches': 'x',  **axis_common},
        'yaxis3': {"range": [-160, 110], 'scaleanchor': 'x3',  **axis_common},
        'yaxis4': {'matches': 'y3', **axis_common},
        'xaxis4': {'matches': 'y',  **axis_common},
    }
);

# JavaScript that fig.to_html appends after the plot is created; plotly
# replaces '{plot_id}' with the id of the generated div.
# NOTE(review): the handlers only log click / double-click events to the
# browser console, presumably scaffolding for the "clickable" views in the TODO
# list at the end of the notebook.
post_script = """
var plot_element = document.getElementById('{plot_id}');

console.log('callback js has been added to page. did we get plot_element?');
console.log(plot_element);


plot_element.on('plotly_click', function(data){
    console.log('Im inside the plotly_click!!');
    console.log(data);
})


plot_element.on('plotly_doubleclick', function(data){
    console.log('Im inside the plotly_doubleclick!!');
    console.log(data);
})
"""

# HTML was never imported anywhere in the notebook, so the final line raised
# NameError. display is imported explicitly as well, rather than relying on
# the IPython-injected builtin.
from IPython.display import HTML, display

# Render the figure as an HTML fragment so post_script can attach JS event
# handlers. include_plotlyjs=False expects plotly.js to already be on the page
# (presumably loaded by pyo.init_notebook_mode() earlier in the notebook).
html = fig.to_html(
    config={
        'displayModeBar': True,
        "displaylogo": False,
        "scrollZoom": True,
    },
    full_html=False,
    post_script=post_script,
    include_plotlyjs=False,
    validate=True,
)

display(HTML(html))

In [None]:
# NOTE(review): IPython help lookup left over from development. It is not
# valid plain-Python syntax; remove it before publishing.
fig.to_html?

In [None]:
# TODO:
# * Change figures to clickable interactive transverse/coronal/sagittal bokeh plots
# * Calculate and report Hausdorff distance and surface Dice as well