In [1]:
import os
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import SimpleITK as sitk
import json
import sys
import csv

import nibabel as nib
import nibabel.orientations as nio
from data_utilities import reorient_centroids_to, reorient_to, load_centroids

import ipywidgets as widgets
from ipywidgets import interact

sys.path.append('/home/donal/PhD/spines/spine_nn/')

In [2]:
%matplotlib inline

In [3]:
# Vertebra names in label order: a centroid labelled k is named all_verts[k - 1].
all_verts = [
    # cervical
    'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7',
    # thoracic
    'T1', 'T2', 'T3', 'T4', 'T5', 'T6', 'T7', 'T8', 'T9', 'T10', 'T11', 'T12',
    # lumbar
    'L1', 'L2', 'L3', 'L4', 'L5', 'L6',
    # remaining labels
    'Sacrum', 'Cocygis', 'T13',
]

## Prepare level labelling with VerSe data

In [4]:
# Index the dataset on disk: every dict is keyed by the scan name ('..._ct') so that
# an image, its mask and its centroids can be looked up with the same key.
root_dir = '/data/VerSe/'
img_dict = {}   # scan name -> path of the CT image
mask_dict = {}  # scan name -> path of the vertebra mask
ctrd_dict = {}  # scan name -> centroids as returned by load_centroids
for root, dirs, files in os.walk(root_dir):
    # Anything below a path containing 'test' is left out
    if 'test' in root:
        continue
    if 'rawdata' in root:
        # Raw images
        for file in files:
            if not file.endswith('.nii.gz'):
                continue
            name = file.split('.')[0]
            img_dict[name] = os.path.join(root, file)
    elif 'derivatives' in root:
        # Annotations: masks (.nii.gz) and centroid files (.json)
        for file in files:
            if file.endswith('.nii.gz'):
                name = file.split('.')[0].replace('_seg-vert_msk', '_ct')
                mask_dict[name] = os.path.join(root, file)
            elif file.endswith('.json'):
                name = file.split('.')[0]
                # The centroid-file suffix depends on which release the path belongs to
                if 'VerSe19' in root:
                    name = name.replace('_seg-vb_ctd', '_ct')
                elif 'VerSe20' in root:
                    name = name.replace('_seg-subreg_ctd', '_ct')
                ctrd_dict[name] = load_centroids(os.path.join(root, file))

In [5]:
# Scan names, in the order the images were discovered
ids = list(img_dict)

In [6]:
# Convert to an array so the index arrays produced by KFold.split can select ids directly
ids = np.array(ids)
ids.shape

(261,)

In [7]:
from sklearn.model_selection import KFold

In [8]:
# Four shuffled folds with a fixed seed so the assignment is reproducible
kf = KFold(n_splits=4, random_state=124, shuffle=True)
q1, q2, q3, q4 = kf.split(ids)

# Each split is an index pair; its last element selects the scans that make up that fold
folds = {
    label: list(ids[split[-1]])
    for label, split in zip(('q1', 'q2', 'q3', 'q4'), (q1, q2, q3, q4))
}

In [9]:
# Sanity check: number of images, masks and centroid files found
print(len(img_dict), len(mask_dict), len(ctrd_dict))

261 261 261


***

### __Train/Test/Validation:__ _141/113/120 = 374_

***

## Resample + Projections

In [10]:
from projections import maximum_projection, WL_norm, average_projection, standard_deviation
from scipy.ndimage import measurements
from skimage import measure

In [11]:
def prep(img, spacing):
    """
    Resample a NiBabel image to `spacing` and reorient it to ('I', 'P', 'L').

    Parameters
    ----------
    img : NiBabel image
    spacing : voxel spacing forwarded to `resample_nib` as `voxel_spacing`

    Returns
    -------
    The resampled (interpolation order 3) and reoriented image.
    """
    # `resample_nib` is not explicitly imported anywhere in this notebook, so calling
    # prep() raised NameError. NOTE(review): assumed to live in `data_utilities`
    # alongside `reorient_to` -- confirm against that module.
    from data_utilities import resample_nib

    img_iso = resample_nib(img, voxel_spacing=spacing, order=3)
    return reorient_to(img_iso, axcodes_to=('I', 'P', 'L'))

def resample(image, new_spacing):
    """
    Resample image to new resolution with pixel dims defined by new_spacing

    Parameters
    ----------
    image : SimpleITK image
    new_spacing : sequence of float, one entry per image dimension

    Returns
    -------
    (resampled, ratio) where ratio[i] = original spacing / new spacing along axis i.
    The output keeps the input's origin and direction and uses linear interpolation.
    """
    # Named `resampler` so the local does not shadow this function's own name
    resampler = sitk.ResampleImageFilter()
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetOutputDirection(image.GetDirection())
    resampler.SetOutputOrigin(image.GetOrigin())
    resampler.SetOutputSpacing(new_spacing)

    # Builtin `int` instead of the `np.int` alias, which NumPy >= 1.24 no longer provides
    orig_size = np.array(image.GetSize(), dtype=int)
    ratio = [old / new for old, new in zip(image.GetSpacing(), new_spacing)]
    # Round up so the resampled grid covers the whole original extent;
    # SetSize needs plain Python ints
    new_size = [int(s) for s in np.ceil(orig_size * ratio)]
    resampler.SetSize(new_size)
    return resampler.Execute(image), ratio

def pad_image(img, output_shape=(626, 452)):
    """
    Re-grid `img` onto a grid of `output_shape` so all model inputs share the same size.

    The output origin is placed at half the size difference between input and output
    (in input index space); samples falling outside the input are filled with 0.

    Returns (re-gridded image, per-axis index offset used for the origin).
    """
    offsets = [(in_len - out_len)//2.0 for out_len, in_len in zip(output_shape, img.GetSize())]

    regrid = sitk.ResampleImageFilter()
    regrid.SetSize(output_shape)
    regrid.SetDefaultPixelValue(0.0)
    regrid.SetOutputSpacing(img.GetSpacing())
    regrid.SetOutputDirection(img.GetDirection())
    regrid.SetOutputOrigin(img.TransformContinuousIndexToPhysicalPoint(offsets))
    return regrid.Execute(img), offsets

def normalize(img, min_=0, max_=255):
    """
    Cast the image to 32-bit integers, then rescale its intensities to [min_, max_].
    """
    as_int = sitk.Cast(img, sitk.sitkInt32)
    rescaler = sitk.RescaleIntensityImageFilter()
    rescaler.SetOutputMinimum(min_)
    rescaler.SetOutputMaximum(max_)
    return rescaler.Execute(as_int)

def post_projection(img, new_spacing, output_shape=(626, 452)):
    """
    Bring a projection onto the standard grid: resample to `new_spacing`, then pad
    to `output_shape`.

    Returns (padded image, resampling ratio from `resample`, offset from `pad_image`).
    """
    iso_img, scale = resample(img, new_spacing)
    framed_img, padding = pad_image(iso_img, output_shape)
    return framed_img, scale, padding

def points2frame(img, points, spacing=1.25, frame_size=512):
    """
    Map centroid coordinates from image index space into the resampled, padded frame.

    Parameters
    ----------
    img : object exposing GetSpacing() and GetSize() (one value per axis)
    points : list whose first entry is the orientation tuple ('R', 'P', 'I'),
        followed by (label, x, y, z) entries
    spacing : target isotropic spacing of the frame (default 1.25)
    frame_size : edge length of the frame in pixels (default 512)

    Returns
    -------
    dict mapping label -> (x, y, z) in frame coordinates. Points with any negative
    coordinate are reported on stdout and dropped; the upper border is not checked.
    """
    assert points[0] == ('R', 'P', 'I') # Check image orientation
    # Both factors depend only on the image, so compute them once rather than per point
    scale = np.array([s / spacing for s in img.GetSpacing()])
    padding = np.array([(frame_size - n * scale[i]) // 2 for i, n in enumerate(img.GetSize())])

    out_dict = {}
    for vert, x, y, z in points[1:]:
        # Scale to the new spacing, then shift by the padding added on each axis
        coords = tuple(np.array([x, y, z]) * scale + padding)
        if any(t < 0.0 for t in coords):
            print(coords)
            print('POINT REMOVED - outside borders')
            continue
        out_dict[vert] = coords
    return out_dict


In [12]:
# Colormap used by viz_img for the semi-transparent mask overlay
cmap = sns.cubehelix_palette(
    start=3.0,
    rot=0.0,
    hue=1.2,
    gamma=0.8,
    as_cmap=True,
)

In [13]:
def viz_img(name=ids, dim=[0, 1, 2], save=False, plot=False):
    """
    Build 512x512 projection images (MIP, average, standard deviation) for one scan.

    The image and mask are oriented to 'RPI', projected along `dim`, rescaled to
    [0, 255], resampled to 1.25 spacing and padded to 512x512.

    Parameters
    ----------
    name : key into img_dict / mask_dict / ctrd_dict. The default (the array of all
        ids) is not a usable key; always pass a single scan name.
    dim : axis to project along: 0, 1 or 2. The default list is not a valid value
        and ends in ValueError; always pass an int.
    save : if True, write each projection as .npy under
        /data/VerSe/all_VerSe/images_sagittal (dim 0) or images_coronal (dim 1),
        inside the fold the scan belongs to. Nothing is written for dim 2.
    plot : if True, show the stacked projections with the mask, its contours and
        the labelled centroids.

    Returns
    -------
    None
    """
    new_spacing = (1.25, 1.25)
    output_shape = (512, 512)

    Image = sitk.ReadImage(img_dict[name])
    Mask = sitk.ReadImage(mask_dict[name])
    points = ctrd_dict[name]

    # Reorder NIB for fixing coordinates
    nib_img = nib.load(img_dict[name])
    img_iso = reorient_to(nib_img, axcodes_to=('R', 'P', 'I'))

    # Change image orientation
    orient = sitk.DICOMOrientImageFilter()
    orient.SetDesiredCoordinateOrientation('RPI')
    Image = orient.Execute(Image)
    # Reorient points + mask
    points = reorient_centroids_to(points, img_iso)
    points = points2frame(Image, points)
    Mask = orient.Execute(Mask)

    bone_Image = WL_norm(Image, window=1000, level=700)
    tissue_Image = WL_norm(Image, window=600, level=200)

    # Project, then drop the projected axis by indexing it at 0
    if dim == 0:
        mip = maximum_projection(bone_Image, dim=dim)[0]
        avg = average_projection(tissue_Image, dim=dim)[0]
        std = standard_deviation(tissue_Image, dim=dim)[0]
        mask = maximum_projection(Mask, dim=dim)[0]
    elif dim == 1:
        mip = maximum_projection(bone_Image, dim=dim)[:, 0]
        avg = average_projection(tissue_Image, dim=dim)[:, 0]
        std = standard_deviation(tissue_Image, dim=dim)[:, 0]
        mask = maximum_projection(Mask, dim=dim)[:, 0]
    elif dim == 2:
        mip = maximum_projection(bone_Image, dim=dim)[..., 0]
        avg = average_projection(tissue_Image, dim=dim)[..., 0]
        std = standard_deviation(tissue_Image, dim=dim)[..., 0]
        mask = maximum_projection(Mask, dim=dim)[..., 0]
    else:
        raise ValueError

    # Cast to int and scale to [0, 255]
    mip = normalize(mip, min_=0, max_=255)
    avg = normalize(avg, min_=0, max_=255)
    std = normalize(std, min_=0, max_=255)

    # Resample + Pad [512, 512]
    padded_mip, scale, padding = post_projection(mip, new_spacing, output_shape)
    padded_avg, _, _ = post_projection(avg, new_spacing, output_shape)
    padded_std, _, _ = post_projection(std, new_spacing, output_shape)
    padded_mask, _, _ = post_projection(mask, new_spacing, output_shape)

    # Convert to array
    mip_img = sitk.GetArrayFromImage(padded_mip)
    avg_img = sitk.GetArrayFromImage(padded_avg)
    std_img = sitk.GetArrayFromImage(padded_std)
    plt_mask = sitk.GetArrayFromImage(padded_mask)

    img = np.stack((mip_img, std_img, avg_img), axis=-1)

    data_dict = {'mip': mip_img, 'avg': avg_img, 'std': std_img, 'all_projections': img}
    # Only dims 0 and 1 have an output directory; dim 2 is never saved
    view_dirs = {0: 'images_sagittal', 1: 'images_coronal'}
    if save and dim in view_dirs:
        # Loop variables are named so that they neither shadow the module-level `ids`
        # nor overwrite `img`, which the plotting code below still needs
        for fold, members in folds.items():
            if name not in members:
                continue
            for dir_, projection in data_dict.items():
                outpath = f'/data/VerSe/all_VerSe/{view_dirs[dim]}/{fold}/{dir_}/'
                os.makedirs(outpath, exist_ok=True)
                np.save(os.path.join(outpath, f'{name}.npy'), projection)

    # ----- PLOTTING ------
    if plot:
        fig, ax = plt.subplots(1, 1, figsize=(10,10))
        ax.axis('off')
        print(img.max(), img.min())
        ax.imshow(np.squeeze(img), cmap='gray')
        # Background (0) is masked out with NaN so only the labelled region is drawn
        ax.imshow(np.where(plt_mask==0, np.nan, plt_mask), alpha=0.5, cmap=cmap)
        # Find contours at a constant value
        contours = measure.find_contours(plt_mask, 1.)

        for contour in contours:
            ax.plot(contour[:, 1], contour[:, 0], linewidth=2, c='k')

        for elem, (x, y, z) in points.items():
            vert = all_verts[elem-1]
            # Pick the two coordinates that remain after projecting along `dim`
            if dim == 0:
                plot_x, plot_y = y, z
            elif dim == 1:
                plot_x, plot_y = x, z
            elif dim == 2:
                plot_x, plot_y = x, y
            ax.scatter(plot_x, plot_y, c='y', edgecolor='k')
            ax.text(plot_x+15, plot_y, vert, c='cyan')

In [13]:
import random

In [2]:
# for dim in [0, 1]:
#     for name in ids:
#         print(name)
#         viz_img(name, dim=dim, save=False, plot=True)
#         break

***

### EXTRACT SAGITTAL MIDLINE

***

In [14]:
from utils.midline_utils import *
from scipy import interpolate

In [15]:
def points2fit(points, size=512):
    """
    Piecewise-linear midline x(z) through the centroids, sampled once per frame row.

    Parameters
    ----------
    points : dict mapping label -> (x, y, z) in frame coordinates
    size : number of rows to sample, i.e. z = 0 .. size-1 (default 512)

    Returns
    -------
    list of `size` floats. Rows above the first / below the last centroid take the
    x value of that centroid.

    Raises
    ------
    ValueError if `points` is empty.
    """
    if not points:
        raise ValueError('points2fit needs at least one centroid')
    # Sort the (z, x) pairs so z is increasing, as np.interp requires
    z_sorted, x_sorted = zip(*sorted((z, x) for x, _, z in points.values()))
    # Evaluating directly at the integer rows keeps index i aligned with z == i and
    # always yields exactly `size` samples, even when a centroid lies beyond the
    # frame; np.interp holds the end values constant outside [min(z), max(z)].
    rows = np.arange(size)
    return np.interp(rows, z_sorted, x_sorted).tolist()

def get_ROI(vol, fit, width=4):
    """
    Extract a thick slice of 2*width samples around the fitted midline.

    Parameters
    ----------
    vol : 3-D array
    fit : sequence of midline positions; entry i is the centre (along the last axis
        of `vol`) of the window taken at index i of the second axis
    width : half-width of the window

    Returns
    -------
    array of shape (vol.shape[0], vol.shape[1], 2*width); columns beyond len(fit)
    stay zero.

    Raises
    ------
    ValueError if the last axis of `vol` is shorter than 2*width.
    """
    # NOTE(review): `fit` comes from points2fit, which models x as a function of z,
    # while `i` here walks the second array axis -- confirm this axis is intended.
    n_rows, n_cols, depth = vol.shape
    if depth < 2 * width:
        raise ValueError('volume is narrower than the requested window')
    #* Extract thick sagittal slice
    roi = np.zeros((n_rows, n_cols, 2 * width))
    for i, y in enumerate(fit[:n_cols]):
        # Clamp the centre so the window always lies fully inside the volume; an
        # unclamped window near an edge gives an empty or short slice that cannot
        # be assigned into roi[:, i]
        centre = min(max(int(y), width), depth - width)
        roi[:, i] = vol[:, i, centre - width:centre + width]
    return roi


In [18]:
def extract_sagittal_midline(name, plot=False, save=False):
    """
    Build a thick sagittal slice that follows the spine midline of one scan.

    The midline is fitted through the centroids (points2fit), a window around it is
    cut from the 1.25-spaced, 512^3-padded volume (get_ROI), windowed (window 1000,
    level 400), rescaled to [0, 255] and collapsed with a maximum over the window.

    Parameters
    ----------
    name : key into img_dict / mask_dict / ctrd_dict
    plot : if True, show the result with the mask and labelled centroids overlaid
    save : if True, stack (mip, slice, std) and write it to
        /data/VerSe/all_VerSe/vert_labelling/<fold>/slices/. The mip and std arrays
        are read from the images_sagittal folders, so those files must already exist.

    Returns
    -------
    None. Returns early (after printing a message) when fewer than two centroids
    remain.
    """
    new_spacing = (1.25, 1.25)
    output_shape = (512, 512)

    Image = sitk.ReadImage(img_dict[name])
    points = ctrd_dict[name]
    Mask = sitk.ReadImage(mask_dict[name])

    # Reorder NIB for fixing coordinates
    nib_img = nib.load(img_dict[name])
    img_iso = reorient_to(nib_img, axcodes_to=('R', 'P', 'I'))

    # Change image orientation
    orient = sitk.DICOMOrientImageFilter()
    orient.SetDesiredCoordinateOrientation('RPI')
    Image = orient.Execute(Image)
    Mask = orient.Execute(Mask)
    mask = maximum_projection(Mask, dim=0)[0]
    padded_mask, _, _ = post_projection(mask, new_spacing, output_shape)
    plt_mask = sitk.GetArrayFromImage(padded_mask)

    # Reorient points + mask
    points = reorient_centroids_to(points, img_iso)
    points = points2frame(Image, points)
    if len(points) < 2:
        print('No annotations for ', name)
        return

    # Get midline
    fit = points2fit(points)

    # Image processing
    # NOTE(review): this result is never used below; the call is kept only in case
    # WL_norm has side effects on `Image` -- confirm and remove.
    bone_Image = WL_norm(Image, window=1000, level=700)

    vol, _ = resample(Image, (1.25, 1.25, 1.25))
    vol, _ = pad_image(vol, (512, 512, 512))
    vol = sitk.GetArrayFromImage(vol)
    thick_slice = get_ROI(vol, fit)
    norm_slice = WL_norm(sitk.GetImageFromArray(thick_slice), window=1000, level=400)
    norm_slice = normalize(norm_slice, min_=0, max_=255)

    # Collapse the window dimension with a maximum
    slice_ = np.max(np.squeeze(sitk.GetArrayFromImage(norm_slice)), axis=-1)

    # Stays None unless a stacked image is actually built, so the plotting code
    # can fall back to the plain slice instead of hitting an unbound name
    img = None
    if save:
        # `members` rather than `ids`, to avoid shadowing the module-level array
        for fold, members in folds.items():
            if name not in members:
                continue
            outpath = f'/data/VerSe/all_VerSe/vert_labelling/{fold}/slices/'
            std = np.load(f'/data/VerSe/all_VerSe/images_sagittal/{fold}/std/{name}.npy')
            mip = np.load(f'/data/VerSe/all_VerSe/images_sagittal/{fold}/mip/{name}.npy')
            img = np.stack((mip, slice_, std), axis=-1)
            os.makedirs(outpath, exist_ok=True)
            np.save(os.path.join(outpath, f'{name}.npy'), img)

    # ----- PLOTTING ------
    if plot:
        fig, ax = plt.subplots(1, 1, figsize=(10,10))
        ax.axis('off')
        if img is not None:
            ax.imshow(img)
        else:
            ax.imshow(slice_, cmap='gray')
        ax.imshow(np.where(plt_mask==0, np.nan, plt_mask), alpha=0.8, cmap='viridis')
        for elem, (x, y, z) in points.items():
            vert = all_verts[elem-1]
            plot_x, plot_y = y, z
            ax.scatter(plot_x, plot_y, c='y', edgecolor='k')
            ax.text(plot_x+15, plot_y, vert, c='cyan')

In [1]:
# start_name = 'sub-verse116_ct'
# for name in ids[list(ids).index(start_name):]:
#     extract_sagittal_midline(name, plot=False, save=True)

***

### TARGETS prep

In [16]:
from scipy.ndimage import gaussian_filter

In [20]:
def create_heatmaps(points, size=512, n_channels=None, sigma=5):
    """
    Render one Gaussian heatmap per labelled centroid.

    Parameters
    ----------
    points : dict mapping label k -> (x, y, z); the heatmap goes to channel k-1 and
        is centred at row round(z), column round(y)
    size : edge length of the square heatmap (default 512)
    n_channels : number of channels; defaults to len(all_verts)
    sigma : standard deviation passed to gaussian_filter (default 5)

    Returns
    -------
    float32 array of shape (size, size, n_channels). Channels without a centroid,
    or whose centroid rounds to a position outside the frame, stay all-zero.
    """
    if n_channels is None:
        n_channels = len(all_verts)
    heatmap_holder = np.zeros((size, size, n_channels), dtype=np.float32)
    for channel, (x, y, z) in points.items():
        row, col = int(round(float(z))), int(round(float(y)))
        # A coordinate that rounds to `size` would raise IndexError and a negative
        # one would silently wrap to the opposite edge, so such points are skipped
        if not (0 <= row < size and 0 <= col < size):
            print('POINT REMOVED - outside borders')
            continue
        tmp = np.zeros((size, size))
        tmp[row, col] = 1
        heatmap_holder[..., channel-1] = gaussian_filter(tmp, sigma=sigma)
    return heatmap_holder

In [21]:
def prepare_targets(name, plot=False, save=False):
    """
    Build the training targets for one scan: projected mask, per-vertebra heatmaps
    and centroid coordinates.

    Parameters
    ----------
    name : key into img_dict / mask_dict / ctrd_dict
    plot : if True, show the saved sagittal MIP with mask, heatmaps and labelled
        centroids overlaid (the MIP file must already exist on disk)
    save : if True, write under /data/VerSe/all_VerSe/vert_labelling/<fold>/targets/
        the mask (masks/), the heatmaps (heatmaps/) and a CSV with columns
        Level, X, Y, Z (coordinates/)

    Returns
    -------
    None. Returns early (after printing a message) when fewer than two centroids
    remain.
    """
    new_spacing = (1.25, 1.25)
    output_shape = (512, 512)

    Image = sitk.ReadImage(img_dict[name])
    points = ctrd_dict[name]
    Mask = sitk.ReadImage(mask_dict[name])

    # Reorder NIB for fixing coordinates
    nib_img = nib.load(img_dict[name])
    img_iso = reorient_to(nib_img, axcodes_to=('R', 'P', 'I'))

    # Change image orientation
    orient = sitk.DICOMOrientImageFilter()
    orient.SetDesiredCoordinateOrientation('RPI')
    Image = orient.Execute(Image)
    # ----- MASK ----
    Mask = orient.Execute(Mask)
    mask = maximum_projection(Mask, dim=0)[0]
    padded_mask, _, _ = post_projection(mask, new_spacing, output_shape)
    plt_mask = sitk.GetArrayFromImage(padded_mask)

    # Reorient points + mask
    points = reorient_centroids_to(points, img_iso)
    points = points2frame(Image, points)
    if len(points) < 2:
        print('No annotations for ', name)
        return

    heatmap = create_heatmaps(points)

    # Folds containing this scan (named so the module-level `ids` is not shadowed)
    name_folds = [fold for fold, members in folds.items() if name in members]

    if save:
        for fold in name_folds:
            outpath = f'/data/VerSe/all_VerSe/vert_labelling/{fold}/targets/masks/'
            heatmap_out = f'/data/VerSe/all_VerSe/vert_labelling/{fold}/targets/heatmaps/'
            os.makedirs(outpath, exist_ok=True)
            os.makedirs(heatmap_out, exist_ok=True)
            np.save(os.path.join(outpath, f'{name}.npy'), plt_mask)
            np.save(os.path.join(heatmap_out, f'{name}.npy'), heatmap)

            coords_out = f'/data/VerSe/all_VerSe/vert_labelling/{fold}/targets/coordinates/'
            os.makedirs(coords_out, exist_ok=True)
            # newline='' lets the csv module control line endings itself
            with open(os.path.join(coords_out, f'{name}.csv'), 'w', newline='') as f:
                wrt = csv.writer(f, dialect='excel')
                wrt.writerow(['Level', 'X', 'Y', 'Z'])
                # Labels are 1-based positions in all_verts
                for idx, vert in enumerate(all_verts, start=1):
                    if idx not in points:
                        continue
                    x, y, z = points[idx]
                    wrt.writerow([vert, x, y, z])

    # ----- PLOTTING ------
    if plot:
        fig, ax = plt.subplots(1, 1, figsize=(10,10))
        ax.axis('off')
        # The MIP is only a plot background, so it is read from disk only here
        if name_folds:
            mip = np.load(f'/data/VerSe/all_VerSe/images_sagittal/{name_folds[-1]}/mip/{name}.npy')
            ax.imshow(mip, cmap='gray')
        ax.imshow(np.where(plt_mask==0, np.nan, plt_mask), alpha=0.8, cmap='viridis')
        ax.imshow(np.max(heatmap, axis=-1), alpha=0.5)
        for elem, (x, y, z) in points.items():
            vert = all_verts[elem-1]
            plot_x, plot_y = y, z
            ax.scatter(plot_x, plot_y, c='y', edgecolor='k')
            ax.text(plot_x+15, plot_y, vert, c='cyan')

In [22]:
# Write masks, heatmaps and coordinate CSVs for every scan.
# The tuples printed below come from points2frame: each is a centroid that was
# dropped because one of its frame coordinates is negative.
for name in ids:
    prepare_targets(name, plot=False, save=True)

(266.76, 237.0821875, -33.64228469371796)
POINT REMOVED - outside borders
(267.72, 232.9140625, -19.543998594284062)
POINT REMOVED - outside borders
(265.96000000000004, 239.05656249999998, -1.8799981117248592)
POINT REMOVED - outside borders
(255.33971280097964, 229.94265625, -29.47999906539917)
POINT REMOVED - outside borders
(256.9835976982117, 238.09953124999998, -15.127998495101927)
POINT REMOVED - outside borders
(271.05109374999995, 275.135, -21.216000471115112)
POINT REMOVED - outside borders
(270.49734375, 281.57234375, -5.736000881195068)
POINT REMOVED - outside borders
(264.04, 231.66531250000003, -34.68235834121704)
POINT REMOVED - outside borders
(262.92, 233.34765625, -23.54632308959961)
POINT REMOVED - outside borders
(261.96000000000004, 240.62265625, -12.955726299285892)
POINT REMOVED - outside borders
(262.44000000000005, 249.35265625, -1.9105974578857357)
POINT REMOVED - outside borders
(302.8790480041504, 249.18515624999998, -37.31491843700409)
POINT REMOVED - outsi

***