In [None]:
import sys
sys.version

In [None]:
import tqdm
import numpy as np
import skimage.measure
import matplotlib.pyplot as plt

In [None]:
import pydicom
pydicom.__version__

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import rai
rai.__version__

In [None]:
import raicontours

from raicontours import TG263

raicontours.__version__

In [None]:
# Fetch the raicontours configuration; it is passed as `cfg=` to the rai helpers
# below and supplies the "reduce_block_sizes" used when downsampling.
cfg = raicontours.get_config()

In [None]:
cfg

In [None]:
# rai.load_model returns two models: the "starting" model is used for the first
# inference pass, and the "dependent" model for the later passes, which are also
# given a `masks_stack` from the previous pass.
rai_starting_model, rai_dependent_model = rai.load_model(cfg=cfg)

In [None]:
# Fetch the example dataset: the image file paths plus the path of the
# structure file that is read with pydicom further down.
image_paths, structure_path = rai.download_deepmind_example()
structure_path

In [None]:
# Build the image volume from the downloaded files. Alongside the stack this
# returns the three coordinate grids and the image UIDs (the UIDs are later
# handed to rai.dicom_to_contours_by_structure).
x_grid, y_grid, z_grid, image_stack, image_uids = rai.paths_to_image_stack_hfs(
    cfg=cfg, paths=image_paths
)

In [None]:
image_stack.shape

In [None]:
# Pad the stack along axis 0 so its length is a whole multiple of the
# first-stage slice reduction factor.
# NOTE: this cell rebinds `image_stack`, so re-running it on its own records the
# already padded length as `original_num_slices`.
original_num_slices = image_stack.shape[0]
slice_reduction = cfg["reduce_block_sizes"][0][0]

num_blocks = np.ceil(original_num_slices / slice_reduction)
desired_num_slices = int(num_blocks * slice_reduction)

if desired_num_slices != original_num_slices:
    # mode="clip" clamps out-of-range indices, so every extra slice is a copy
    # of the last real slice.
    padded_indices = range(desired_num_slices)
    image_stack = image_stack.take(padded_indices, axis=0, mode="clip")

num_slices = image_stack.shape[0]
assert num_slices == desired_num_slices
assert image_stack.shape[1:3] == (512, 512)

image_stack.shape

In [None]:
# Downsample the padded stack by block-averaging with the first entry of
# cfg["reduce_block_sizes"].
first_stage_block_size = cfg["reduce_block_sizes"][0]
reduced_image_stack = skimage.measure.block_reduce(
    image_stack,
    block_size=first_stage_block_size,
    func=np.mean,
)
reduced_image_stack.shape

In [None]:
310 / 4

In [None]:
# Split the reduced stack along axis 0 into equal steps of at most 40 slices;
# `z` holds the step boundaries, with the slice count as the final entry.
max_step_size = 40
reduced_num_slices = reduced_image_stack.shape[0]

num_steps = np.ceil(reduced_num_slices / max_step_size)
step_size = int(np.ceil(reduced_num_slices / num_steps))

z = [*range(0, reduced_num_slices, step_size), reduced_num_slices]
z

In [None]:
# In-plane grid positions for the first inference pass over the reduced stack.
y = [30, 60, 90]
x = [30, 60, 90]

# First pass: the starting model runs on the reduced images alone.
predicted_masks = rai.inference_over_jittered_grid(
    cfg=cfg,
    model=rai_starting_model,
    grid=(z, y, x),
    image_stack=reduced_image_stack,
    max_batch_size=10,
)

In [None]:
# Indices of every mask element above the 127.5 threshold; the per-axis minimum
# gives the lower corner of their bounding box.
where_mask = np.nonzero(predicted_masks > 127.5)
np.asarray(where_mask).min(axis=1)

In [None]:
np.max(where_mask, axis=1)

In [None]:
# Block-average the x and y grids by the same in-plane factors as the
# first-stage image reduction, then contour and plot the first-pass masks.
first_stage_block_size = cfg["reduce_block_sizes"][0]
reduced_x_grid = skimage.measure.block_reduce(
    x_grid, block_size=first_stage_block_size[2], func=np.mean
)
reduced_y_grid = skimage.measure.block_reduce(
    y_grid, block_size=first_stage_block_size[1], func=np.mean
)

reduced_predicted_contours_by_structure = rai.masks_to_contours_by_structure(
    cfg=cfg,
    x_grid=reduced_x_grid,
    y_grid=reduced_y_grid,
    masks=predicted_masks,
)
rai.plot_contours_by_structure(
    reduced_x_grid,
    reduced_y_grid,
    reduced_image_stack,
    reduced_predicted_contours_by_structure,
)

In [None]:
# predicted_masks = rai.inference_over_jittered_grid(
#     cfg=cfg, model=rai_dependent_model, grid=(z, y, x), image_stack=reduced_image_stack, masks_stack=predicted_masks
# )

In [None]:
cfg["reduce_block_sizes"]

In [None]:
# Upsample the first-pass masks by repeating every element twice along each of
# the first three axes.
# NOTE(review): the factor 2 is hard-coded; presumably it is the ratio between
# the two entries of cfg["reduce_block_sizes"] - confirm against the config.
upscaled = predicted_masks
for axis in (0, 1, 2):
    upscaled = np.repeat(upscaled, 2, axis=axis)

# After upsampling, the slice count must match the padded stack again.
assert image_stack.shape[0] == upscaled.shape[0]

upscaled.shape

In [None]:
# Lower corner of the bounding box of upsampled mask elements above 127.5.
where_mask = np.nonzero(upscaled > 127.5)
np.asarray(where_mask).min(axis=1)

In [None]:
np.max(where_mask, axis=1)

In [None]:
# Downsample the padded stack again, this time with the second entry of
# cfg["reduce_block_sizes"].
second_stage_block_size = cfg["reduce_block_sizes"][1]
reduced_image_stack = skimage.measure.block_reduce(
    image_stack,
    block_size=second_stage_block_size,
    func=np.mean,
)
reduced_image_stack.shape

In [None]:
# Recompute the axis-0 step boundaries for the second-stage stack: equal steps
# of at most 40 slices, ending with the slice count itself.
max_step_size = 40
reduced_num_slices = reduced_image_stack.shape[0]

num_steps = np.ceil(reduced_num_slices / max_step_size)
step_size = int(np.ceil(reduced_num_slices / num_steps))

z = [*range(0, reduced_num_slices, step_size), reduced_num_slices]
z

In [None]:
# In-plane grid positions for the second inference pass.
y = [60, 95, 130, 165]
x = [65, 105, 145, 185]

# Second pass: the dependent model receives the upsampled first-pass masks
# together with the second-stage images.
predicted_masks = rai.inference_over_jittered_grid(
    cfg=cfg,
    model=rai_dependent_model,
    grid=(z, y, x),
    image_stack=reduced_image_stack,
    masks_stack=upscaled,
    max_batch_size=10,
)

In [None]:
# Block-average the x and y grids by the second-stage in-plane factors, then
# contour and plot the second-pass masks, cropped to the original slice count.
second_stage_block_size = cfg["reduce_block_sizes"][1]
reduced_x_grid = skimage.measure.block_reduce(
    x_grid, block_size=second_stage_block_size[2], func=np.mean
)
reduced_y_grid = skimage.measure.block_reduce(
    y_grid, block_size=second_stage_block_size[1], func=np.mean
)

reduced_predicted_contours_by_structure = rai.masks_to_contours_by_structure(
    cfg=cfg,
    x_grid=reduced_x_grid,
    y_grid=reduced_y_grid,
    masks=predicted_masks[:original_num_slices, ...],
)
rai.plot_contours_by_structure(
    reduced_x_grid,
    reduced_y_grid,
    reduced_image_stack[:original_num_slices, ...],
    reduced_predicted_contours_by_structure,
)

In [None]:
predicted_masks.shape

In [None]:
np.max(predicted_masks)

In [None]:
# Upsample the second-pass masks by a factor of two along axes 1 and 2 only;
# axis 0 is left as it is.
upscaled = predicted_masks
for axis in (1, 2):
    upscaled = np.repeat(upscaled, 2, axis=axis)

upscaled.shape

In [None]:
# Read the structure file and list the ROI names it contains.
# `pydicom.read_file` is a deprecated alias that was removed in pydicom 3.0;
# `dcmread` is the supported reader and takes the same path argument.
structure_ds = pydicom.dcmread(structure_path)
[item.ROIName for item in structure_ds.StructureSetROISequence]

In [None]:
# (structure-set name, TG263 structure) pairs. Each name is looked up in the
# DICOM structure set below and compared against the listed predicted structure.
dicom_to_tg263 = [
    ("Brain", TG263.Brain),
    ("Brainstem", TG263.Brainstem),
    ("Cochlea-Lt", TG263.Cochlea_L),
    ("Cochlea-Rt", TG263.Cochlea_R),
    ("Lacrimal-Lt", TG263.Glnd_Lacrimal_L),
    ("Lacrimal-Rt", TG263.Glnd_Lacrimal_R),
    ("Lens-Lt", TG263.Lens_L),
    ("Lens-Rt", TG263.Lens_R),
    ("Lung-Lt", TG263.Lung_L),
    ("Lung-Rt", TG263.Lung_R),
    ("Mandible", TG263.Bone_Mandible),
    ("Optic-Nerve-Lt", TG263.OpticNrv_L),
    ("Optic-Nerve-Rt", TG263.OpticNrv_R),
    ("Orbit-Lt", TG263.Eye_L),
    ("Orbit-Rt", TG263.Eye_R),
    ("Parotid-Lt", TG263.Parotid_L),
    ("Parotid-Rt", TG263.Parotid_R),
    ("Spinal-Cord", TG263.SpinalCord),
    ("Submandibular-Lt", TG263.Glnd_Submand_L),
    ("Submandibular-Rt", TG263.Glnd_Submand_R),
]

# Each name maps to a one-element list of structures.
align_map = {
    dicom_name: [tg263_structure] for dicom_name, tg263_structure in dicom_to_tg263
}
structure_names = [*align_map]

dicom_contours_by_structure = rai.dicom_to_contours_by_structure(
    ds=structure_ds,
    image_uids=image_uids,
    structure_names=structure_names,
)



In [None]:
# Axis-0 step boundaries for the full-resolution stack: equal steps of at most
# 40 slices, ending with the slice count itself.
num_slices = image_stack.shape[0]
num_steps = np.ceil(num_slices / 40)
step_size = int(np.ceil(num_slices / num_steps))

z = [*range(0, num_slices, step_size), num_slices]
z

In [None]:
# Lower corner of the bounding box of upsampled second-pass mask elements
# above 127.5.
where_mask = np.nonzero(upscaled > 127.5)
np.asarray(where_mask).min(axis=1)

In [None]:
np.max(where_mask, axis=1)

In [None]:
# y positions of the inference grid for the full-resolution pass:
# every 40 from 125 up to 325.
y = list(range(125, 330, 40))
y

In [None]:
# x positions of the inference grid for the full-resolution pass:
# every 40 from 130 up to 370.
x = list(range(130, 380, 40))
x

In [None]:
# Full-resolution pass(es): start from the upsampled second-pass masks and feed
# each pass's output back in as the next pass's `masks_stack`. After every pass
# the masks are contoured, merged via `align_map`, and scored against the DICOM
# contours.
predicted_masks = upscaled
looped_dice = []

num_refinement_passes = 1
for _ in range(num_refinement_passes):
    predicted_masks = rai.inference_over_jittered_grid(
        cfg=cfg,
        model=rai_dependent_model,
        grid=(z, y, x),
        image_stack=image_stack,
        masks_stack=predicted_masks,
        max_batch_size=10,
        verify=False,  # TODO: Remove this before publishing
    )

    # Crop the padded slices off before contouring.
    predicted_contours_by_structure = rai.masks_to_contours_by_structure(
        cfg=cfg,
        x_grid=x_grid,
        y_grid=y_grid,
        masks=predicted_masks[:original_num_slices, ...],
    )

    aligned_predicted_contours_by_structure = rai.merge_contours_by_structure(
        predicted_contours_by_structure, align_map
    )

    # One dice value per structure name for this pass.
    dice = {
        name: rai.dice_from_contours_by_slice(
            dicom_contours_by_structure[name],
            aligned_predicted_contours_by_structure[name],
        )
        for name in align_map
    }
    looped_dice.append(dice)

In [None]:
# Plot one line per structure: its dice value after each refinement pass.
for name in structure_names:
    dice = [pass_dice[name] for pass_dice in looped_dice]
    plt.plot(dice, '-o', label=name)

plt.legend()

In [None]:
# predicted_contours_by_structure

In [None]:
# looped_dice

In [None]:
rai.plot_contours_by_structure(
    x_grid, y_grid, image_stack[0:original_num_slices, ...], predicted_contours_by_structure, align_map
)

In [None]:
# Prefix every DICOM structure name with "DICOM " so these entries cannot
# collide with the predicted structure names when the two dicts are merged.
renamed_dicom_contours_by_structure = dict(
    (f"DICOM {name}", contours)
    for name, contours in dicom_contours_by_structure.items()
)

In [None]:
# TODO: Create an ipywidget slider for the slices

In [None]:
# Overlay predicted and DICOM contours in one plot. DICOM entries are added
# second, so they win if a key is ever duplicated.
combined_contours_by_structure = dict(predicted_contours_by_structure)
combined_contours_by_structure.update(renamed_dicom_contours_by_structure)

rai.plot_contours_by_structure(
    x_grid,
    y_grid,
    image_stack[:original_num_slices, ...],
    combined_contours_by_structure,
    align_map,
)

In [None]:
import plotly.offline as pyo
pyo.init_notebook_mode()

In [None]:
# TODO: Create interactive bokeh
# TODO: Press button to centre on a structure

# Grids, images and masks with the padded slices cropped off.
# NOTE(review): presumably intended as inputs for the interactive viewer - confirm.
grids = (z_grid, y_grid, x_grid)
images = image_stack[:original_num_slices, ...]
masks = predicted_masks[:original_num_slices, ...]

In [None]:
from rai.display import interactive
import pandas

In [None]:
# Build the figure with rai.display.interactive.main() and render it.
fig = interactive.main()
fig.show()

In [None]:
# Scratch demo: polar bar chart with random colours on a 10 x 20 (r, theta) mesh.
import numpy as np
import plotly.graph_objects as go

r, theta = np.mgrid[0.1:1:10j, 0:360:20j]
color = np.random.random(r.shape)

barpolar = go.Barpolar(
    r=r.ravel(),
    theta=theta.ravel(),
    marker_color=color.ravel(),
)
fig = go.Figure(barpolar)
fig.update_layout(polar_bargap=0)
fig.show()

In [None]:
import plotly.graph_objects as go

In [None]:
from plotly.subplots import make_subplots

In [None]:
from rai.mask import convert

In [None]:
# Scratch demo: an RGB sample image next to a histogram of each colour channel.
from plotly.subplots import make_subplots
from skimage import data

img = data.chelsea()
fig = make_subplots(rows=1, cols=2)

# We use go.Image because subplots require traces, whereas px functions return a figure
fig.add_trace(go.Image(z=img), row=1, col=1)

for channel, color in enumerate(('red', 'green', 'blue')):
    channel_histogram = go.Histogram(
        x=img[..., channel].ravel(),
        opacity=0.5,
        marker_color=color,
        name='%s channel' % color,
    )
    fig.add_trace(channel_histogram, row=1, col=2)

fig.update_layout(height=400)
fig.show()

In [None]:
img.shape

In [None]:
img = image_stack[50, ...]

In [None]:
# NOTE(review): the x0/dx/y0/dy values computed here are overwritten two cells
# further down, where the same names are recomputed directly from the grids.
x0, dx = convert._grid_to_transform(x_grid)
y0, dy = convert._grid_to_transform(y_grid)

In [None]:
y_grid[-1]

In [None]:
# Starting coordinate and step along each axis, used as the x0/dx and y0/dy
# arguments of the heatmap traces below.
x0 = x_grid[0]
dx = x_grid[1] - x_grid[0]
# y starts from the last entry of y_grid and steps backwards through it.
y0 = y_grid[-1]
dy = y_grid[-2] - y_grid[-1]
z0 = z_grid[0]
dz = z_grid[1] - z_grid[0]

In [None]:
# go.Heatmap?

In [None]:
def _grayscale_slice(name, z, x0, dx, y0, dy):
    """Return a grayscale heatmap trace for one 2D slice of the image stack.

    ``x0``/``dx`` and ``y0``/``dy`` are the starting coordinate and step along
    the trace's horizontal and vertical axes. The display window is fixed to
    0.2-0.4 and hover output is disabled.
    """
    return go.Heatmap(
        name=name,
        z=z,
        x0=x0,
        dx=dx,
        y0=y0,
        dy=dy,
        colorscale="gray",
        zmin=0.2,
        zmax=0.4,
        hoverinfo="skip",
        # These traces use `colorscale`, not a shared layout `coloraxis`, so the
        # colour bar has to be switched off on the trace itself;
        # `fig.update_coloraxes(showscale=False)` has nothing to act on.
        showscale=False,
    )


fig = make_subplots(
    rows=2,
    cols=2,
    vertical_spacing=0,
    horizontal_spacing=0,
)

# `add_trace(..., row, col)` binds each trace to its subplot's axes, overriding
# any xaxis/yaxis set on the trace: (1, 1) -> x/y, (2, 1) -> x3/y3 and
# (2, 2) -> x4/y4. These are the axis ids configured in the layout below.
fig.add_trace(
    _grayscale_slice("transverse", image_stack[50, ...], x0, dx, y0, dy),
    row=1,
    col=1,
)
fig.add_trace(
    _grayscale_slice("sagittal", image_stack[:, :, 256], y0, dy, z0, dz),
    row=2,
    col=2,
)
fig.add_trace(
    _grayscale_slice("coronal", image_stack[:, 256, :], x0, dx, z0, dz),
    row=2,
    col=1,
)

# Link the views so panning or zooming one keeps the others aligned: the coronal
# x-axis follows the transverse x-axis, the sagittal x-axis follows the
# transverse y-axis, and the two lower views share their vertical axis.
fig.update_layout(
    {
        "height": 900,
        "width": 900,
        "dragmode": "pan",
        'xaxis': {"range": [-120, 150], 'scaleanchor': 'y', "constrain": "domain", "showticklabels": False},
        'yaxis': {"range": [-100, 170], "constrain": "domain", "showticklabels": False},
        'xaxis3': {'matches': 'x', "constrain": "domain", "showticklabels": False},
        'yaxis3': {"range": [-160, 110], 'scaleanchor': 'x3', "constrain": "domain", "showticklabels": False},
        'yaxis4': {'matches': 'y3', "constrain": "domain", "showticklabels": False},
        'xaxis4': {"range": [-100, 170], 'matches': 'y', "constrain": "domain", "showticklabels": False},
    }
)

fig.show(
    config={
        'displayModeBar': False,
        "displaylogo": False,
        "scrollZoom": True,
    }
)
# # fig.update_layout(coloraxis_showscale=False)
# fig.update_xaxes(
#     selector={"name": "coronal"},
#     range=[-150, 150],
#     constrain="domain",
# )
# fig.update_yaxes(
#     selector={"name": "coronal"},
#     scaleanchor = "x",
#     scaleratio = 1,
#     range=[-150, 150],
#     constrain="domain",
# )
# fig.update_xaxes(
#     selector={"name": "sagittal"},
#     scaleanchor = "y",
#     scaleratio = 1,
# )

In [None]:
fig.show?

In [None]:
print(fig)

In [None]:
fig.update_xaxes?

In [None]:
go.Heatmap?

In [None]:
# `px` was never imported anywhere in the notebook, so this cell raised a
# NameError; import plotly.express here so the cell runs on a fresh kernel.
import plotly.express as px

# Single transverse slice with a fixed 0.2-0.4 grayscale window.
fig = px.imshow(
    image_stack[50, ...],
    x=x_grid,
    y=y_grid,
    width=800,
    height=800,
    color_continuous_scale='gray',
    zmin=0.2,
    zmax=0.4,
)

fig.update_layout(coloraxis_showscale=False)
fig.update_xaxes(showticklabels=False, range=[-100, 150])
fig.update_yaxes(showticklabels=False, range=[150, -100])

In [None]:
fig.update_yaxes?

In [None]:
# Scratch demo: two line traces rendered with plotly's offline notebook mode.
import plotly.offline as pyo
import plotly.graph_objs as go

# Enable offline rendering inside the notebook.
pyo.init_notebook_mode()

# Both traces share the same x values.
trace0 = go.Scatter(x=[1, 2, 3, 4], y=[10, 15, 13, 17])
trace1 = go.Scatter(x=[1, 2, 3, 4], y=[16, 5, 11, 9])

# NOTE: `data` rebinds the name imported earlier via `from skimage import data`.
data = [trace0, trace1]
pyo.iplot(data, filename='basic-line')

In [None]:
# predicted_contours_by_structure = rai.masks_to_contours_by_structure(
#     cfg=cfg, x_grid=x_grid, y_grid=y_grid, masks=predicted_masks[0:original_num_slices, ...]
# )

# rai.plot_contours_by_structure(
#     x_grid, y_grid, image_stack[0:original_num_slices, ...], combined_contours_by_structure, align_map
# )

In [None]:
# TODO:
# * Change figures to clickable interactive transverse/coronal/sagittal bokeh
# * Calculate and report hausdorff and surface dice as well