-
Notifications
You must be signed in to change notification settings - Fork 665
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #332 from NVIDIA-AI-IOT/SrivastavaKshitij-new_trt_ops
Srivastava kshitij new trt ops
- Loading branch information
Showing
31 changed files
with
1,259 additions
and
677 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,127 +1,47 @@ | ||
import os | ||
import glob | ||
import shutil | ||
import sys | ||
import torch | ||
from setuptools import setup, find_packages | ||
from setuptools.command.install import install | ||
from setuptools.command.develop import develop | ||
from distutils.cmd import Command | ||
from build import build | ||
|
||
package_data = {} | ||
|
||
plugins_user_options = [ | ||
('plugins', None, 'Build plugins'), | ||
('cuda-dir=', None, 'Location of CUDA (if not default location)'), | ||
('torch-dir=', None, 'Location of PyTorch (if not default location)'), | ||
('trt-inc-dir=', None, 'Location of TensorRT include files (if not default location)'), | ||
('trt-lib-dir=', None, 'Location of TensorRT libraries (if not default location)'), | ||
] | ||
|
||
|
||
def initialize_plugins_options(cmd_obj): | ||
cmd_obj.plugins = False | ||
cmd_obj.cuda_dir = None | ||
cmd_obj.torch_dir = None | ||
cmd_obj.trt_inc_dir = None | ||
cmd_obj.trt_lib_dir = None | ||
|
||
|
||
def run_plugins_compilation(cmd_obj): | ||
if cmd_obj.plugins: | ||
build_args = {} | ||
if cmd_obj.cuda_dir: | ||
build_args['cuda_dir'] = cmd_obj.cuda_dir | ||
if cmd_obj.torch_dir: | ||
build_args['torch_dir'] = cmd_obj.torch_dir | ||
if cmd_obj.trt_inc_dir: | ||
build_args['trt_inc_dir'] = cmd_obj.trt_inc_dir | ||
if cmd_obj.trt_lib_dir: | ||
build_args['trt_lib_dir'] = cmd_obj.trt_lib_dir | ||
|
||
print('Building in plugin support') | ||
build(**build_args) | ||
package_data['torch2trt'] = ['libtorch2trt.so'] | ||
|
||
|
||
class DevelopCommand(develop): | ||
description = "Builds the package and symlinks it into the PYTHONPATH" | ||
user_options = develop.user_options + plugins_user_options | ||
|
||
def initialize_options(self): | ||
develop.initialize_options(self) | ||
initialize_plugins_options(self) | ||
|
||
def finalize_options(self): | ||
develop.finalize_options(self) | ||
|
||
def run(self): | ||
run_plugins_compilation(self) | ||
develop.run(self) | ||
|
||
|
||
class InstallCommand(install): | ||
description = "Builds the package" | ||
user_options = install.user_options + plugins_user_options | ||
|
||
def initialize_options(self): | ||
install.initialize_options(self) | ||
initialize_plugins_options(self) | ||
|
||
def finalize_options(self): | ||
install.finalize_options(self) | ||
|
||
def run(self): | ||
run_plugins_compilation(self) | ||
install.run(self) | ||
|
||
|
||
class CleanCommand(Command): | ||
"""Custom clean command to tidy up the project root.""" | ||
PY_CLEAN_FILES = ['./build', './dist', './__pycache__', './*.pyc', './*.tgz', './*.egg-info'] | ||
description = "Command to tidy up the project root" | ||
user_options = [] | ||
|
||
def initialize_options(self): | ||
pass | ||
|
||
def finalize_options(self): | ||
pass | ||
|
||
def run(self): | ||
root_dir = os.path.dirname(os.path.realpath(__file__)) | ||
for path_spec in self.PY_CLEAN_FILES: | ||
# Make paths absolute and relative to this path | ||
abs_paths = glob.glob(os.path.normpath(os.path.join(root_dir, path_spec))) | ||
for path in [str(p) for p in abs_paths]: | ||
if not path.startswith(root_dir): | ||
# Die if path in CLEAN_FILES is absolute + outside this directory | ||
raise ValueError("%s is not a path inside %s" % (path, root_dir)) | ||
print('Removing %s' % os.path.relpath(path)) | ||
shutil.rmtree(path) | ||
|
||
cmd_list = { | ||
"Removing generated protobuf cc files": "find . -name '*.pb.cc' -print0 | xargs -0 rm -f;", | ||
"Removing generated protobuf h files": "find . -name '*.pb.h' -print0 | xargs -0 rm -f;", | ||
"Removing generated protobuf py files": "find . -name '*_pb2.py' -print0 | xargs -0 rm -f;", | ||
"Removing generated ninja files": "find . -name '*.ninja*' -print0 | xargs -0 rm -f;", | ||
"Removing generated o files": "find . -name '*.o' -print0 | xargs -0 rm -f;", | ||
"Removing generated so files": "find . -name '*.so' -print0 | xargs -0 rm -f;", | ||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension | ||
|
||
def trt_inc_dir(): | ||
return "/usr/include/aarch64-linux-gnu" | ||
|
||
def trt_lib_dir(): | ||
return "/usr/lib/aarch64-linux-gnu" | ||
|
||
ext_modules = [] | ||
|
||
plugins_ext_module = CUDAExtension( | ||
name='plugins', | ||
sources=[ | ||
'torch2trt/plugins/interpolate.cpp' | ||
], | ||
include_dirs=[ | ||
trt_inc_dir() | ||
], | ||
library_dirs=[ | ||
trt_lib_dir() | ||
], | ||
libraries=[ | ||
'nvinfer' | ||
], | ||
extra_compile_args={ | ||
'cxx': ['-DUSE_DEPRECATED_INTLIST'] if torch.__version__ < "1.5" else [], | ||
'nvcc': [] | ||
} | ||
) | ||
|
||
for cmd, script in cmd_list.items(): | ||
print("{}".format(cmd)) | ||
os.system(script) | ||
|
||
if '--plugins' in sys.argv: | ||
ext_modules.append(plugins_ext_module) | ||
sys.argv.remove('--plugins') | ||
|
||
setup( | ||
name='torch2trt', | ||
version='0.0.3', | ||
version='0.1.0', | ||
description='An easy to use PyTorch to TensorRT converter', | ||
cmdclass={ | ||
'install': InstallCommand, | ||
'clean': CleanCommand, | ||
'develop': DevelopCommand, | ||
}, | ||
packages=find_packages(), | ||
package_data=package_data | ||
ext_package='torch2trt', | ||
ext_modules=ext_modules, | ||
cmdclass={'build_ext': BuildExtension} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,23 @@ | ||
from torch2trt.torch2trt import * | ||
from torch2trt.module_test import add_module_test | ||
|
||
|
||
@tensorrt_converter('torch.nn.BatchNorm2d.forward') | ||
@tensorrt_converter("torch.nn.BatchNorm2d.forward", enabled=trt_version() < '7.0') | ||
def convert_BatchNorm2d(ctx): | ||
module = ctx.method_args[0] | ||
input = ctx.method_args[1] | ||
input_trt = trt_(ctx.network, input) | ||
output = ctx.method_return | ||
|
||
scale = module.weight.detach().cpu().numpy() / np.sqrt(module.running_var.detach().cpu().numpy() + module.eps) | ||
bias = module.bias.detach().cpu().numpy() - module.running_mean.detach().cpu().numpy() * scale | ||
|
||
scale = module.weight.detach().cpu().numpy() / np.sqrt( | ||
module.running_var.detach().cpu().numpy() + module.eps | ||
) | ||
bias = ( | ||
module.bias.detach().cpu().numpy() | ||
- module.running_mean.detach().cpu().numpy() * scale | ||
) | ||
power = np.ones_like(scale) | ||
|
||
layer = ctx.network.add_scale(input_trt, trt.ScaleMode.CHANNEL, bias, scale, power) | ||
|
||
output._trt = layer.get_output(0) | ||
output._trt = layer.get_output(0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from torch2trt.torch2trt import * | ||
from torch2trt.module_test import add_module_test | ||
|
||
|
||
@tensorrt_converter('torch.nn.Conv2d.forward', enabled=trt_version() >= '7.0') | ||
@tensorrt_converter('torch.nn.Conv3d.forward', enabled=trt_version() >= '7.0') | ||
def convert_Conv_trt7(ctx): | ||
module = ctx.method_args[0] | ||
input = ctx.method_args[1] | ||
input_trt = trt_(ctx.network, input) | ||
output = ctx.method_return | ||
|
||
input_dim = input.dim() - 2 | ||
|
||
kernel_size = module.kernel_size | ||
if not isinstance(kernel_size, tuple): | ||
kernel_size = (kernel_size, ) * input_dim | ||
|
||
stride = module.stride | ||
if not isinstance(stride, tuple): | ||
stride = (stride, ) * input_dim | ||
|
||
padding = module.padding | ||
if not isinstance(padding, tuple): | ||
padding = (padding, ) * input_dim | ||
|
||
dilation = module.dilation | ||
if not isinstance(dilation, tuple): | ||
dilation = (dilation, ) * input_dim | ||
|
||
kernel = module.weight.detach().cpu().numpy() | ||
|
||
bias = None #trt.Weights(torch_dtype_to_trt(module.weight.dtype)) | ||
if module.bias is not None: | ||
bias = module.bias.detach().cpu().numpy() | ||
|
||
layer = ctx.network.add_convolution_nd( | ||
input=input_trt, | ||
num_output_maps=module.out_channels, | ||
kernel_shape=kernel_size, | ||
kernel=kernel, | ||
bias=bias) | ||
layer.stride_nd = stride | ||
layer.padding_nd = padding | ||
layer.dilation_nd = dilation | ||
|
||
if module.groups is not None: | ||
layer.num_groups = module.groups | ||
|
||
output._trt = layer.get_output(0) | ||
|
||
|
||
|
||
@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)], enabled=trt_version() >= '7.0') | ||
def test_Conv2d_basic_trt7(): | ||
return torch.nn.Conv2d(10, 5, kernel_size=1, stride=1, padding=0) | ||
|
||
|
||
@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)], enabled=trt_version() >= '7.0') | ||
def test_Conv2d_stride2_trt7(): | ||
return torch.nn.Conv2d(10, 5, kernel_size=1, stride=2, padding=0) | ||
|
||
|
||
@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)], enabled=trt_version() >= '7.0') | ||
def test_Conv2d_kernel3_trt7(): | ||
return torch.nn.Conv2d(10, 5, kernel_size=3, stride=2, padding=1) | ||
|
||
|
||
@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)], enabled=trt_version() >= '7.0') | ||
def test_Conv2d_dilation2_trt7(): | ||
return torch.nn.Conv2d(10, 5, kernel_size=3, stride=1, padding=1, dilation=2) | ||
|
||
|
||
@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 64, 64, 64)], enabled=trt_version() >= '7.0') | ||
def test_Conv3d_basic_trt7(): | ||
return torch.nn.Conv3d(10, 5, kernel_size=1, stride=1, padding=0) | ||
|
||
|
||
@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 64, 64, 64)], enabled=trt_version() >= '7.0') | ||
def test_Conv3d_stride2_trt7(): | ||
return torch.nn.Conv3d(10, 5, kernel_size=1, stride=2, padding=0) | ||
|
||
|
||
@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 64, 64, 64)], enabled=trt_version() >= '7.0') | ||
def test_Conv3d_kernel3_trt7(): | ||
return torch.nn.Conv3d(10, 5, kernel_size=3, stride=2, padding=1) | ||
|
||
|
||
@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 64, 64, 64)], enabled=trt_version() >= '7.0') | ||
def test_Conv3d_dilation2_trt7(): | ||
return torch.nn.Conv3d(10, 5, kernel_size=3, stride=1, padding=1, dilation=2) |
Oops, something went wrong.