Using TensorRT to Accelerate Inference

The core of TensorRT™ is a C++ library that facilitates high-performance inference on NVIDIA graphics processing units (GPUs). TensorRT optimizes the inference speed of a trained model during the deployment phase. The TensorRT version covered by this document is 4.0.1.

A Brief Introduction to TensorRT

TensorRT Workflow

  1. Import the model. We use TensorRT to parse a trained model. There are two ways to get the network definition. One is to import the model using the TensorRT parser library, which supports serialized models in the following formats:

    • Caffe
    • ONNX
    • UFF (used for TensorFlow)

    The other is to define the model directly using the TensorRT API. You can also use the TensorRT API to add custom layers.

  2. Optimize the engine. Once the network definition is available, TensorRT performs a set of model optimizations for the target deployment GPU. The output of this step is an optimized inference execution engine, which is serialized to a file on disk called a plan file (see the sketch after this list).

  3. De-serialize the engine. Load and de-serialize a saved plan file to create a TensorRT engine object.

  4. Deploy the runtime. Deploy the generated TensorRT runtime engine to the target platform and use it to run inference on new data.

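As a minimal sketch of steps 2 and 3, assuming the legacy TensorRT 4 Python API (the helper names trt.utils.write_engine_to_file and trt.utils.load_engine, and the file name inception_v3.plan, are assumptions for illustration):

# Hedged sketch: serialize an optimized engine to a plan file, then load it back.
import tensorrt as trt

G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.ERROR)

# 'engine' is assumed to be an engine built earlier, e.g. with trt.utils.uff_to_trt_engine(...)
trt.utils.write_engine_to_file("inception_v3.plan", engine.serialize())  # step 2: save the plan file

# Later (possibly on the deployment target), de-serialize the plan file into an engine object.
engine = trt.utils.load_engine(G_LOGGER, "inception_v3.plan")            # step 3: load the plan file
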
TensorRT Supported Layers

Activation, Concatenation, Constant, Convolution, Deconvolution, ElementWise, Flatten, FullyConnected, Gather, LRN, MatrixMultiply, Padding, Plugin, Pooling, Ragged SoftMax, Reduce, RNN, RNNv2, LSTM, Scale, Shuffle, SoftMax, Squeeze, TopK, Unary.

You can define your own custom layers via the Custom Layer API provided by TensorRT.

TensorRT Supported Import Methods

C++ API, Python API, NvCaffeParser, NvUffParser, NvONNXParser.

TensorRT Optimizations and Optimization Performance Results

The main optimizations are:

  1. Layer and tensor fusion and elimination of unused layers;
  2. FP16 and INT8 reduced precision calibration;
  3. Target-specific autotuning;
  4. Efficient memory reuse.

With the TensorFlow-TensorRT integration, the ResNet-50 model runs about 8x faster at under 7 ms latency on NVIDIA Volta Tensor Cores than it does with TensorFlow alone on the same GPU.
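
For example, reduced precision can be requested when the engine is built. A minimal sketch, assuming the legacy TensorRT 4 Python API; the extra DataType.HALF argument to uff_to_trt_engine is an assumption based on that API, and it only pays off on GPUs with native FP16 support:

# Hedged sketch: build an engine that uses FP16 kernels where possible.
import tensorrt as trt
from tensorrt.parsers import uffparser

G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.ERROR)
parser = uffparser.create_uff_parser()
parser.register_input("input", (3, 299, 299), 0)   # input name and CHW shape
parser.register_output("dense/BiasAdd")            # output node name

# 'uff_model' is the serialized UFF graph produced by uff.from_tensorflow,
# as shown in the example later on this page.
engine = trt.utils.uff_to_trt_engine(G_LOGGER, uff_model, parser,
                                     1,                          # max batch size
                                     1 << 20,                    # max workspace size
                                     trt.infer.DataType.HALF)    # assumed FP16 flag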

Installing TensorRT

You can follow the link below to install TensorRT.

TensorRT installation Guide

Note that if you want to use the TensorRT Python API, you should install PyCUDA first.
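
After installing, a quick sanity check that the TensorRT Python bindings and PyCUDA can see the GPU (a minimal sketch; it does not build any engine):

# Hedged sketch: verify the TensorRT and PyCUDA installations.
import tensorrt as trt            # fails if the TensorRT Python package is missing
import pycuda.driver as cuda
import pycuda.autoinit            # creates a CUDA context on the default GPU

print("GPU:", cuda.Device(0).name())   # prints the name of GPU 0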

Some Examples

There are three detailed examples using Python.

  1. This example shows the complete process of using TensorFlow with TensorRT: Generating TensorRT Engines from TensorFlow

  2. This example shows the complete process of using Caffe with TensorRT: Using TensorRT to Optimize Caffe Models in Python

  3. This example shows the complete process of using PyTorch with TensorRT; since PyTorch is not supported by the UFF converter, the engine is constructed manually: Manually Constructing a TensorRT Engine

This guideline demonstrates how to import a model from ONNX using Python: Import Model by ONNX Parser in Python.

More C++ samples can be found at: https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#c_samples_section

One Example using MMdnn

Here is an example of using MMdnn to convert the inception_v3 model from PyTorch (inception_v3.pth) to TensorFlow (saved_model.pb) and then using TensorRT to perform inference on the converted model, which is in SavedModel format. We use the picture stored in "./MMdnn/mmdnn/conversion/examples/data/seagull.jpg" as the test input.

1. Install the stable MMdnn

$ pip install mmdnn

2. Download the inception_v3 model and convert it from PyTorch to TensorFlow

You can use the command below to download the inception_v3 model, or you can download it directly from the model collection.

$ mmdownload -f pytorch -n inception_v3 -o './'

Next, convert the PyTorch model to the IR format.

$ mmtoir -f pytorch -d inception_v3 --inputShape 3,299,299 -n imagenet_inception_v3.pth

IR network structure is saved as [inception_v3.json].
IR network structure is saved as [inception_v3.pb].
IR weights are saved as [inception_v3.npy].

Then convert the IR files to a TensorFlow code snippet.

$ mmtocode -f tensorflow --IRModelPath inception_v3.pb --IRWeightPath inception_v3.npy --dstModelPath tf_inception_v3.py

Parse file [inception_v3.pb] with binary format successfully.
Target network code snippet is saved as [tf_inception_v3.py].

Finally, convert the TensorFlow code snippet and IR weights to a TensorFlow model.

$ mmtomodel -f tensorflow -in tf_inception_v3.py -iw inception_v3.npy -o tf_inception_v3 --dump_tag SERVING

Tensorflow file is saved as [tf_inception_v3/saved_model.pb], generated by [tf_inception_v3.py] and [inception_v3.npy].

For more details about converting models with MMdnn, please refer to this.

3. Use TensorRT and the converted model to perform inference

First, convert the TensorFlow model to UFF.

The UFF parser supports two import methods. One is to provide a model stream plus the name(s) of the desired output node(s); the other is to import a TensorFlow frozen protobuf file directly. In this example, we load the SavedModel and create the model stream to convert, so we also need to provide the name(s) of the output node(s). If you do not know the name, you can use TensorBoard or the script below to print it (for this model, the output node is 'dense/BiasAdd').

# Run this script if you do not know the name of the output node
import tensorflow as tf

export_dir = "./tf_inception_v3"
with tf.Session(graph=tf.Graph()) as sess:
    meta_graph_def = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    signature = meta_graph_def.signature_def
    signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

    x_tensor_name = signature[signature_key].inputs['input'].name

    y_tensor_name = signature[signature_key].outputs['output'].name

    print('x_tensor_name: ', x_tensor_name)
    print('y_tensor_name: ', y_tensor_name)
INFO:tensorflow:Restoring parameters from ./tf_inception_v3/variables/variables
x_tensor_name:  input:0
y_tensor_name:  dense/BiasAdd:0
# Convert tensorflow model to UFF
import tensorflow as tf
import tensorrt as trt
from tensorrt.parsers import uffparser
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np
from PIL import Image
import time
import os
import uff
OUTPUT_NAMES = ['dense/BiasAdd']	# The name of inception_v3 output node
export_dir = "./tf_inception_v3"	# The directory path of inception_v3 saved model
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    graphdef = tf.get_default_graph().as_graph_def()
    frozen_graph = tf.graph_util.convert_variables_to_constants(sess,
    graphdef,
    OUTPUT_NAMES)    
    tf_model = tf.graph_util.remove_training_nodes(frozen_graph)
    uff_model = uff.from_tensorflow(tf_model, OUTPUT_NAMES)
INFO:tensorflow:Restoring parameters from ./tf_inception_v3/variables/variables
INFO:tensorflow:Froze 472 variables.
INFO:tensorflow:Converted 472 variables to const ops.
Using output node dense/BiasAdd
Converting to UFF graph
DEBUG: convert reshape to flatten node
DEBUG: convert reshape to flatten node
No. nodes: 1447

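As mentioned above, the UFF converter can also read a TensorFlow frozen protobuf file directly instead of an in-memory GraphDef. A minimal sketch, assuming the frozen graph has already been written to a hypothetical file named inception_v3_frozen.pb:

# Hedged sketch: convert a frozen .pb file to UFF directly.
import uff

# 'inception_v3_frozen.pb' is a hypothetical path to a graph frozen with
# tf.graph_util.convert_variables_to_constants and written to disk.
uff_model = uff.from_tensorflow_frozen_model("inception_v3_frozen.pb",
                                             ["dense/BiasAdd"])  # output node name(s)
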
Second, import the UFF model into TensorRT and build an engine. These steps are quite similar to those in the TensorFlow example for TensorRT, and you can find more details here.

G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.ERROR)
parser = uffparser.create_uff_parser()
parser.register_input("input", (3,299,299), 0)   # input name and CHW shape
parser.register_output(OUTPUT_NAMES[0])          # output node name
# Build the engine with max batch size 1 and a 1 MB workspace
engine = trt.utils.uff_to_trt_engine(G_LOGGER, uff_model, parser, 1, 1 << 20)
parser.destroy()

Third, get the test image. We use the image at './MMdnn/mmdnn/conversion/examples/data/seagull.jpg'.

path = './MMdnn/mmdnn/conversion/examples/data/seagull.jpg'
img = Image.open(path)
img = img.resize((299, 299))
x = np.array(img, dtype=np.float32)
# Inception-style preprocessing: scale pixel values to [-1, 1]
x /= 255.0
x -= 0.5
x *= 2.0
img = x.transpose([2,0,1])  # TensorRT requires CHW format.
img = np.ascontiguousarray(img)

Fourth, create the execution context for the engine. Again, for more details please refer to this.

runtime = trt.infer.create_infer_runtime(G_LOGGER)
context = engine.create_execution_context()
output = np.empty(1000, dtype = np.float32)

# Allocate device memory
d_input = cuda.mem_alloc(1 * img.nbytes)
d_output = cuda.mem_alloc(1 * output.nbytes)
bindings = [int(d_input), int(d_output)]
stream = cuda.Stream()
# Transfer input data to device
cuda.memcpy_htod_async(d_input, img, stream)
# Execute model
context.enqueue(1, bindings, stream.handle, None)
# Transfer predictions back
cuda.memcpy_dtoh_async(output, d_output, stream)
# Synchronize the stream
stream.synchronize()

Finally, get the prediction.

imagenet_file_path = './MMdnn/mmdnn/conversion/examples/data/imagenet_1000.txt'
LABELS = open(imagenet_file_path,'r').readlines()
print("Prediction: ", LABELS[np.argmax(output)])
Prediction:  n01608432 kite # The prediction is acceptable: imagenet_1000.txt does not include a 'seagull' class, and 'kite' is a bird class.
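
The time module imported earlier can be used to get a rough latency number for the engine. A minimal sketch that reuses the context, bindings, stream and buffers created above (the warm-up and iteration counts are arbitrary choices):

# Hedged sketch: measure average inference latency with the existing engine context.
for _ in range(10):                            # warm-up runs
    context.enqueue(1, bindings, stream.handle, None)
stream.synchronize()

n_runs = 100
start = time.time()
for _ in range(n_runs):
    cuda.memcpy_htod_async(d_input, img, stream)
    context.enqueue(1, bindings, stream.handle, None)
    cuda.memcpy_dtoh_async(output, d_output, stream)
stream.synchronize()
print("Average latency: %.2f ms" % ((time.time() - start) / n_runs * 1000))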

Reference

[1] https://devblogs.nvidia.com/tensorrt-3-faster-tensorflow-inference/

[2] https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html

[3] https://devblogs.nvidia.com/tensorrt-4-accelerates-translation-speech-recommender/