# Load saved TensorFlow model(s)
* Explores how to load TensorFlow models saved with different methods and make predictions with them.
* Broadly, models were saved using:
    * [tf.train.Saver](https://www.tensorflow.org/api_docs/python/tf/train/Saver)
    * [tf.saved_model.builder.SavedModelBuilder](https://www.tensorflow.org/api_docs/python/tf/saved_model/Builder)

__NOTE:__ For more on how to save models, refer to the [Exploring-TensorFlow-Low-Level-API.ipynb](Exploring-TensorFlow-Low-Level-API.ipynb) notebook.

## Necessary imports

```python
import tensorflow as tf
from mnist import mnist
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import time
import numpy as np
```

## Load data

```python
obj = mnist()
X_train, y_train, X_test, y_test = obj.load_data()

# split the original test set into a new test set (75%) and a validation set (25%);
# train_test_split holds out 25% by default when test_size is not given
X_test, X_validation, y_test, y_validation = train_test_split(X_test, y_test, random_state=0)
```
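With `test_size` left unset, `train_test_split` holds out 25% of the rows, so the first return value (reused here as `X_test`) keeps 75% and `X_validation` receives the remaining 25%. A quick check with stand-in arrays, assuming the custom `mnist` loader returns the standard 10,000-image MNIST test set as flattened 784-pixel rows (the shape is an assumption and the zeros are dummy data):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# dummy stand-ins shaped like an assumed 10,000 x 784 MNIST test set
X_test = np.zeros((10000, 784), dtype=np.float32)
y_test = np.arange(10000) % 10

# same call as in the notebook: default split is 75% / 25%
X_test, X_validation, y_test, y_validation = train_test_split(X_test, y_test, random_state=0)
print(X_test.shape, X_validation.shape)
```

Under that assumption, the accuracies reported below are computed on 7,500 test images.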

## Load model class

```python
class loadModel(object):
    def __init__(self):
        # session options shared by every model this object loads:
        # use at most 1 GPU and 3 CPU devices, allocate GPU memory on demand,
        # and cap this process at 75% of the GPU memory
        self.config = tf.ConfigProto(device_count={'GPU': 1, 'CPU': 3})
        self.config.gpu_options.allow_growth = True
        self.config.gpu_options.per_process_gpu_memory_fraction = 0.75
        self.sess = None

    def _new_session(self):
        # give each model its own empty graph and session so that tensor
        # names such as 'input:0' from one model cannot shadow another's
        if self.sess is not None:
            self.sess.close()
        graph = tf.Graph()
        self.sess = tf.Session(graph=graph, config=self.config)
        return graph

    def load_model_saved_using_tf_train_saver(self, model_path):
        graph = self._new_session()
        # path prefix of the newest checkpoint in model_path (e.g. '.../lenet.ckpt-1')
        checkpoint_path = tf.train.latest_checkpoint(model_path)
        with graph.as_default():
            # recreate the graph stored in the MetaGraphDef (.meta) file
            saver = tf.train.import_meta_graph('{}.meta'.format(checkpoint_path))
            # restore the saved variable values into the session
            saver.restore(self.sess, checkpoint_path)
        return graph

    def load_model_saved_using_tf_saved_model_builder(self, model_path):
        graph = self._new_session()
        with graph.as_default():
            # load the graph and variables of the MetaGraph tagged 'serve'
            tf.saved_model.loader.load(self.sess, ['serve'], model_path)
        return graph

    def predict(self, X_test, model_path, function):
        # function is one of the two loaders above; it returns the loaded graph
        graph = function(model_path)
        x = graph.get_tensor_by_name('input:0')
        y = graph.get_tensor_by_name('output:0')
        # the predicted class is the index of the largest output for each sample
        return np.argmax(self.sess.run(y, feed_dict={x: X_test}), axis=1)

    def calculate_accuracy(self, y_pred, y_test):
        # y_pred and y_test are 1-D arrays of integer class labels
        correct_instances = np.where(y_test == y_pred)[0].shape[0]
        total_instances = y_test.shape[0]
        accuracy = float(correct_instances) / total_instances
        print('Accuracy: {}'.format(accuracy * 100.0))
```
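`predict` turns each row of the network's `output` tensor into a class label with `np.argmax(..., axis=1)`, and `calculate_accuracy` counts exact label matches. A minimal NumPy sketch of both steps on made-up scores (the arrays are hypothetical), which also shows that the count-based formula equals `np.mean(y_pred == y_true)` when the labels are 1-D integer arrays; one-hot labels would need an `argmax` first:

```python
import numpy as np

# hypothetical network outputs: 4 samples, 3 classes
scores = np.array([[2.0, 0.1, 0.3],
                   [0.2, 1.5, 0.1],
                   [0.1, 0.2, 3.0],
                   [1.0, 0.9, 0.8]])
y_true = np.array([0, 1, 1, 0])

# same post-processing as predict(): highest-scoring class per row
y_pred = np.argmax(scores, axis=1)

# calculate_accuracy() formula: number of matching positions / total
correct = np.where(y_true == y_pred)[0].shape[0]
accuracy = float(correct) / y_true.shape[0]

# the vectorized form gives the same value
assert accuracy == np.mean(y_pred == y_true)
print(y_pred, accuracy * 100.0)
```

By the same formula, the 94.9333...% reported below would correspond to 7,120 correct predictions out of 7,500, assuming the split test set holds 7,500 images (75% of the standard 10,000-image MNIST test set).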

## Initialize

```python
model = loadModel()
```
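`loadModel.__init__` sets its session options field by field. The same `ConfigProto` can be built in one constructor call; a sketch assuming TensorFlow 1.15 or 2.x accessed through `tensorflow.compat.v1`. With both GPU options set, memory is allocated on demand but the process is still capped at 75% of each visible GPU:

```python
import tensorflow.compat.v1 as tf

config = tf.ConfigProto(
    device_count={'GPU': 1, 'CPU': 3},         # use at most 1 GPU and 3 CPU devices
    gpu_options=tf.GPUOptions(
        allow_growth=True,                     # allocate GPU memory on demand
        per_process_gpu_memory_fraction=0.75,  # but never more than 75% per GPU
    ),
)
print(config)
```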

## Prediction

### a. Load model saved via tf.train.Saver

```python
y_pred = model.predict(X_test=X_test, model_path='./models/mnist/lenet/',
                       function=model.load_model_saved_using_tf_train_saver)
model.calculate_accuracy(y_pred, y_test)
```

```
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from ./models/mnist/lenet/lenet.ckpt-1
Accuracy: 94.93333333333334
```
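The `latest_checkpoint`, `import_meta_graph`, `restore` sequence can be checked end to end on a toy model. The sketch below assumes TensorFlow 1.15 or 2.x via `tensorflow.compat.v1`; the toy graph, its variable `w`, the `toy.ckpt` prefix and the temporary directory are hypothetical stand-ins for the LeNet checkpoint in `./models/mnist/lenet/`. Only the tensor names `input` and `output` mirror the notebook. Importing into a fresh `tf.Graph()` keeps the restored names from colliding with anything already in the default graph.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph-mode semantics, needed under TensorFlow 2.x

ckpt_dir = tempfile.mkdtemp()
batch = np.random.RandomState(0).rand(5, 4).astype(np.float32)

# save a hypothetical toy model: output = input @ w
train_graph = tf.Graph()
with train_graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 4], name='input')
    w = tf.get_variable('w', shape=[4, 3],
                        initializer=tf.random_normal_initializer(seed=0))
    y = tf.identity(tf.matmul(x, w), name='output')
    with tf.Session(graph=train_graph) as sess:
        sess.run(tf.global_variables_initializer())
        expected = sess.run(y, feed_dict={x: batch})
        # writes toy.ckpt-1.{meta,index,data-*} plus a 'checkpoint' index file
        tf.train.Saver().save(sess, os.path.join(ckpt_dir, 'toy.ckpt'), global_step=1)

# load it back into a fresh graph, mirroring load_model_saved_using_tf_train_saver
checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)
load_graph = tf.Graph()
with load_graph.as_default():
    saver = tf.train.import_meta_graph(checkpoint_path + '.meta')
    with tf.Session(graph=load_graph) as sess:
        saver.restore(sess, checkpoint_path)
        x_in = load_graph.get_tensor_by_name('input:0')
        y_out = load_graph.get_tensor_by_name('output:0')
        restored = sess.run(y_out, feed_dict={x_in: batch})

print(checkpoint_path)
print(np.allclose(expected, restored))
```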


### b. Load model saved via tf.saved_model.builder.SavedModelBuilder

```python
y_pred = model.predict(X_test=X_test, model_path='./models/mnist/lenet/using_SavedModelBuilder/',
                       function=model.load_model_saved_using_tf_saved_model_builder)
model.calculate_accuracy(y_pred, y_test)
```

```
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.
INFO:tensorflow:Restoring parameters from ./models/mnist/lenet/using_SavedModelBuilder/variables/variables
Accuracy: 94.93333333333334
```
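The SavedModel route can be exercised the same way. A sketch assuming TensorFlow 1.15 or 2.x through `tensorflow.compat.v1`, the compatibility path the deprecation message above points to; the toy graph and export directory are hypothetical stand-ins for `./models/mnist/lenet/using_SavedModelBuilder/`. `tf.saved_model.tag_constants.SERVING` is the constant for the `'serve'` tag used in the notebook.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph-mode semantics, needed under TensorFlow 2.x

# SavedModelBuilder refuses to write into an existing directory
export_dir = os.path.join(tempfile.mkdtemp(), 'toy_saved_model')
batch = np.random.RandomState(0).rand(5, 4).astype(np.float32)

# export a hypothetical toy model: output = input @ w
train_graph = tf.Graph()
with train_graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 4], name='input')
    w = tf.get_variable('w', shape=[4, 3],
                        initializer=tf.random_normal_initializer(seed=0))
    y = tf.identity(tf.matmul(x, w), name='output')
    with tf.Session(graph=train_graph) as sess:
        sess.run(tf.global_variables_initializer())
        expected = sess.run(y, feed_dict={x: batch})
        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])
        builder.save()  # writes saved_model.pb and a variables/ directory

# load the 'serve' MetaGraph into a fresh graph and session
load_graph = tf.Graph()
with load_graph.as_default():
    with tf.Session(graph=load_graph) as sess:
        tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)
        x_in = load_graph.get_tensor_by_name('input:0')
        y_out = load_graph.get_tensor_by_name('output:0')
        restored = sess.run(y_out, feed_dict={x_in: batch})

print(np.allclose(expected, restored))
```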

