In [1]:
import SimpleITK as sitk
import skimage as ski
import numpy as np
import pandas as pd
import os, re, glob
import pickle
# import ReadImages  # Image reader provided by Silas
from IPython.display import HTML
import matplotlib.pyplot as plt
from matplotlib import animation
% matplotlib inline

-----

# Step 1: Crop all images to minimal bounding box

In [2]:
def bbox(img):

    # find planes with values different from 0
    x = np.any(img, axis=(1, 2))
    y = np.any(img, axis=(0, 2))
    z = np.any(img, axis=(0, 1))

    xmin, xmax = np.where(x)[0][[0, -1]]
    ymin, ymax = np.where(y)[0][[0, -1]]
    zmin, zmax = np.where(z)[0][[0, -1]]

    return xmin, xmax, ymin, ymax, zmin, zmax

In [3]:
def crop_imgs(paths):
    x = 0
    y = 0
    z = 0
    
    for path in paths:
        
        # load img
        sitk_image = sitk.ReadImage(path, sitk.sitkFloat32)
        image = sitk.GetArrayFromImage(sitk_image)
        
        # crop to minimal bounding box
        xmin, xmax, ymin, ymax, zmin, zmax = bbox(image)
        cropped = image[xmin:xmax, ymin:ymax, zmin:zmax]
        
        # update global max dims
        if cropped.shape[0] > x:
            x = cropped.shape[0]
        if cropped.shape[1] > y:
            y = cropped.shape[1]
        if cropped.shape[2] > z:
            z = cropped.shape[2]
            
        # export image
        directory = "/Users/anders1991/deepbreath/Data/Scans_cropped/"
        if not os.path.exists(directory):
            os.makedirs(directory)
        name = path.split("/")[-1].split(".nii")[0]
        file = directory + name + ".npy"
        np.save(file, cropped)

    return x, y, z

In [4]:
# get paths
scan_paths = sorted(glob.glob("/Users/anders1991/deepbreath/Data/Scans/vol*.nii"))

# crop images
dims = crop_imgs(scan_paths)

In [5]:
print("Target shape:", dims)

Target shape: (0, 0, 0)


Visualization of cropped image (sanity check):

In [None]:
# # load image to visualize
# paths = sorted(glob.glob("/Users/anders1991/deepbreath/Data/Scans_cropped/*.npy"))
# scan_bbox = np.load(paths[0])

# # create figure
# fig = plt.figure(figsize=(5,5))
# i = np.linspace(1, scan_bbox.shape[0], num=scan_bbox.shape[0], dtype=int)-1
# im = plt.imshow(scan_bbox[0,:,:], animated=True, cmap='jet', vmin=-1000, vmax=0)
# plt.axis('off')

# def animate(i):
#     im.set_array(scan_bbox[i,:,:])
#     return im

# ani = animation.FuncAnimation(fig, animate, frames=scan_bbox.shape[0], 
#                               interval=40, repeat=False)
# plt.close()
# HTML(ani.to_html5_video())

-----

# Step 2: Pad all images to largest minimal bounding box

In [6]:
def pad(length, target):
    gap = target - length

    if (gap > 0) and (gap % 2 == 0):
        pad = (int(gap/2), int(gap/2))
    elif (gap > 2):
        pad = (int(gap/2), int(gap/2 + 1))
    elif (gap == 1):
        pad = (0,1)
    else:
        pad = (0,0)
    
    return pad

In [7]:
def pad_img(path, target_dims):
    x,y,z = target_dims
    img = np.load(path)

    # set pad targets, centering img
    xpad = pad(img.shape[0], x)
    ypad = pad(img.shape[1], y)    
    zpad = pad(img.shape[2], z)
    
    # pad image
    padded = np.pad(img, pad_width=(xpad, ypad, zpad), mode='constant', constant_values=0)

    # save to disk
    directory = "/Users/anders1991/deepbreath/Data/Scans_padded/"
    if not os.path.exists(directory):
        os.makedirs(directory)
    name = path.split("/")[-1]
    file = directory + name
    np.save(file, padded)

In [None]:
# get paths
paths = glob.glob("/Users/anders1991/deepbreath/Data/Scans_cropped/*.npy")
target_dims = dims

# pad images
for path in paths:
    pad_img(path, target_dims)

Visualize padded scan (sanity check):

In [None]:
# # load image to visualize
# paths = sorted(glob.glob("/Users/anders1991/deepbreath/Data/Scans_padded/*.npy"))
# scan_bbox = np.load(paths[0])

# # create figure
# fig = plt.figure(figsize=(5,5))
# i = np.linspace(1, scan_bbox.shape[0], num=scan_bbox.shape[0], dtype=int)-1
# im = plt.imshow(scan_bbox[0,:,:], animated=True, cmap='jet',
#                vmin=-1000, vmax=0)
# plt.axis('off')

# def animate(i):
#     im.set_array(scan_bbox[i,:,:])
#     return im

# ani = animation.FuncAnimation(fig, animate, frames=scan_bbox.shape[0], 
#                               interval=40, repeat=False)
# plt.close()
# HTML(ani.to_html5_video())

# Step 3: Split into train, valid, test sets

Get list of volume numbers

In [None]:
paths = sorted(glob.glob("/Users/anders1991/deepbreath/Data/Scans_padded/*.npy"))

vols = []
for path in paths:
    num = path.split("_")[1].split("vol")[-1]
    vols.append(int(num))
    
vol_id = np.unique(vols)

Split volumes into train/val/test partitions

In [None]:
# set seed!
np.random.seed(2)

# split partitions
ids = np.copy(vol_id)
partition = dict()

# Test set = 20% 
idx_test = np.random.choice(np.arange(ids.size), 
                            size = int(len(ids) * 0.2), 
                            replace = False)
partition["test"] = ids[idx_test]
ids = np.delete(ids, idx_test)

# train/val = 70/30 of rest
idx_val = np.random.choice(np.arange(ids.size),
                          size = int(len(ids) * 0.3),
                          replace=False)
partition["valid"] = ids[idx_val]
ids = np.delete(ids, idx_val)
partition["train"] = ids

# convert key: list(str) -> key: list(int)
partition = {key: list(map(int, val)) for key, val in partition.items()}

In [None]:
# with open("/Users/anders1991/deepbreath/Data/partition.pkl", 'wb') as f:
#     pickle.dump(partition, f, pickle.HIGHEST_PROTOCOL)

Create dict of batch paths by volume

In [None]:
train_vols = dict()
for vol in partition["train"]:
    volpaths = glob.glob("/Users/anders1991/deepbreath/Data/Scans_padded/vol" + str(vol) + "_*.npy")
    train_vols[vol] = volpaths
    
val_vols = dict()
for vol in partition["valid"]:
    volpaths = glob.glob("/Users/anders1991/deepbreath/Data/Scans_padded/vol" + str(vol) + "_*.npy")
    val_vols[vol] = volpaths
    
test_vols = dict()
for vol in partition["test"]:
    volpaths = glob.glob("/Users/anders1991/deepbreath/Data/Scans_padded/vol" + str(vol) + "_*.npy")
    test_vols[vol] = volpaths

# Step 4: Batch by volume

Single data point version

In [8]:
def batch(split, vol_dict):
    for vol, paths in vol_dict.items():
        paths = sorted(paths) # oldest scan first

        # create volume batch
        img0 = np.expand_dims(np.load(paths[-1]), axis=0)
        batch = np.expand_dims(img0, axis=0)
        
        # output to disk
        directory = "/Users/anders1991/deepbreath/Data/" + split
        if not os.path.exists(directory):
            os.makedirs(directory)
        np.save(directory + "/vol_" + str(vol) + ".npy", batch)

In [None]:
# batch("single_test", test_vols)
# batch("single_valid", val_vols)
# batch("single_train", train_vols)

Time distributed version

In [9]:
def timebatch(split, vol_dict):
    for vol, paths in vol_dict.items():
        paths = sorted(paths) # oldest scan first

        # create volume batch
        img0 = np.expand_dims(np.load(paths[0]), axis=0)
        batch = np.expand_dims(img0, axis=0)

        for path in paths[1:]:
            img = np.expand_dims(np.load(path), axis=0)
            img = np.expand_dims(img, axis=0)
            batch = np.concatenate((batch, img), axis=0)
        
        # output to disk
        directory = "/Users/anders1991/deepbreath/Data/" + split
        if not os.path.exists(directory):
            os.makedirs(directory)
        np.save(directory + "/vol_" + str(vol) + ".npy", batch)

In [None]:
# timebatch("time_test", test_vols)
# timebatch("time_valid", val_vols)
# timebatch("time_train", train_vols)

----