In [1]:
# Trained segmentation model file and the folder (joined under root_folder) holding it.
models_folder_name = "saved_models"
model_file_name = "model_2019-05-31_21-22-03.h5"

# Sequence browser node names in the Slicer scene: source scan and segmentation output.
sequence_browser_name = "SagittalScan"
output_browser_name = "BoneSequenceBrowser"

# Scalar volume node names for the per-frame input image and the segmented output image.
input_image_name = "Image_Image"
output_image_name = "Segmented_Image"



In [2]:
import logging
import os

import numpy as np
import scipy.ndimage

from local_vars import root_folder

from keras.models import load_model

# Resolve the absolute path of the trained model file.
models_path = os.path.join(root_folder, models_folder_name)
model_fullpath = os.path.join(models_path, model_file_name)

# Fail fast with a descriptive error. A bare `raise` outside an except block does not
# raise this message; it only reports "No active exception to reraise".
if not os.path.exists(model_fullpath):
    raise RuntimeError("Could not find model: " + model_fullpath)



In [3]:
# Deserialize the trained Keras network stored at model_fullpath.
print("Loading model from: " + model_fullpath)
model = load_model(model_fullpath)

Loading model from: J:\Data\saved_models\model_2019-05-31_21-22-03.h5


In [4]:
# Look up the sequence browser that drives the input frames; stop early with a clear
# message instead of an AttributeError on None below.
sequence_browser_node = slicer.util.getFirstNodeByName(sequence_browser_name, className='vtkMRMLSequenceBrowserNode')
if sequence_browser_node is None:
    raise RuntimeError("Could not find sequence browser node: {}".format(sequence_browser_name))
print("Sequence browser node ID: " + str(sequence_browser_node.GetID()))
print("Sequence browser node name: " + str(sequence_browser_node.GetName()))

Sequence browser node ID: vtkMRMLSequenceBrowserNode2
Sequence browser node name: SagittalScan


In [5]:
# Look up the scalar volume that holds the current input frame.
input_image_node = slicer.util.getFirstNodeByName(input_image_name, className="vtkMRMLScalarVolumeNode")
if input_image_node is None:
    message = "Could not find input image node: {}".format(input_image_name)
    logging.error(message)
    # A bare `raise` here has no active exception to re-raise; raise an explicit one.
    raise RuntimeError(message)



In [8]:
# Get or create the sequence browser that will hold the segmentation results.
output_browser_node = slicer.util.getFirstNodeByName(output_browser_name, className='vtkMRMLSequenceBrowserNode')
if output_browser_node is None:
    output_browser_node = slicer.mrmlScene.AddNewNodeByClass('vtkMRMLSequenceBrowserNode', output_browser_name)

output_sequence_name = "SegmentationSequence"

# Get or create the sequence node that stores the segmented frames, and attach it to the browser.
output_image_sequence_node = slicer.util.getFirstNodeByName(output_sequence_name, className="vtkMRMLSequenceNode")
if output_image_sequence_node is None:
    output_image_sequence_node = slicer.mrmlScene.AddNewNodeByClass('vtkMRMLSequenceNode', output_sequence_name)

output_browser_node.AddSynchronizedSequenceNode(output_image_sequence_node)

# Reuse an existing output volume, or clone the input volume so geometry and scalar type match.
output_image_node = slicer.util.getFirstNodeByName(output_image_name, className="vtkMRMLScalarVolumeNode")
if output_image_node is None:
    volumes_logic = slicer.modules.volumes.logic()
    print("Cloning output image from: {}".format(input_image_node.GetName()))
    output_image_node = volumes_logic.CloneVolume(slicer.mrmlScene, input_image_node, output_image_name)
    print("Created new image named:   {}".format(output_image_node.GetName()))

# Assigned unconditionally. Previously it was only set when the volume was cloned, so it
# was undefined whenever an existing output node was found.
output_image_id = output_image_node.GetID()

# Register the output volume as the proxy that the output sequence records from.
print("Adding new proxy node: {}".format(output_image_node.GetName()))
output_browser_node.AddProxyNode(output_image_node, output_image_sequence_node)
print("New proxy node:        {}".format(output_browser_node.GetProxyNode(output_image_sequence_node).GetName()))

Adding new proxy node: Segmented_Image
New proxy node:        Sequence_Segmented_Image


In [None]:
def _slicer_frame_to_model_input(frame, zoom_factors):
    """Convert one 2-D frame from the input volume into a single-item model batch.

    frame: 2-D array (one slice of the Slicer volume array).
    zoom_factors: per-axis factors that resize `frame` onto the model's input grid.
    Returns a float array shaped (1, rows, cols, 1) with intensities divided by 255.
    """
    resized = scipy.ndimage.zoom(frame, zoom_factors)
    # Vertical flip; presumably matches the image orientation used in training -- confirm.
    resized = np.flip(resized, axis=0)
    resized = resized / 255.0
    # Add the batch (axis 0) and channel (axis 3) dimensions the model expects.
    return resized[np.newaxis, :, :, np.newaxis]


def _model_output_to_slicer_frame(prediction, zoom_factors):
    """Convert a model prediction batch into a 2-D frame for the output volume.

    Takes channel 1 of the first batch item, undoes the input flip, resizes it with
    `zoom_factors`, scales it by 255 and clips the result to [0, 255].
    """
    # Flip a fresh view rather than writing the flipped view back into `prediction`.
    probability = np.flip(prediction[0, :, :, 1], axis=0)
    upscaled = scipy.ndimage.zoom(probability, zoom_factors) * 255.0
    # Spline interpolation in zoom can overshoot the [0, 1] range. Clipping stops
    # negative / >255 values from wrapping around when stored in an integer-typed volume.
    return np.clip(upscaled, 0.0, 255.0)


# Upper bound on frames to segment (20 in the original run); set to None for all frames.
max_frames = 20

num_items = sequence_browser_node.GetNumberOfItems()
n = num_items if max_frames is None else min(num_items, max_frames)
sequence_browser_node.SelectFirstItem()

# The model input is 4-D (batch, rows, cols, channels). The Slicer arrays here are indexed
# with the slice on axis 0, so axes 1 and 2 are the frame's rows and columns.
model_rows, model_cols = model.input_shape[1], model.input_shape[2]
input_array = slicer.util.array(input_image_node.GetID())
input_rows, input_cols = input_array.shape[1], input_array.shape[2]

# Per-axis zoom factors, with explicit float division (safe under Python 2 as well).
# Non-square frames are then resized exactly onto the model grid and back, instead of
# scaling both axes by the row ratio alone.
slicer_to_model_scaling = (float(model_rows) / input_rows, float(model_cols) / input_cols)
model_to_slicer_scaling = (float(input_rows) / model_rows, float(input_cols) / model_cols)

for i in range(n):
    input_array = slicer.util.array(input_image_node.GetID())
    model_input = _slicer_frame_to_model_input(input_array[0, :, :], slicer_to_model_scaling)
    y = model.predict(model_input)
    segmented_frame = _model_output_to_slicer_frame(y, model_to_slicer_scaling)

    output_array = slicer.util.array(output_image_node.GetID())
    output_array[0, :, :] = segmented_frame
    # Notify Slicer that the volume's voxel array was modified in place.
    slicer.util.arrayFromVolumeModified(output_image_node)

    # Record the current output proxy state into the output sequence.
    output_browser_node.SaveProxyNodesState()
    sequence_browser_node.SelectNextItem()