In [1]:
# --- Trained TensorFlow/Keras model: file name and the folder (under root_folder) holding it ---
model_file_name = "model_2019-05-31_21-22-03.h5"
models_folder_name = "saved_models"

# --- Names of the input ultrasound browser and image nodes in the Slicer scene ---
input_browser_name = "SagittalScan"
input_image_name = "Image_Image"

# --- Names used for the segmentation output nodes (reused if already present) ---
output_browser_name = "BoneSequenceBrowser"
output_sequence_name = "SegmentationSequence"
output_image_name = "Segmented_Image"



In [2]:
import numpy as np
import os
import scipy.ndimage

from keras.models import load_model

from local_vars import root_folder



In [3]:
# Check if keras model file exists. Abort if not found. Load model otherwise.
#
# The model is expected at <root_folder>/<models_folder_name>/<model_file_name>.

models_path = os.path.join(root_folder, models_folder_name)
model_fullpath = os.path.join(models_path, model_file_name)

# Raise an exception that carries the reason. A bare `raise` outside an
# `except` block only yields "RuntimeError: No active exception to reraise".
# isfile() (rather than exists()) also rejects a directory with that name.
if not os.path.isfile(model_fullpath):
    raise FileNotFoundError("Could not find model: " + model_fullpath)

print("Loading model from: " + model_fullpath)

model = load_model(model_fullpath)

Loading model from: c:\Data\saved_models\model_2019-05-31_21-22-03.h5


In [5]:
# Check input. Abort if browser or image doesn't exist.
#
# Both the input sequence browser and the image volume it drives are required
# by the cells below; stop here with a descriptive error if either is missing.

input_browser_node = slicer.util.getFirstNodeByName(input_browser_name, className='vtkMRMLSequenceBrowserNode')
input_image_node = slicer.util.getFirstNodeByName(input_image_name, className="vtkMRMLScalarVolumeNode")

# Report the configured *name*. The node variable is always None here, so
# formatting it would hide which node was missing. Raise explicitly: a bare
# `raise` has no active exception to re-raise, and `logging` is not imported
# in this notebook.
if input_browser_node is None:
    raise LookupError("Could not find input browser node: {}".format(input_browser_name))

if input_image_node is None:
    raise LookupError("Could not find input image node: {}".format(input_image_name))



In [6]:
# Create output image and browser for segmentation output.
#
# Output nodes are looked up by their configured names and only created when
# missing. A missing output image is created by cloning the input image node
# through the volumes module logic. The output sequence is then bound to the
# output image as its proxy and flagged for recording.

def _find_or_add_node(node_name, class_name):
    """Return the first scene node found by name and class, adding a new one if none exists."""
    existing_node = slicer.util.getFirstNodeByName(node_name, className=class_name)
    if existing_node is not None:
        return existing_node
    return slicer.mrmlScene.AddNewNodeByClass(class_name, node_name)

output_browser_node = _find_or_add_node(output_browser_name, 'vtkMRMLSequenceBrowserNode')
output_sequence_node = _find_or_add_node(output_sequence_name, "vtkMRMLSequenceNode")
output_browser_node.AddSynchronizedSequenceNode(output_sequence_node)

output_image_node = slicer.util.getFirstNodeByName(output_image_name, className="vtkMRMLScalarVolumeNode")
if output_image_node is None:
    output_image_node = slicer.modules.volumes.logic().CloneVolume(
        slicer.mrmlScene, input_image_node, output_image_name)

browser_logic = slicer.modules.sequencebrowser.logic()
browser_logic.AddSynchronizedNode(output_sequence_node, output_image_node, output_browser_node)
output_browser_node.SetRecording(output_sequence_node, True)



In [7]:
# Add all input sequences to the output browser for being able to conveniently replay everything
#
# Every proxy node of the input browser gets a new (default-named) sequence
# node in the output browser. Each new sequence is flagged for recording, so
# the input frames are stored together with the segmentation output whenever
# the output browser saves its proxy state.

proxy_collection = vtk.vtkCollection()
input_browser_node.GetAllProxyNodes(proxy_collection)

proxy_count = proxy_collection.GetNumberOfItems()
for proxy_index in range(proxy_count):
    proxy_node = proxy_collection.GetItemAsObject(proxy_index)
    recorded_sequence = slicer.mrmlScene.AddNewNodeByClass('vtkMRMLSequenceNode')
    browser_logic.AddSynchronizedNode(recorded_sequence, proxy_node, output_browser_node)
    output_browser_node.SetRecording(recorded_sequence, True)



In [8]:
# Iterate input sequence, compute segmentation for each frame, record output sequence.
#
# For each input browser item, the loop:
# - resizes the current frame to the model's input size;
# - runs the model and maps one output channel back onto the frame's grid;
# - writes the result into the output image proxy;
# - records the state of all output-browser proxies.


def _segment_frame(frame, keras_model, to_model_scaling, to_slicer_scaling):
    """Segment one 2D frame and return a 0..255 map on the frame's own grid.

    frame: 2D array (rows, columns) taken from slice 0 of the input volume.
    keras_model: model whose predict() is fed a (1, rows, cols, 1) array and
        returns an array indexed as [batch, row, col, channel].
    to_model_scaling / to_slicer_scaling: per-axis (rows, cols) zoom factors.
    """
    model_input = scipy.ndimage.zoom(frame, to_model_scaling)
    # Rows are flipped before prediction and flipped back afterwards
    # (presumably to match the orientation used in training -- TODO confirm).
    model_input = np.flip(model_input, axis=0) / 255.0
    # Add batch and channel axes: (rows, cols) -> (1, rows, cols, 1).
    model_input = model_input[np.newaxis, :, :, np.newaxis]
    prediction = keras_model.predict(model_input)
    # Channel 1 is used as the segmentation (assumed bone/foreground class -- TODO confirm).
    foreground = np.flip(prediction[0, :, :, 1], axis=0)
    upscaled = scipy.ndimage.zoom(foreground, to_slicer_scaling) * 255.0
    # Spline interpolation in zoom() can overshoot [0, 1]. Clip so the cast into
    # the output volume's dtype cannot wrap around (e.g. 256 -> 0 for uint8).
    return np.clip(upscaled, 0.0, 255.0)


num_items = input_browser_node.GetNumberOfItems()
input_browser_node.SelectFirstItem()

# The volume array is 3D; slice 0 is used as the 2D frame (axes 1 and 2).
input_array = slicer.util.array(input_image_node.GetID())
frame_rows, frame_cols = input_array.shape[1], input_array.shape[2]
# Model input_shape assumed (batch, rows, cols, channels), matching the
# (1, rows, cols, 1) layout fed to predict().
model_rows, model_cols = model.layers[0].input_shape[1], model.layers[0].input_shape[2]

# Per-axis factors, so frames whose aspect ratio differs from the model input
# still resize to exactly the model's input shape (and back).
slicer_to_model_scaling = (float(model_rows) / frame_rows, float(model_cols) / frame_cols)
model_to_slicer_scaling = (float(frame_rows) / model_rows, float(frame_cols) / model_cols)

for frame_index in range(num_items):
    # Re-read the array for every item, since selecting another browser item
    # updates the proxy node's content.
    input_array = slicer.util.array(input_image_node.GetID())
    segmentation = _segment_frame(input_array[0, :, :], model,
                                  slicer_to_model_scaling, model_to_slicer_scaling)

    output_array = slicer.util.array(output_image_node.GetID())
    output_array[0, :, :] = segmentation
    # Writing through the numpy view does not notify VTK; mark the image data
    # modified so views showing the output volume refresh.
    output_image_node.GetImageData().Modified()

    output_browser_node.SaveProxyNodesState()
    input_browser_node.SelectNextItem()

