
Inconsistent Results between TensorRT and PyTorch when Converting Torch Models using Conv1d, Conv2d, and Conv3d Operators #3503

@hongliyu0716

Description

While converting Torch models to TensorRT with torch2trt, I found that the TensorRT outputs differ from the PyTorch outputs. The inconsistency arises specifically with the Conv1d, Conv2d, and Conv3d operators.
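All three reproducers configure the convolution with a non-default padding_mode ('circular' for Conv2d and Conv3d, 'reflect' for Conv1d). The plain-Python sketch below is not torch2trt or TensorRT code: pad_1d and conv1d are hypothetical helpers for a single-channel 1D convolution, written only to show what padding_mode changes. PyTorch's Conv modules handle a non-zero padding_mode by padding the input explicitly with that mode and then convolving with zero padding, so the mode only alters the values next to the borders. If a converter applied the same padding amounts as zero padding (an assumption, not verified against torch2trt's source), only the outputs whose window touches the padding would change.

```python
# Hypothetical helpers for illustration only (not torch2trt/TensorRT code):
# single-channel 1D cross-correlation (what Conv1d computes), no bias, dilation 1.

def pad_1d(x, p, mode):
    """Pad list x by p elements on each side using the given padding mode."""
    x = list(x)
    if p == 0:
        return x
    if mode == "zeros":
        return [0.0] * p + x + [0.0] * p
    if mode == "circular":
        # Wrap around: the left pad is taken from the end, the right pad from the start.
        return x[-p:] + x + x[:p]
    if mode == "reflect":
        # Mirror about the edge element without repeating it (requires p < len(x)).
        left = [x[i] for i in range(p, 0, -1)]
        right = [x[-2 - i] for i in range(p)]
        return left + x + right
    raise ValueError(f"unknown padding mode: {mode}")


def conv1d(x, w, stride=1, padding=0, mode="zeros"):
    """Pad with the chosen mode, then slide the kernel with no further padding."""
    xp = pad_1d(x, padding, mode)
    k = len(w)
    n_out = (len(xp) - k) // stride + 1
    return [sum(w[j] * xp[i * stride + j] for j in range(k)) for i in range(n_out)]


# Same kernel size, stride and padding as the Conv1d reproducer (k=3, stride=2, padding=1, length 4).
x = [1.0, 2.0, 3.0, 4.0]
w = [1.0, 1.0, 1.0]
for mode in ("zeros", "circular", "reflect"):
    print(mode, conv1d(x, w, stride=2, padding=1, mode=mode))
# zeros [3.0, 9.0]
# circular [7.0, 9.0]
# reflect [5.0, 9.0]
```

Only the first output, whose window includes a padded element, depends on the mode; the second window lies entirely inside the input, so it is identical in all three cases.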

Environment

TensorRT Version: 8.6.1.6
NVIDIA GPU: NVIDIA GeForce RTX 3090
NVIDIA Driver Version: 525.116.03
CUDA Version: 12.0
CUDNN Version: 8.9.0
Operating System: Ubuntu 18.04
Python Version: 3.9.18
PyTorch Version: 2.1.1
torch2trt Version: 0.4.0

Steps To Reproduce

Commands or scripts:

1. Conv3d
import torch
from torch2trt import torch2trt

# Conv3d with circular padding (the original positional arguments, spelled out)
model = torch.nn.Conv3d(
    in_channels=2, out_channels=3, kernel_size=3, stride=2,
    padding=(1, 2, 3), dilation=1, groups=1, bias=True,
    padding_mode='circular',
).cuda()
input_data = torch.randn([2, 2, 4, 4, 4], dtype=torch.float32).cuda()
model_trt = torch2trt(model, [input_data])
y = model(input_data)
y_trt = model_trt(input_data)

# check the output against PyTorch (maximum absolute difference)
print(torch.max(torch.abs(y - y_trt)))

The output is tensor(1.1362, device='cuda:0', grad_fn=<MaxBackward1>)

2. Conv2d
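import torch
from torch2trt import torch2trt

# Grouped Conv2d (groups=3), no bias, circular padding (original positional arguments, spelled out)
model = torch.nn.Conv2d(
    in_channels=12, out_channels=6, kernel_size=(3, 2), stride=(1, 2),
    padding=(1, 1), dilation=(1, 1), groups=3, bias=False,
    padding_mode='circular',
).eval().cuda()
input_data = torch.randn([2, 12, 5, 7], dtype=torch.float32).cuda()
model_trt = torch2trt(model, [input_data])
output = model(input_data)
output_trt = model_trt(input_data)

# maximum absolute difference between the PyTorch and TensorRT outputs
print(torch.max(torch.abs(output - output_trt)))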

The output is tensor(1.1816, device='cuda:0', grad_fn=<MaxBackward1>)

3. Conv1d

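import torch
from torch2trt import torch2trt

# Conv1d with reflect padding (original positional arguments, spelled out)
model = torch.nn.Conv1d(
    in_channels=2, out_channels=3, kernel_size=3, stride=2,
    padding=(1,), dilation=1, groups=1, bias=True,
    padding_mode='reflect',
).eval().cuda()
input_data = torch.randn([2, 2, 4], dtype=torch.float32).cuda()
model_trt = torch2trt(model, [input_data])
output = model(input_data)
output_trt = model_trt(input_data)

# maximum absolute difference between the PyTorch and TensorRT outputs
print(torch.max(torch.abs(output - output_trt)))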

The output is tensor(0.8214, device='cuda:0', grad_fn=<MaxBackward1>)
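Since every padding mode adds the same number of elements, padding_mode does not enter the output-shape formula. That would explain why a conversion which pads differently still returns tensors that subtract cleanly from the PyTorch outputs: the discrepancy appears only in the values, not as a shape error. The hypothetical helpers below (conv_out_size, conv_out_shape, written for illustration only) apply the output-size formula from the PyTorch Conv1d/Conv2d/Conv3d documentation to the three reproducers.

```python
# Output-size formula from the PyTorch ConvNd docs (per spatial dim):
#   out = floor((n + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1
# padding_mode does not appear in it: zeros, circular and reflect pad by the same amount.

def conv_out_size(n, kernel, stride=1, padding=0, dilation=1):
    """Output length along one spatial dimension."""
    return (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_out_shape(in_shape, out_channels, kernel, stride=1, padding=0, dilation=1):
    """Output shape [N, C_out, *spatial] for an input of shape [N, C_in, *spatial]."""
    batch, _, *spatial = in_shape
    dims = len(spatial)

    def per_dim(v):
        # Accept an int (same value for every dim) or a tuple (one value per dim).
        return tuple(v) if isinstance(v, tuple) else (v,) * dims

    k, s, p, d = (per_dim(v) for v in (kernel, stride, padding, dilation))
    return [batch, out_channels] + [
        conv_out_size(n, k[i], s[i], p[i], d[i]) for i, n in enumerate(spatial)
    ]


# Shapes for the three reproducers (parameters copied from the scripts above).
print(conv_out_shape([2, 2, 4, 4, 4], 3, kernel=3, stride=2, padding=(1, 2, 3)))  # -> [2, 3, 2, 3, 4]
print(conv_out_shape([2, 12, 5, 7], 6, kernel=(3, 2), stride=(1, 2),
                     padding=(1, 1), dilation=(1, 1)))                          # -> [2, 6, 5, 4]
print(conv_out_shape([2, 2, 4], 3, kernel=3, stride=2, padding=(1,)))             # -> [2, 3, 2]
```

The groups and bias settings do not affect the shape either, so they are left out of the helper.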

I would greatly appreciate any insights or guidance on resolving this issue.

Labels: triaged (issue has been triaged by maintainers)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions