Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions torch2trt/converters/BatchNorm3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,9 @@ def convert_BatchNorm3d(ctx):
layer = ctx.network.add_scale(input_trt, trt.ScaleMode.CHANNEL, bias, scale, power)

output._trt = layer.get_output(0)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 16, 16, 16)])
@add_module_test(torch.float32, torch.device('cuda'), [(2, 3, 16, 16, 16)], max_batch_size=2)
def test_BatchNorm3d_basic():
return torch.nn.BatchNorm3d(3)
1 change: 1 addition & 0 deletions torch2trt/converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .AdaptiveAvgPool2d import *
from .BatchNorm1d import *
from .BatchNorm2d import *
from .BatchNorm3d import *
from .clone import *
from .conv_functional import *
from .Conv import *
Expand Down