In [None]:
import SimpleITK as sitk
import skimage as ski
import numpy as np
import pandas as pd
import os, re, glob
import pickle
# import ReadImages  # Image reader provided by Silas
from IPython.display import HTML
import matplotlib.pyplot as plt
from matplotlib import animation
% matplotlib inline

-----

# Step 1: Center-crop all images to the smallest common shape (minimal bounding box over all scans)

In [None]:
def bbox(img):
    """Return the index bounds of the non-zero region of a 3-D array.

    For each axis the first and last index whose plane contains at least one
    non-zero value is reported. Bounds are inclusive.

    Returns
    -------
    tuple
        ``(xmin, xmax, ymin, ymax, zmin, zmax)`` for axes 0, 1 and 2.

    Raises
    ------
    IndexError
        If ``img`` has no non-zero value at all.
    """
    bounds = []

    # for axis k, collapse the two remaining axes to find the occupied planes
    for collapsed_axes in ((1, 2), (0, 2), (0, 1)):
        occupied = np.flatnonzero(np.any(img, axis=collapsed_axes))
        first, last = occupied[[0, -1]]
        bounds.append(first)
        bounds.append(last)

    return tuple(bounds)

In [None]:
def max_bbox(paths):
    """Return the largest extent along each of the first three array axes.

    Every file in ``paths`` is read with SimpleITK as float32 and converted to
    an array; the per-axis maximum of the array shapes is tracked.

    Returns
    -------
    tuple
        Three values, one per axis. ``(0, 0, 0)`` when ``paths`` is empty.
    """
    largest = [0, 0, 0]

    for scan_path in paths:
        # read the volume only to learn its shape
        volume = sitk.GetArrayFromImage(sitk.ReadImage(scan_path, sitk.sitkFloat32))

        # grow the running maximum axis by axis
        for axis in range(3):
            largest[axis] = np.max((largest[axis], volume.shape[axis]))

    return tuple(largest)

In [None]:
def min_bbox(paths):
    """Return the smallest extent along each of the first three array axes.

    Every file in ``paths`` is read with SimpleITK as float32 and converted to
    an array; the per-axis minimum of the array shapes is tracked.

    Returns
    -------
    tuple
        Three values, one per axis. The running minimum starts at ``np.inf``,
        so the result is ``(inf, inf, inf)`` when ``paths`` is empty.
    """
    smallest = [np.inf, np.inf, np.inf]

    for scan_path in paths:
        # read the volume only to learn its shape
        volume = sitk.GetArrayFromImage(sitk.ReadImage(scan_path, sitk.sitkFloat32))

        # shrink the running minimum axis by axis
        for axis in range(3):
            smallest[axis] = np.min((smallest[axis], volume.shape[axis]))

    return tuple(smallest)

In [None]:
# get paths; both lists are sorted and later paired by position in crop(),
# so scans and masks must sort into the same order
# NOTE(review): this assumes matching file names in both folders — confirm
scan_paths = sorted(glob.glob("../data/Data/scans/vol*.nii"))
mask_paths = sorted(glob.glob("../data/Data/Masks/vol*.nii"))

# target shape: smallest extent along each axis over all scans
# (nothing is cropped yet; this reads every scan once)
dims = min_bbox(scan_paths)

In [None]:
# show the crop target; entries are floats because min_bbox() starts from np.inf
print("Target shape:", dims)

In [None]:
def crop(dims, scan_paths, mask_paths):
    """Shift the background of every scan by -800 and center-crop it to ``dims``.

    Scans and masks are paired by position (extra entries in the longer list
    are ignored). For each pair the scan is read as float32, every voxel whose
    mask value is 0 is lowered by 800, the volume is center-cropped to exactly
    ``dims`` and saved as ``../data/scans_cropped_small/<scan name>.npy``.

    Parameters
    ----------
    dims : sequence of numbers
        Target shape in the axis order of the loaded arrays. Each value is
        converted with ``int``. Odd values are honoured exactly.
    scan_paths : sequence of str
        Paths of the scans to crop.
    mask_paths : sequence of str
        Paths of the matching masks (0 = background, non-zero = foreground).

    Raises
    ------
    ValueError
        If the mask does not have as many voxels as the scan, if ``dims`` has
        a different number of axes than the scan, or if the scan is smaller
        than ``dims`` along any axis.
    """
    target = tuple(int(dim) for dim in dims)
    directory = "../data/scans_cropped_small/"

    for img_path, mask_path in zip(scan_paths, mask_paths):
        # load img
        image = sitk.GetArrayFromImage(sitk.ReadImage(img_path, sitk.sitkFloat32))

        # load mask: 0 = background, anything else = foreground
        mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_path, sitk.sitkFloat32))
        foreground = np.asarray(mask, dtype=bool).reshape(image.shape)

        # shift background by -800 (cf. gerdas paper); foreground stays untouched
        image[~foreground] -= 800

        # a target larger than the scan would give a negative slice start,
        # which NumPy silently interprets as counting from the end
        if len(target) != image.ndim or any(
            dim > size for dim, size in zip(target, image.shape)
        ):
            raise ValueError(
                "cannot crop {} with shape {} to target shape {}".format(
                    img_path, image.shape, target
                )
            )

        # crop: a window of exactly `dim` voxels per axis, centered in the scan
        window = tuple(
            slice((size - dim) // 2, (size - dim) // 2 + dim)
            for dim, size in zip(target, image.shape)
        )
        cropped = image[window]

        # save
        os.makedirs(directory, exist_ok=True)
        name = os.path.basename(img_path).split(".nii")[0]
        np.save(directory + name + ".npy", cropped)

In [None]:
crop(dims, scan_paths, mask_paths)

Visualization of cropped image (sanity check):

In [None]:
# # load image to visualize
# paths = sorted(glob.glob("../data/scans_cropped_small/*.npy"))
# scan_bbox = np.load(paths[5])

# # create figure
# fig = plt.figure(figsize=(5,5))
# i = np.linspace(1, scan_bbox.shape[0], num=scan_bbox.shape[0], dtype=int)-1
# im = plt.imshow(scan_bbox[0,:,:], animated=True, cmap='jet', vmin=-1000, vmax=0)
# plt.axis('off')

# def animate(i):
#     im.set_array(scan_bbox[i,:,:])
#     return im

# ani = animation.FuncAnimation(fig, animate, frames=scan_bbox.shape[0], 
#                               interval=40, repeat=False)
# plt.close()
# HTML(ani.to_html5_video())

# Step 2: Batch by volume

In [None]:
# Load the data split; later cells read the keys "train", "valid" and "test",
# each iterated as a collection of volume ids.
# NOTE(review): pickle.load can execute arbitrary code — only open a partition
# file from a trusted source.
with open("../data/partition.pkl", 'rb') as f:
    partition = pickle.load(f)

In [None]:
def _scan_paths_by_volume(volume_ids):
    """Map each volume id to the cropped scan files found for it on disk."""
    return {
        vol: glob.glob("../data/scans_cropped_small/vol" + str(vol) + "_*.npy")
        for vol in volume_ids
    }


# one {volume id: [cropped scan paths]} dict per split; a volume without any
# matching file maps to an empty list
train_vols = _scan_paths_by_volume(partition["train"])
val_vols = _scan_paths_by_volume(partition["valid"])
test_vols = _scan_paths_by_volume(partition["test"])

Single data point version: each volume is represented only by the last scan in sorted order.

In [None]:
def batch(split, vol_dict):
    """Save one single-scan array per volume under ``../data/<split>/``.

    For every volume the last path in sorted order is loaded, two leading
    axes of length 1 are added, and the result is written to
    ``../data/<split>/vol_<vol>.npy``.

    Parameters
    ----------
    split : str
        Name of the output sub-directory below ``../data/``.
    vol_dict : dict
        Maps a volume id to a list of ``.npy`` scan paths.

    Raises
    ------
    IndexError
        If a volume has an empty path list.
    """
    out_dir = "../data/" + split

    for vol_id in vol_dict:
        # oldest scan first
        # NOTE(review): ordering is plain string sorting — confirm that file
        # names sort chronologically (e.g. "_10" sorts before "_2")
        ordered = sorted(vol_dict[vol_id])

        # keep only the last scan and give it shape (1, 1, *scan.shape)
        latest = np.load(ordered[-1])
        sample = latest[np.newaxis, np.newaxis]

        # output to disk
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        np.save(out_dir + "/vol_" + str(vol_id) + ".npy", sample)

In [None]:
# batch("single_test", test_vols)
# batch("single_valid", val_vols)
# batch("single_train", train_vols)

Time-distributed version: all scans of a volume are stacked along a leading time axis, in sorted order.

In [None]:
def timebatch(split, vol_dict):
    """Save one time-stacked array per volume under ``../data/<split>/``.

    For every volume all scans are loaded in sorted path order and stacked
    into an array of shape ``(number of scans, 1, *scan.shape)``, which is
    written to ``../data/<split>/vol_<vol>.npy``.

    Parameters
    ----------
    split : str
        Name of the output sub-directory below ``../data/``.
    vol_dict : dict
        Maps a volume id to a list of ``.npy`` scan paths.

    Raises
    ------
    IndexError
        If a volume has an empty path list.
    ValueError
        If the scans of one volume do not all have the same shape.
    """
    directory = "../data/" + split

    for vol, paths in vol_dict.items():
        # oldest scan first
        # NOTE(review): ordering is plain string sorting — confirm that file
        # names sort chronologically (e.g. "_10" sorts before "_2")
        paths = sorted(paths)
        if not paths:
            # same exception type as indexing an empty list, but names the volume
            raise IndexError("no scans found for volume {}".format(vol))

        # create volume batch: load every scan once and stack in a single call,
        # instead of re-copying a growing array for each additional scan
        frames = [np.load(path)[np.newaxis] for path in paths]
        stacked = np.stack(frames, axis=0)

        # output to disk
        os.makedirs(directory, exist_ok=True)
        np.save(directory + "/vol_" + str(vol) + ".npy", stacked)

In [None]:
# timebatch("time_test", test_vols)
# timebatch("time_valid", val_vols)
# timebatch("time_train", train_vols)

----