TensorRT 6.0 ONNX_Parser doesn't support the ONNX model exported by PyTorch 1.3.1 #376
Comments
Hi @RizhaoCai , PyTorch 1.3 + TensorRT 6 is a known incompatibility. Please use PyTorch <= 1.2 with TensorRT 6, or upgrade to TensorRT 7 (when it becomes available for Jetson, since you mentioned the TX2).
I am trying to convert an ONNX model to TensorRT on a Jetson Nano with PyTorch 1.1 and TensorRT 6.
Hi @Akshaysharma29 , the ONNX parser failed to parse your model; you need to print out the errors to see why. Something like this:

```python
import sys

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

if __name__ == "__main__":
    # Build the engine
    with trt.Builder(TRT_LOGGER) as builder, \
         builder.create_network() as network, \
         trt.OnnxParser(network, TRT_LOGGER) as parser:
        # Fill network attributes with information by parsing the model
        with open("model.onnx", "rb") as f:
            if not parser.parse(f.read()):
                print("ERROR: Failed to parse the ONNX file: model.onnx")
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                sys.exit(1)
```
Hello @Akshaysharma29 , sometimes the error can be caused by dynamic operations in the forward function of your model, such as tensor.view(-1).
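As an illustration, here is a minimal, hypothetical module (the shapes are made up) showing how a dynamic reshape can be replaced with a static one before export:

```python
import torch
import torch.nn as nn

class Head(nn.Module):
    """Hypothetical module flattening a (1, 16, 8, 8) feature map."""
    def forward(self, x):
        # Dynamic: -1 is resolved at runtime, so the exported graph
        # contains shape ops with runtime-computed dimensions, which
        # TensorRT 6's ONNX parser can trip over.
        # return x.view(x.size(0), -1)

        # Static: spell the shape out so the ONNX graph carries
        # constant dimensions (1 x 1024 here).
        return x.view(1, 16 * 8 * 8)
```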
I have followed your link @RizhaoCai @rmccorm4
For that error, I would try building this repo's 19.12 (TRT 6) branch from source and then parsing the model again. I don't have much experience building it on Jetson, but it now seems to be documented in the README for JetPack users.
Finally, the bug is solved. Thanks for the help @RizhaoCai and @rmccorm4
Description
The TensorRT 6.0 ONNX parser doesn't support ONNX models exported by PyTorch 1.3.1.
The TensorRT ONNX parser does not seem to be fully compatible with newer PyTorch versions (1.3 or 1.4).
I have a Jetson TX2 (JetPack 4.3, TensorRT 6) on which to deploy my model.
This is my workflow: export the trained PyTorch model to ONNX with torch.onnx.export, then parse the ONNX file with TensorRT's OnnxParser to build an engine.
However, when building the engine, TensorRT reports:
[TensorRT] ERROR: Network must have at least one output
There are already issues related to this error, e.g. #319 and #286, but they did not take the PyTorch version into account, so I am pointing it out here. This issue is open as a reminder for those who hit the same problem but don't know how to solve it.
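For reference, the workaround usually suggested in those issues is to mark an output manually after parsing; a sketch, assuming the first output of the network's last layer is the tensor you actually want:

```python
# If parsing succeeds but no output is marked, mark the last
# layer's output manually (this assumes the last layer produces
# the tensor you need, which is not guaranteed for every model).
last_layer = network.get_layer(network.num_layers - 1)
network.mark_output(last_layer.get_output(0))
```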
At the very beginning, I thought changing to opset 7 might help, as it is mentioned in the TensorRT documentation.
Then I tried:

```python
torch.onnx.export(model, dummy_input, onnx_model_path, verbose=True,
                  input_names=input_names, output_names=output_names,
                  opset_version=7)
```
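For reference, a self-contained version of that export call (the model, shapes, and names here are placeholders, not my actual network):

```python
import torch
import torch.nn as nn

# Placeholder model and input just to make the snippet runnable;
# substitute your own model, input shape, and output path.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model, dummy_input, "model.onnx",
    verbose=True,
    input_names=["input"],
    output_names=["output"],
    opset_version=7,
)
```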
When I used PyTorch 1.3.1, the problem was still there, and the size of the exported model was 13,599 KB.
Interestingly, when I used PyTorch 1.2.0, the size was 13,986 KB, meaning that different PyTorch versions with the same opset version are not guaranteed to export the same ONNX model.
Therefore, using PyTorch 1.2 may help you solve the problem (1.1 as well).
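To verify which PyTorch version and opset actually produced a given file, you can inspect the model's metadata with the onnx package:

```python
import onnx

m = onnx.load("model.onnx")
print(m.producer_name, m.producer_version)                  # e.g. "pytorch" "1.2"
print([(op.domain, op.version) for op in m.opset_import])   # exported opset(s)
```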
Environment
TensorRT Version: 6.0.1
GPU Type: TX2
CUDA Version: 10.0
CUDNN Version: 7.6
Operating System + Version: Ubuntu 18.04
Python Version (if applicable): 3.6
PyTorch Version (if applicable): 1.3.1