<h1>Rib fracture, RoI and prediction visualisations</h1>

<h3>Package imports</h3>

In [1]:
# Import all the necessary packages
import numpy as np
import nibabel as nib
import itk
import itkwidgets
from ipywidgets import interact, interactive, IntSlider, ToggleButtons
import matplotlib.pyplot as plt
import cv2
from dataset.fracnet_dataset import FracNetTrainDataset
from dataset import transforms as tsfm

# Matplotlib setting
%matplotlib inline

<h3>Function definition</h3>

In [2]:
def load_img_data(image_path):
    """Load a NIfTI image from ``image_path`` and return its voxel data.

    Parameters
    ----------
    image_path : str
        Path to a ``.nii`` / ``.nii.gz`` file readable by ``nibabel.load``.

    Returns
    -------
    np.ndarray
        The array produced by the image's ``get_fdata()`` (floating point).
    """
    return nib.load(image_path).get_fdata()

def convert_gray_to_rgb(gray_img):
    """Convert a grayscale image (i, j, k) into an RGB image (i, j, k, 3).

    Intensities are min-max rescaled to [0, 255] and truncated to integers.
    The same values are copied into all three channels, so the result can be
    shown with ``plt.imshow`` and individual voxels can be painted with RGB
    triplets.

    Parameters
    ----------
    gray_img : array-like
        Grayscale image of any shape.

    Returns
    -------
    np.ndarray
        Integer array of shape ``gray_img.shape + (3,)`` with values in
        [0, 255]. A constant image (max == min) yields all zeros. Without
        this special case, 0/0 = NaN would be cast to int and produce
        garbage.
    """
    gray_img = np.asarray(gray_img)
    lowest = gray_img.min()
    value_range = gray_img.max() - lowest
    if value_range == 0:
        channel = np.zeros(gray_img.shape, dtype=int)
    else:
        # Scale a single channel and only then replicate it. The float
        # temporaries are then 3x smaller than when scaling the stacked
        # (i, j, k, 3) array, and the per-voxel values are identical.
        channel = ((gray_img - lowest) / value_range * 255).astype(int)
    return np.stack((channel,) * 3, axis=-1)

def explore_axial_single(layer):
    """Show axial slice ``layer`` (third axis) of the global ``gray_img1`` in grayscale.

    Intended as an ``ipywidgets.interact`` callback. It reads the
    module-level ``gray_img1`` rather than taking the volume as an argument.
    Returns ``layer`` unchanged.
    """
    plt.figure(figsize=(20, 10))
    axial_slice = gray_img1[:, :, layer]
    plt.imshow(axial_slice, cmap='gray')
    plt.title('Explore Layers of Ribs', fontsize=20)
    plt.axis('off')
    return layer

def _make_plane_explorer(img1, img2, title1, title2, axis, fig_size, title_font_size):
    """Build an ``interact`` callback that shows one slice of two volumes side by side.

    Parameters
    ----------
    img1, img2 : np.ndarray
        RGB volumes of shape (i, j, k, 3), drawn in the left and right panels.
    title1, title2 : str
        Titles of the left and right panels.
    axis : int
        Spatial axis the slider walks along: 0 (sagittal), 1 (coronal) or 2 (axial).
    fig_size : tuple
        Matplotlib figure size.
    title_font_size : int
        Font size of both panel titles.

    Returns
    -------
    callable
        ``explore_plane_func(layer)``, which draws slice ``layer`` of both
        volumes and returns ``layer``.
    """
    def explore_plane_func(layer):
        _, (ax1, ax2) = plt.subplots(1, 2, figsize=fig_size)
        for ax, img, title in ((ax1, img1, title1), (ax2, img2, title2)):
            # A scalar index into np.take drops `axis`, which gives the same
            # RGB slice as img[layer], img[:, layer] or img[:, :, layer].
            ax.imshow(np.take(img, layer, axis=axis))
            ax.axis('off')
            ax.set_title(title, fontsize=title_font_size)
        return layer
    return explore_plane_func


def explore_saggital_plane(img1, img2, title1, title2, fig_size=(18, 18), title_font_size=22):
    """Return an ``interact`` callback showing sagittal slices (axis 0) of two RGB volumes.

    The function keeps its historical spelling ("saggital") so that existing
    calls keep working.
    """
    return _make_plane_explorer(img1, img2, title1, title2, 0, fig_size, title_font_size)


def explore_coronal_plane(img1, img2, title1, title2, fig_size=(18, 18), title_font_size=22):
    """Return an ``interact`` callback showing coronal slices (axis 1) of two RGB volumes."""
    return _make_plane_explorer(img1, img2, title1, title2, 1, fig_size, title_font_size)


def explore_axial_plane(img1, img2, title1, title2, fig_size=(18, 18), title_font_size=22):
    """Return an ``interact`` callback showing axial slices (axis 2) of two RGB volumes."""
    return _make_plane_explorer(img1, img2, title1, title2, 2, fig_size, title_font_size)

def get_roi_coordinates(img_arr, centroid, crop_size=64):
    """Compute the clipped bounds of a cubic crop centred on ``centroid``.

    Only ``img_arr.shape`` is used, and only for the first ``len(centroid)``
    dimensions.

    Parameters
    ----------
    img_arr : np.ndarray
        Array whose shape bounds the crop.
    centroid : sequence of int
        Centre of the crop, one coordinate per spatial dimension.
    crop_size : int
        Edge length of the crop window.

    Returns
    -------
    tuple of four lists
        ``(src_beg, src_end, dst_beg, dst_end)``. The ``src`` bounds are the
        crop's extent inside ``img_arr``, clipped to the array. The ``dst``
        bounds are where that extent lands inside a ``crop_size`` window.
    """
    half = crop_size // 2
    src_beg, src_end, dst_beg, dst_end = [], [], [], []
    for dim, center in enumerate(centroid):
        extent = img_arr.shape[dim]
        src_beg.append(max(0, center - half))
        src_end.append(min(extent, center + half))
        dst_beg.append(max(0, half - center))
        dst_end.append(min(extent - (center - half), crop_size))
    return src_beg, src_end, dst_beg, dst_end

<h3>Data Loading</h3>

In [3]:
# Root of the RibFrac data, relative to this notebook. Change it here to point
# every load below at another copy of the dataset.
DATA_DIR = "../FracNet/data"

# CT volumes: one training case (RibFrac301) and two validation cases (RibFrac381, RibFrac382)
gray_img1 = load_img_data(f"{DATA_DIR}/train/ribfrac-train-images/RibFrac301-image.nii.gz")
gray_img2 = load_img_data(f"{DATA_DIR}/val/ribfrac-val-images/RibFrac381-image.nii.gz")
gray_img3 = load_img_data(f"{DATA_DIR}/val/ribfrac-val-images/RibFrac382-image.nii.gz")

# Ground-truth label volumes for the same cases
label_img1 = load_img_data(f"{DATA_DIR}/train/ribfrac-train-labels/RibFrac301-label.nii.gz")
label_img2 = load_img_data(f"{DATA_DIR}/val/ribfrac-val-labels/RibFrac381-label.nii.gz")
label_img3 = load_img_data(f"{DATA_DIR}/val/ribfrac-val-labels/RibFrac382-label.nii.gz")

# Model predictions for the two validation cases.
# NOTE(review): the directory is spelled "ribfra-..." (not "ribfrac-"). It is
# kept exactly as before; confirm that this matches the folder on disk.
pred_img2 = load_img_data(f"{DATA_DIR}/val/ribfra-val-pred_w_const_gauss_noise/RibFrac381-label.nii.gz")
pred_img3 = load_img_data(f"{DATA_DIR}/val/ribfra-val-pred_w_const_gauss_noise/RibFrac382-label.nii.gz")

<h3>Colours</h3>

In [4]:
# RGB triplets used for the overlays. Red marks fractures and true positives,
# blue marks RoIs and false positives, green marks false negatives.
red, blue, green = (
    np.array(rgb) for rgb in ([255, 0, 0], [0, 0, 140], [0, 100, 0])
)

In [5]:
# RGB versions of the three grayscale volumes, ready to receive colour overlays
rgb_img1 = convert_gray_to_rgb(gray_img=gray_img1)
rgb_img2 = convert_gray_to_rgb(gray_img=gray_img2)
rgb_img3 = convert_gray_to_rgb(gray_img=gray_img3)

In [6]:
# Shapes of the grayscale volumes and of their RGB versions (one extra channel axis)
gray_shape = tuple(vol.shape for vol in (gray_img1, gray_img2, gray_img3))
rgb_shape = tuple(vol.shape for vol in (rgb_img1, rgb_img2, rgb_img3))
print(f'gray_shape: {gray_shape}')
print(f'rgb_shape: {rgb_shape}')

gray_shape: ((512, 512, 407), (512, 512, 349), (512, 512, 313))
rgb_shape: ((512, 512, 407, 3), (512, 512, 349, 3), (512, 512, 313, 3))


<h3>Ribs visualisation (Axial plane)</h3>

In [7]:
# Slider over every axial slice (third axis) of the first grayscale volume
interact(
    explore_axial_single,
    layer=(0, gray_img1.shape[2] - 1),
)

interactive(children=(IntSlider(value=203, description='layer', max=406), Output()), _dom_classes=('widget-int…

<function __main__.explore_axial_single(layer)>

<h3>Fracture visualisation</h3>

In [8]:
# Overlay the ground-truth fractures of the first case on a copy of its RGB volume.
# Indexing with the boolean mask selects the same voxels as np.where(mask).
rgb_img1_w_labels = rgb_img1.copy()
rgb_img1_w_labels[label_img1 != 0] = red

# Titles for the side-by-side comparison
title1, title2 = 'Original image', 'Image with fracture'

<h3>Fracture visualisation (Sagittal plane)</h3>

In [9]:
# Sagittal slices are taken along axis 0, so the slider walks that axis
explore_func = explore_saggital_plane(
    img1=rgb_img1,
    img2=rgb_img1_w_labels,
    title1=title1,
    title2=title2,
)

# One slider position per sagittal slice
interact(explore_func, layer=(0, rgb_img1.shape[0] - 1))

interactive(children=(IntSlider(value=255, description='layer', max=511), Output()), _dom_classes=('widget-int…

<function __main__.explore_saggital_plane.<locals>.explore_plane_func(layer)>

<h3>Fracture visualisation (Coronal plane)</h3>

In [10]:
# Coronal slices are taken along axis 1, so the slider walks that axis
explore_func = explore_coronal_plane(
    img1=rgb_img1,
    img2=rgb_img1_w_labels,
    title1=title1,
    title2=title2,
)

# One slider position per coronal slice
interact(explore_func, layer=(0, rgb_img1.shape[1] - 1))

interactive(children=(IntSlider(value=255, description='layer', max=511), Output()), _dom_classes=('widget-int…

<function __main__.explore_coronal_plane.<locals>.explore_plane_func(layer)>

<h3>Fracture visualisation (Axial plane)</h3>

In [11]:
# Axial slices are taken along axis 2, so the slider walks that axis
explore_func = explore_axial_plane(
    img1=rgb_img1,
    img2=rgb_img1_w_labels,
    title1=title1,
    title2=title2,
)

# One slider position per axial slice
interact(explore_func, layer=(0, rgb_img1.shape[2] - 1))

interactive(children=(IntSlider(value=203, description='layer', max=406), Output()), _dom_classes=('widget-int…

<function __main__.explore_axial_plane.<locals>.explore_plane_func(layer)>

<h3>Prediction visualisation</h3>

In [12]:
# Start from a clean copy of the RGB volume so re-running this cell is idempotent
rgb_img2_w_preds = rgb_img2.copy()

# Paint every voxel with a non-zero prediction red.
# Indexing with the boolean mask selects the same voxels as np.where(mask).
rgb_img2_w_preds[pred_img2 != 0] = red

# Titles for the side-by-side comparison
title1 = 'Original image'
title2 = 'Image with fracture prediction'

<h3>Prediction visualisation (Sagittal plane)</h3>

In [13]:
# Sagittal slices are taken along axis 0, so the slider walks that axis
explore_func = explore_saggital_plane(
    img1=rgb_img2,
    img2=rgb_img2_w_preds,
    title1=title1,
    title2=title2,
)

# One slider position per sagittal slice
interact(explore_func, layer=(0, rgb_img2.shape[0] - 1))

interactive(children=(IntSlider(value=255, description='layer', max=511), Output()), _dom_classes=('widget-int…

<function __main__.explore_saggital_plane.<locals>.explore_plane_func(layer)>

<h3>Prediction visualisation (Coronal plane)</h3>

In [14]:
# Coronal slices are taken along axis 1, so the slider walks that axis
explore_func = explore_coronal_plane(
    img1=rgb_img2,
    img2=rgb_img2_w_preds,
    title1=title1,
    title2=title2,
)

# One slider position per coronal slice
interact(explore_func, layer=(0, rgb_img2.shape[1] - 1))

interactive(children=(IntSlider(value=255, description='layer', max=511), Output()), _dom_classes=('widget-int…

<function __main__.explore_coronal_plane.<locals>.explore_plane_func(layer)>

<h3>Prediction visualisation (Axial plane)</h3>

In [15]:
# Axial slices are taken along axis 2, so the slider walks that axis
explore_func = explore_axial_plane(
    img1=rgb_img2,
    img2=rgb_img2_w_preds,
    title1=title1,
    title2=title2,
)

# One slider position per axial slice
interact(explore_func, layer=(0, rgb_img2.shape[2] - 1))

interactive(children=(IntSlider(value=174, description='layer', max=348), Output()), _dom_classes=('widget-int…

<function __main__.explore_axial_plane.<locals>.explore_plane_func(layer)>

<h3>Prediction vs label visualisation</h3>

In [16]:
# Start from a clean copy of the RGB volume
rgb_img2_w_preds_n_labels = rgb_img2.copy()

# Voxel-wise confusion matrix between prediction and ground truth
# (true negatives are left uncoloured)
true_pos_cond = np.logical_and(pred_img2 != 0, label_img2 != 0)
false_pos_cond = np.logical_and(pred_img2 != 0, label_img2 == 0)
false_neg_cond = np.logical_and(pred_img2 == 0, label_img2 != 0)

# Paint each outcome in its colour. Indexing with a boolean mask selects the
# same voxels as indexing with np.where(mask).
rgb_img2_w_preds_n_labels[true_pos_cond] = red
rgb_img2_w_preds_n_labels[false_pos_cond] = blue
rgb_img2_w_preds_n_labels[false_neg_cond] = green

# Titles for the side-by-side comparison
title1 = 'Original image'
title2 = 'Image w inference (TP: red, FP: blue, FN: green)'

<h3>Prediction vs label visualisation (Sagittal plane)</h3>

In [17]:
# Sagittal slices are taken along axis 0, so the slider walks that axis
explore_func = explore_saggital_plane(
    img1=rgb_img2,
    img2=rgb_img2_w_preds_n_labels,
    title1=title1,
    title2=title2,
)

# One slider position per sagittal slice
interact(explore_func, layer=(0, rgb_img2.shape[0] - 1))

interactive(children=(IntSlider(value=255, description='layer', max=511), Output()), _dom_classes=('widget-int…

<function __main__.explore_saggital_plane.<locals>.explore_plane_func(layer)>

<h3>Prediction vs label visualisation (Coronal plane)</h3>

In [18]:
# Coronal slices are taken along axis 1, so the slider walks that axis
explore_func = explore_coronal_plane(
    img1=rgb_img2,
    img2=rgb_img2_w_preds_n_labels,
    title1=title1,
    title2=title2,
)

# One slider position per coronal slice
interact(explore_func, layer=(0, rgb_img2.shape[1] - 1))

interactive(children=(IntSlider(value=255, description='layer', max=511), Output()), _dom_classes=('widget-int…

<function __main__.explore_coronal_plane.<locals>.explore_plane_func(layer)>

<h3>Prediction vs label visualisation (Axial plane)</h3>

In [19]:
# Axial slices are taken along axis 2, so the slider walks that axis
explore_func = explore_axial_plane(
    img1=rgb_img2,
    img2=rgb_img2_w_preds_n_labels,
    title1=title1,
    title2=title2,
)

# One slider position per axial slice
interact(explore_func, layer=(0, rgb_img2.shape[2] - 1))

interactive(children=(IntSlider(value=174, description='layer', max=348), Output()), _dom_classes=('widget-int…

<function __main__.explore_axial_plane.<locals>.explore_plane_func(layer)>

<h3>Original vs modified positive sampling visualisation</h3>

In [20]:
# Copy rgb image
rgb_img2_w_rois = rgb_img2.copy()

# Transforms for FracNetTrainDataset: intensity window and min-max
# normalisation, both with bounds (-200, 1000) (see dataset.transforms)
transforms = [
    tsfm.Window(-200, 1000),
    tsfm.MinMaxNorm(-200, 1000)
]

# Instantiate FracNetTrainDataset on the validation split. The paths are
# relative, like the ones in the Data Loading section, rather than an
# absolute machine-specific /home/... path, so the notebook runs on other machines.
frac_dataset = FracNetTrainDataset(
    image_dir='../FracNet/data/val/ribfrac-val-images/',
    label_dir='../FracNet/data/val/ribfrac-val-labels/',
    transforms=transforms
)

Here, we simulate 10 epochs of positive sampling regions of interest

In [21]:
# Simulate `epochs_sim` epochs of positive sampling with the original and the
# modified centroid samplers, accumulating every sampled centroid
epochs_sim = 10
pos_centroids = []
pos_centroids_new = []

# The integer label volume is the same in every epoch. Convert it once rather
# than twice per epoch, since each conversion copies the whole volume.
# NOTE(review): assumes neither sampler modifies its input array in place.
label_img2_int = label_img2.astype(int)

for epoch in range(epochs_sim):
    pos_centroids += frac_dataset._get_pos_centroids(label_img2_int)
    pos_centroids_new += frac_dataset._get_pos_centroids_new(label_img2_int)

# Release the large temporary copy
del label_img2_int

In [22]:
# Crop bounds around every sampled positive centroid, clipped to the volume
# (get_roi_coordinates only reads the overlay's shape)
pos_roi_coordinates = [get_roi_coordinates(rgb_img2_w_rois, centroid) for centroid in pos_centroids]
pos_roi_coordinates_new = [get_roi_coordinates(rgb_img2_w_rois, centroid) for centroid in pos_centroids_new]

In [23]:
# Build both overlays from clean copies of rgb_img2. Re-running this cell then
# starts from an unpainted image instead of accumulating RoIs from earlier
# runs (or copying the already-painted left overlay into the right one).
rgb_img2_w_rois = rgb_img2.copy()
rgb_img2_w_rois_new = rgb_img2.copy()

# Colour the RoIs of the original positive sampling
# (only the source bounds inside the volume are needed here)
for src_beg, src_end, _dst_beg, _dst_end in pos_roi_coordinates:
    rgb_img2_w_rois[
        src_beg[0]:src_end[0],
        src_beg[1]:src_end[1],
        src_beg[2]:src_end[2],
        :
    ] = blue

# Colour the RoIs of the modified positive sampling
for src_beg, src_end, _dst_beg, _dst_end in pos_roi_coordinates_new:
    rgb_img2_w_rois_new[
        src_beg[0]:src_end[0],
        src_beg[1]:src_end[1],
        src_beg[2]:src_end[2],
        :
    ] = blue

In [24]:
# Paint labelled fracture voxels red on both overlays. This runs after the RoI
# boxes so the fractures stay visible inside them. Boolean-mask indexing
# selects the same voxels as np.where(mask).
rgb_img2_w_rois[label_img2 != 0] = red
rgb_img2_w_rois_new[label_img2 != 0] = red

In [25]:
# Panel titles: left = original sampler, right = modified sampler
title1, title2 = 'Original positive sampling', 'Modified positive sampling'

<h3>Original vs modified positive sampling visualisation (Sagittal plane)</h3>

In [26]:
# Sagittal slices are taken along axis 0, so the slider walks that axis
explore_func = explore_saggital_plane(
    img1=rgb_img2_w_rois,
    img2=rgb_img2_w_rois_new,
    title1=title1,
    title2=title2,
)

# One slider position per sagittal slice
interact(explore_func, layer=(0, rgb_img2.shape[0] - 1))

interactive(children=(IntSlider(value=255, description='layer', max=511), Output()), _dom_classes=('widget-int…

<function __main__.explore_saggital_plane.<locals>.explore_plane_func(layer)>

<h3>Original vs modified positive sampling visualisation (Coronal plane)</h3>

In [27]:
# Coronal slices are taken along axis 1, so the slider walks that axis
explore_func = explore_coronal_plane(
    img1=rgb_img2_w_rois,
    img2=rgb_img2_w_rois_new,
    title1=title1,
    title2=title2,
)

# One slider position per coronal slice
interact(explore_func, layer=(0, rgb_img2.shape[1] - 1))

interactive(children=(IntSlider(value=255, description='layer', max=511), Output()), _dom_classes=('widget-int…

<function __main__.explore_coronal_plane.<locals>.explore_plane_func(layer)>

<h3>Original vs modified positive sampling visualisation (Axial plane)</h3>

In [28]:
# Axial slices are taken along axis 2, so the slider walks that axis
explore_func = explore_axial_plane(
    img1=rgb_img2_w_rois,
    img2=rgb_img2_w_rois_new,
    title1=title1,
    title2=title2,
)

# One slider position per axial slice
interact(explore_func, layer=(0, rgb_img2.shape[2] - 1))

interactive(children=(IntSlider(value=174, description='layer', max=348), Output()), _dom_classes=('widget-int…

<function __main__.explore_axial_plane.<locals>.explore_plane_func(layer)>

<h3>Original vs modified negative sampling visualisation</h3>

In [29]:
# Copy rgb image
rgb_img2_w_rois = rgb_img2.copy()

# Transforms for FracNetTrainDataset: intensity window and min-max
# normalisation, both with bounds (-200, 1000) (see dataset.transforms).
# This is the same configuration as in the positive-sampling section.
transforms = [
    tsfm.Window(-200, 1000),
    tsfm.MinMaxNorm(-200, 1000)
]

# Instantiate FracNetTrainDataset on the validation split. The paths are
# relative, like the ones in the Data Loading section, rather than an
# absolute machine-specific /home/... path, so the notebook runs on other machines.
frac_dataset = FracNetTrainDataset(
    image_dir='../FracNet/data/val/ribfrac-val-images/',
    label_dir='../FracNet/data/val/ribfrac-val-labels/',
    transforms=transforms
)

Here, we simulate 10 epochs of negative sampling regions of interest

In [30]:
# Simulate `epochs_sim` epochs of negative sampling.
# "Original" collects one _get_neg_centroids call per epoch.
# "Modified" collects a second _get_neg_centroids call per epoch, plus the
# same number of centroids from _new_get_neg_centroids.
epochs_sim = 10
neg_centroids = []
neg_centroids_new = []

# NOTE(review): pos_centroids is the list accumulated over all simulated
# positive-sampling epochs in In [21], not one epoch's centroids. Confirm
# that this is the intended input for a single epoch of negative sampling.
for epoch in range(epochs_sim):
    neg_centroids += frac_dataset._get_neg_centroids(pos_centroids, label_img2.shape)
    neg_centroids_temp = frac_dataset._get_neg_centroids(pos_centroids, label_img2.shape)
    # Size the extra draw by this epoch's count. The running total
    # len(neg_centroids) grows every epoch and would inflate later epochs.
    neg_centroids_new += neg_centroids_temp + frac_dataset._new_get_neg_centroids(
        gray_img2, n_centroids=len(neg_centroids_temp)
    )

In [31]:
# Get the negative RoIs: crop bounds around every sampled negative centroid,
# clipped to the volume (get_roi_coordinates only reads the overlay's shape)
neg_roi_coordinates = [get_roi_coordinates(rgb_img2_w_rois, centroid) for centroid in neg_centroids]
neg_roi_coordinates_new = [get_roi_coordinates(rgb_img2_w_rois, centroid) for centroid in neg_centroids_new]

In [32]:
# Build both overlays from clean copies of rgb_img2. Re-running this cell then
# starts from an unpainted image instead of accumulating RoIs from earlier
# runs (or copying the already-painted left overlay into the right one).
rgb_img2_w_rois = rgb_img2.copy()
rgb_img2_w_rois_new = rgb_img2.copy()

# Colour the RoIs of the original negative sampling
# (only the source bounds inside the volume are needed here)
for src_beg, src_end, _dst_beg, _dst_end in neg_roi_coordinates:
    rgb_img2_w_rois[
        src_beg[0]:src_end[0],
        src_beg[1]:src_end[1],
        src_beg[2]:src_end[2],
        :
    ] = blue

# Colour the RoIs of the modified negative sampling
for src_beg, src_end, _dst_beg, _dst_end in neg_roi_coordinates_new:
    rgb_img2_w_rois_new[
        src_beg[0]:src_end[0],
        src_beg[1]:src_end[1],
        src_beg[2]:src_end[2],
        :
    ] = blue

In [33]:
# Paint labelled fracture voxels red on both overlays. This runs after the RoI
# boxes so the fractures stay visible inside them. Boolean-mask indexing
# selects the same voxels as np.where(mask).
rgb_img2_w_rois[label_img2 != 0] = red
rgb_img2_w_rois_new[label_img2 != 0] = red

In [34]:
# Panel titles: left = original sampler, right = modified sampler
title1, title2 = 'Original negative sampling', 'Modified negative sampling'

<h3>Original vs modified negative sampling visualisation (Sagittal plane)</h3>

In [35]:
# Sagittal slices are taken along axis 0, so the slider walks that axis
explore_func = explore_saggital_plane(
    img1=rgb_img2_w_rois,
    img2=rgb_img2_w_rois_new,
    title1=title1,
    title2=title2,
)

# One slider position per sagittal slice
interact(explore_func, layer=(0, rgb_img2.shape[0] - 1))

interactive(children=(IntSlider(value=255, description='layer', max=511), Output()), _dom_classes=('widget-int…

<function __main__.explore_saggital_plane.<locals>.explore_plane_func(layer)>

<h3>Original vs modified negative sampling visualisation (Coronal plane)</h3>

In [36]:
# Coronal slices are taken along axis 1, so the slider walks that axis
explore_func = explore_coronal_plane(
    img1=rgb_img2_w_rois,
    img2=rgb_img2_w_rois_new,
    title1=title1,
    title2=title2,
)

# One slider position per coronal slice
interact(explore_func, layer=(0, rgb_img2.shape[1] - 1))

interactive(children=(IntSlider(value=255, description='layer', max=511), Output()), _dom_classes=('widget-int…

<function __main__.explore_coronal_plane.<locals>.explore_plane_func(layer)>

<h3>Original vs modified negative sampling visualisation (Axial plane)</h3>

In [37]:
# Axial slices are taken along axis 2, so the slider walks that axis
explore_func = explore_axial_plane(
    img1=rgb_img2_w_rois,
    img2=rgb_img2_w_rois_new,
    title1=title1,
    title2=title2,
)

# One slider position per axial slice
interact(explore_func, layer=(0, rgb_img2.shape[2] - 1))

interactive(children=(IntSlider(value=174, description='layer', max=348), Output()), _dom_classes=('widget-int…

<function __main__.explore_axial_plane.<locals>.explore_plane_func(layer)>