In [1]:
import tensorflow as tf
import numpy as np

In [2]:
def wrap_frozen_graph(graph_def, inputs, outputs, print_graph=False):
    """Wrap a frozen ``GraphDef`` as a callable pruned TF function.

    The graph is imported inside ``tf.compat.v1.wrap_function`` and the
    resulting wrapped function is pruned down to the sub-graph that maps
    ``inputs`` to ``outputs``.

    Args:
        graph_def: A ``tf.compat.v1.GraphDef`` holding the frozen model.
        inputs: Input tensor name(s), e.g. ``["x:0"]``; any structure that
            ``tf.nest.map_structure`` accepts.
        outputs: Output tensor name(s) to fetch, e.g. ``["Identity:0"]``.
        print_graph: If truthy, print the name of every operation in the
            imported graph between two separator lines. Nothing is printed
            otherwise.

    Returns:
        The object returned by ``prune`` on the wrapped import; calling it with
        the input tensor(s) runs the frozen graph and yields the outputs.
    """
    def _imports_graph_def():
        # name="" imports nodes without a name prefix, so tensors can be
        # referenced by their original names such as "x:0".
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph

    # Emit the banner only when the listing is requested, so callers never
    # see an empty "Frozen model layers:" header with nothing under it.
    if print_graph:
        print("-" * 50)
        print("Frozen model layers: ")
        for op in import_graph.get_operations():
            print(op.name)
        print("-" * 50)

    # Resolve tensor names to graph elements, then keep only the sub-graph
    # needed to compute `outputs` from `inputs`.
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))

In [3]:
# Serialized frozen model to load; adjust here if the export location moves.
FROZEN_GRAPH_PATH = "./frozen_models/frozen_graph.pb"

with tf.io.gfile.GFile(FROZEN_GRAPH_PATH, "rb") as graph_file:
    graph_def = tf.compat.v1.GraphDef()
    # ParseFromString fills `graph_def` in place from the raw file bytes; its
    # return value is stored in `loaded` (not used by any visible cell).
    loaded = graph_def.ParseFromString(graph_file.read())

In [4]:
# Wrap the loaded graph as a callable that maps tensor "x:0" to "Identity:0",
# listing every operation name while doing so.
frozen_func = wrap_frozen_graph(
    graph_def=graph_def,
    inputs=["x:0"],
    outputs=["Identity:0"],
    print_graph=True,
)

--------------------------------------------------
Frozen model layers: 
x
model/dense/MatMul/ReadVariableOp/resource
model/dense/MatMul/ReadVariableOp
model/dense/MatMul
model/dense/BiasAdd/ReadVariableOp/resource
model/dense/BiasAdd/ReadVariableOp
model/dense/BiasAdd
Identity
--------------------------------------------------


In [6]:
# Show the input and output tensor signatures of the pruned function.
io_sections = (
    ("Frozen model inputs: ", frozen_func.inputs),
    ("Frozen model outputs: ", frozen_func.outputs),
)
print("-" * 50)
for section_title, section_tensors in io_sections:
    print(section_title)
    print(section_tensors)

--------------------------------------------------
Frozen model inputs: 
[<tf.Tensor 'x:0' shape=(None, 4) dtype=float32>]
Frozen model outputs: 
[<tf.Tensor 'Identity:0' shape=(None, 1) dtype=float32>]


In [14]:
# One sample of four features, all equal to 1, as float32 to match the
# (None, 4) float32 input signature printed above.
test_x = np.ones((1, 4), dtype=np.float32)

In [15]:
np.shape(test_x)

(1, 4)

In [18]:
# Feed the sample through the frozen graph; index 0 selects the first
# (and only) requested output, "Identity:0".
input_tensor = tf.constant(test_x)
pred_y = frozen_func(x=input_tensor)[0]
print(pred_y)

tf.Tensor([[32.246185]], shape=(1, 1), dtype=float32)


In [19]:
# NOTE(review): presumably the linear coefficients the model was trained to
# learn (no bias term) — confirm against the training notebook.
TRUE_WEIGHTS = np.array([[3], [14], [6], [10]])
true_y = test_x @ TRUE_WEIGHTS
print(true_y)

[[33.]]
