In [1]:
# Import all the necessary packages
import numpy as np
import nibabel as nib
import itk
import itkwidgets
from ipywidgets import interact, interactive, IntSlider, ToggleButtons
import matplotlib.pyplot as plt
import cv2
from dataset.fracnet_dataset import FracNetTrainDataset
from dataset import transforms as tsfm

%matplotlib inline
import seaborn as sns
sns.set_style('darkgrid')

<h3>Data Loading</h3>

In [2]:
# Load the image data
# Root of the local RibFrac data tree; edit this one constant to run on another machine.
DATA_DIR = "/home/ryan/PycharmProjects/FracNet/data"


def load_img_data(image_path):
    """Load a NIfTI file with nibabel and return its voxel array via ``get_fdata()``."""
    image_obj = nib.load(image_path)
    return image_obj.get_fdata()


gray_img1 = load_img_data(f"{DATA_DIR}/train/ribfrac-train-images/RibFrac301-image.nii.gz")
gray_img2 = load_img_data(f"{DATA_DIR}/val/ribfrac-val-images/RibFrac381-image.nii.gz")
gray_img3 = load_img_data(f"{DATA_DIR}/val/ribfrac-val-images/RibFrac382-image.nii.gz")
label_img1 = load_img_data(f"{DATA_DIR}/train/ribfrac-train-labels/RibFrac301-label.nii.gz")
label_img2 = load_img_data(f"{DATA_DIR}/val/ribfrac-val-labels/RibFrac381-label.nii.gz")
label_img3 = load_img_data(f"{DATA_DIR}/val/ribfrac-val-labels/RibFrac382-label.nii.gz")
# Predicted fracture volumes for the two validation scans (directory name kept exactly as on disk).
pred_img2 = load_img_data(f"{DATA_DIR}/val/ribfra-val-pred_w_const_gauss_noise/RibFrac381-label.nii.gz")
pred_img3 = load_img_data(f"{DATA_DIR}/val/ribfra-val-pred_w_const_gauss_noise/RibFrac382-label.nii.gz")

<h3>Colours</h3>

In [3]:
# RGB overlay colours (0-255 per channel) painted onto the RGB copies of the volumes.
red, blue, green = (
    np.array(channels)
    for channels in ([255, 0, 0], [0, 0, 140], [0, 100, 0])
)

In [4]:
def convert_gray_to_rgb(gray_img, dtype=int):
    """Min-max scale a grayscale array to 0-255 and replicate it into 3 RGB channels.

    Parameters
    ----------
    gray_img : array_like
        Grayscale intensities of any shape.
    dtype : data-type, optional
        Output dtype (default ``int``, matching the previous behaviour). ``np.uint8``
        is sufficient for 0-255 values and needs far less memory for full CT volumes.

    Returns
    -------
    ndarray
        Shape ``gray_img.shape + (3,)`` with identical R, G and B channels. Scaled
        values are truncated (not rounded) to integers. A constant input maps to
        all zeros rather than dividing by zero.
    """
    gray = np.asarray(gray_img)
    lo, hi = gray.min(), gray.max()
    span = hi - lo
    if span == 0:
        # Constant image: (x - min) / (max - min) would be 0/0 -> NaN -> garbage ints.
        scaled = np.zeros(gray.shape)
    else:
        scaled = (gray - lo) / span
    # Scale the single channel first and stack last: same values as stacking first,
    # but a third of the arithmetic and a much lower peak memory footprint.
    intensity = (scaled * 255).astype(dtype)
    return np.stack((intensity,) * 3, axis=-1)

In [5]:
# Expand each grayscale volume into an RGB volume so coloured masks can be painted on it.
rgb_img1, rgb_img2, rgb_img3 = (
    convert_gray_to_rgb(gray) for gray in (gray_img1, gray_img2, gray_img3)
)

In [6]:
tuple(volume.shape for volume in (rgb_img1, rgb_img2, rgb_img3))

((512, 512, 407, 3), (512, 512, 349, 3), (512, 512, 313, 3))

<h3>Exploration</h3>
<p><b>Visualize the original volume (Image 1)</b></p>

In [7]:
# Define a function to visualize the data
def explore_3dimage(layer):
    """Show slice ``layer`` (last axis) of the grayscale volume ``gray_img1``.

    Returns ``layer``; ``interact`` displays non-None return values.
    """
    _, ax = plt.subplots(figsize=(20, 10))
    ax.imshow(gray_img1[:, :, layer], cmap='gray')
    ax.set_title('Explore Layers of Ribs', fontsize=20)
    ax.axis('off')
    return layer

# Run the ipywidgets interact() function to explore the data
interact(explore_3dimage, layer=(0, gray_img1.shape[2] - 1));

interactive(children=(IntSlider(value=203, description='layer', max=406), Output()), _dom_classes=('widget-int…

<p><b>Visualize the fractures (Image 1)</b></p>

In [8]:
# Annotate a copy so the plain ``rgb_img1`` remains available for side-by-side display.
rgb_img1_w_labels = rgb_img1.copy(order='C')

In [9]:
# Paint every voxel with a non-zero fracture label using the shared ``red`` colour.
rgb_img1_w_labels[label_img1 != 0] = red

In [10]:
# Define a function to visualize the data
def explore_3dimage(layer):
    """Plot slice ``layer`` of Image 1 as-is (left) and with its fracture-label
    voxels painted (right).

    Returns ``layer``; ``interact`` displays non-None return values.
    """
    _, axes = plt.subplots(1, 2, figsize=(20, 20))
    panels = [
        (rgb_img1, 'Original image'),
        (rgb_img1_w_labels, 'Image with fracture'),
    ]
    for ax, (volume, title) in zip(axes, panels):
        ax.imshow(volume[:, :, layer, :])
        ax.axis('off')
        ax.set_title(title, fontsize=22)
    return layer

# Run the ipywidgets interact() function to explore the data
interact(explore_3dimage, layer=(0, rgb_img1.shape[2] - 1));

interactive(children=(IntSlider(value=203, description='layer', max=406), Output()), _dom_classes=('widget-int…

<p><b>Visualize the model's fracture predictions (Image 2)</b></p>

In [11]:
# Annotate a copy so the plain ``rgb_img2`` remains available for side-by-side display.
rgb_img2_w_preds = rgb_img2.copy(order='C')

In [12]:
# Paint every voxel with a non-zero prediction red.
rgb_img2_w_preds[pred_img2 != 0] = red

In [13]:
# Define a function to visualize the data
def explore_3dimage(layer):
    """Plot slice ``layer`` of Image 2 as-is (left) and with the predicted
    fracture voxels painted red (right).

    Returns ``layer``; ``interact`` displays non-None return values.
    """
    _, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(20, 20))
    for ax, volume, title in (
        (left_ax, rgb_img2, 'Original image'),
        (right_ax, rgb_img2_w_preds, 'Image with fracture pred'),
    ):
        ax.imshow(volume[:, :, layer, :])
        ax.axis('off')
        ax.set_title(title, fontsize=22)
    return layer

# Run the ipywidgets interact() function to explore the data
interact(explore_3dimage, layer=(0, rgb_img2.shape[2] - 1));

interactive(children=(IntSlider(value=174, description='layer', max=348), Output()), _dom_classes=('widget-int…

<p><b>Compare the predictions with the ground-truth labels (Image 2)</b></p>

In [14]:
# Fresh copy of the plain RGB volume for the prediction-vs-label overlay.
rgb_img2_w_preds_n_labels = rgb_img2.copy(order='C')

In [15]:
# Voxel-wise comparison requires the prediction and label volumes to be aligned;
# fail early with a readable message instead of a broadcasting error (or silent broadcast).
assert pred_img2.shape == label_img2.shape, (
    f"prediction {pred_img2.shape} and label {label_img2.shape} volumes differ in shape"
)
# Binarise each volume once, then derive the three mutually exclusive regions.
pred_mask = pred_img2 != 0
label_mask = label_img2 != 0
true_pos_cond = pred_mask & label_mask
false_pos_cond = pred_mask & ~label_mask
false_neg_cond = ~pred_mask & label_mask
# True negatives are left unpainted (plain grayscale).
rgb_img2_w_preds_n_labels[true_pos_cond] = red
rgb_img2_w_preds_n_labels[false_pos_cond] = blue
rgb_img2_w_preds_n_labels[false_neg_cond] = green

In [16]:
tuple(volume.shape for volume in (pred_img2, label_img2))

((512, 512, 349), (512, 512, 349))

In [17]:
# Define a function to visualize the data
def explore_3dimage(layer):
    """Plot slice ``layer`` of Image 2 as-is (left) and with the TP/FP/FN
    overlay (right).

    Returns ``layer``; ``interact`` displays non-None return values.
    """
    _, axes = plt.subplots(1, 2, figsize=(20, 20))
    titles = ('Original image', 'Image w inference (TP: red, FP: blue, FN: green)')
    volumes = (rgb_img2, rgb_img2_w_preds_n_labels)
    for ax, volume, title in zip(axes, volumes, titles):
        ax.imshow(volume[:, :, layer, :])
        ax.axis('off')
        ax.set_title(title, fontsize=22)
    return layer

# Run the ipywidgets interact() function to explore the data
interact(explore_3dimage, layer=(0, rgb_img2.shape[2] - 1));

interactive(children=(IntSlider(value=174, description='layer', max=348), Output()), _dom_classes=('widget-int…

<p><b>Visualize the ROIs (Image 2)</b></p>

In [18]:
# Fresh copy of the plain RGB volume for the ROI overlay.
rgb_img2_w_rois = rgb_img2.copy(order='C')

In [19]:
# Validation split location; one constant instead of repeating the absolute prefix.
VAL_DATA_DIR = "/home/ryan/PycharmProjects/FracNet/data/val"

# Window(-200, 1000) followed by MinMaxNorm over the same range.
transforms = [
    tsfm.Window(-200, 1000),
    tsfm.MinMaxNorm(-200, 1000)
]

# NOTE(review): the dataset is only used below for its private ROI-sampling helpers
# (_get_pos_centroids / _get_neg_centroids); whether they depend on ``transforms``
# is not visible here.
frac_dataset = FracNetTrainDataset(
    image_dir=f'{VAL_DATA_DIR}/ribfrac-val-images/',
    label_dir=f'{VAL_DATA_DIR}/ribfrac-val-labels/',
    transforms=transforms
)

In [20]:
# Retrieve ROIs
# NOTE(review): these are private FracNetTrainDataset helpers; negative centroids may be
# sampled randomly (no seed is set here) - TODO confirm and seed for reproducibility.
pos_centroids = frac_dataset._get_pos_centroids(np.asarray(label_img2, dtype=int))
neg_centroids = frac_dataset._get_neg_centroids(pos_centroids, np.shape(label_img2))

In [21]:
def get_roi_coordinates(img_arr, centroid, crop_size=64):
    """Compute the ``crop_size``-wide crop box centred on ``centroid``, clipped to ``img_arr``.

    Only ``img_arr.shape`` is read, and only its first ``len(centroid)`` axes.

    Returns
    -------
    (src_beg, src_end, dst_beg, dst_end) : tuple of lists, one entry per axis
        ``src_*`` are the clipped [begin, end) ranges inside ``img_arr``;
        ``dst_*`` are the matching ranges inside a ``crop_size``-sided patch,
        so ``src_end - src_beg == dst_end - dst_beg`` on every axis.
    """
    half_lo = crop_size // 2
    # The high side gets the remainder so the box is exactly ``crop_size`` wide even
    # for odd sizes (with ``c + crop_size // 2`` the source span was one voxel short
    # of the destination span). Identical to before for even sizes.
    half_hi = crop_size - half_lo
    axes = range(len(centroid))
    src_beg = [max(0, centroid[i] - half_lo) for i in axes]
    src_end = [min(img_arr.shape[i], centroid[i] + half_hi) for i in axes]
    dst_beg = [max(0, half_lo - centroid[i]) for i in axes]
    dst_end = [min(img_arr.shape[i] - (centroid[i] - half_lo), crop_size) for i in axes]
    return src_beg, src_end, dst_beg, dst_end

In [22]:
# Crop-box coordinates (src/dst begin/end per axis) around every sampled centroid.
pos_roi_coordinates, neg_roi_coordinates = (
    [get_roi_coordinates(rgb_img2_w_rois, centroid) for centroid in centroids]
    for centroids in (pos_centroids, neg_centroids)
)

In [23]:
# Fill each positive-centroid crop box (source range on all spatial axes, every channel) with blue.
for src_beg, src_end, dst_beg, dst_end in pos_roi_coordinates:
    rgb_img2_w_rois[tuple(slice(b, e) for b, e in zip(src_beg, src_end))] = blue

In [24]:
# Fill each negative-centroid crop box with green (drawn after the blue boxes, so it wins on overlap).
for src_beg, src_end, dst_beg, dst_end in neg_roi_coordinates:
    rgb_img2_w_rois[tuple(slice(b, e) for b, e in zip(src_beg, src_end))] = green

In [25]:
# Paint ground-truth fracture voxels after the ROI boxes so they stay visible on top.
rgb_img2_w_rois[label_img2 != 0] = red

In [26]:
# Define a function to visualize the data
def explore_3dimage(layer):
    """Plot slice ``layer`` of Image 2 as-is (left) and with fractures plus the
    positive/negative ROI boxes painted (right).

    Returns ``layer``; ``interact`` displays non-None return values.
    """
    _, axes = plt.subplots(1, 2, figsize=(20, 20))
    panels = {
        'Original image': rgb_img2,
        'Image w ROIs (Fracture: red, Pos ROI: blue, Neg ROI: green)': rgb_img2_w_rois,
    }
    for ax, (title, volume) in zip(axes, panels.items()):
        ax.imshow(volume[:, :, layer, :])
        ax.axis('off')
        ax.set_title(title, fontsize=22)
    return layer

# Run the ipywidgets interact() function to explore the data
interact(explore_3dimage, layer=(0, rgb_img2.shape[2] - 1));

interactive(children=(IntSlider(value=174, description='layer', max=348), Output()), _dom_classes=('widget-int…

<p><b>Visualize the ROIs (Image 3)</b></p>

In [27]:
# Fresh copy of the plain RGB volume for the ROI overlay.
rgb_img3_w_rois = rgb_img3.copy(order='C')

In [28]:
# Rebuilds the same validation dataset as the Image 2 section so this section can run on its own.
VAL_DATA_DIR = "/home/ryan/PycharmProjects/FracNet/data/val"

# Window(-200, 1000) followed by MinMaxNorm over the same range.
transforms = [
    tsfm.Window(-200, 1000),
    tsfm.MinMaxNorm(-200, 1000)
]

# NOTE(review): only the private ROI-sampling helpers of this dataset are used below.
frac_dataset = FracNetTrainDataset(
    image_dir=f'{VAL_DATA_DIR}/ribfrac-val-images/',
    label_dir=f'{VAL_DATA_DIR}/ribfrac-val-labels/',
    transforms=transforms
)

In [29]:
# Retrieve ROIs
# NOTE(review): private FracNetTrainDataset helpers; negative centroids may be sampled
# randomly (no seed is set here) - TODO confirm and seed for reproducibility.
pos_centroids = frac_dataset._get_pos_centroids(np.asarray(label_img3, dtype=int))
neg_centroids = frac_dataset._get_neg_centroids(pos_centroids, np.shape(label_img3))

In [30]:
def get_roi_coordinates(img_arr, centroid, crop_size=64):
    """Compute the ``crop_size``-wide crop box centred on ``centroid``, clipped to ``img_arr``.

    Only ``img_arr.shape`` is read, and only its first ``len(centroid)`` axes.

    Returns
    -------
    (src_beg, src_end, dst_beg, dst_end) : tuple of lists, one entry per axis
        ``src_*`` are the clipped [begin, end) ranges inside ``img_arr``;
        ``dst_*`` are the matching ranges inside a ``crop_size``-sided patch,
        so ``src_end - src_beg == dst_end - dst_beg`` on every axis.
    """
    half_lo = crop_size // 2
    # The high side gets the remainder so the box is exactly ``crop_size`` wide even
    # for odd sizes (with ``c + crop_size // 2`` the source span was one voxel short
    # of the destination span). Identical to before for even sizes.
    half_hi = crop_size - half_lo
    axes = range(len(centroid))
    src_beg = [max(0, centroid[i] - half_lo) for i in axes]
    src_end = [min(img_arr.shape[i], centroid[i] + half_hi) for i in axes]
    dst_beg = [max(0, half_lo - centroid[i]) for i in axes]
    dst_end = [min(img_arr.shape[i] - (centroid[i] - half_lo), crop_size) for i in axes]
    return src_beg, src_end, dst_beg, dst_end

In [31]:
# Crop-box coordinates (src/dst begin/end per axis) around every sampled centroid.
pos_roi_coordinates, neg_roi_coordinates = (
    [get_roi_coordinates(rgb_img3_w_rois, centroid) for centroid in centroids]
    for centroids in (pos_centroids, neg_centroids)
)

In [32]:
# Fill each positive-centroid crop box (source range on all spatial axes, every channel) with blue.
for src_beg, src_end, dst_beg, dst_end in pos_roi_coordinates:
    rgb_img3_w_rois[tuple(slice(b, e) for b, e in zip(src_beg, src_end))] = blue

In [33]:
# Fill each negative-centroid crop box with green (drawn after the blue boxes, so it wins on overlap).
for src_beg, src_end, dst_beg, dst_end in neg_roi_coordinates:
    rgb_img3_w_rois[tuple(slice(b, e) for b, e in zip(src_beg, src_end))] = green

In [34]:
# Paint ground-truth fracture voxels after the ROI boxes so they stay visible on top.
rgb_img3_w_rois[label_img3 != 0] = red

In [35]:
# Define a function to visualize the data
def explore_3dimage(layer):
    """Plot slice ``layer`` of Image 3 as-is (left) and with fractures plus the
    positive/negative ROI boxes painted (right).

    Returns ``layer``; ``interact`` displays non-None return values.
    """
    _, (plain_ax, overlay_ax) = plt.subplots(1, 2, figsize=(20, 20))
    plain_ax.imshow(rgb_img3[:, :, layer, :])
    plain_ax.axis('off')
    plain_ax.set_title('Original image', fontsize=22)
    overlay_ax.imshow(rgb_img3_w_rois[:, :, layer, :])
    overlay_ax.axis('off')
    overlay_ax.set_title('Image w ROIs (Fracture: red, Pos ROI: blue, Neg ROI: green)', fontsize=22)
    return layer

# Run the ipywidgets interact() function to explore the data
interact(explore_3dimage, layer=(0, rgb_img3.shape[2] - 1));

interactive(children=(IntSlider(value=156, description='layer', max=312), Output()), _dom_classes=('widget-int…