# 把checkpoint转化成frozen graph

In [1]:
import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import logging

logging.getLogger("tensorflow").setLevel(logging.ERROR)

import os
os.environ['CUDA_VISIBLE_DEVICES']='0'


config = tf.ConfigProto()
config.gpu_options.allow_growth=True
import sys
sys.path.append('/home/mao/Github/models/research/slim')
import nets.nets_factory
SAVED_MODEL_DIR = "/home/mao/Github/tensorRT/resnet_v1_50_2016_08_28/"
graph = tf.Graph()
with graph.as_default():
    with tf.Session(config=config) as sess:
        tf_input = tf.placeholder(tf.float32, [None, 224, 224, 3], name='input')
        network_fn = nets.nets_factory.get_network_fn('resnet_v1_50', 1000,
                                                      is_training=False)
        tf_net, tf_end_points = network_fn(tf_input)
                
        saver = tf.train.Saver()
        saver.restore(sess, SAVED_MODEL_DIR+"resnet_v1_50.ckpt")
        
        tf_output = tf.identity(tf_net, name='logits')
        tf_output_classes = tf.argmax(tf_output, axis=1, name='classes')        
        #tf_output_classes = tf.reshape(tf_output_classes, (BATCH_SIZE,), name='classes')
        
        # freeze graph
        fp32_frozen_graph = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph_def,
            output_node_names=['logits', 'classes']
        )

# 测试frozen graph

In [2]:
def benchmark_frozen_graph(frozen_graph, SAVED_MODEL_DIR=None, BATCH_SIZE=8):
    with tf.Session(graph=tf.Graph(), config=config) as sess:
        next_element = tf.convert_to_tensor(np.random.random((8, 224, 224, 3)))

        output_node = tf.import_graph_def(
            frozen_graph,
            return_elements=['classes'],
            name="")
        
        print('Warming up for 50 batches...')
        for _ in range (50):
            sess.run(['classes:0'], feed_dict={"input:0": sess.run(next_element)})

        num_predict = 0
        start_time = time.time()
        try:
            for i in range(0, 1000):        
                image_data = sess.run(next_element)    
                img = image_data
                output = sess.run(['classes:0'], feed_dict={"input:0": img})
                num_predict += len(output[0])
        except tf.errors.OutOfRangeError as e:
            pass

        print('Inference speed: %.2f samples/s'%(num_predict/(time.time()-start_time)))
        
        #Optionally, save model for serving if an ouput directory argument is presented
        if SAVED_MODEL_DIR:
            print('Saving model to %s'%SAVED_MODEL_DIR)
            tf.saved_model.simple_save(
                session=sess,
                export_dir=SAVED_MODEL_DIR,
                inputs={"input":tf.get_default_graph().get_tensor_by_name("input:0")},
                outputs={"classes":tf.get_default_graph().get_tensor_by_name("classes:0")},
                legacy_init_op=None
             )

# 保存FP32,并测试

In [3]:
FP32_SAVED_MODEL_DIR = "/home/mao/Github/tensorRT/model/Resnet_FP32/"
benchmark_frozen_graph(fp32_frozen_graph, FP32_SAVED_MODEL_DIR)

Warming up for 50 batches...
Inference speed: 245.79 samples/s
Saving model to /home/mao/Github/tensorRT/model/Resnet_FP32/


# 保存tensorRT的FP32,并测试

In [4]:
#Now we create the TFTRT FP32 engine
BATCH_SIZE = 8
converter = trt.TrtGraphConverter(input_graph_def=fp32_frozen_graph,
                                  max_batch_size=BATCH_SIZE,
                                  precision_mode=trt.TrtPrecisionMode.FP32,
                                  nodes_blacklist=['classes', 'logits'])
trt_fp32_graph = converter.convert()

In [5]:
TRT_FP32_SAVED_MODEL_DIR = "/home/mao/Github/tensorRT/model/Resnet_TRT_FP32/"
benchmark_frozen_graph(trt_fp32_graph, TRT_FP32_SAVED_MODEL_DIR)

Warming up for 50 batches...
Inference speed: 328.11 samples/s
Saving model to /home/mao/Github/tensorRT/model/Resnet_TRT_FP32/


# 保存tensorRT的INT8,并测试

In [6]:
num_calibration_batches = 2
BATCH_SIZE = 8
batched_input = np.zeros((BATCH_SIZE * num_calibration_batches, 224, 224, 3), dtype=np.float32)

with tf.Session(graph=tf.Graph(), config=config) as sess:
    # prepare dataset iterator
    next_element = tf.convert_to_tensor(np.random.random((8, 224, 224, 3)))
    for i in range(num_calibration_batches):
        print(batched_input[i*BATCH_SIZE:(i+1)*BATCH_SIZE, :].shape, sess.run(next_element).shape)
        batched_input[i*BATCH_SIZE:(i+1)*BATCH_SIZE, :] = sess.run(next_element)

#batched_input = tf.constant(batched_input)
print('Calibration data shape: ', batched_input.shape)

def calibration_input_fn_gen():
    for i in range(num_calibration_batches):
        yield batched_input[i*BATCH_SIZE:(i+1)*BATCH_SIZE, :]
        
calibration_input_fn = calibration_input_fn_gen()

(8, 224, 224, 3) (8, 224, 224, 3)
(8, 224, 224, 3) (8, 224, 224, 3)
Calibration data shape:  (16, 224, 224, 3)


In [7]:
#Now we create the TFTRT FP16 engine
converter = trt.TrtGraphConverter(input_graph_def=fp32_frozen_graph,
                                  max_batch_size=BATCH_SIZE,
                                  precision_mode=trt.TrtPrecisionMode.INT8,
                                  nodes_blacklist=['classes', 'logits'])
trt_int8_graph = converter.convert()


# Run calibration for num_calibration_batches times.
trt_int8_calibrated_graph = converter.calibrate(
      fetch_names=['classes:0'],
      num_runs=num_calibration_batches,
      feed_dict_fn=lambda: {"input:0": next(calibration_input_fn)})

In [8]:
INT8_SAVED_MODEL_DIR = "/home/mao/Github/tensorRT/model/Resnet_TRT_INT8/"
benchmark_frozen_graph(trt_int8_calibrated_graph, INT8_SAVED_MODEL_DIR)

Warming up for 50 batches...
Inference speed: 503.99 samples/s
Saving model to /home/mao/Github/tensorRT/model/Resnet_TRT_INT8/


# 分别保存frozen_graph的.pb和tensorRT的FP32和INT8的.pb

In [9]:
from tensorflow.python.framework import graph_io

In [10]:
graph_io.write_graph(fp32_frozen_graph, './model', 'fp32_frozen_graph.pb', as_text=False)

'./model/fp32_frozen_graph.pb'

In [11]:
graph_io.write_graph(trt_fp32_graph, './model', 'trt_fp32_graph.pb', as_text=False)

'./model/trt_fp32_graph.pb'

In [12]:
graph_io.write_graph(trt_int8_calibrated_graph, './model', 'trt_int8_calibrated_graph.pb', as_text=False)

'./model/trt_int8_calibrated_graph.pb'