In [1]:
# --- Model ---------------------------------------------------------------
# File name of the trained Keras/TensorFlow model and the folder that holds it.

model_file_name = "model_2019-06-16_16-15-15.h5"
models_folder_name = "saved_models"

# --- Input ---------------------------------------------------------------
# Names of the sequence browser and the image node of the ultrasound recording.

input_browser_name = "LandmarkingScan_1"
input_image_name = "Image_Image"

# --- Output --------------------------------------------------------------
# Names given to the nodes that receive the segmentation result.

output_browser_name = "BoneSequenceBrowser"
output_sequence_name = "SegmentationSequence"
output_image_name = "Segmented_Image"

# --- Optional numpy export -----------------------------------------------
# When array_output is True, ultrasound and segmentation frames are also
# written to numpy files under array_folder_name.

array_output = False
array_folder_name = "Temp"
array_segmentation_name = "segmentation"
array_ultrasound_name = "ultrasound"



In [2]:
import logging
import os

import numpy as np
import scipy.ndimage
from keras.models import load_model

from local_vars import root_folder



In [3]:
# Locate the Keras model file under root_folder, abort with a clear error if it
# is missing, prepare the optional numpy output folder, then load the model.

models_path = os.path.join(root_folder, models_folder_name)
model_fullpath = os.path.join(models_path, model_file_name)

if not os.path.exists(model_fullpath):
    # A bare `raise` outside an `except` block only yields
    # "RuntimeError: No active exception to reraise", which hides the real cause.
    raise FileNotFoundError("Could not find model: " + model_fullpath)

print("Loading model from: " + model_fullpath)

if array_output:
    array_output_fullpath = os.path.join(root_folder, array_folder_name)
    array_segmentation_fullname = os.path.join(array_output_fullpath, array_segmentation_name)
    array_ultrasound_fullname = os.path.join(array_output_fullpath, array_ultrasound_name)
    if not os.path.exists(array_output_fullpath):
        # makedirs also creates missing parent folders, which os.mkdir would fail on.
        os.makedirs(array_output_fullpath)
        print("Folder created: {}".format(array_output_fullpath))
    print("Will save segmentation output to {}".format(array_segmentation_fullname))
    print("Will save ultrasound output to   {}".format(array_ultrasound_fullname))


model = load_model(model_fullpath)

# model.summary()

Loading model from: c:\Data\saved_models\model_2019-06-16_16-15-15.h5


In [4]:
# Look up the input sequence browser and the input image in the Slicer scene.
# Abort with a descriptive error if either one does not exist.

input_browser_node = slicer.util.getFirstNodeByName(input_browser_name, className='vtkMRMLSequenceBrowserNode')
input_image_node = slicer.util.getFirstNodeByName(input_image_name, className="vtkMRMLScalarVolumeNode")

if input_browser_node is None:
    # Report the *name* that was searched for; the node itself is None here.
    message = "Could not find input browser node: {}".format(input_browser_name)
    logging.error(message)
    # RuntimeError is the type a bare `raise` produced, now with a useful message.
    raise RuntimeError(message)

if input_image_node is None:
    message = "Could not find input image node: {}".format(input_image_name)
    logging.error(message)
    raise RuntimeError(message)



In [5]:
# Set up the nodes that receive the segmentation: a sequence browser, a sequence
# and an image volume. Nodes that already exist under the configured names are reused.

def _find_or_add_node(node_name, class_name):
    """Return the first scene node of class_name named node_name, adding a new one if none is found."""
    node = slicer.util.getFirstNodeByName(node_name, className=class_name)
    if node is None:
        node = slicer.mrmlScene.AddNewNodeByClass(class_name, node_name)
    return node


output_browser_node = _find_or_add_node(output_browser_name, 'vtkMRMLSequenceBrowserNode')
output_sequence_node = _find_or_add_node(output_sequence_name, 'vtkMRMLSequenceNode')
output_browser_node.AddSynchronizedSequenceNode(output_sequence_node)

# The output image is cloned from the input image rather than created empty.
output_image_node = slicer.util.getFirstNodeByName(output_image_name, className="vtkMRMLScalarVolumeNode")
if output_image_node is None:
    volumes_logic = slicer.modules.volumes.logic()
    output_image_node = volumes_logic.CloneVolume(slicer.mrmlScene, input_image_node, output_image_name)

# Attach the output image as proxy of the output sequence and switch recording on.
browser_logic = slicer.modules.sequencebrowser.logic()
browser_logic.AddSynchronizedNode(output_sequence_node, output_image_node, output_browser_node)
output_browser_node.SetRecording(output_sequence_node, True)



In [6]:
# Mirror every proxy node of the input browser into the output browser, each with
# its own new sequence set to record, so everything can be replayed together.

proxy_collection = vtk.vtkCollection()
input_browser_node.GetAllProxyNodes(proxy_collection)

proxy_count = proxy_collection.GetNumberOfItems()
for proxy_index in range(proxy_count):
    input_proxy = proxy_collection.GetItemAsObject(proxy_index)
    mirrored_sequence = slicer.mrmlScene.AddNewNodeByClass('vtkMRMLSequenceNode')
    browser_logic.AddSynchronizedNode(mirrored_sequence, input_proxy, output_browser_node)
    output_browser_node.SetRecording(mirrored_sequence, True)



In [7]:
# Prepare the per-frame segmentation loop: rewind the input browser, derive the
# resize factors between the Slicer image and the model input, and allocate the
# optional numpy buffers that collect every frame.

num_items = input_browser_node.GetNumberOfItems()
n = num_items
input_browser_node.SelectFirstItem()

input_array = slicer.util.array(input_image_node.GetID())
frame_rows, frame_columns = input_array.shape[1], input_array.shape[2]

# Both factors are scalars derived from axis 1 only.
# NOTE(review): this assumes the image and the model input have the same
# row/column ratio (e.g. both square) -- confirm for non-square recordings.
model_input_size = model.layers[0].input_shape[1]
slicer_to_model_scaling = model_input_size / frame_rows
model_to_slicer_scaling = frame_rows / model_input_size

print("Will segment {} images".format(n))

if array_output:
    # Size the buffers from the real frame shape; using shape[1] for both axes
    # only fits square frames and makes the per-frame assignment fail otherwise.
    array_output_ultrasound = np.zeros((n, frame_rows, frame_columns))
    array_output_segmentation = np.zeros((n, frame_rows, frame_columns))

Will segment 1670 images


In [8]:
# Segment every frame of the input sequence: resize the frame to the model
# resolution, predict, resize the prediction back, write it into the output
# image and record it in the output browser. Optionally keep all frames in
# numpy buffers and save them at the end.

for i in range(n):
    input_array = slicer.util.array(input_image_node.GetID())
    frame = input_array[0, :, :]

    if array_output:
        array_output_ultrasound[i, :, :] = frame

    # Resize to the model resolution and flip along the row axis.
    # NOTE(review): the flip presumably matches the orientation used in training -- confirm.
    resized_input_array = scipy.ndimage.zoom(frame, slicer_to_model_scaling)
    resized_input_array = np.flip(resized_input_array, axis=0)

    # Scale intensity to 0-1. An all-black frame has a maximum of 0; dividing by
    # it would fill the model input with NaN, so such a frame is left at zero.
    peak_intensity = resized_input_array.max()
    divisor = peak_intensity if peak_intensity != 0 else 1
    resized_input_array = resized_input_array / divisor

    # Add the batch axis in front and the channel axis at the end.
    resized_input_array = np.expand_dims(resized_input_array, axis=0)
    resized_input_array = np.expand_dims(resized_input_array, axis=3)
    y = model.predict(resized_input_array)
    # Undo the row flip so the prediction lines up with the Slicer image again.
    y[0,:,:,:] = np.flip(y[0,:,:,:], axis=0)

    # Channel 1 of the prediction is used as the segmentation.
    # NOTE(review): presumably the foreground (bone) class -- confirm against the model.
    upscaled_output_array = scipy.ndimage.zoom(y[0,:,:,1], model_to_slicer_scaling)
    upscaled_output_array = upscaled_output_array * 255
    upscaled_output_array = np.clip(upscaled_output_array, 0, 255)

    if array_output:
        array_output_segmentation[i, :, :] = upscaled_output_array[:, :]

    slicer.util.updateVolumeFromArray(output_image_node, upscaled_output_array.astype(np.uint8)[np.newaxis, ...])

    # Record the current proxy states in the output browser, then step the input.
    output_browser_node.SaveProxyNodesState()
    input_browser_node.SelectNextItem()

    # Let the application process pending events between frames.
    slicer.app.processEvents()

if array_output:
    np.save(array_ultrasound_fullname, array_output_ultrasound)
    np.save(array_segmentation_fullname, array_output_segmentation)
    print("Saved {}".format(array_ultrasound_fullname))
    print("Saved {}".format(array_segmentation_fullname))

