
[MXNET-703] TensorRT runtime integration (#11325)
* [MXNET-703] TensorRT runtime integration

Co-authored-by: Clement Fuji-Tsang <caenorst@hotmail.com>
Co-authored-by: Kellen Sunderland <kellen.sunderland@gmail.com>

* correctly assign self._optimized_symbol in executor

* declare GetTrtCompatibleSubsets and ReplaceSubgraph only if MXNET_USE_TENSORRT

* add comments in ReplaceSubgraph

* Addressing Haibin's code review points

* Check that shared_buffer is not empty when USE_TENSORRT is set

* Added check that TensorRT binding is for inference only

* Removed redundant decl.

* WIP Refactored TRT integration and tests

* Add more build guards, remove unused code

* Remove ccache report

* Remove redundant const in declaration

* Clean Cmake TRT files

* Remove TensorRT env var usage

We don't want to use environment variables with TensorRT yet; the reasoning is that we want as much forward compatibility as possible while working on an experimental feature. Were we to add environment variables, they would have to be guaranteed to keep working until the next major version change. Moving the functionality to a contrib call reduces this risk.
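For context, a rough sketch of what such a contrib-based flow might look like from Python. The names `tensorrt_bind` and `get_optimized_symbol` are inferred from the bullet points below, and the checkpoint name, shapes, and keyword arguments are assumptions for illustration rather than the precise API added by this commit:

```python
# Hypothetical usage sketch; signatures and argument names are assumptions.
import mxnet as mx
from mxnet.contrib import tensorrt as trt

sym, arg_params, aux_params = mx.model.load_checkpoint('lenet5', 0)

# Bind through an explicit contrib call instead of a global environment
# variable, keeping the experimental TensorRT pass opt-in per executor.
executor = trt.tensorrt_bind(sym,
                             ctx=mx.gpu(0),
                             all_params=dict(arg_params, **aux_params),
                             data=(1, 1, 28, 28),
                             softmax_label=(1,),
                             grad_req='null')  # inference only

# Inspect the graph after the TensorRT subgraph replacement.
optimized_sym = trt.get_optimized_symbol(executor)
```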

* Use contrib optimize_graph instead of bind

* Clean up cycle detector

* Convert lenet test to contrib optimize

* Protect interface with trt build flag

* Fix whitespace issues

* Add another build guard to c_api

* Move get_optimized_symbol to contrib area

* Ignore gz files in test folder

* Make trt optimization implicit

* Remove unused declaration

* Replace build guards with runtime errors

* Change default value of TensorRT to off

This change applies to both TensorRT and non-TensorRT builds.

* Warn user when TRT not active at runtime

* Move TensorRTBind declaration, add descriptive errors

* Test TensorRT graph execution, fix bugs

* Fix lint and whitespace issues

* Fix typo

* Removed default value for set_use_tensorrt

* Improved documentation and fixed spacing issues

* Move static exec funcs to util files

* Update comments to match util style

* Apply const to loop element

* Fix a few namespace issues

* Make static funcs inline to avoid compiler warning

* Remove unused inference code from lenet5_train

* Add explicit trt contrib bind, update tests to use it

* Rename trt bind call

* Remove documentation that is not needed for trt

* Reorder arguments, allow positional calling
mkolod authored and marcoabreu committed Aug 10, 2018
1 parent af15853 commit c053262
Showing 42 changed files with 4,138 additions and 303 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -26,3 +26,6 @@
[submodule "3rdparty/tvm"]
path = 3rdparty/tvm
url = https://github.com/dmlc/tvm
[submodule "3rdparty/onnx-tensorrt"]
path = 3rdparty/onnx-tensorrt
url = https://github.com/onnx/onnx-tensorrt.git
1 change: 1 addition & 0 deletions 3rdparty/onnx-tensorrt
Submodule onnx-tensorrt added at e7be19
31 changes: 31 additions & 0 deletions CMakeLists.txt
@@ -37,6 +37,7 @@ mxnet_option(ENABLE_CUDA_RTC "Build with CUDA runtime compilation support"
mxnet_option(BUILD_CPP_EXAMPLES "Build cpp examples" ON)
mxnet_option(INSTALL_EXAMPLES "Install the example source files." OFF)
mxnet_option(USE_SIGNAL_HANDLER "Print stack traces on segfaults." OFF)
mxnet_option(USE_TENSORRT "Enable inference optimization with TensorRT." OFF)

message(STATUS "CMAKE_SYSTEM_NAME ${CMAKE_SYSTEM_NAME}")
if(USE_CUDA AND NOT USE_OLDCMAKECUDA)
@@ -185,6 +186,36 @@ if(USE_VTUNE)
list(APPEND mxnet_LINKER_LIBS dl)
endif()

if(USE_TENSORRT)
message(STATUS "Using TensorRT")
set(ONNX_PATH 3rdparty/onnx-tensorrt/third_party/onnx/build/)
set(ONNX_TRT_PATH 3rdparty/onnx-tensorrt/build/)

include_directories(${ONNX_PATH})
include_directories(3rdparty/onnx-tensorrt/)
include_directories(3rdparty/)
add_definitions(-DMXNET_USE_TENSORRT=1)
add_definitions(-DONNX_NAMESPACE=onnx)

find_package(Protobuf REQUIRED)

find_library(ONNX_LIBRARY NAMES libonnx.so REQUIRED
PATHS ${ONNX_PATH}
DOC "Path to onnx library.")
find_library(ONNX_PROTO_LIBRARY NAMES libonnx_proto.so REQUIRED
PATHS ${ONNX_PATH}
DOC "Path to onnx_proto library.")
find_library(ONNX_TRT_RUNTIME_LIBRARY NAMES libnvonnxparser_runtime.so REQUIRED
PATHS ${ONNX_TRT_PATH}
DOC "Path to onnx_proto library.")
find_library(ONNX_TRT_PARSER_LIBRARY NAMES libnvonnxparser.so REQUIRED
PATHS ${ONNX_TRT_PATH}
DOC "Path to onnx_proto library.")

list(APPEND mxnet_LINKER_LIBS libnvinfer.so ${ONNX_TRT_PARSER_LIBRARY} ${ONNX_TRT_RUNTIME_LIBRARY}
${ONNX_PROTO_LIBRARY} ${ONNX_LIBRARY} ${PROTOBUF_LIBRARY})
endif()

if(USE_MKLDNN)
include(cmake/MklDnn.cmake)
# CPU architecture (e.g., C5) can't run on another architecture (e.g., g3).
28 changes: 28 additions & 0 deletions Jenkinsfile
@@ -30,6 +30,7 @@ mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/li
mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests'
mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0'
mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
mx_tensorrt_lib = 'lib/libmxnet.so, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so'
// timeout in minutes
max_time = 120

@@ -301,6 +302,17 @@ core_logic: {
}
}
},
'TensorRT': {
node(NODE_LINUX_CPU) {
ws('workspace/build-tensorrt') {
timeout(time: max_time, unit: 'MINUTES') {
utils.init_git()
utils.docker_run('ubuntu_gpu_tensorrt', 'build_ubuntu_gpu_tensorrt', false)
utils.pack_lib('tensorrt', mx_tensorrt_lib)
}
}
}
},
'Build CPU windows':{
node(NODE_WINDOWS_CPU) {
timeout(time: max_time, unit: 'MINUTES') {
@@ -616,6 +628,22 @@ core_logic: {
}
}
},
'Python3: TensorRT GPU': {
node(NODE_LINUX_GPU_P3) {
ws('workspace/build-tensorrt') {
timeout(time: max_time, unit: 'MINUTES') {
try {
utils.init_git()
utils.unpack_lib('tensorrt', mx_tensorrt_lib)
utils.docker_run('ubuntu_gpu_tensorrt', 'unittest_ubuntu_tensorrt_gpu', true)
utils.publish_test_coverage()
} finally {
utils.collect_test_results_unix('nosetests_tensorrt.xml', 'nosetests_python3_tensorrt_gpu.xml')
}
}
}
}
},
'Scala: CPU': {
node(NODE_LINUX_CPU) {
ws('workspace/ut-scala-cpu') {
8 changes: 8 additions & 0 deletions Makefile
@@ -91,6 +91,14 @@ else
endif
CFLAGS += -I$(TPARTYDIR)/mshadow/ -I$(TPARTYDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -I$(TPARTYDIR)/tvm/include -Iinclude $(MSHADOW_CFLAGS)
LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS)


ifeq ($(USE_TENSORRT), 1)
CFLAGS += -I$(ROOTDIR) -I$(TPARTYDIR) -DONNX_NAMESPACE=$(ONNX_NAMESPACE) -DMXNET_USE_TENSORRT=1
LDFLAGS += -lprotobuf -pthread -lonnx -lonnx_proto -lnvonnxparser -lnvonnxparser_runtime -lnvinfer -lnvinfer_plugin
endif
# -L/usr/local/lib

ifeq ($(DEBUG), 1)
NVCCFLAGS += -std=c++11 -Xcompiler -D_FORCE_INLINES -g -G -O0 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS)
else
14 changes: 7 additions & 7 deletions amalgamation/amalgamation.py
@@ -23,13 +23,12 @@
import platform

blacklist = [
'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh',
'cuda_runtime.h', 'cudnn.h', 'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h',
'glog/logging.h', 'io/azure_filesys.h', 'io/hdfs_filesys.h', 'io/s3_filesys.h',
'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h',
'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h',
'nvml.h', 'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h',
'omp.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h',
'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh', 'cuda_runtime.h', 'cudnn.h',
'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h', 'glog/logging.h', 'io/azure_filesys.h',
'io/hdfs_filesys.h', 'io/s3_filesys.h', 'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h',
'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h', 'NvInfer.h', 'nvml.h',
'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h', 'omp.h',
'onnx/onnx.pb.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h',
'cusolverDn.h', 'internal/concurrentqueue_internal_debug.h', 'relacy/relacy_std.hpp',
'relacy_shims.h', 'ittnotify.h', 'shared_mutex'
]
@@ -150,6 +149,7 @@ def expand(x, pending, stage):
h not in sysheaders and
'mkl' not in h and
'nnpack' not in h and
'tensorrt' not in h and
not h.endswith('.cuh')): sysheaders.append(h)
else:
expand.treeDepth += 1
41 changes: 41 additions & 0 deletions ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
@@ -0,0 +1,41 @@
# -*- mode: dockerfile -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# Dockerfile to build and run MXNet on Ubuntu 16.04 for GPU with TensorRT

FROM nvidia/cuda:9.0-cudnn7-devel

WORKDIR /work/deps

COPY install/ubuntu_core.sh /work/
RUN /work/ubuntu_core.sh
COPY install/deb_ubuntu_ccache.sh /work/
RUN /work/deb_ubuntu_ccache.sh
COPY install/ubuntu_python.sh /work/
RUN /work/ubuntu_python.sh
COPY install/tensorrt.sh /work
RUN /work/tensorrt.sh

ARG USER_ID=0
COPY install/ubuntu_adduser.sh /work/
RUN /work/ubuntu_adduser.sh

COPY runtime_functions.sh /work/

WORKDIR /work/mxnet
ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib
45 changes: 45 additions & 0 deletions ci/docker/install/tensorrt.sh
@@ -0,0 +1,45 @@
#!/bin/bash

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Install gluoncv since we're testing Gluon models as well
pip2 install gluoncv==0.2.0
pip3 install gluoncv==0.2.0

# Install Protobuf
# Install protoc 3.5 and build protobuf here (for onnx and onnx-tensorrt)
pushd .
cd ..
apt-get update
apt-get install -y automake libtool
git clone --recursive -b 3.5.1.1 https://github.com/google/protobuf.git
cd protobuf
./autogen.sh
./configure
make -j$(nproc)
make install
ldconfig
popd

# Install TensorRT
echo "TensorRT build enabled. Installing TensorRT."
wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0_1-1_amd64.deb
dpkg -i tensorrt.deb
apt-get update
apt-get install -y --allow-downgrades libnvinfer-dev
rm tensorrt.deb
65 changes: 65 additions & 0 deletions ci/docker/runtime_functions.sh
@@ -414,6 +414,60 @@ build_ubuntu_gpu() {
build_ubuntu_gpu_cuda91_cudnn7
}

build_ubuntu_gpu_tensorrt() {

set -ex

build_ccache_wrappers

# Build ONNX
pushd .
echo "Installing ONNX."
cd 3rdparty/onnx-tensorrt/third_party/onnx
rm -rf build
mkdir -p build
cd build
cmake \
-DCMAKE_CXX_FLAGS=-I/usr/include/python${PYVER}\
-DBUILD_SHARED_LIBS=ON ..\
-G Ninja
ninja -v
export LIBRARY_PATH=`pwd`:`pwd`/onnx/:$LIBRARY_PATH
export CPLUS_INCLUDE_PATH=`pwd`:$CPLUS_INCLUDE_PATH
popd

# Build ONNX-TensorRT
pushd .
cd 3rdparty/onnx-tensorrt/
mkdir -p build
cd build
cmake ..
make -j$(nproc)
export LIBRARY_PATH=`pwd`:$LIBRARY_PATH
popd

mkdir -p /work/mxnet/lib/
cp 3rdparty/onnx-tensorrt/third_party/onnx/build/*.so /work/mxnet/lib/
cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser_runtime.so.0 /work/mxnet/lib/
cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser.so.0 /work/mxnet/lib/

rm -rf build
make \
DEV=1 \
USE_BLAS=openblas \
USE_CUDA=1 \
USE_CUDA_PATH=/usr/local/cuda \
USE_CUDNN=1 \
USE_OPENCV=0 \
USE_DIST_KVSTORE=0 \
USE_TENSORRT=1 \
USE_JEMALLOC=0 \
USE_GPERFTOOLS=0 \
ONNX_NAMESPACE=onnx \
CUDA_ARCH="-gencode arch=compute_70,code=compute_70"\
-j$(nproc)
}

build_ubuntu_gpu_mkldnn() {
set -ex

@@ -610,6 +664,15 @@ unittest_ubuntu_python3_gpu_nocudnn() {
nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_gpu.xml --verbose tests/python/gpu
}

unittest_ubuntu_tensorrt_gpu() {
set -ex
export PYTHONPATH=./python/
export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
export LD_LIBRARY_PATH=/work/mxnet/lib:$LD_LIBRARY_PATH
python tests/python/tensorrt/lenet5_train.py
nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_trt_gpu.xml --verbose tests/python/tensorrt/
}

# quantization gpu currently only runs on P3 instances
# need to separate it from unittest_ubuntu_python2_gpu()
unittest_ubuntu_python2_quantization_gpu() {
@@ -961,3 +1024,5 @@ EOF
declare -F | cut -d' ' -f3
echo
fi


7 changes: 7 additions & 0 deletions include/mxnet/c_api.h
@@ -1761,6 +1761,13 @@ MXNET_DLL int MXExecutorReshape(int partial_shaping,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec,
ExecutorHandle *out);

/*!
* \brief get optimized graph from graph executor
*/
MXNET_DLL int MXExecutorGetOptimizedSymbol(ExecutorHandle handle,
SymbolHandle *out);

/*!
* \brief set a call back to notify the completion of operation
*/
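As a quick illustration of how a frontend might wrap the new `MXExecutorGetOptimizedSymbol` entry point, here is a minimal ctypes sketch. It assumes the usual MXNet Python conventions (`_LIB`, `check_call`, `SymbolHandle`) and is not a verbatim copy of the helper this commit adds under `python/mxnet/contrib/`:

```python
# Minimal sketch, assuming an already-bound executor; not the exact code in this commit.
import ctypes
from mxnet.base import _LIB, check_call, SymbolHandle
from mxnet.symbol import Symbol

def get_optimized_symbol(executor):
    """Return the symbol as rewritten by the graph optimization pass."""
    handle = SymbolHandle()
    check_call(_LIB.MXExecutorGetOptimizedSymbol(executor.handle,
                                                 ctypes.byref(handle)))
    return Symbol(handle=handle)
```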
1 change: 1 addition & 0 deletions include/mxnet/executor.h
@@ -166,6 +166,7 @@ class Executor {
std::unordered_map<std::string, NDArray>*
shared_data_arrays = nullptr,
Executor* shared_exec = nullptr);

/*!
* \brief the prototype of user-defined monitor callback
*/
16 changes: 16 additions & 0 deletions python/mxnet/base.py
@@ -729,3 +729,19 @@ def write_all_str(module_file, module_all_list):
module_op_file.close()
write_all_str(module_internal_file, module_internal_all)
module_internal_file.close()

def cint(init_val=0):
"""create a C int with an optional initial value"""
return C.c_int(init_val)

def int_addr(x):
"""given a c_int, return it's address as an int ptr"""
x_addr = C.addressof(x)
int_p = C.POINTER(C.c_int)
x_int_addr = C.cast(x_addr, int_p)
return x_int_addr

def checked_call(f, *args):
"""call a cuda function and check for success"""
error_t = f(*args)
assert error_t == 0, "Failing cuda call %s returns %s." % (f.__name__, error_t)
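A short usage sketch for the ctypes helpers added above. The call to `cudaGetDeviceCount` through `libcudart` is only an assumed example of the kind of CUDA entry point `checked_call` is meant to wrap, not something exercised by this commit:

```python
# Illustrative only; assumes the CUDA runtime library can be loaded on this system.
import ctypes as C
from mxnet.base import cint, int_addr, checked_call

cudart = C.CDLL('libcudart.so')

count = cint(0)  # C int that will receive the device count
checked_call(cudart.cudaGetDeviceCount, int_addr(count))
print('Visible CUDA devices:', count.value)
```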
1 change: 1 addition & 0 deletions python/mxnet/contrib/__init__.py
@@ -32,3 +32,4 @@
from . import io
from . import quantization
from . import quantization as quant
from . import tensorrt
