# Notebook for optimizing brain extraction tensorflow models 

## Dependencies

In [1]:
import os
import numpy as np
from datetime import datetime
import sys

import tensorflow as tf
from tensorflow import data
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import freeze_graph
from tensorflow.python import ops
from tensorflow.tools.graph_transforms import TransformGraph

import tflearn
from tflearn.layers.conv import conv_2d, max_pool_2d, upsample_2d

Instructions for updating:
Colocations handled automatically by placer.


## Define functions that optimized tensorflow SavedModel
Taken from https://medium.com/google-cloud/optimizing-tensorflow-models-for-serving-959080e9ddbf

### Load and convert the SavedModel into a GraphDef

In [2]:
toolkit_dir = '/Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit'
image = f'{toolkit_dir}/data/sub-01/anat/sub-01_run-1_T2w.nii.gz'
manual_mask = f'{toolkit_dir}/data/derivatives/manual_masks/sub-01/anat/sub-01_run-1_desc-brain_mask.nii.gz'
modelCkptLoc = f'{toolkit_dir}/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization/Unet.ckpt-88000'
modelCkptSeg = f'{toolkit_dir}/pymialsrtk/data/Network_checkpoints/Network_checkpoints_segmentation/Unet.ckpt-20000'

### Run original brain mask extraction interface

In [3]:
from nipype import Node
from pymialsrtk.interfaces.preprocess import BrainExtraction
brainmask = Node(interface=BrainExtraction(), name='brainmask_wf_node', base_dir = '/Users/sebastientourbier/Desktop/mialsrtk')
brainmask.inputs.bids_dir = f'{toolkit_dir}/data'
brainmask.inputs.in_file = image
brainmask.inputs.in_ckpt_loc = modelCkptLoc#+'v2'
brainmask.inputs.threshold_loc = 0.49
brainmask.inputs.in_ckpt_seg = modelCkptSeg#+'v2'
brainmask.inputs.threshold_seg = 0.5
brainmask.inputs.out_postfix = '_brainMask.nii.gz'
brainmask.run() # doctest: +SKIP

201205-08:51:26,737 nipype.workflow INFO:
	 [Node] Setting-up "brainmask_wf_node" in "/Users/sebastientourbier/Desktop/mialsrtk/brainmask_wf_node".
201205-08:51:26,741 nipype.workflow INFO:
	 [Node] Cached "brainmask_wf_node" - collecting precomputed outputs
201205-08:51:26,741 nipype.workflow INFO:
	 [Node] "brainmask_wf_node" found cached.


<nipype.interfaces.base.support.InterfaceResult at 0x7fc5bf665320>

## Resave the graph with checkpoint

### Define function to create the tensorflow graph structure

In [4]:
normalize = "local_max"
width = 128
height = 128
border_x = 15
border_y = 15
n_channels = 1

# Tensorflow graph
def create_graph():
    g = tf.Graph()
    with g.as_default():

        with tf.name_scope('inputs'):
            x = tf.placeholder(tf.float32, [None, width, height, n_channels], name='image')
            print(x)
        conv1 = conv_2d(x, 32, 3, activation='relu', padding='same', regularizer="L2")
        conv1 = conv_2d(conv1, 32, 3, activation='relu', padding='same', regularizer="L2")
        pool1 = max_pool_2d(conv1, 2)

        conv2 = conv_2d(pool1, 64, 3, activation='relu', padding='same', regularizer="L2")
        conv2 = conv_2d(conv2, 64, 3, activation='relu', padding='same', regularizer="L2")
        pool2 = max_pool_2d(conv2, 2)

        conv3 = conv_2d(pool2, 128, 3, activation='relu', padding='same', regularizer="L2")
        conv3 = conv_2d(conv3, 128, 3, activation='relu', padding='same', regularizer="L2")
        pool3 = max_pool_2d(conv3, 2)

        conv4 = conv_2d(pool3, 256, 3, activation='relu', padding='same', regularizer="L2")
        conv4 = conv_2d(conv4, 256, 3, activation='relu', padding='same', regularizer="L2")
        pool4 = max_pool_2d(conv4, 2)

        conv5 = conv_2d(pool4, 512, 3, activation='relu', padding='same', regularizer="L2")
        conv5 = conv_2d(conv5, 512, 3, activation='relu', padding='same', regularizer="L2")

        up6 = upsample_2d(conv5, 2)
        up6 = tflearn.layers.merge_ops.merge([up6, conv4], 'concat', axis=3)
        conv6 = conv_2d(up6, 256, 3, activation='relu', padding='same', regularizer="L2")
        conv6 = conv_2d(conv6, 256, 3, activation='relu', padding='same', regularizer="L2")

        up7 = upsample_2d(conv6, 2)
        up7 = tflearn.layers.merge_ops.merge([up7, conv3], 'concat', axis=3)
        conv7 = conv_2d(up7, 128, 3, activation='relu', padding='same', regularizer="L2")
        conv7 = conv_2d(conv7, 128, 3, activation='relu', padding='same', regularizer="L2")

        up8 = upsample_2d(conv7, 2)
        up8 = tflearn.layers.merge_ops.merge([up8, conv2], 'concat', axis=3)
        conv8 = conv_2d(up8, 64, 3, activation='relu', padding='same', regularizer="L2")
        conv8 = conv_2d(conv8, 64, 3, activation='relu', padding='same', regularizer="L2")

        up9 = upsample_2d(conv8, 2)
        up9 = tflearn.layers.merge_ops.merge([up9, conv1], 'concat', axis=3)
        conv9 = conv_2d(up9, 32, 3, activation='relu', padding='same', regularizer="L2")
        conv9 = conv_2d(conv9, 32, 3, activation='relu', padding='same', regularizer="L2")

        pred = conv_2d(conv9, 2, 1,  activation='linear', padding='valid')
        #tf.identity(pred, name='prediction')
        print(pred)
    return g, x, pred

### Option 1 - `tf.train.Saver.save()`

In [49]:
normalize = "local_max"
width = 128
height = 128
border_x = 15
border_y = 15
n_channels = 1

# Tensorflow graph
g, x, pred = create_graph()

with tf.Session(graph=g) as sess_test_loc:
    # Restore the model
    tf_saver = tf.train.Saver()
    tf_saver.restore(sess_test_loc, modelCkptLoc)
    # save the model
    saved_path = tf_saver.save(sess_test_loc, ''.join([modelCkptLoc,'v2']))
    print('model saved in {}'.format(saved_path))

# Tensorflow graph
g, x, pred = create_graph()

with tf.Session(graph=g) as sess_test_seg:
    # Restore the model
    tf_saver = tf.train.Saver()
    tf_saver.restore(sess_test_seg, modelCkptSeg)
    saved_path = tf_saver.save(sess_test_seg, ''.join([modelCkptSeg,'v2']))
    print('model saved in {}'.format(saved_path))
    

Tensor("inputs/image:0", shape=(?, 128, 128, 1), dtype=float32)
Tensor("Conv2D_18/BiasAdd:0", shape=(?, 128, 128, 2), dtype=float32)
INFO:tensorflow:Restoring parameters from /Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization/Unet.ckpt-88000
model saved in /Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization/Unet.ckpt-88000v2
Tensor("inputs/image:0", shape=(?, 128, 128, 1), dtype=float32)
Tensor("Conv2D_18/BiasAdd:0", shape=(?, 128, 128, 2), dtype=float32)
INFO:tensorflow:Restoring parameters from /Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_segmentation/Unet.ckpt-20000
model saved in /Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_segmentation/Unet.ckpt-20000v2


### Option 2 - `tf.saved_model.simple_save()`

In [6]:
normalize = "local_max"
width = 128
height = 128
border_x = 15
border_y = 15
n_channels = 1

# Tensorflow graph
g, x, pred = create_graph()

im = np.zeros((1, width, height, n_channels))
pred3d = []
with tf.Session(graph=g) as sess_test_loc:
    # Restore the model
    tf_saver = tf.train.Saver()
    tf_saver.restore(sess_test_loc, modelCkptLoc)
    # save the model
    export_dir = f'{toolkit_dir}/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization2/'
    saved_path = tf.saved_model.simple_save(
        sess_test_loc,
        export_dir,
        inputs={"inputs/image": x},
        outputs={"Conv2D_18/BiasAdd": pred})
    
    print('model saved in {}'.format(saved_path))
    
with tf.Session(graph=g) as sess_test_seg:
    # Restore the model
    tf_saver = tf.train.Saver()
    tf_saver.restore(sess_test_seg, modelCkptSeg)
    # save the model
    export_dir = f'{toolkit_dir}/pymialsrtk/data/Network_checkpoints/Network_checkpoints_segmentation2/'
    saved_path = tf.saved_model.simple_save(sess_test_seg,
                               export_dir,
                               inputs={"inputs/image": x},
                               outputs={"Conv2D_18/BiasAdd": pred})
    print('model saved in {}'.format(saved_path))
    

Tensor("inputs/image:0", shape=(?, 128, 128, 1), dtype=float32)
Tensor("Conv2D_18/BiasAdd:0", shape=(?, 128, 128, 2), dtype=float32)
INFO:tensorflow:Restoring parameters from /Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization/Unet.ckpt-88000
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.simple_save.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: /Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization2/saved_model.pb
model saved in None
INFO:tensorflow:Restoring parameters from /Users/sebast

### Option 3 - `tf.saved_model.builder.SavedModelBuilder()`

In [7]:
normalize = "local_max"
width = 128
height = 128
border_x = 15
border_y = 15
n_channels = 1

# Tensorflow graph
g, x, pred = create_graph()

im = np.zeros((1, width, height, n_channels))
pred3d = []
with tf.Session(graph=g) as sess_test_loc:
    # Restore the model
    tf_saver = tf.train.Saver()
    tf_saver.restore(sess_test_loc, modelCkptLoc)
    
    # save the model
    export_dir = f'{toolkit_dir}/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization3/'
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

    tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
    tensor_info_pred = tf.saved_model.utils.build_tensor_info(pred)

    prediction_signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs={'inputs/image': tensor_info_x},
          outputs={'Conv2D_18/BiasAdd': tensor_info_pred},
          method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))#

    builder.add_meta_graph_and_variables(
        sess_test_loc,
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature },
      )
    saved_path = builder.save()
    print('model saved in {}'.format(saved_path))
    
with tf.Session(graph=g) as sess_test_seg:
    # Restore the model
    tf_saver = tf.train.Saver()
    tf_saver.restore(sess_test_seg, modelCkptSeg)
    
    # save the model
    export_dir = f'{toolkit_dir}/pymialsrtk/data/Network_checkpoints/Network_checkpoints_segmentation3/'
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

    tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
    tensor_info_pred = tf.saved_model.utils.build_tensor_info(pred)

    prediction_signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs={'inputs/image': tensor_info_x},
          outputs={'Conv2D_18/BiasAdd': tensor_info_pred},
          method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

    builder.add_meta_graph_and_variables(
        sess_test_seg,
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature},
        )
    saved_path = builder.save()
    print('model saved in {}'.format(saved_path))
    

Tensor("inputs/image:0", shape=(?, 128, 128, 1), dtype=float32)
Tensor("Conv2D_18/BiasAdd:0", shape=(?, 128, 128, 2), dtype=float32)
INFO:tensorflow:Restoring parameters from /Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization/Unet.ckpt-88000
INFO:tensorflow:No assets to save.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: /Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization3/saved_model.pb
model saved in b'/Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization3/saved_model.pb'
INFO:tensorflow:Restoring parameters from /Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_segmentation/Unet.ckpt-20000
INFO:tensorflow:No assets to save.
INFO:tensorflow:No assets to wri

## Optimize localization graph

### Load and convert the SavedModel into a GraphDef

In [8]:
def get_graph_def_from_saved_model(saved_model_dir): 
    with tf.Session() as session:
        meta_graph_def = tf.saved_model.loader.load(
        session,
        tags=[tag_constants.SERVING],
        export_dir=saved_model_dir
        ) 
    return meta_graph_def.graph_def

graph_def = get_graph_def_from_saved_model('/Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization3')

Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.
INFO:tensorflow:Restoring parameters from /Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization3/variables/variables


### Show graph description

In [9]:
def describe_graph(graph_def, show_nodes=False):
    print('Input Feature Nodes: {}'.format(
        [node.name for node in graph_def.node if node.op=='Placeholder']))
    print('')
    print('Unused Nodes: {}'.format(
        [node.name for node in graph_def.node if 'unused'  in node.name]))
    print('')
    print('Output Nodes: {}'.format( 
        [node.name for node in graph_def.node if (
            'Conv2D_18/BiasAdd' in node.name or 'softmax' in node.name)]))
    print('')
    print('Quantization Nodes: {}'.format(
        [node.name for node in graph_def.node if 'quant' in node.name]))
    print('')
    print('Constant Count: {}'.format(
        len([node for node in graph_def.node if node.op=='Const'])))
    print('')
    print('Variable Count: {}'.format(
      len([node for node in graph_def.node if 'Variable' in node.op])))
    print('')
    print('Identity Count: {}'.format(
      len([node for node in graph_def.node if node.op=='Identity'])))
    print('', 'Total nodes: {}'.format(len(graph_def.node)), '')

    if show_nodes==True:
        for node in graph_def.node:
            print('Op:{} - Name: {}'.format(node.op, node.name))

describe_graph(graph_def, show_nodes=False)

Input Feature Nodes: ['inputs/image']

Unused Nodes: []

Output Nodes: ['Conv2D_18/BiasAdd']

Quantization Nodes: []

Constant Count: 119

Variable Count: 38

Identity Count: 41
 Total nodes: 512 


### Show size

In [10]:
def get_size(model_dir, model_file='saved_model.pb'):
    model_file_path = os.path.join(model_dir, model_file)
    print(model_file_path, '')
    pb_size = os.path.getsize(model_file_path)
    variables_size = 0
    if os.path.exists(os.path.join(model_dir,'variables/variables.data-00000-of-00001')):
        variables_size = os.path.getsize(os.path.join(
            model_dir,'variables/variables.data-00000-of-00001'))
        variables_size += os.path.getsize(os.path.join(
            model_dir,'variables/variables.index'))
    print('Model size: {} KB'.format(round(pb_size/(1024.0),3)))
    print('Variables size: {} KB'.format(round( variables_size/(1024.0),3)))
    print('Total Size: {} KB'.format(round((pb_size + variables_size)/(1024.0),3)))

get_size(f'{toolkit_dir}/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization3',
         model_file='saved_model.pb')

/Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization3/saved_model.pb 
Model size: 103.242 KB
Variables size: 30650.242 KB
Total Size: 30753.484 KB


### Freezing the graph

In [15]:
def get_graph_def_from_file(graph_filepath):
    with ops.Graph().as_default():
        with tf.gfile.GFile(graph_filepath, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            return graph_def

def freeze_model(saved_model_dir, output_node_names, output_filename, checkpoints):
    
    output_graph_filename = os.path.join(saved_model_dir, output_filename)
    initializer_nodes = ''
    freeze_graph.freeze_graph(
        input_saved_model_dir=saved_model_dir,
        output_graph=output_graph_filename,
        saved_model_tags = tag_constants.SERVING,
        output_node_names=output_node_names,
        initializer_nodes=initializer_nodes,
        input_graph=None,
        input_saver=False,
        input_binary=False,
        input_checkpoint=checkpoints,
        restore_op_name=None,
        filename_tensor_name=None,
        clear_devices=False,
        input_meta_graph=False,
    )
    print('graph freezed!')
    
def freeze_model2(input_checkpoint,output_graph="frozen_model.pb"):
    from tensorflow.python.framework import graph_util
    
    print("[INFO] input_checkpoint:", input_checkpoint)
    
    # Before exporting our graph, we need to precise what is our output node
    # This is how TF decides what part of the Graph he has to keep and what part it can dump
    output_node_names = "Conv2D_18/BiasAdd" # NOTE: Change here

    # We clear devices to allow TensorFlow to control on which device it will load operations
    clear_devices = True
    
    # We import the meta graph and retrieve a Saver
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

    # We retrieve the protobuf graph definition
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    # We start a session and restore the graph weights
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)

        # We use a built-in TF helper to export variables to constants
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,                        # The session is used to retrieve the weights
            input_graph_def,             # The graph_def is used to retrieve the nodes 
            output_node_names.split(",") # The output node names are used to select the usefull nodes
        ) 

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

        print("[INFO] output_graph:",output_graph)
        print("[INFO] all done")

saved_model_dir = f'{toolkit_dir}/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization3'
frozen_model_dir = f'{toolkit_dir}/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization3'
saved_model_dir = f'{toolkit_dir}/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization'
frozen_model_dir = f'{toolkit_dir}/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization'
frozen_filepath = os.path.join(frozen_model_dir, 'frozen_model.pb')
#freeze_model(saved_model_dir, 'Conv2D_18/BiasAdd', frozen_filepath, modelCkptLoc)
freeze_model2(modelCkptLoc, frozen_filepath)
#frozen_filepath = os.path.join(saved_model_dir,'frozen_model.pb')
#get_size(frozen_filepath)
describe_graph(get_graph_def_from_file(frozen_filepath))

[INFO] input_checkpoint: /Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization/Unet.ckpt-88000
INFO:tensorflow:Restoring parameters from /Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization/Unet.ckpt-88000
INFO:tensorflow:Froze 38 variables.
INFO:tensorflow:Converted 38 variables to const ops.
161 ops in the final graph.
[INFO] output_graph: /Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization/frozen_model.pb
[INFO] all done
Input Feature Nodes: ['inputs/image']

Unused Nodes: []

Output Nodes: ['Conv2D_18/BiasAdd']

Quantization Nodes: []

Constant Count: 50

Variable Count: 0

Identity Count: 38
 Total nodes: 161 


### Optimization : pruning, constant folding and quantization

In [34]:
def get_graph_def_from_file(graph_filepath):
    with ops.Graph().as_default():
        with tf.gfile.GFile(graph_filepath, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            return graph_def

def optimize_graph(model_dir, graph_filename, transforms, output_node):
    input_names = []
    output_names = [output_node]
    if graph_filename is None:
        graph_def = get_graph_def_from_saved_model(model_dir)
    else:
        graph_def = get_graph_def_from_file(os.path.join(model_dir, graph_filename))
    optimized_graph_def = TransformGraph(graph_def,
                                         input_names,
                                         output_names,
                                         transforms)
    tf.train.write_graph(optimized_graph_def,
                         logdir=model_dir,
                         as_text=False,
                         name='optimized_model.pb')
    print('Graph optimized!')

transforms = ['remove_nodes(op=Identity)', 
              'merge_duplicate_nodes',
              'strip_unused_nodes',
              'fold_constants(ignore_errors=true)',
              'fold_batch_norms']#,
              #'quantize_nodes',
              #'quantize_weights']

saved_model_dir = f'{toolkit_dir}/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization'
optimize_graph(saved_model_dir, 'frozen_model.pb' , transforms, 'Conv2D_18/BiasAdd')
optimized_filepath = os.path.join(saved_model_dir,'optimized_model.pb')
#get_size(optimized_filepath)
describe_graph(get_graph_def_from_file(optimized_filepath))

Graph optimized!
Input Feature Nodes: ['inputs/image']

Unused Nodes: []

Output Nodes: ['Conv2D_18/BiasAdd']

Quantization Nodes: []

Constant Count: 44

Variable Count: 0

Identity Count: 0
 Total nodes: 117 


In [35]:
from tensorflow.python.framework import importer

def convert_graph_def_to_saved_model(export_dir, graph_filepath):
    if tf.gfile.Exists(export_dir):
        tf.gfile.DeleteRecursively(export_dir)
    graph_def = get_graph_def_from_file(graph_filepath)
    
    with tf.Session(graph=tf.Graph()) as session:
        tf.import_graph_def(graph_def, name='')
        # save the model
        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

        tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
        tensor_info_pred = tf.saved_model.utils.build_tensor_info(pred)

        prediction_signature = (
          tf.saved_model.signature_def_utils.build_signature_def(
              inputs={'inputs/image': tensor_info_x},
              outputs={'Conv2D_18/BiasAdd': tensor_info_pred},
              method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

        builder.add_meta_graph_and_variables(
          session, [tf.saved_model.tag_constants.SERVING],
          signature_def_map={
              tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                  prediction_signature 
          },
          strip_default_attrs=True,
          )
        saved_path = builder.save()
    print('Optimized graph converted to SavedModel!')

optimized_dir = f'{toolkit_dir}/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization_opt' 
optimized_filepath = os.path.join(saved_model_dir,'optimized_model.pb')
convert_graph_def_to_saved_model(optimized_dir, optimized_filepath)

INFO:tensorflow:No assets to save.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: /Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization_opt/saved_model.pb
Optimized graph converted to SavedModel!


In [36]:
g = tf.Graph()
with tf.Session(graph=g) as sess:
    loaded = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], optimized_dir)
    print(loaded)  # ["serving_default"]


INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.


IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



## Optimize segmentation graph

### Load and convert the SavedModel into a GraphDef

In [21]:
graph_def = get_graph_def_from_saved_model('/Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_segmentation3')

INFO:tensorflow:Restoring parameters from /Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_segmentation3/variables/variables


### Show graph description

In [22]:
describe_graph(graph_def, show_nodes=False)

Input Feature Nodes: ['inputs/image']

Unused Nodes: []

Output Nodes: ['Conv2D_18/BiasAdd']

Quantization Nodes: []

Constant Count: 132

Variable Count: 38

Identity Count: 44
 Total nodes: 619 


### Show size

In [23]:
get_size(f'{toolkit_dir}/pymialsrtk/data/Network_checkpoints/Network_checkpoints_segmentation3',
         model_file='saved_model.pb')

/Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_segmentation3/saved_model.pb 
Model size: 124.902 KB
Variables size: 30650.242 KB
Total Size: 30775.145 KB


### Freezing the graph

In [26]:
saved_model_dir = f'{toolkit_dir}/pymialsrtk/data/Network_checkpoints/Network_checkpoints_segmentation'
frozen_filepath = os.path.join(saved_model_dir, 'frozen_model.pb')
#freeze_model(saved_model_dir, 'Conv2D_18/BiasAdd', frozen_filepath, checkpoints=modelCkptSeg)
freeze_model2(modelCkptSeg, frozen_filepath)
frozen_filepath = os.path.join(saved_model_dir,'frozen_model.pb')
#get_size(frozen_filepath)
describe_graph(get_graph_def_from_file(frozen_filepath))

[INFO] input_checkpoint: /Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_segmentation/Unet.ckpt-20000
INFO:tensorflow:Restoring parameters from /Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_segmentation/Unet.ckpt-20000
INFO:tensorflow:Froze 38 variables.
INFO:tensorflow:Converted 38 variables to const ops.
161 ops in the final graph.
[INFO] output_graph: /Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_segmentation/frozen_model.pb
[INFO] all done
Input Feature Nodes: ['inputs/image']

Unused Nodes: []

Output Nodes: ['Conv2D_18/BiasAdd']

Quantization Nodes: []

Constant Count: 50

Variable Count: 0

Identity Count: 38
 Total nodes: 161 


### Optimization : pruning, constant folding and quantization

In [37]:
transforms = ['remove_nodes(op=Identity)', 
              'merge_duplicate_nodes',
              'strip_unused_nodes',
              'fold_constants(ignore_errors=true)',
              'fold_batch_norms']

optimize_graph(saved_model_dir, 'frozen_model.pb' , transforms, 'Conv2D_18/BiasAdd')
optimized_filepath = os.path.join(saved_model_dir,'optimized_model.pb')
#get_size(optimized_filepath)
describe_graph(get_graph_def_from_file(optimized_filepath))

Graph optimized!
Input Feature Nodes: ['inputs/image']

Unused Nodes: []

Output Nodes: ['Conv2D_18/BiasAdd']

Quantization Nodes: []

Constant Count: 44

Variable Count: 0

Identity Count: 0
 Total nodes: 117 


In [38]:
optimized_dir = f'{toolkit_dir}/pymialsrtk/data/Network_checkpoints/Network_checkpoints_segmentation_opt' 
optimized_filepath = os.path.join(saved_model_dir,'optimized_model.pb')
convert_graph_def_to_saved_model(optimized_dir, optimized_filepath)

INFO:tensorflow:No assets to save.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: /Users/sebastientourbier/Softwares/mialsuperresolutiontoolkit/pymialsrtk/data/Network_checkpoints/Network_checkpoints_segmentation_opt/saved_model.pb
Optimized graph converted to SavedModel!


## Create interface prototype

In [31]:
from pymialsrtk.interfaces.preprocess import BrainExtraction

In [47]:
class BrainExtraction2(BrainExtraction):
    # Redefine _extract_brain()
    def _extractBrain(self, dataPath, modelCkptLoc, thresholdLoc, modelCkptSeg, thresholdSeg, bidsDir, out_postfix):
        """Generate a brain mask by passing the input image(s) through two networks.

        The first network localizes the brain by a coarse-grained segmentation while the
        second one segments it more precisely. The function saves the output mask in the
        specific module folder created in bidsDir

        Parameters
        ----------
        dataPath <string>
            Input image file (required)

        modelCkptLoc <string>
            Network_checkpoint for localization (required)

        thresholdLoc <Float>
             Threshold determining cutoff probability (default is 0.49)

        modelCkptSeg <string>
            Network_checkpoint for segmentation

        thresholdSeg <Float>
             Threshold determining cutoff probability (default is 0.5)

        bidsDir <string>
            BIDS root directory (required)

        out_postfix <string>
            Suffix of the automatically generated mask (default is '_brainMask.nii.gz')

        """

        ##### Step 1: Brain localization #####
        normalize = "local_max"
        width = 128
        height = 128
        border_x = 15
        border_y = 15
        n_channels = 1

        img_nib = nibabel.load(os.path.join(dataPath))
        image_data = img_nib.get_data()
        images = np.zeros((image_data.shape[2], width, height, n_channels))
        pred3dFinal = np.zeros((image_data.shape[2], image_data.shape[0], image_data.shape[1], n_channels))

        slice_counter = 0
        for ii in range(image_data.shape[2]):
            img_patch = cv2.resize(image_data[:, :, ii], dsize=(width, height), fx=width,
                                   fy=height)

            if normalize:
                if normalize == "local_max":
                    images[slice_counter, :, :, 0] = img_patch / np.max(img_patch)
                elif normalize == "global_max":
                    images[slice_counter, :, :, 0] = img_patch / max_val
                elif normalize == "mean_std":
                    images[slice_counter, :, :, 0] = (img_patch-np.mean(img_patch))/np.std(img_patch)
                else:
                    raise ValueError('Please select a valid normalization')
            else:
                images[slice_counter, :, :, 0] = img_patch

            slice_counter += 1

        # Thresholding parameter to binarize predictions
        percentileLoc = thresholdLoc*100

        im = np.zeros((1, width, height, n_channels))
        pred3d = []
        # Create a clean graph and import the MetaGraphDef nodes.
        g = tf.Graph()
        with tf.Session(graph=g) as sess_test_loc:
            signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
            input_key = 'inputs/image'
            output_key = 'Conv2D_18/BiasAdd'
            # Restore the model
            meta_graph_def = tf.saved_model.loader.load(sess_test_loc,
                                                        [tf.saved_model.tag_constants.SERVING],
                                                        modelCkptLoc)
            
            signature = meta_graph_def.signature_def
            
            x_tensor_name = signature[signature_key].inputs[input_key].name
            pred_tensor_name = signature[signature_key].outputs[output_key].name

            x = sess_test_loc.graph.get_tensor_by_name(x_tensor_name)
            pred = sess_test_loc.graph.get_tensor_by_name(pred_tensor_name)

            for idx in range(images.shape[0]):

                im = np.reshape(images[idx, :, :, :], [1, width, height, n_channels])
                print(im.shape)
                feed_dict = {x: im}
                pred_ = sess_test_loc.run(pred, feed_dict=feed_dict)

                theta = np.percentile(pred_, percentileLoc)
                pred_bin = np.where(pred_ > theta, 1, 0)
                pred3d.append(pred_bin[0, :, :, 0].astype('float64'))

            #####
            pred3d = np.asarray(pred3d)
            heights = []
            widths = []
            coms_x = []
            coms_y = []

            # Apply PPP
            ppp = True
            if ppp:
                pred3d = self._post_processing(pred3d)

            pred3d = [cv2.resize(elem,dsize=(image_data.shape[1], image_data.shape[0]), interpolation=cv2.INTER_NEAREST) for elem in pred3d]
            pred3d = np.asarray(pred3d)
            for i in range(np.asarray(pred3d).shape[0]):
                if np.sum(pred3d[i, :, :]) != 0:
                    pred3d[i, :, :] = self._extractLargestCC(pred3d[i, :, :].astype('uint8'))
                    contours, _ = cv2.findContours(pred3d[i, :, :].astype('uint8'), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                    area = cv2.minAreaRect(np.squeeze(contours))
                    heights.append(area[1][0])
                    widths.append(area[1][1])
                    bbox = cv2.boxPoints(area).astype('int')
                    coms_x.append(int((np.max(bbox[:, 1])+np.min(bbox[:, 1]))/2))
                    coms_y.append(int((np.max(bbox[:, 0])+np.min(bbox[:, 0]))/2))
            # Saving localization points
            med_x = int(np.median(coms_x))
            med_y = int(np.median(coms_y))
            half_max_x = int(np.max(heights)/2)
            half_max_y = int(np.max(widths)/2)
            x_beg = med_x-half_max_x-border_x
            x_end = med_x+half_max_x+border_x
            y_beg = med_y-half_max_y-border_y
            y_end = med_y+half_max_y+border_y

        ##### Step 2: Brain segmentation #####
        width = 96
        height = 96

        images = np.zeros((image_data.shape[2], width, height, n_channels))

        slice_counter = 0
        for ii in range(image_data.shape[2]):
            img_patch = cv2.resize(image_data[x_beg:x_end, y_beg:y_end, ii], dsize=(width, height))

            if normalize:
                if normalize == "local_max":
                    images[slice_counter, :, :, 0] = img_patch / np.max(img_patch)
                elif normalize == "mean_std":
                    images[slice_counter, :, :, 0] = (img_patch-np.mean(img_patch))/np.std(img_patch)
                else:
                    raise ValueError('Please select a valid normalization')
            else:
                images[slice_counter, :, :, 0] = img_patch

            slice_counter += 1

        g = tf.Graph()
        with tf.Session(graph=g) as sess_test_seg:
            signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
            input_key = 'inputs/image'
            output_key = 'Conv2D_18/BiasAdd'
            # Restore the model
            meta_graph_def = tf.saved_model.loader.load(sess_test_seg,
                                                        [tf.saved_model.tag_constants.SERVING],
                                                        modelCkptSeg)
            
            signature = meta_graph_def.signature_def
            
            x_tensor_name = signature[signature_key].inputs[input_key].name
            pred_tensor_name = signature[signature_key].outputs[output_key].name

            x = sess_test_seg.graph.get_tensor_by_name(x_tensor_name)
            pred = sess_test_seg.graph.get_tensor_by_name(pred_tensor_name)

            for idx in range(images.shape[0]):

                im = np.reshape(images[idx, :, :], [1, width, height, n_channels])
                feed_dict = {x: im}
                pred_ = sess_test_seg.run(pred, feed_dict=feed_dict)
                percentileSeg = thresholdSeg * 100
                theta = np.percentile(pred_, percentileSeg)
                pred_bin = np.where(pred_ > theta, 1, 0)
                # Map predictions to original indices and size
                pred_bin = cv2.resize(pred_bin[0, :, :, 0], dsize=(y_end-y_beg, x_end-x_beg), interpolation=cv2.INTER_NEAREST)
                pred3dFinal[idx, x_beg:x_end, y_beg:y_end,0] = pred_bin.astype('float64')

            pppp = True
            if pppp:
                pred3dFinal = self._post_processing(np.asarray(pred3dFinal))
            pred3d = [cv2.resize(elem, dsize=(image_data.shape[1], image_data.shape[0]), interpolation=cv2.INTER_NEAREST) for elem in pred3dFinal]
            pred3d = np.asarray(pred3d)
            upsampled = np.swapaxes(np.swapaxes(pred3d,1,2),0,2) #if Orient module applied, no need for this line(?)
            up_mask = nibabel.Nifti1Image(upsampled,img_nib.affine)
            # Save output mask

            _, name, ext = split_filename(os.path.abspath(dataPath))
            save_file = os.path.join(os.getcwd(), ''.join((name, out_postfix, ext)))
            nibabel.save(up_mask, save_file)

In [48]:
from nipype import Node
import cv2
import skimage.measure

import scipy.ndimage as snd
from skimage import morphology
from scipy.signal import argrelextrema

import nibabel
import numpy as np

brainmask = Node(interface=BrainExtraction2(),
                 name='brainmask2_wf_node',
                 base_dir = '/Users/sebastientourbier/Desktop/mialsrtk')
brainmask.inputs.bids_dir = f'{toolkit_dir}/data'
brainmask.inputs.in_file = image
brainmask.inputs.in_ckpt_loc = f'{toolkit_dir}/pymialsrtk/data/Network_checkpoints/Network_checkpoints_localization_opt'
brainmask.inputs.threshold_loc = 0.49
brainmask.inputs.in_ckpt_seg = f'{toolkit_dir}/pymialsrtk/data/Network_checkpoints/Network_checkpoints_segmentation_opt'
brainmask.inputs.threshold_seg = 0.5
brainmask.inputs.out_postfix = '_brainMask2.nii.gz'
brainmask.run() # doctest: +SKIP

201205-11:35:35,937 nipype.workflow INFO:
	 [Node] Setting-up "brainmask2_wf_node" in "/Users/sebastientourbier/Desktop/mialsrtk/brainmask2_wf_node".
201205-11:35:35,943 nipype.workflow INFO:
	 [Node] Running "brainmask2_wf_node" ("__main__.BrainExtraction2")



* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0


INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
(1, 128, 128, 1)
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.
Failed
Traceback (most recent call last):
  File "/Applications/miniconda3/envs/pymialsrtk-env/lib/python3.6/site-packages/pymialsrtk/interfaces/preprocess.py", line 1542, in _run_interface
    self.inputs.in_ckpt_seg, self.inputs.threshold_seg, self.inputs.bids_dir, se

<nipype.interfaces.base.support.InterfaceResult at 0x7fc583f9cd68>

### Seems challenging to load frozen graph generated from TFLEARN. Should we also describe all inputs to layers? The error above might suggest so.  