In [1]:
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import nibabel as nib
%matplotlib inline

# Prepare midline data

In [2]:
ordered_verts = ['T4', 'T5', 'T6', 'T7', 'T8', 'T9', 'T10', 'T11', 'T12', 'L1', 'L2', 'L3', 'L4']

In [3]:
points_path = '/home/donal/PhD/initial_spines/CT_models/MIP/data/points/' #! Originally from /mnt/stopfrac/HICF_and_STOpFrac/CT_volumes/CT_REFORMATS/points_coronal_pass1
#data_list = '/home/donal/PhD/initial_spines/CT_models/data_lists/data_list_all_forviewing.txt' 
coronal_path = '../images_coronal/all_projections/'

### Collect data lists for folds

In [4]:
data_list_dir = '/home/donal/PhD/initial_spines/CT_models/data_lists/'
data_lists = []
for file in os.listdir(data_list_dir):
    if '_q' in file:
        print(file)
        data_lists.append(data_list_dir + file)
print(data_lists)

data_list_q2.txt
data_list_q4.txt
data_list_q1.txt
data_list_q3.txt
['/home/donal/PhD/initial_spines/CT_models/data_lists/data_list_q2.txt', '/home/donal/PhD/initial_spines/CT_models/data_lists/data_list_q4.txt', '/home/donal/PhD/initial_spines/CT_models/data_lists/data_list_q1.txt', '/home/donal/PhD/initial_spines/CT_models/data_lists/data_list_q3.txt']


In [5]:
def get_id(data_list):
    """
    Collect path to point files in a dict.
    """
    pts_files = {}
    with open(data_list, 'r') as f:
        # Read data list (pts file : image name)
        lines = f.readlines()
        lines = [line.strip() for line in lines]
        # Iterate over points file
        for line in lines:
            pts, img = line.split(':')
            coronal_name = pts.split('_midline')[0] +'.npy'
            # Find points w. matching coronal projection (accounts for failures)
            matched_mip = [filename for filename in os.listdir(
                coronal_path) if coronal_name in filename]
            
            # If multiple/no matches (some have slightly diff. file names -> ID_SAG_3mm.png ...)
            if len(matched_mip) > 1:
                print('Too many', matched_mip)
                continue
            elif len(matched_mip) == 1:
                id_ = matched_mip[0].split('.npy')[0]
                file = list(filter(lambda x: id_ in x, os.listdir(points_path)))
                pts_files[id_] = file
            else:
                print(coronal_name)
                print('No matching mip found, skipping...')
                continue
    return pts_files

In [6]:
def get_points(pts_files):
    """
     Get coordinates of each vertebral centre point
    """
    pts_dict = {}
    for key, val in pts_files.items():
        pts_list = []
        name = f'{key}_kj'
        filename = val[0]
        with open(points_path + filename, 'r') as f:
            lines = f.readlines()
            lines = [line.strip() for line in lines]
            start = lines.index('{')
            end = lines.index('}')
            points = lines[start+1:end]
            for coord in points:
                x, y = coord.split(' ')
                x, y = float(x), float(y)
                pts_list.append((x, y))
        pts_dict[name] = pts_list
    return pts_dict

In [7]:
fold_data = {}
for data_list in data_lists:
    num = data_list.split('/')[-1].split('.')[0].split('_')[-1]
    pts_files = get_id(data_list)
    print('FOLD:', num)
    print(f'Found {len(list(pts_files.keys()))} points files w/ matching MIP')
    pts_dict= get_points(pts_files)
    print(f'Found {len(list(pts_dict.keys()))} vertebral annotations.')
    fold_data[num] = pts_dict

18_03_2015_20150318151507421_SRS00000.npy
No matching mip found, skipping...
18_03_2015_20150318151758781_SRS00000.npy
No matching mip found, skipping...
18_03_2015_20150318152008968_SRS00000.npy
No matching mip found, skipping...
18_03_2015_20150318152747953_SRS00001.npy
No matching mip found, skipping...
18_03_2015_20150319114454875_SRS00000.npy
No matching mip found, skipping...
18_03_2015_20150319114614375_SRS00000.npy
No matching mip found, skipping...
18_03_2015_20150319114803000_SRS00000.npy
No matching mip found, skipping...
18_03_2015_20150319115452796_SRS00001.npy
No matching mip found, skipping...
18_03_2015_20150319120307453_SRS00001.npy
No matching mip found, skipping...
18_03_2015_20150319121258156_SRS00000.npy
No matching mip found, skipping...
18_03_2015_20150319121709796_SRS00001.npy
No matching mip found, skipping...
18_03_2015_20150319123927953_SRS00000.npy
No matching mip found, skipping...
18_03_2015_20150319124859640_SRS00000.npy
No matching mip found, skipping...

In [8]:
df_list = []
for key, pts_dict in fold_data.items():
    pts_df = pd.DataFrame(dict([ (k,pd.Series(v)) for k,v in pts_dict.items() ]))
    pts_df.columns = pts_dict.keys()
    pts_df = pts_df.T
    pts_df['Fold'] = key
    df_list.append(pts_df)

fold_pts = pd.concat(df_list)
fold_pts.to_csv('../formatted_pts.csv', index=True)

In [9]:
pts_df = pd.read_csv('../formatted_pts.csv', index_col=0, header=0)
pts_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,12,13,14,15,16,17,Fold,18,19,20
04_06_2014_428_Sag_kj,"(178.643, 36.1739)","(177.53, 71.7913)","(176.974, 109.078)","(176.974, 144.696)","(180.313, 180.313)","(180.313, 217.043)","(180.87, 249.322)","(180.313, 281.6)",,,...,,,,,,,q2,,,
04_06_2014_431_Sag_kj,"(192.557, 2.22609)","(193.113, 32.8348)","(193.67, 71.7913)","(192.0, 104.07)","(190.887, 139.13)","(192.0, 174.191)","(192.557, 209.809)","(191.443, 245.426)","(191.443, 276.591)",,...,,,,,,,q2,,,
04_06_2014_432_Sag_kj,"(254.33, 27.2696)","(248.209, 75.1304)","(232.626, 166.957)","(228.73, 212.035)","(226.504, 256.557)","(226.504, 296.07)",,,,,...,,,,,,,q2,,,
04_06_2014_433_Sag_kj,"(314.435, 16.6957)","(316.661, 53.9826)","(316.104, 93.4957)","(319.443, 130.226)","(323.896, 168.626)","(324.452, 208.139)","(322.783, 247.652)","(320.0, 281.6)",,,...,,,,,,,q2,,,
04_06_2014_434_Sag_kj,"(212.591, 13.913)","(209.252, 41.7391)","(204.243, 68.4522)","(199.791, 101.843)","(194.783, 130.783)","(186.435, 166.957)","(182.539, 200.348)","(180.313, 234.852)","(177.53, 269.913)","(175.304, 294.957)",...,,,,,,,q2,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
fr_592_LS_Sag_kj,"(171.864, 32.3368)","(170.667, 71.2608)","(171.265, 115.574)","(170.068, 169.469)","(171.265, 215.579)","(169.469, 255.701)","(169.469, 290.433)","(166.475, 320.973)",,,...,,,,,,,q3,,,
fr_592_TS_Sag_kj,"(210.189, 22.1567)","(208.992, 35.9298)","(207.195, 57.4877)","(207.195, 79.6444)","(204.201, 142.522)","(205.998, 179.05)","(206.596, 217.974)","(200.009, 251.509)","(196.416, 291.63)","(196.416, 335.345)",...,"(197.015, 474.274)",,,,,,q3,,,
fr_593_LS_Sag_kj,"(170.068, 41.3193)","(172.463, 75.4526)","(176.655, 113.179)","(178.451, 146.713)","(181.446, 187.434)","(184.44, 226.358)","(184.44, 264.084)","(184.44, 299.415)",,,...,,,,,,,q3,,,
fr_593_TS_Sag_kj,"(218.573, 56.8889)","(220.968, 83.8363)","(222.166, 114.377)","(220.37, 141.324)","(218.573, 168.271)","(218.573, 195.818)","(219.771, 227.556)","(219.771, 262.288)","(219.172, 297.619)","(219.172, 332.95)",...,"(226.358, 464.692)",,,,,,q3,,,


# Get coronal midline annotations

In [10]:
from ast import literal_eval
import SimpleITK as sitk
from scipy.ndimage import gaussian_filter
from PIL import Image
from scipy import interpolate

In [28]:
def align_points(name, df):
    # Collect resampling info
    cor_info_path = '/data/PAB_data/images_coronal/annotation_info.csv'
    cor_info = pd.read_csv(cor_info_path)
    filename = name.replace('_kj', '')
    cor = cor_info.loc[cor_info['Name'] == filename]
    padding = literal_eval(cor.iloc[0]['Padding'])
    scale = literal_eval(cor.iloc[0]['Pixel Scaling'])
    orig_pix = literal_eval(cor.iloc[0]['Orig. Pix'])
    orig_thick = round(float(cor.iloc[0]['Slice Thickness']))
    num_slices = int(cor.iloc[0]['Num Slices'])
    # Get original points
    print(filename, name)
    print(df.loc[name].dropna().values)
    og_x, og_y = zip(*[literal_eval(x) for x in df.loc[name].dropna().values])
    # Update points
    x = np.array(og_x, dtype=float)
    y = np.array(og_y, dtype=float)
    if orig_thick != max(orig_pix):
        # Handle mismatch in slice thickness along sagittal plane
        dist = (orig_thick- orig_pix[-1])*num_slices
        out_y = (y*min(orig_pix)/(4*0.3125)) - padding[0]
        out_x = x*min(orig_pix)*orig_pix[-1]/(orig_thick*4*0.3125) -padding[1]
        return out_x, out_y, og_x, og_y
    else:
        out_y = (y*min(orig_pix)/(4*0.3125)) - padding[0]
        out_x = (x*min(orig_pix)/(4*0.3125)) - padding[1]
        return out_x, out_y, og_x, og_y

In [24]:
def coronal_midline(name, pts_df, extent=50, plot=False, write_coords=False):
    # Extend +- in mm
    og_img = Image.open(f'/home/donal/PhD/initial_spines/CT_models/MIP/data/images/{name}.tiff')
    og_img = np.array(og_img)
    if any(og_img.shape) > 512:
        print(name)
    # Read coronal mip
    filename = name.replace('_kj', '')
    filepath = f'/data/PAB_data/images_coronal/all_projections/{filename}.npy'
    img = np.load(filepath)
    mip = img[..., 0]
    fold = pts_df.loc[name]['Fold']
    print(fold)
    x_coords, y_coords, og_x, og_y = align_points(name, pts_df)
    if write_coords:
        df = pd.DataFrame(columns=['X', 'Y'])
        df['X'], df['Y'] = x_coords, y_coords
        os.makedirs(f'/data/PAB_data/midline_data/parent_data/{fold}/coordinates/', exist_ok=True)
        df.to_csv(f'/data/PAB_data/midline_data/parent_data/{fold}/coordinates/{filename}.csv', index=False)
    
    # --- Fitting ---
    tck = interpolate.splrep(y_coords, x_coords, k=1, s=0)
    fit_x = np.linspace(int(min(y_coords)), int(max(y_coords)), int(max(y_coords)-min(y_coords)))
    fit = interpolate.splev(fit_x, tck, der=0)
    
    # Convert annotations to mask
    midline_holder = np.zeros((*mip.shape[:2], 1), dtype=np.float32)
    # annotation extent in pixels
    pix_extent = int(extent/1.25)
    print(f'Extend: {extent} mm & {pix_extent} in pix')
    for x, y in zip(fit_x.astype(int), fit.astype(int)):
        midline_holder[x, y-pix_extent:y+pix_extent] = 1 # This controls size of point
    #midline_holder = gaussian_filter(midline_holder, sigma=5)
    # Convert to binary mask
    #midline_holder = np.where(midline_holder < 0.5, 0, 1).astype(np.int16)
    os.makedirs(f'/data/PAB_data/midline_data/parent_data/{fold}/midline_mask/', exist_ok=True)
    os.makedirs(f'/data/PAB_data/midline_data/parent_data/{fold}/mips/', exist_ok=True)
    np.save(f'/data/PAB_data/midline_data/parent_data/midline_mask/{filename}.npy', midline_holder)
    np.save(f'/data/PAB_data/midline_data/parent_data/mips/{filename}.npy', mip)
    if plot:
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
        ax.imshow(mip, cmap='gray')
        ax.imshow(np.where(midline_holder==0, np.nan, midline_holder), alpha=0.5)
        ax.plot(fit, fit_x , c='y',lw=1.5, alpha=1)
        ax.scatter(x_coords, y_coords, c='y', lw=1.5, edgecolor='k')
        fig.savefig(f'../midline_data/sanity/{name}.png')
        plt.close()

In [29]:
for name in pts_df.index:
    print(name)
    coronal_midline(name, pts_df, extent=2.5, plot=True, write_coords=False)
    break

04_06_2014_428_Sag_kj
q2
04_06_2014_428_Sag 04_06_2014_428_Sag_kj


ValueError: malformed node or string: <_ast.Name object at 0x7f4ad49d2d30>

# Get sagittal midline projection

In [9]:
from PIL import Image
from scipy import optimize, interpolate

In [10]:
def get_ROI(img, fit, width=4):
    roi = np.zeros((512, 2*width, 512))
    roi -= 1024
    for i, y in enumerate(fit):
        sample = img[i, int(y)-width:int(y)+width]
        roi[i] = sample
    return roi

def plot_midline(name, mip, scaling, padding, orig_scale, plot=False):
    img = Image.open(f'/home/donal/PhD/initial_spines/CT_models/MIP/data/images/{name}_kj.tiff')
    img = np.array(img)
    # --- Load coordinates and transform to frame ---
    og_x, og_y = zip(*[literal_eval(x) for x in df.loc[f'{name}_kj'].dropna().values])
    x = np.array(og_x, dtype=float)
    x *= min(orig_scale) 
    x /= 4*0.3125
    x -= padding[2]
    y = np.array(og_y, dtype=float)
    y *= min(orig_scale)
    y /= 4*0.3125
    y -= padding[1]
    # --- FIT 1st order spline ---
    tck = interpolate.splrep(og_y, og_x, k=1, s=0)
    xd = np.linspace(0, 512, 512)
    fit = interpolate.splev(xd, tck, der=0)
    if plot:
        fig, ax = plt.subplots(1,2, figsize=(20, 10))
        ax[0].imshow(img, cmap='gray')
        sns.scatterplot(x=og_x, y=og_y, color='y', s=45, edgecolor='k', ax=ax[0])
        ax[0].plot(fit, xd, linewidth=1.5, color='r')
        ax[0].set_ylim([512, 0])
        # My attempt
        ax[1].imshow(mip.T, cmap='gray')
        sns.scatterplot(x=x, y=y, color='y', s=45, edgecolor='k', ax=ax[1])
    return fit

In [11]:
def pad_image(img, output_shape=(512, 512)):
    """
    Insert into array of fixed size - such that all inputs to model have same dimensions
    """
    padding = [(s-x)//2.0 for x, s in zip(output_shape, img.GetSize())]
    output_origin = img.TransformContinuousIndexToPhysicalPoint(padding)
    pad = sitk.ResampleImageFilter()
    pad.SetInterpolator(sitk.sitkLinear)
    pad.SetDefaultPixelValue(-1024)
    pad.SetOutputSpacing(img.GetSpacing())
    pad.SetSize(output_shape)
    pad.SetOutputOrigin(output_origin)
    #pad.SetOutputDirection(img.GetDirection())
    pad.SetOutputDirection(base_direction)
    return pad.Execute(img), padding

def normalize(img, min=0, max=255):
    """
    Cast float image to int, and normalise to [0, 255]
    """
    img = sitk.Cast(img, sitk.sitkInt32)
    norm = sitk.RescaleIntensityImageFilter()
    norm.SetOutputMaximum(max)
    norm.SetOutputMinimum(min)
    return norm.Execute(img)

def resample(image, new_spacing, output_shape=(512, 512, 512)):
    """
    Resample image to new resolution with pixel dims defined by new_spacing
    """
    # Calculate image size
    orig_size = np.array(image.GetSize(), dtype=np.int)
    orig_spacing = image.GetSpacing()
    ratio = [x/y for x, y in zip(orig_spacing, new_spacing)]
    new_size = orig_size*ratio
    new_size = np.ceil(new_size).astype(np.int)
    new_size = [int(s) for s in new_size]

    # Prepare filter
    resample = sitk.ResampleImageFilter()
    resample.SetInterpolator(sitk.sitkLinear)
    resample.SetDefaultPixelValue(-1024)
    resample.SetOutputDirection(image.GetDirection())
    resample.SetOutputOrigin(image.GetOrigin())
    resample.SetOutputSpacing(new_spacing)
    resample.SetSize(new_size)
    #resample.SetTransform(transform)
    res_img = resample.Execute(image)

    padding = [(s-x)//2 for x, s in zip(output_shape, res_img.GetSize())]
    output_origin = res_img.TransformContinuousIndexToPhysicalPoint(padding)
    
    
    pad = sitk.ResampleImageFilter()
    pad.SetInterpolator(sitk.sitkLinear)
    pad.SetDefaultPixelValue(-1024)
    pad.SetOutputSpacing(res_img.GetSpacing())
    pad.SetOutputDirection(res_img.GetDirection())
    pad.SetSize(output_shape)
    pad.SetOutputOrigin(output_origin)
    #pad.SetTransform(transform)
    
    return pad.Execute(res_img), res_img, ratio, padding

In [12]:
def get_orig_pix(name):
    pix_info = f'/home/donal/CT_volumes/resampled_mip/pixels/{name}_mip_WL.txt'
    with open(pix_info, 'r') as f:
        text = f.readlines()
        x, y, z = (float(x) for x in text[1].split(' '))
    return (x, y, z)

def resample_data(name, data):
    # Resample Image
    min_pix = 4*0.3125
    pad_img, res_img, scaling, padding = resample(data, (min_pix, min_pix, min_pix))
    return pad_img, res_img, scaling, padding

In [13]:
def get_sagittal_midline(name, plot=False):
    # ---- READ VOLUME -- 
    reader = sitk.ImageFileReader()
    reader.SetImageIO("NiftiImageIO")
    reader.SetFileName(f'../ct_volumes/{name}.nii')
    data = reader.Execute()
    
    # ---- Norm. coordinates --- 
    orig_spacing = data.GetSpacing()
    orig_size =  data.GetSize()
    orig_direction = data.GetDirection() 
    print(orig_spacing, orig_size, orig_direction)
    data.SetDirection(base_direction)
    # --- REFORMAT DATA ---
    pad_img, res_img, scale, padding = resample_data(name, data)
    pad_img = sitk.GetArrayFromImage(pad_img)
          
    # ----Coronal MIP ---
    #clipped = np.clip(sitk.GetArrayFromImage(res_img), a_min=-100, a_max=900)
    clipped = np.clip(pad_img, a_min=-100, a_max=900)
    mip = np.max(clipped, axis=-1)
    
    # Get Fit from points
    fit = plot_midline(name, mip, scale, padding,                                                                                 , plot=True)
    
    if plot:
#         fig, ax = plt.subplots(1, 1, figsize=(10, 10))
#         ax.imshow(res_img[:, 150, :])
        pass

In [None]:
base_direction = (0.0, 0.0, -1.0, 1.0, 0.0, 0.0, 0.0, -1.0, 0.0)
base_origin = (0.0, 0.0, 0.0)
for name in df.index:
    if name  == '03_06_2014_389_Sag_kj':
        print(name)
        get_sagittal_midline(name.strip('_kj'), plot=True)
        
    elif name == 'fr_555_LS_Sag_kj':
        get_sagittal_midline(name.strip('_kj'), plot=True)
        break

## Prep. annotations