diff --git a/torch2trt/converters/BatchNorm3d.py b/torch2trt/converters/BatchNorm3d.py index cb620300..1ce3af63 100644 --- a/torch2trt/converters/BatchNorm3d.py +++ b/torch2trt/converters/BatchNorm3d.py @@ -21,3 +21,9 @@ def convert_BatchNorm3d(ctx): layer = ctx.network.add_scale(input_trt, trt.ScaleMode.CHANNEL, bias, scale, power) output._trt = layer.get_output(0) + + +@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 16, 16, 16)]) +@add_module_test(torch.float32, torch.device('cuda'), [(2, 3, 16, 16, 16)], max_batch_size=2) +def test_BatchNorm3d_basic(): + return torch.nn.BatchNorm3d(3) \ No newline at end of file diff --git a/torch2trt/converters/__init__.py b/torch2trt/converters/__init__.py index 25143f3e..409d0caf 100644 --- a/torch2trt/converters/__init__.py +++ b/torch2trt/converters/__init__.py @@ -7,6 +7,7 @@ from .AdaptiveAvgPool2d import * from .BatchNorm1d import * from .BatchNorm2d import * +from .BatchNorm3d import * from .clone import * from .conv_functional import * from .Conv import *