<a href="https://colab.research.google.com/github/rakesh4real/role-models/blob/main/optimize-TF-svavedmodels.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

- **Tool:** [TF Graph Transforms python API](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms)
- **Input:** `SavedModel` [format](https://www.tensorflow.org/guide/saved_model) combines a `GraphDef` with checkpoint files that store weights, **all collected in a folder**.

# Steps

1. Freeze the `SavedModel` by converting it to the `GraphDef` format
2. Optimize the frozen `GraphDef` model
3. Unfreeze the optimized graph back to the `SavedModel` format

In [1]:
from __future__ import print_function

import os
import numpy as np
from datetime import datetime
import sys

import tensorflow as tf
from tensorflow import data
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import freeze_graph
from tensorflow.python import ops
from tensorflow.tools.graph_transforms import TransformGraph

# Number of output classes; used as the width of the final 'logits' Dense layer.
NUM_CLASSES = 10
# Root directory for trained models; joined with MODEL_NAME to form model_dir.
MODELS_LOCATION = 'models/mnist'
# Per-model sub-directory under MODELS_LOCATION holding checkpoints and the 'export' folder.
MODEL_NAME = 'keras_classifier'

def load_mnist_keras():
  """Load MNIST through tf.keras.datasets and flatten its nested splits.

  Returns:
    A 4-tuple ``(train_data, train_labels, eval_data, eval_labels)`` unpacked
    from the ``((train_x, train_y), (test_x, test_y))`` pair returned by
    ``tf.keras.datasets.mnist.load_data()``.
  """
  train_split, eval_split = tf.keras.datasets.mnist.load_data()
  train_images, train_targets = train_split
  eval_images, eval_targets = eval_split
  return train_images, train_targets, eval_images, eval_targets

def keras_model_fn(params):
  """Build the (uncompiled) Keras CNN classifier for 28x28 MNIST images.

  Architecture: reshape -> ``params.num_conv_layers`` x (Conv2D -> MaxPool2D ->
  BatchNormalization) -> flatten -> one (Dense -> Dropout) pair per entry of
  ``params.hidden_units`` -> Dense logits -> softmax.

  Args:
    params: hyper-parameter container. Reads ``num_conv_layers``,
      ``init_filters`` (filters of the first conv layer; doubled per layer),
      ``hidden_units`` (iterable of Dense widths) and ``dropout`` (passed to
      ``tf.keras.layers.Dropout`` as its rate).

  Returns:
    A ``tf.keras.models.Model`` whose single input is named ``input_image``
    with per-example shape (28, 28), and whose output is a NUM_CLASSES-way
    softmax.
  """
  inputs = tf.keras.layers.Input(shape=(28, 28), name='input_image')
  # Conv2D needs an explicit channel axis, so add a single channel.
  net = tf.keras.layers.Reshape(target_shape=(28, 28, 1), name='reshape')(inputs)

  # Convolutional stack: the filter count doubles at every layer.
  for i in range(params.num_conv_layers):
    filters = params.init_filters * (2 ** i)
    net = tf.keras.layers.Conv2D(kernel_size=3, filters=filters, strides=1, padding='SAME', activation='relu')(net)
    net = tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding='SAME')(net)
    net = tf.keras.layers.BatchNormalization()(net)

  net = tf.keras.layers.Flatten(name='flatten')(net)

  # Fully-connected stack: iterate the widths directly rather than by index.
  for units in params.hidden_units:
    net = tf.keras.layers.Dense(units=units, activation='relu')(net)
    net = tf.keras.layers.Dropout(params.dropout)(net)

  # Softmax classifier head.
  logits = tf.keras.layers.Dense(units=NUM_CLASSES, name='logits')(net)
  softmax = tf.keras.layers.Activation('softmax', name='softmax')(logits)

  return tf.keras.models.Model(inputs, softmax)


def create_estimator_keras(params, run_config):
  """Compile the Keras classifier and wrap it in a tf.estimator.Estimator.

  Args:
    params: hyper-parameters. ``learning_rate`` configures the Adam optimizer;
      the whole object is also forwarded to ``keras_model_fn``.
    run_config: ``tf.estimator.RunConfig`` given to the converted estimator.

  Returns:
    The estimator produced by ``tf.keras.estimator.model_to_estimator``.
  """
  keras_model = keras_model_fn(params)
  # summary() prints the table itself and returns None; wrapping it in
  # print() would emit a stray "None" line.
  keras_model.summary()

  optimizer = tf.keras.optimizers.Adam(lr=params.learning_rate)
  # Pass the optimizer object (not the string 'adam') so that
  # params.learning_rate actually takes effect.
  keras_model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

  mnist_classifier = tf.keras.estimator.model_to_estimator(
      keras_model=keras_model,
      config=run_config
  )
  return mnist_classifier

ModuleNotFoundError: ignored

In [None]:
def run_experiment(hparams, train_data, train_labels, run_config,
                   create_estimator_fn=None, eval_data=None, eval_labels=None):
  """Build an estimator and run tf.estimator.train_and_evaluate on numpy data.

  Args:
    hparams: hyper-parameters. Reads ``batch_size``, ``max_training_steps`` and
      ``eval_throttle_secs``; the whole object is forwarded to
      ``create_estimator_fn``.
    train_data: training features, fed under the ``input_image`` key.
    train_labels: training labels aligned with ``train_data``.
    run_config: ``tf.estimator.RunConfig`` forwarded to ``create_estimator_fn``.
    create_estimator_fn: callable ``(hparams, run_config) -> Estimator``.
      When None, ``create_estimator_keras`` is used. It is resolved at call
      time, so no estimator factory has to exist when this function is defined.
    eval_data: optional held-out features used for evaluation.
    eval_labels: labels aligned with ``eval_data``. If both are omitted, the
      training set is used for evaluation, so the reported metrics are
      training-set metrics.

  Returns:
    The trained estimator.

  Raises:
    ValueError: if only one of ``eval_data`` / ``eval_labels`` is given.
  """
  if create_estimator_fn is None:
    create_estimator_fn = create_estimator_keras
  if (eval_data is None) != (eval_labels is None):
    raise ValueError('eval_data and eval_labels must be provided together')
  if eval_data is None:
    eval_data, eval_labels = train_data, train_labels

  train_spec = tf.estimator.TrainSpec(
      input_fn=tf.estimator.inputs.numpy_input_fn(
          x={'input_image': train_data},
          y=train_labels,
          batch_size=hparams.batch_size,
          # Cycle over the data indefinitely; max_steps bounds training.
          num_epochs=None,
          shuffle=True),
      max_steps=hparams.max_training_steps
  )
  eval_spec = tf.estimator.EvalSpec(
      input_fn=tf.estimator.inputs.numpy_input_fn(
          x={'input_image': eval_data},
          y=eval_labels,
          batch_size=hparams.batch_size,
          num_epochs=1,
          shuffle=False),
      # steps=None: evaluate until the single-epoch input is exhausted.
      steps=None,
      throttle_secs=hparams.eval_throttle_secs
  )

  tf.logging.set_verbosity(tf.logging.INFO)

  time_start = datetime.utcnow()
  print('Experiment started at {}'.format(time_start.strftime('%H:%M:%S')))
  print('.......................................')

  estimator = create_estimator_fn(hparams, run_config)

  tf.estimator.train_and_evaluate(
      estimator=estimator,
      train_spec=train_spec,
      eval_spec=eval_spec
  )

  time_end = datetime.utcnow()
  print('.......................................')
  print('Experiment finished at {}'.format(time_end.strftime('%H:%M:%S')))
  print('')
  time_elapsed = time_end - time_start
  print('Experiment elapsed time: {} seconds'.format(time_elapsed.total_seconds()))

  return estimator


def train_and_export_model(train_data, train_labels):
  """Train the Keras-backed MNIST estimator from scratch and export it.

  Any existing ``MODELS_LOCATION/MODEL_NAME`` directory is deleted first, so
  each call starts without old checkpoints.

  Args:
    train_data: training images, fed under the ``input_image`` key.
    train_labels: labels aligned with ``train_data``.

  Returns:
    The export *base* directory ``<model_dir>/export`` handed to
    ``export_savedmodel``. Per the Estimator docs the SavedModel itself is
    written to a timestamped subdirectory of it, which is not returned here.
  """
  model_dir = os.path.join(MODELS_LOCATION, MODEL_NAME)

  hyper_params = tf.contrib.training.HParams(
      batch_size=100,
      hidden_units=[1024],
      num_conv_layers=2,
      init_filters=64,
      # NOTE(review): this reaches tf.keras.layers.Dropout as the drop *rate*,
      # so 0.85 drops 85% of units -- looks like a keep-probability; confirm.
      dropout=0.85,
      max_training_steps=50,
      eval_throttle_secs=10,
      learning_rate=1e-3,
      debug=True
  )

  config = tf.estimator.RunConfig(
      tf_random_seed=19830610,
      save_checkpoints_steps=1000,
      keep_checkpoint_max=3,
      model_dir=model_dir
  )

  # Wipe earlier artifacts so training does not resume from stale checkpoints.
  if tf.gfile.Exists(model_dir):
    print('Removing previous artifacts...')
    tf.gfile.DeleteRecursively(model_dir)
  os.makedirs(model_dir)

  trained = run_experiment(hyper_params, train_data, train_labels, config, create_estimator_keras)

  export_base = os.path.join(model_dir, 'export')
  if tf.gfile.Exists(export_base):
    tf.gfile.DeleteRecursively(export_base)

  # Raw (non-tf.Example) serving signature: callers feed float images directly.
  serving_features = {
      'input_image': tf.placeholder(
          shape=[None, 28, 28], dtype=tf.float32, name='serving_input_image')
  }
  receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(serving_features)

  trained.export_savedmodel(
      export_dir_base=export_base,
      serving_input_receiver_fn=receiver_fn
  )

  return export_base
