expand_as + avg_pool1d converters were added #699

Open · wants to merge 14 commits into master
1 change: 1 addition & 0 deletions requirements.txt
@@ -0,0 +1 @@
termcolor
4 changes: 2 additions & 2 deletions scripts/build_contrib.sh
@@ -13,11 +13,11 @@ pushd /tmp/TensorRT
git sparse-checkout set /tools/pytorch-quantization/
git apply --reject --whitespace=fix pytorch_nvidia_quantization.patch
cd tools/pytorch-quantization/
-python setup.py install
+sudo python3 setup.py install
popd

pushd $parentdir
-python3 setup.py install --plugins --contrib
+sudo python3 setup.py install --plugins --contrib
popd


5 changes: 5 additions & 0 deletions setup.py
@@ -5,6 +5,10 @@
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
from packaging import version

REQUIREMENTS_PATH = 'requirements.txt'

with open(REQUIREMENTS_PATH, 'r') as file:
required_libraries = file.read().splitlines()
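# Note: requirements.txt currently contains a single entry, termcolor (added above).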

def trt_inc_dir():
return "/usr/include/aarch64-linux-gnu"
@@ -55,5 +59,6 @@ def trt_lib_dir():
packages=find_packages(exclude=exclude_dir),
ext_package='torch2trt',
ext_modules=ext_modules,
install_requires=required_libraries,
cmdclass={'build_ext': BuildExtension}
)
1 change: 1 addition & 0 deletions torch2trt/converters/__init__.py
@@ -69,3 +69,4 @@
from .transpose import *
from .unary import *
from .view import *
from .zeros import *
65 changes: 60 additions & 5 deletions torch2trt/converters/avg_pool.py
@@ -2,6 +2,55 @@
from torch2trt.module_test import add_module_test


@tensorrt_converter('torch.nn.functional.avg_pool1d')
def convert_avg_pool1d(ctx):
    # At the time of this implementation, TensorRT 8.x does not yet support 1D average
    # pooling via `add_pooling_nd(...)`. As a workaround, we unsqueeze an extra dimension
    # into the input (transforming it from (N, C, L) to (N, C, L, 1)) so that 2D average
    # pooling can be applied over the last two dimensions.

input = get_arg(ctx, 'input', pos=0, default=None)
input_trt = trt_(ctx.network, input)
output = ctx.method_return

kernel_size = get_arg(ctx, 'kernel_size', pos=1, default=None)
stride = get_arg(ctx, 'stride', pos=2, default=None)
padding = get_arg(ctx, 'padding', pos=3, default=0)
ceil_mode = get_arg(ctx, 'ceil_mode', pos=4, default=False)
count_include_pad = get_arg(ctx, 'count_include_pad', pos=5, default=True)

    # The input is always 1D here, so convert the pooling parameters to their 2D
    # equivalents. As in PyTorch, stride defaults to kernel_size (already a 2D tuple).
kernel_size = (kernel_size, 1)
stride = kernel_size if not stride else (stride, 1)
padding = (padding, 0)

    # Shuffle layer to unsqueeze an extra dimension for 2D average pooling.
unsqueeze_layer = ctx.network.add_shuffle(input_trt)
set_layer_precision(ctx, unsqueeze_layer)
unsqueeze_layer.reshape_dims = tuple([*input_trt.shape, 1])
unsqueeze_trt = unsqueeze_layer.get_output(0)

    # Use 2D average pooling to emulate 1D average pooling.
layer = ctx.network.add_pooling_nd(
input=unsqueeze_trt,
type=trt.PoolingType.AVERAGE,
window_size=kernel_size,
)
set_layer_precision(ctx, layer)
layer.stride_nd = stride
layer.padding_nd = padding
layer.average_count_excludes_padding = not count_include_pad

if ceil_mode:
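        # EXPLICIT_ROUND_UP rounds the output size up, matching PyTorch's ceil_mode=True.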
layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP

pooling_trt = layer.get_output(0)

    # Shuffle layer to squeeze out the dimension added for 2D average pooling, so the
    # output is 1D again.
squeeze_layer = ctx.network.add_shuffle(pooling_trt)
set_layer_precision(ctx, squeeze_layer)
squeeze_layer.reshape_dims = tuple(pooling_trt.shape[:-1])
output._trt = squeeze_layer.get_output(0)
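
# A minimal sketch (not part of this diff) of the trick used above: 2D average
# pooling over the unsqueezed (N, C, L, 1) input matches 1D average pooling on
# the original (N, C, L) input.
def _avg_pool1d_equivalence_sketch():
    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 10)                        # (N, C, L)
    ref = F.avg_pool1d(x, kernel_size=3, stride=2, padding=1)
    alt = F.avg_pool2d(x.unsqueeze(-1),              # unsqueeze to (N, C, L, 1)
                       kernel_size=(3, 1), stride=(2, 1), padding=(1, 0))
    assert torch.allclose(ref, alt.squeeze(-1))      # squeeze back to (N, C, L')
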

@tensorrt_converter("torch.nn.functional.avg_pool2d", enabled=trt_version() < '7.0')
def convert_avg_pool2d(ctx):
# parse args
@@ -83,12 +132,14 @@ def convert_avg_pool_trt7(ctx):
layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP

output._trt = layer.get_output(0)


@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 4, 6)])
@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 5, 7)])
def test_avg_pool2d_without_ceil_mode():
-    return torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)
+    return torch.nn.AvgPool2d(
+        kernel_size=3, stride=2, padding=1, ceil_mode=False
+    )


@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 4, 6)])
@@ -102,10 +153,14 @@ def test_avg_pool2d_with_ceil_mode():
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 4, 6)], enabled=trt_version() >= '7.0')
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 5, 7)], enabled=trt_version() >= '7.0')
def test_avg_pool3d_without_ceil_mode_trt7():
-    return torch.nn.AvgPool3d(kernel_size=3, stride=2, padding=1, ceil_mode=False)
+    return torch.nn.AvgPool3d(
+        kernel_size=3, stride=2, padding=1, ceil_mode=False
+    )


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 4, 6)], enabled=trt_version() >= '7.0')
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 5, 7)], enabled=trt_version() >= '7.0')
def test_avg_pool3d_with_ceil_mode_trt7():
-    return torch.nn.AvgPool3d(kernel_size=3, stride=2, padding=1, ceil_mode=True, count_include_pad=False)  # TRT does not support ceil_mode=True && count_include_pad=True
+    return torch.nn.AvgPool3d(
+        kernel_size=3, stride=2, padding=1, ceil_mode=True, count_include_pad=False
+    )  # TRT does not support ceil_mode=True && count_include_pad=True
2 changes: 1 addition & 1 deletion torch2trt/converters/expand.py
@@ -2,10 +2,10 @@
from torch2trt.module_test import add_module_test


@tensorrt_converter('torch.Tensor.expand_as')
@tensorrt_converter('torch.Tensor.expand')
def convert_expand(ctx):
input = ctx.method_args[0]
sizes = ctx.method_args[1:]
output = ctx.method_return

inshape = tuple(input.shape)[1:] # exclude batch
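
# A minimal sketch (not part of this diff) of why one converter serves both ops:
# x.expand_as(other) returns the same tensor as x.expand(*other.shape), so the
# traced output (ctx.method_return) has the correct expanded shape either way.
def _expand_as_equivalence_sketch():
    import torch

    x = torch.randn(1, 1, 3)
    other = torch.randn(1, 4, 3)
    assert torch.equal(x.expand_as(other), x.expand(*other.shape))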
66 changes: 66 additions & 0 deletions torch2trt/converters/zeros.py
@@ -0,0 +1,66 @@
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


def _set_layer_precision(ctx, layer):
# Supported TRT precisions as given by torch2trt_kwargs.
INT8_MODE = "int8_mode"
FP16_MODE = "fp16_mode"

# Check that args exist as expected in torch2trt_kwargs.
trt_kwargs = ctx.torch2trt_kwargs
assert INT8_MODE in trt_kwargs
assert FP16_MODE in trt_kwargs

is_int8 = trt_kwargs.get(INT8_MODE, False)
is_fp16 = trt_kwargs.get(FP16_MODE, False)

if is_int8:
layer.precision = trt.int8
layer.set_output_type(0, trt.int8)
elif is_fp16:
layer.precision = trt.float16
layer.set_output_type(0, trt.float16)


@tensorrt_converter('torch.zeros')
def convert_zeros(ctx):
tensor = ctx.method_return

# Implementation copied from add_trt_constant.
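    # (torch2trt builds implicit-batch networks, so the constant omits the batch dim.)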
shape = tuple(tensor.shape[1:])
array = tensor[0].detach().cpu().numpy()
layer = ctx.network.add_constant(shape, array)

_set_layer_precision(ctx, layer)

tensor._trt = layer.get_output(0)


class Zeros(torch.nn.Module):
def __init__(self, *size):
super().__init__()
self.size = size

def forward(self, x):
return x + torch.zeros(*self.size, device=torch.device('cuda'))


@add_module_test(torch.float32, torch.device('cuda'), [(1, 2, 3, 4)])
def test_zeros():
return Zeros((1, 2, 3, 4))


@add_module_test(torch.float32, torch.device('cuda'), [(1, 2, 3, 4)])
def test_zeros_var_args():
return Zeros(1, 2, 3, 4)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 2, 3, 4)], fp16_mode=True)
def test_zeros_fp16_mode():
return Zeros(1, 2, 3, 4)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 2, 3, 4)], int8_mode=True)
def test_zeros_int8_mode():
return Zeros(1, 2, 3, 4)
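

# A minimal usage sketch (not part of this diff): the fp16_mode/int8_mode flags that
# _set_layer_precision reads from ctx.torch2trt_kwargs are the keyword arguments
# passed to torch2trt() when the engine is built.
def _zeros_precision_usage_sketch():
    import torch
    from torch2trt import torch2trt

    model = Zeros(1, 2, 3, 4).cuda().eval()
    x = torch.randn(1, 2, 3, 4).cuda()
    return torch2trt(model, [x], fp16_mode=True)  # ends up in ctx.torch2trt_kwargs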
4 changes: 0 additions & 4 deletions torch2trt/module_test.py
@@ -1,7 +1,3 @@
-import torch
-import torchvision
-
-
class ModuleTest(object):
def __init__(self, module_fn, dtype, device, input_shapes, **torch2trt_kwargs):
self.module_fn = module_fn