In [1]:
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import ipywidgets as widgets
from ipywidgets import interact

In [2]:
import matplotlib

# Global font for every matplotlib figure created after this cell:
# serif family, normal weight, 22 pt.
font = dict(family='serif', weight='normal', size=22)
matplotlib.rc('font', **font)

In [3]:
# Vertebral level names; position i corresponds to column i of `labels`.
ordered_verts = [f'T{level}' for level in range(4, 13)] + [f'L{level}' for level in range(1, 5)]

## Post-processing for midline finder

In [4]:
# Midline-finder outputs: scan ids, mask logits and per-level label logits
# (downstream cells apply `sigmoid` to both masks and labels).
# Use a context manager so the .npz archive is closed once the arrays are read;
# NpzFile loads each array into memory on access, so they stay valid afterwards.
with np.load('../outputs/midline_finder_preds.npz') as preds:
    # Iterating `.files` gives the same archive order as `.values()`.
    ids, masks, labels = (preds[key] for key in preds.files)
ids.shape, masks.shape, labels.shape

((46,), (46, 1, 512, 512), (46, 13))

In [5]:
#np.savez('../outputs/midline_no_mips_preds.npz',ids=ids, masks=masks, labels=labels)

In [6]:
def sigmoid(x):
    """Element-wise logistic function 1 / (1 + exp(-x)).

    Delegates to scipy.special.expit. It gives the same values as the explicit
    formula, but does not emit an overflow RuntimeWarning for large negative
    inputs: in float64, np.exp(-x) overflows once x is below about -709.
    """
    from scipy.special import expit
    return expit(x)

In [7]:
@interact
def plot_mask(name=list(ids), show_mask=False):
    """Show the midline channel of a coronal test slice, titled with the
    vertebral levels whose label probability exceeds 0.5.

    Args:
        name: scan id (dropdown over `ids`).
        show_mask: overlay the predicted mask (pixels with p >= 0.5).
    """
    img = np.load(f'../midline_data/testing/slices/coronal/{name}.npy')
    n = list(ids).index(name)

    # p > 0.5 is equivalent to the previous `np.round(p) == 1` test
    # (np.round rounds 0.5 down to 0).
    present = sigmoid(labels[n]) > 0.5
    levels = [level for level, hit in zip(ordered_verts, present) if hit]

    fig, ax = plt.subplots(1, 1, figsize=(9, 9))
    ax.axis('off')
    ax.set_title(', '.join(levels) if levels else 'no levels detected')
    # Slice channels: 0 = MIP, 1 = midline, 2 = STD
    ax.imshow(img[..., 1], cmap='gray')
    if show_mask:
        # NaN out the background so only the predicted region is coloured.
        mask = np.where(sigmoid(masks[n, 0]) < 0.5, np.nan, 1.0)
        ax.imshow(mask, alpha=1, cmap='autumn')
        

interactive(children=(Dropdown(description='name', options=('04_06_2014_431_Sag', 'fr_553_LS_Sag', '16_05_2014…

In [8]:
from scipy.ndimage import gaussian_filter

In [9]:
@interact
def plot_midline(name=list(ids)):
    """Plot the per-row argmax of the midline prediction next to the
    thresholded mask overlaid on the coronal slice.

    Note: a later cell redefines `plot_midline`. The widget created here keeps
    working because `@interact` binds the function when this cell runs.
    """
    idx = list(ids).index(name)
    img = np.load(f'../midline_data/testing/slices/coronal/{name}.npy')
    # Other cells treat `masks` as logits (they apply `sigmoid` before
    # thresholding), so convert here too; the 0.5 cut-off then means p = 0.5.
    prob = sigmoid(masks[idx, 0])
    n_rows = prob.shape[0]
    rows = np.arange(n_rows)            # one y coordinate per image row
    norm_pred = np.where(prob < 0.5, 0, 1)
    max_ = np.argmax(prob, axis=1)      # column of peak response in each row

    fig, ax = plt.subplots(1, 2, figsize=(20, 10))
    ax[0].scatter(max_, rows, s=20)
    ax[0].set_ylim([n_rows, 0])         # row 0 at the top, as in imshow
    ax[1].imshow(img)
    ax[1].imshow(norm_pred, alpha=0.5)

interactive(children=(Dropdown(description='name', options=('04_06_2014_431_Sag', 'fr_553_LS_Sag', '16_05_2014…

In [10]:
# Predictions used by the contour / spline cells below.
# NOTE(review): this is the same file loaded into `preds` above - confirm
# whether a separate MIP predictions file was intended here.
with np.load('../outputs/midline_finder_preds.npz') as mip_preds:
    # Close the archive after reading; arrays are materialised on access.
    mip_ids, mip_masks, mip_labels = (mip_preds[key] for key in mip_preds.files)
mip_ids.shape, mip_masks.shape, mip_labels.shape

((46,), (46, 1, 512, 512), (46, 13))

In [11]:
from skimage import measure
from scipy import interpolate

In [12]:
@interact
def plot_midline(name=list(ids)):
    """Fit a piecewise-linear midline through the predicted mask regions.

    Left panel: the slice, each p = 0.5 contour and its centroid.
    Right panel: the slice, the thresholded mask, the centroids and the
    linear spline col = f(row) through them.

    Note: this redefines the `plot_midline` from the earlier cell.
    """
    idx = list(ids).index(name)
    img = np.load(f'../midline_data/testing/slices/coronal/{name}.npy')
    # extract_midline / evaluate_midline apply `sigmoid` to `mip_masks` before
    # contouring at 0.5; do the same so the level means p = 0.5 everywhere.
    prob = sigmoid(mip_masks[idx, 0])
    n_rows, n_cols = prob.shape
    norm_pred = np.where(prob < 0.5, 0, 1)
    contours = measure.find_contours(prob, 0.5)
    # (row, col) centroid of every contour, sorted by row. splrep needs an
    # increasing abscissa, which find_contours' ordering does not guarantee.
    centres = sorted((np.mean(c[:, 0]), np.mean(c[:, 1])) for c in contours)

    fig, ax = plt.subplots(1, 2, figsize=(20, 10))
    ax[0].imshow(img)
    ax[1].imshow(img)
    ax[1].imshow(norm_pred, alpha=0.5, cmap='magma')

    if len(centres) >= 2:
        ys, xs = zip(*centres)
        tck = interpolate.splrep(ys, xs, k=1, s=0)
        rows = np.arange(n_rows)        # sample the spline once per image row
        fit = interpolate.splev(rows, tck, der=0)
        ax[1].plot(fit, rows, c='y', lw=2, alpha=0.8)
    else:
        # A linear (k=1) spline needs at least two points.
        print("Fewer than two predicted regions, can't fit line!")

    for contour in contours:
        ax[0].plot(contour[:, 1], contour[:, 0], linewidth=3)
    for cy, cx in centres:
        ax[0].scatter(cx, cy, marker='+', c='r', s=500)
        ax[1].scatter(cx, cy, marker='+', c='r', s=500)

    ax[1].set_ylim([n_rows, 0])
    ax[1].set_xlim([0, n_cols])

interactive(children=(Dropdown(description='name', options=('04_06_2014_431_Sag', 'fr_553_LS_Sag', '16_05_2014…

# Extract Sagittal Midline

In [13]:
from ast import literal_eval
import SimpleITK as sitk

In [14]:
# Per-scan DICOM geometry, indexed by scan name. The 'Direction' column holds a
# tuple literal and is parsed with literal_eval ('Origin' stays a string).
direct_df = pd.read_csv(
    '../dicom_direction.csv',
    index_col='Name',
    converters={'Direction': literal_eval},
)

# Identity 3x3 direction matrix, flattened row-major.
base_direction = (
    1.0, 0.0, 0.0,
    0.0, 1.0, 0.0,
    0.0, 0.0, 1.0,
)

direct_df.head()

Unnamed: 0_level_0,Direction,Origin
Name,Unnamed: 1_level_1,Unnamed: 2_level_1
03_06_2014_389_Sag,"(0.0, 0.0, -1.0, 1.0, 0.0, 0.0, 0.0, -1.0, 0.0)","(201.6633148, -353.9631348, 131.249939)"
03_06_2014_402_Sag,"(0.0, 0.0, -1.0, 1.0, 0.0, 0.0, 0.0, -1.0, 0.0)","(173.4781189, -159.1593781, 98.0)"
03_06_2014_396_Sag,"(0.0, 0.0, -1.0, 1.0, 0.0, 0.0, 0.0, -1.0, 0.0)","(203.453125, -223.1618958, 57.64001465)"
03_06_2014_395_Sag,"(-0.3499322470372717, 0.0, -0.9367750116668602...","(190.7415924, -124.9147186, 81.60995483)"
03_06_2014_399_TS_Sag,"(0.0, 0.0, -1.0, 1.0, 0.0, 0.0, 0.0, -1.0, 0.0)","(165.7218781, -164.1593781, 122.75)"


In [15]:
def pad_image(img, output_shape=(512, 512, 512)):
    """
    Centre `img` in a fixed-size grid so every model input has the same dimensions.

    The image is resampled onto an `output_shape` grid with the same spacing and
    direction. The grid origin is shifted by half the size difference, and voxels
    outside the source are filled with -1024 (presumably air in HU).

    Returns:
        (padded image, per-axis index offset used to place the grid; negative
        where padding was added, positive where the image was cropped)
    """
    offset = [(in_size - out_size) // 2.0
              for out_size, in_size in zip(output_shape, img.GetSize())]

    resampler = sitk.ResampleImageFilter()
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(-1024)
    resampler.SetOutputSpacing(img.GetSpacing())
    resampler.SetSize(output_shape)
    resampler.SetOutputOrigin(img.TransformContinuousIndexToPhysicalPoint(offset))
    resampler.SetOutputDirection(img.GetDirection())
    return resampler.Execute(img), offset

def resample(image, new_spacing):
    """
    Resample `image` onto a grid with voxel spacing `new_spacing`.

    Origin and direction are kept. The output size is the input size scaled by
    old/new spacing and rounded up, so the full field of view is retained.

    Returns:
        (resampled image, per-axis ratio orig_spacing / new_spacing)
    """
    orig_spacing = image.GetSpacing()
    ratio = [old / new for old, new in zip(orig_spacing, new_spacing)]
    # Use the builtin int: the `np.int` alias was removed in NumPy 1.24.
    scaled = np.asarray(image.GetSize(), dtype=float) * ratio
    new_size = [int(s) for s in np.ceil(scaled)]

    # Named `resampler` so the local does not shadow this function's name.
    resampler = sitk.ResampleImageFilter()
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetOutputDirection(image.GetDirection())
    resampler.SetOutputOrigin(image.GetOrigin())
    resampler.SetOutputSpacing(new_spacing)
    resampler.SetSize(new_size)
    return resampler.Execute(image), ratio

def normalize(img, min=0, max=255):
    """
    Cast `img` to 32-bit integers, then linearly rescale intensities to [min, max].

    The cast happens before rescaling, so fractional parts are not preserved.
    `min`/`max` shadow the builtins but are kept for caller compatibility.
    """
    rescaler = sitk.RescaleIntensityImageFilter()
    rescaler.SetOutputMinimum(min)
    rescaler.SetOutputMaximum(max)
    return rescaler.Execute(sitk.Cast(img, sitk.sitkInt32))

In [16]:
def get_ROI(img, fit, width=4, fill_value=-1024):
    """
    Extract a slab `2*width` pixels wide centred on the fitted midline.

    For each index i, entries `int(fit[i]) - width` up to (not including)
    `int(fit[i]) + width` along axis 1 of `img[i]` are copied into `roi[i]`.
    Any part of the window outside `img` is left at `fill_value`. Without this,
    a negative start would wrap around and an end past the edge would fail on
    a short slice, which happens when the spline extrapolates near the border.

    Args:
        img: 3-D array indexed img[i, j, k] (512^3 in this notebook).
        fit: centre position along axis 1 for each i, e.g. the evaluated spline.
        width: half-width of the slab in pixels.
        fill_value: value for voxels outside `img` (default -1024, as before).

    Returns:
        float array of shape (img.shape[0], 2*width, img.shape[2]).
    """
    img = np.asarray(img)
    n_slices, n_rows, n_cols = img.shape
    roi = np.full((n_slices, 2 * width, n_cols), fill_value, dtype=float)
    for i, y in enumerate(fit):
        centre = int(y)
        lo, hi = centre - width, centre + width
        # Clip the source window to the image; offset the destination to match.
        src_lo, src_hi = max(lo, 0), min(hi, n_rows)
        if src_lo < src_hi:
            roi[i, src_lo - lo:src_hi - lo] = img[i, src_lo:src_hi]
    return roi

def reformat_scan(name, new_spacing=(1.25, 1.25, 1.25), volume_dir='../ct_volumes'):
    """
    Load a CT volume and resample it onto a fixed-size grid.

    The grid uses spacing `new_spacing` and is padded to pad_image's default
    512^3 shape.

    Args:
        name: scan id; reads f'{volume_dir}/{name}.nii' and its row in `direct_df`.
        new_spacing: output voxel spacing in mm.
        volume_dir: directory holding the .nii volumes.

    Returns:
        (original voxel array, resampled-and-padded voxel array), both as the
        numpy arrays returned by sitk.GetArrayFromImage.
    """
    reader = sitk.ImageFileReader()
    reader.SetImageIO("NiftiImageIO")
    reader.SetFileName(f'{volume_dir}/{name}.nii')
    data = reader.Execute()
    # Set image orientation (this is lost when writing out to .nii). Look the
    # column up by name, not position, so a reordered CSV cannot feed the
    # 'Origin' string in here.
    data.SetDirection(direct_df.loc[name, 'Direction'])

    # NOTE(review): the array round-trip keeps only the voxel buffer. The
    # direction set above does not reorder voxels, and it is replaced by the
    # identity `base_direction` below - confirm this is the intended reorientation.
    image_out = sitk.GetImageFromArray(sitk.GetArrayFromImage(data))
    image_out.SetOrigin(data.GetOrigin())
    image_out.SetSpacing(data.GetSpacing())
    image_out.SetDirection(base_direction)  # set to RAI

    norm_img, _ = resample(image_out, new_spacing)
    norm_img, _ = pad_image(norm_img)
    return sitk.GetArrayFromImage(data), sitk.GetArrayFromImage(norm_img)

In [17]:
@interact
def extract_midline(name=list(ids)):
    """
    Reformat a CT volume, fit the predicted midline, and project the slab
    around it.

    Panels:
        (0) max-intensity projection (last axis) of the original scan,
        (1) the same projection of the reformat, windowed to [400, 1000], with
            the fitted line and the contour centroids,
        (2) max over the width of the midline slab, windowed to [-100, 500].
    """
    n = list(ids).index(name)
    pred = sigmoid(mip_masks[n, 0])
    n_rows, n_cols = pred.shape

    # (row, col) centroids of the p = 0.5 contours, sorted by row because
    # splrep needs an increasing abscissa.
    contours = measure.find_contours(pred, 0.5)
    centres = sorted((np.mean(c[:, 0]), np.mean(c[:, 1])) for c in contours)
    if len(centres) < 2:
        # A linear spline needs >= 2 points. Report and return instead of
        # raising BaseException from inside the widget callback.
        print("No predictions, can't fit line!")
        return
    ys, xs = zip(*centres)
    tck = interpolate.splrep(ys, xs, k=1, s=0)
    # Sample at integer row indices so fit[i] lines up with slice i in get_ROI.
    # linspace(0, 512, 512) was spaced 512/511 apart, not 1 apart.
    rows = np.arange(n_rows)
    fit = interpolate.splev(rows, tck, der=0)

    data, norm_data = reformat_scan(name)
    roi = get_ROI(norm_data, fit, width=4)
    # W/L the slab, then take the max across its width.
    roi = np.max(np.clip(roi, a_min=-100, a_max=500), axis=1)

    clipped_norm = np.clip(norm_data, a_min=400, a_max=1000)
    proj_d = np.max(data, axis=-1)
    proj_n = np.max(clipped_norm, axis=-1)

    # ---- PLOTTING ----
    fig, ax = plt.subplots(1, 3, figsize=(20, 10))
    plt.tight_layout()
    ax[0].set_title('Original Scan')
    ax[0].imshow(proj_d, cmap='gray')

    ax[1].set_title('Reformat')
    ax[1].imshow(proj_n, cmap='gray')
    ax[1].plot(fit, rows, c='y', lw=2, alpha=0.8)
    for contour in contours:
        ax[1].plot(contour[:, 1], contour[:, 0], linewidth=3, c='y')
    for cy, cx in centres:
        ax[1].scatter(cx, cy, marker='+', c='r', s=700, edgecolor='k')
    ax[1].set_ylim([n_rows, 0])
    ax[1].set_xlim([0, n_cols])

    ax[2].set_title('Midline')
    ax[2].imshow(roi, cmap='gray')

interactive(children=(Dropdown(description='name', options=('04_06_2014_431_Sag', 'fr_553_LS_Sag', '16_05_2014…

# Evaluate midline accuracy

In [18]:
@interact
def evaluate_midline(name=list(ids)):
    """
    Compare ground-truth level coordinates with the predicted midline.

    Left: ground-truth points from ../midline_data/coordinates/{name}.csv.
    Right: predicted p = 0.5 contours and their centroids.
    When a line can be fitted, it is drawn on both panels for visual comparison.
    """
    img = np.load(f'../midline_data/testing/slices/coronal/{name}.npy')
    # ---- Ground-truth ----
    # Assumes exactly two coordinate columns, (x, y) in image pixels - TODO confirm.
    coord_df = pd.read_csv(f'../midline_data/coordinates/{name}.csv', index_col='Level')
    gt_x, gt_y = coord_df.to_numpy().T

    # ---- Predictions ----
    n = list(ids).index(name)
    pred = sigmoid(mip_masks[n, 0])
    contours = measure.find_contours(pred, 0.5)
    # (row, col) centroids sorted by row, as splrep requires.
    centres = sorted((np.mean(c[:, 0]), np.mean(c[:, 1])) for c in contours)

    # ---- Plotting ----
    fig, ax = plt.subplots(1, 2, figsize=(20, 10))
    plt.tight_layout()
    ax[0].imshow(img)
    ax[0].scatter(gt_x, gt_y, marker='o', c='y', s=700, edgecolor='k')

    ax[1].imshow(img)
    for contour in contours:
        ax[1].plot(contour[:, 1], contour[:, 0], linewidth=3, c='b')
    for cy, cx in centres:
        ax[1].scatter(cx, cy, marker='*', c='r', s=700, edgecolor='k')

    # ---- Spline fitting ----
    if len(centres) >= 2:
        ys, xs = zip(*centres)
        tck = interpolate.splrep(ys, xs, k=1, s=0)
        rows = np.arange(pred.shape[0])
        fit = interpolate.splev(rows, tck, der=0)
        for axis in ax:
            axis.plot(fit, rows, c='c', lw=2, alpha=0.8)
            # Keep the image framing even if the spline extrapolates outside it.
            axis.set_xlim([0, img.shape[1]])
            axis.set_ylim([img.shape[0], 0])
    else:
        print("No predictions, can't fit line!")

interactive(children=(Dropdown(description='name', options=('04_06_2014_431_Sag', 'fr_553_LS_Sag', '16_05_2014…