# Working with Frozen TensorFlow model
* This notebook is about:
    * Freezing a TensorFlow graph
    * Loading a frozen model and making predictions with it

## Necessary imports

In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import gzip
import os
import multiprocessing
from requests import get
import pickle
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from mnist import mnist

## Load data

In [2]:
# Load the train/test splits through the project's `mnist` helper.
obj = mnist()
X_train, y_train, X_test, y_test = obj.load_data()

# Carve a validation set out of the original test split only; the training
# split is left untouched. test_size=0.25 is train_test_split's default,
# spelled out for readability, and random_state=0 keeps the split reproducible.
X_test, X_validation, y_test, y_validation = train_test_split(
    X_test, y_test, test_size=0.25, random_state=0)

In [3]:
# Header followed by one blank line, then one "<label> <features shape> <labels shape>"
# line per split. The generator keeps its loop names out of the notebook namespace.
print('==== Printing shapes of data ===', end='\n\n')
print('\n'.join(
    '{} {} {}'.format(label, features.shape, targets.shape)
    for label, features, targets in (
        ('Train data:      ', X_train, y_train),
        ('Test data:       ', X_test, y_test),
        ('Validation data: ', X_validation, y_validation),
    )
))

==== Printing shapes of data ===

Train data:       (60000, 32, 32, 1) (60000,)
Test data:        (7500, 32, 32, 1) (7500,)
Validation data:  (2500, 32, 32, 1) (2500,)


## Class for freezing a model

In [4]:
class freezeGraph(object):
    """Freeze a TensorFlow 1.x checkpoint into a single self-contained GraphDef.

    Every variable needed by the requested output nodes is replaced by a
    constant holding its checkpointed value. The result is written to
    ``frozen_model.pb`` in the same directory as the checkpoint.
    """

    def __init__(self):
        # Session used to restore checkpoint weights. freeze_graph replaces it
        # with a session bound to a fresh graph on every call.
        self.sess = tf.Session()

    # THIS IMPLEMENTATION IS ADAPTED FROM THE METAFLOW BLOG
    # MENTIONED IN THE REFERENCE SECTION OF THIS NOTEBOOK
    def freeze_graph(self, model_dir, output_node_names):
        """Extract the sub graph defined by the output nodes and convert
        all its variables into constants.

        Args:
            model_dir: the root folder containing the checkpoint state file.
            output_node_names: a string containing the output nodes' names,
                comma separated. Whitespace around each name is ignored.

        Returns:
            -1 if no output node name was supplied. Otherwise returns None
            after writing ``frozen_model.pb`` next to the checkpoint.

        Raises:
            AssertionError: if ``model_dir`` does not exist.
            ValueError: if ``model_dir`` holds no checkpoint state.
        """
        if not tf.gfile.Exists(model_dir):
            raise AssertionError(
                "Export directory doesn't exists. Please specify an export "
                "directory: %s" % model_dir)

        # Tolerate "a, b" as well as "a,b"; drop empty entries.
        node_names = [name.strip() for name in (output_node_names or '').split(',')
                      if name.strip()]
        if not node_names:
            print("You need to supply the name of a node to --output_node_names.")
            return -1

        # Retrieve the full path of the checkpoint recorded in the state file.
        checkpoint = tf.train.get_checkpoint_state(model_dir)
        if checkpoint is None or not checkpoint.model_checkpoint_path:
            raise ValueError("No checkpoint state found in directory: %s" % model_dir)
        input_checkpoint = checkpoint.model_checkpoint_path

        # The frozen graph goes next to the checkpoint files. os.path handles
        # platform separators and a checkpoint path with no directory part.
        output_graph = os.path.join(os.path.dirname(input_checkpoint), "frozen_model.pb")

        # Clear devices so TensorFlow decides where operations are placed on load.
        clear_devices = True

        # Import into a fresh graph rather than the process-wide default graph.
        # Otherwise repeated calls would accumulate duplicate, renamed nodes.
        graph = tf.Graph()
        self.sess.close()
        self.sess = tf.Session(graph=graph)
        with graph.as_default():
            saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                               clear_devices=clear_devices)
            # Restore the trained weights into the session bound to `graph`.
            saver.restore(self.sess, input_checkpoint)
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                self.sess,
                graph.as_graph_def(),
                node_names
            )

        # Serialize and dump the frozen graph to the filesystem.
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())

        print("%d ops in the final graph." % len(output_graph_def.node))

## Initialize

In [5]:
# Create the freezer; its constructor opens a TensorFlow session.
f_model = freezeGraph()

## Freeze model

In [6]:
# Freeze the checkpoint recorded in the LeNet model directory, keeping only the
# subgraph needed to compute the node named 'output'.
f_model.freeze_graph(model_dir='./models/mnist/lenet/',
                     output_node_names='output')

Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from ./models/mnist/lenet/lenet.ckpt-1
Instructions for updating:
Use tf.compat.v1.graph_util.convert_variables_to_constants
Instructions for updating:
Use tf.compat.v1.graph_util.extract_sub_graph
INFO:tensorflow:Froze 10 variables.
INFO:tensorflow:Converted 10 variables to const ops.
40 ops in the final graph.


## Class for loading a frozen model and making predictions

In [7]:
class loadFrozenGraph(object):
    """Load a frozen TensorFlow GraphDef (``.pb``) and run predictions with it.

    The graph is imported under a name scope (``'prefix'`` by default), so its
    tensors are addressed as ``'<prefix>/<node>:0'``.
    """

    def __init__(self, frozen_graph_path, input_name='input', output_name='output',
                 prefix='prefix'):
        """Import the frozen graph and open a session bound to it.

        Args:
            frozen_graph_path: path to the serialized GraphDef file.
            input_name: name of the input node inside the frozen graph.
            output_name: name of the output node inside the frozen graph.
            prefix: name scope the graph is imported under.
        """
        self.graph = self.import_graph(frozen_graph_path, name=prefix)
        # MOST IMPORTANT: the session must be bound to the imported graph,
        # because the default graph does not contain these nodes.
        self.sess = tf.Session(graph=self.graph)
        # Look up the input and output tensors.
        self.x = self.graph.get_tensor_by_name('{}/{}:0'.format(prefix, input_name))
        self.y = self.graph.get_tensor_by_name('{}/{}:0'.format(prefix, output_name))

    def import_graph(self, frozen_graph_path, name='prefix'):
        """Read a serialized GraphDef and import it into a new, isolated tf.Graph.

        Args:
            frozen_graph_path: path to the ``.pb`` file.
            name: name scope prepended to every imported node.

        Returns:
            The tf.Graph containing the imported nodes.
        """
        # tf.gfile works for local and remote filesystems; a plain `open` would
        # also do for local paths (see the 'Learnings' section).
        with tf.gfile.GFile(frozen_graph_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        # Import into a fresh graph instead of the process-wide default graph.
        graph = tf.Graph()
        with graph.as_default():
            tf.import_graph_def(graph_def, name=name)
        return graph

    def get_tensor_names(self):
        """Print the name of every operation in the loaded graph, one per line."""
        for op in self.graph.get_operations():
            print(op.name)

    def predict_from_frozen_graph(self, X_test):
        """Run the graph on ``X_test`` and return the arg-max index per row.

        NOTE(review): assumes the output tensor is (batch, num_classes), because
        the arg-max is taken over axis 1.
        """
        scores = self.sess.run(self.y, feed_dict={self.x: X_test})
        return np.argmax(scores, axis=1)

    def calculate_accuracy(self, y_pred, y_test):
        """Print the accuracy as a percentage and return it as a fraction in [0, 1].

        Both inputs are flattened first. This avoids silently broadcasting an
        (n, 1) column against an (n,) vector into an n x n comparison.

        Raises:
            ValueError: if the inputs differ in length or are empty.
        """
        y_pred = np.asarray(y_pred).ravel()
        y_test = np.asarray(y_test).ravel()
        if y_pred.shape != y_test.shape:
            raise ValueError('y_pred and y_test differ in length: {} vs {}'.format(
                y_pred.shape[0], y_test.shape[0]))
        total_instances = y_test.shape[0]
        if total_instances == 0:
            raise ValueError('Cannot compute accuracy on empty inputs')
        correct_instances = np.count_nonzero(y_test == y_pred)
        accuracy = float(correct_instances) / total_instances
        print('Accuracy: {}'.format(accuracy*100.0))
        return accuracy

    def close(self):
        """Release the TensorFlow session held by this loader."""
        self.sess.close()

## Initialize

In [8]:
# Load the graph written by the freeze step and open a session on it.
model = loadFrozenGraph(
    frozen_graph_path='./models/mnist/lenet/frozen_model.pb'
)

## Prediction using a frozen model

In [9]:
# Predicted class for each test sample = arg-max over the graph's output scores.
y_pred = model.predict_from_frozen_graph(X_test=X_test)
model.calculate_accuracy(y_pred=y_pred, y_test=y_test);

Accuracy: 95.88


## References 

* https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc
* https://cv-tricks.com/how-to/freeze-tensorflow-models/

## Learnings

* https://stackoverflow.com/questions/52934795/what-is-difference-frozen-inference-graph-pb-and-saved-model-pb
* https://stackoverflow.com/questions/42256938/what-does-tf-gfile-do-in-tensorflow
* https://stackoverflow.com/questions/47059848/difference-between-tensorflows-graph-and-graphdef
* Things left to explore:
    * optimize_for_inference tool: https://medium.com/@prasadpal107/saving-freezing-optimizing-for-inference-restoring-of-tensorflow-models-b4146deb21b5 , https://stackoverflow.com/questions/45382917/how-to-optimize-for-inference-a-simple-saved-tensorflow-1-0-1-graph