# Exporting Frozen Graphs in TensorFlow 2
In order to use a trained model as input to MIGraphX, the model must first be saved in a frozen graph format, in which the variables are replaced by constants. In TensorFlow 1 this was done by launching the graph in a `tf.Session` and converting the session's variables to constants before saving the graph. TensorFlow 2, however, deprecates Sessions (they remain available only through `tf.compat.v1`) in favor of `tf.function` and the SavedModel format.
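For comparison, the TensorFlow 1 workflow can still be sketched with the `tf.compat.v1` API that TensorFlow 2 ships. The toy graph below (names `x`, `w`, `y`) is hypothetical and stands in for a real model. `convert_variables_to_constants` replaces each variable reachable from the requested output with a `Const` node holding the value read from the session.

```python
import tensorflow as tf

tf1 = tf.compat.v1

# Hypothetical toy graph standing in for a real TensorFlow 1 model.
graph = tf.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 4], name="x")
    w = tf1.get_variable("w", shape=[4, 2], initializer=tf1.ones_initializer())
    y = tf.identity(tf.matmul(x, w), name="y")

    # Launch the graph in a Session, initialize the variables, then replace
    # every variable that "y" depends on with a constant holding its value.
    with tf1.Session(graph=graph) as sess:
        sess.run(tf1.global_variables_initializer())
        frozen_graph_def = tf1.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), output_node_names=["y"])

# After freezing, no variable ops should remain in the GraphDef.
op_types = {node.op for node in frozen_graph_def.node}
print(sorted(op_types))
```

This path emits deprecation warnings in TensorFlow 2, which is why the rest of this notebook uses `tf.function` and `convert_variables_to_constants_v2` instead.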

After importing the necessary libraries, the next step is to instantiate a model. For simplicity, this example uses the ResNet50 architecture with pre-trained ImageNet weights. These weights may also be trained or fine-tuned before freezing.

In [1]:
import tensorflow as tf
# Eager execution is the default in TensorFlow 2; on TensorFlow 1.x enable it explicitly
if not tf.executing_eagerly():
    tf.compat.v1.enable_eager_execution()
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

MODEL_NAME = "resnet50"
model = tf.keras.applications.ResNet50(weights="imagenet")
model.summary()

Model: "resnet50"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 230, 230, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
conv1_conv (Conv2D)             (None, 112, 112, 64) 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
conv1_bn (BatchNormalization)   (None, 112, 112, 64) 256         conv1_conv[0][0]                 
___________________________________________________________________________________________

## SavedModel format
The simplest way to save a model is with `tf.saved_model.save()`.

This creates a SavedModel directory containing the serialized TensorFlow program and, stored separately, its variables. It can later be loaded for fine-tuning or inference, but it is not directly compatible with MIGraphX.

In [2]:
tf.saved_model.save(model, "./Saved_Models/{}".format(MODEL_NAME))

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: ./Saved_Models/resnet50/assets
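To see why a SavedModel is not a single self-contained graph, the sketch below saves a tiny hypothetical `tf.Module` (`TinyDense` is illustrative, not part of the original notebook) and lists the export directory. The graph lives in `saved_model.pb` while the weights live under `variables/`; reloading with `tf.saved_model.load` reattaches them.

```python
import os
import tempfile

import tensorflow as tf


class TinyDense(tf.Module):
    """Hypothetical stand-in for a trained model: y = x @ w + b."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.ones([3, 2]), name="w")
        self.b = tf.Variable(tf.zeros([2]), name="b")

    @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b


export_dir = os.path.join(tempfile.mkdtemp(), "tiny_dense")
tf.saved_model.save(TinyDense(), export_dir)

# The graph (saved_model.pb) and the weights (variables/) are stored separately.
contents = sorted(os.listdir(export_dir))
print(contents)

# Reloading restores a callable object with the weights attached.
restored = tf.saved_model.load(export_dir)
y = restored(tf.ones([1, 3]))
print(y.numpy())  # [[3. 3.]]
```

The exact set of files varies slightly between TensorFlow versions, so only `saved_model.pb` and `variables` should be relied on.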


## Convert to ConcreteFunction
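To begin, wrap the Keras model in a `tf.function` and obtain a concrete function for a fixed input signature (here, the model's input shape and dtype). A `ConcreteFunction` holds a single traced graph, so it is not retraced on later calls.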

In [3]:
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
    x=tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
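The effect of concretizing can be checked with a small hypothetical function (`scale` and the `trace_calls` list are illustrative). A Python side effect inside a `tf.function` runs only while the function is being traced, so the list grows only when a new graph is built. Because the `TensorSpec` leaves the batch dimension as `None`, as `model.inputs[0].shape` does for ResNet50, one concrete function serves several batch sizes.

```python
import tensorflow as tf

# Python side effects inside a tf.function run only during tracing,
# so this list records how many graphs have been built.
trace_calls = []


@tf.function
def scale(x):
    trace_calls.append(1)
    return x * 2.0


# A None batch dimension lets a single concrete function accept any batch size.
spec = tf.TensorSpec(shape=[None, 3], dtype=tf.float32)
concrete = scale.get_concrete_function(spec)
traces_after_concretize = len(trace_calls)

# Calling the concrete function reuses its graph: no new trace is recorded.
out_small = concrete(tf.ones([1, 3]))
out_large = concrete(tf.ones([5, 3]))
print(len(trace_calls), out_large.shape)
```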

## Freeze ConcreteFunction and Serialize
Since the graph is being saved for inference only, all variables can be converted to constants (i.e. "frozen").

Next, obtain a serialized `GraphDef` representation of the graph.

Optionally, print each operation in the frozen graph, followed by the graph's inputs and outputs.

In [4]:
frozen_func = convert_variables_to_constants_v2(full_model)
graph_def = frozen_func.graph.as_graph_def()  # serialized GraphDef of the frozen graph

# Name of every operation in the frozen graph
op_names = [op.name for op in frozen_func.graph.get_operations()]
print("-" * 50)
print("Frozen model layers: ")
for name in op_names:
    print(name)

print("-" * 50)
print("Frozen model inputs: ")
print(frozen_func.inputs)
print("Frozen model outputs: ")
print(frozen_func.outputs)

--------------------------------------------------
Frozen model layers: 
x
resnet50/conv1_pad/Pad/paddings
resnet50/conv1_pad/Pad
resnet50/conv1_conv/Conv2D/ReadVariableOp/resource
resnet50/conv1_conv/Conv2D/ReadVariableOp
resnet50/conv1_conv/Conv2D
resnet50/conv1_conv/BiasAdd/ReadVariableOp/resource
resnet50/conv1_conv/BiasAdd/ReadVariableOp
resnet50/conv1_conv/BiasAdd
resnet50/conv1_bn/ReadVariableOp/resource
resnet50/conv1_bn/ReadVariableOp
resnet50/conv1_bn/ReadVariableOp_1/resource
resnet50/conv1_bn/ReadVariableOp_1
resnet50/conv1_bn/FusedBatchNormV3/ReadVariableOp/resource
resnet50/conv1_bn/FusedBatchNormV3/ReadVariableOp
resnet50/conv1_bn/FusedBatchNormV3/ReadVariableOp_1/resource
resnet50/conv1_bn/FusedBatchNormV3/ReadVariableOp_1
resnet50/conv1_bn/FusedBatchNormV3
resnet50/conv1_relu/Relu
resnet50/pool1_pad/Pad/paddings
resnet50/pool1_pad/Pad
resnet50/pool1_pool/MaxPool
resnet50/conv2_block1_0_conv/Conv2D/ReadVariableOp/resource
resnet50/conv2_block1_0_conv/Conv2D/ReadVariable
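To see what freezing changes without downloading ResNet50, the same two calls can be applied to a tiny hypothetical `tf.Module` (`TinyDense` is illustrative, not part of the original notebook). The sketch compares the number of graph inputs before and after freezing, since captured variable handles are extra inputs of the unfrozen concrete function, and tallies operation types in the serialized `GraphDef` as a more compact alternative to printing every operation name.

```python
import collections

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)


class TinyDense(tf.Module):
    """Hypothetical stand-in for the Keras model: y = x @ w + b."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.ones([3, 2]), name="w")
        self.b = tf.Variable(tf.zeros([2]), name="b")

    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b


module = TinyDense()
concrete = tf.function(lambda x: module(x)).get_concrete_function(
    tf.TensorSpec([None, 3], tf.float32))
frozen = convert_variables_to_constants_v2(concrete)

# Before freezing, each captured variable handle is an extra resource-typed
# graph input; after freezing, only the data input should remain.
print(len(concrete.inputs), "->", len(frozen.inputs))

# Tally operation types in the serialized GraphDef.
graph_def = frozen.graph.as_graph_def()
op_counts = collections.Counter(node.op for node in graph_def.node)
print(op_counts.most_common())
```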

## Save Frozen Graph as Protobuf
Finally, write the frozen graph to disk as a binary protobuf (`as_text=False`). It is stored as `./frozen_models/<MODEL_NAME>_frozen_graph.pb`.

In [5]:
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir="./frozen_models",
                  name="{}_frozen_graph.pb".format(MODEL_NAME),
                  as_text=False)

'./frozen_models/resnet50_frozen_graph.pb'
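Before handing the file to MIGraphX, it can be useful to confirm within TensorFlow that the `.pb` round-trips and still computes the same result. The sketch below uses a tiny hypothetical `TinyDense` module in place of ResNet50 and a hypothetical helper, `wrap_frozen_graph`, built on `tf.compat.v1.wrap_function` and `prune`, to turn the reloaded `GraphDef` back into a callable.

```python
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)


class TinyDense(tf.Module):
    """Hypothetical stand-in model: y = x @ w + b."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.ones([3, 2]), name="w")
        self.b = tf.Variable(tf.zeros([2]), name="b")

    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b


def wrap_frozen_graph(graph_def, input_name, output_name):
    """Hypothetical helper: import a GraphDef and prune it to input -> output."""
    def _import():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped = tf.compat.v1.wrap_function(_import, [])
    graph = wrapped.graph
    return wrapped.prune(graph.as_graph_element(input_name),
                         graph.as_graph_element(output_name))


module = TinyDense()
concrete = tf.function(lambda x: module(x)).get_concrete_function(
    tf.TensorSpec([None, 3], tf.float32))
frozen = convert_variables_to_constants_v2(concrete)

# Write the binary GraphDef the same way as the cell above, then read it back.
path = tf.io.write_graph(frozen.graph, tempfile.mkdtemp(),
                         "tiny_frozen_graph.pb", as_text=False)
loaded = tf.compat.v1.GraphDef()
with open(path, "rb") as f:
    loaded.ParseFromString(f.read())

# Run the reloaded graph and compare with the original module.
infer = wrap_frozen_graph(loaded, frozen.inputs[0].name, frozen.outputs[0].name)
x = tf.constant([[1.0, 2.0, 3.0]])
from_file = np.asarray(infer(x))
expected = module(x).numpy()
print(from_file, expected)
```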

Assuming MIGraphX has already been built and installed on your system, its `migraphx-driver` tool can be used to verify that the frozen graph was exported correctly.
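A slightly more defensive way to call the driver is to check first that the binary exists and to surface its error output when it fails. This is a sketch: the function name `migraphx_read` is hypothetical, the default path simply mirrors the one used in this notebook, and treating a nonzero exit code as a failed read is an assumption about the driver's behavior.

```python
import os
import subprocess


def migraphx_read(model_path, driver="/opt/rocm/bin/migraphx-driver"):
    """Hypothetical wrapper around `migraphx-driver read`.

    Returns the driver's stdout, or None if the driver is not installed.
    Raises RuntimeError if the driver exits with a nonzero status (assumed
    to indicate that the model could not be parsed).
    """
    if not (os.path.isfile(driver) and os.access(driver, os.X_OK)):
        return None
    process = subprocess.run([driver, "read", model_path],
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             universal_newlines=True)
    if process.returncode != 0:
        raise RuntimeError("migraphx-driver failed:\n" + process.stderr)
    return process.stdout
```

Returning `None` instead of raising when the driver is missing lets the same notebook run on machines without ROCm installed.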

In [6]:
import subprocess
driver = "/opt/rocm/bin/migraphx-driver"
command = "read"
model_path = "./frozen_models/{}_frozen_graph.pb".format(MODEL_NAME)
process = subprocess.run([driver, command, model_path], 
                         stdout=subprocess.PIPE, 
                         universal_newlines=True)

print(process.stdout)

Reading: ./frozen_models/resnet50_frozen_graph.pb
@0 = @literal{ ... } -> float_type, {1000}, {1}
@1 = @literal{ ... } -> float_type, {2048, 1000}, {1000, 1}
@2 = @literal{1, 2} -> int32_type, {2}, {1}
@3 = @literal{ ... } -> float_type, {2048}, {1}
@4 = @literal{ ... } -> float_type, {2048}, {1}
@5 = @literal{ ... } -> float_type, {2048}, {1}
@6 = @literal{ ... } -> float_type, {2048}, {1}
@7 = @literal{ ... } -> float_type, {2048}, {1}
@8 = @literal{ ... } -> float_type, {1, 1, 512, 2048}, {1048576, 1048576, 2048, 1}
@9 = @literal{ ... } -> float_type, {512}, {1}
@10 = @literal{ ... } -> float_type, {512}, {1}
@11 = @literal{ ... } -> float_type, {512}, {1}
@12 = @literal{ ... } -> float_type, {512}, {1}
@13 = @literal{ ... } -> float_type, {512}, {1}
@14 = @literal{ ... } -> float_type, {3, 3, 512, 512}, {786432, 262144, 512, 1}
@15 = @literal{ ... } -> float_type, {512}, {1}
@16 = @literal{ ... } -> float_type, {512}, {1}
@17 = @literal{ ... } -> float_type, {512}, {1}
@18 = @liter