
Stand-alone pad operation fails with: Assertion failed: inputs.at(1).is_weights() #439

Closed

copaah opened this issue Mar 17, 2020 · 25 comments
Labels: Export: torch.onnx, ONNX, triaged

@copaah commented Mar 17, 2020

Description

My current workflow is PyTorch -> ONNX -> TensorRT, and I encounter an issue with the nn.ConstantPad2d operation that results in the following error:

While parsing node number 23 [Pad -> "24"]:
--- Begin node ---
input: "input"
input: "22"
input: "23"
output: "24"
op_type: "Pad"
attribute {
  name: "mode"
  s: "constant"
  type: STRING
}

--- End node ---
ERROR: /mypath/onnx-tensorrt/builtin_op_importers.cpp:2106 In function importPad:
[8] Assertion failed: inputs.at(1).is_weights()
[03/12/2020-09:06:29] [E] Failed to parse onnx file
[03/12/2020-09:06:29] [E] Parsing model failed
[03/12/2020-09:06:29] [E] Engine creation failed
[03/12/2020-09:06:29] [E] Engine set up failed

Environment

OS: Ubuntu 18.04
torch: 1.4.0
onnx: 1.6.0
tensorrt: 7.0.0
cuda: 10.0
python: 2.7

Steps To Reproduce

# github_repro_example.py
# -----------
import onnx
import argparse
import torch
import torch.nn as nn

class MinimalModel(nn.Module):
    def __init__(self):
        super(MinimalModel, self).__init__()
        self.constant_zero_pad = nn.ConstantPad2d((1, 0, 0, 0), 0)

    def forward(self, input_tensor):
        return self.constant_zero_pad(input_tensor)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PSMNet')
    parser.add_argument('output_onnx')
    args = parser.parse_args()

    minimal_model = MinimalModel()
    minimal_model = nn.DataParallel(minimal_model)
    minimal_model.cuda()

    # Random deep feature
    input_tensor = torch.rand((1, 32, 128, 128))
    # Check model can do a forward pass
    minimal_model(input_tensor)
    # Export to onnx
    torch.onnx.export(
        minimal_model.module,
        (input_tensor),
        args.output_onnx,
        export_params=True, verbose=True, training=False, opset_version=11
    )

    original_model = onnx.load(args.output_onnx)
    onnx.checker.check_model(original_model)

Run with:

python2 github_repro_example.py ./test.onnx

Run it through TensorRT:

trtexec --explicitBatch --onnx=./test.onnx --verbose

Which will result in the above error.

Related issues

onnx/onnx-tensorrt#378
onnx/onnx-tensorrt#411

@rmccorm4 (Collaborator) commented

Hi @copaah ,

As for this error:

ERROR: /mypath/onnx-tensorrt/builtin_op_importers.cpp:2106 In function importPad:
[8] Assertion failed: inputs.at(1).is_weights()
[03/12/2020-09:06:29] [E] Failed to parse onnx file
[03/12/2020-09:06:29] [E] Parsing model failed
[03/12/2020-09:06:29] [E] Engine creation failed
[03/12/2020-09:06:29] [E] Engine set up failed

PyTorch generated kind of a funky ONNX graph for this simple model. I don't know if this is an issue to be fixed on their part or on the ONNX parser's part; @kevinch-nv might be able to answer that.


As a workaround to your issue, you can try using onnx-simplifier on the onnx model and then parsing that, which worked for me.

# Start TRT 7 container
nvidia-docker run -it -v $PWD:/mnt nvcr.io/nvidia/tensorrt:20.02-py3

# Install TensorRT OSS components
wget https://raw.githubusercontent.com/rmccorm4/tensorrt-utils/master/OSS/build_OSS.sh
source build_OSS.sh

# Install dependencies
python3 -m pip install torch==1.4 onnx==1.6 onnx-simplifier
# Create repro model
python3 repro.py model.onnx
# Run onnx-simplifier on repro model
python3 -m onnxsim model.onnx model.simple.onnx

# Parse simplified model
trtexec --explicitBatch --onnx=model.simple.onnx
...
&&&& PASSED TensorRT.trtexec # trtexec --explicitBatch --onnx=model.simple.onnx

@rmccorm4 (Collaborator) commented

Uploaded the original and simplified ONNX models here: onnx-models.zip

rmccorm4 added the Export: torch.onnx, ONNX, and Release: 7.x labels on Mar 26, 2020
@kevinch-nv (Collaborator) commented

For padding, the ONNX-TRT parser expects the pad values to be initializers (i.e. constants) in the ONNX graph. I checked @rmccorm4's zip package, and the nodes that were contributing to the pad dimensions were constant-folded into an initializer by onnx-simplifier.
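
For illustration, here is a minimal sketch (not from this thread) of how one could check whether a Pad node's pads input is an initializer, using the standard onnx Python API; the path "model.onnx" is a placeholder:

import onnx

# Load the exported model and collect the names of all graph initializers.
model = onnx.load("model.onnx")
initializer_names = {init.name for init in model.graph.initializer}

# At opset 11+, input 1 of a Pad node is the "pads" tensor. The ONNX-TRT
# parser asserts that this input resolves to weights, i.e. an initializer.
for node in model.graph.node:
    if node.op_type == "Pad" and len(node.input) > 1:
        pads_input = node.input[1]
        print(node.name or "(unnamed Pad)",
              "pads is an initializer:", pads_input in initializer_names)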

@rmccorm4 (Collaborator) commented Mar 26, 2020

@kevinch-nv thanks for the insight. I added torch.onnx.export(..., do_constant_folding=True, keep_initializers_as_inputs=True), and I get a slightly different model, but it still fails with the same error. Any idea why?

Is there any parameter to torch.onnx.export that will correctly create the padded values as initializers?

torch.onnx.export(..., do_constant_folding=True) model: model_constant_fold.zip

torch.onnx.export(..., do_constant_folding=True, keep_initializers_as_inputs=True) model: model_keep_inits.zip

@kevinch-nv (Collaborator) commented

Perhaps the constant-folding functionality of the current torch2onnx export doesn't support this particular structure yet. Pre-opset 11, the pads input of the Pad node was an attribute instead of an input; can you try exporting to opset 10 and inspecting the resulting graph?
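
A minimal sketch of that experiment, reusing MinimalModel from the repro script above (only opset_version changes; the output path is a placeholder):

import torch

# Re-export the repro model at opset 10, where the pad amounts live in a node
# attribute rather than in a second input tensor, so the parser's
# is_weights() check on input 1 never applies.
model = MinimalModel()  # defined in github_repro_example.py above
input_tensor = torch.rand((1, 32, 128, 128))
torch.onnx.export(
    model,
    (input_tensor,),
    "model_opset10.onnx",
    export_params=True,
    opset_version=10,
)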

@rmccorm4 (Collaborator) commented Mar 27, 2020

Good call @kevinch-nv !

With opset 10, whether you use default args or set do_constant_folding=True, we get a simple graph containing only the Pad op, which is parseable by TensorRT.

I raised an issue in PyTorch to see what they think: pytorch/pytorch#35516
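
To confirm what the exported graph looks like, a short sketch using the standard onnx API (the path is a placeholder):

import onnx

# Load and validate the opset-10 export, then print the graph; it should
# contain a single Pad node with its pad amounts stored as an attribute.
model = onnx.load("model_opset10.onnx")
onnx.checker.check_model(model)
print(onnx.helper.printable_graph(model.graph))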

@kevinch-nv (Collaborator) commented

@copaah are you able to import your model with TensorRT now?

@hhhharold commented

Same issue; has anyone solved this?

@vilmara commented Jan 20, 2021

> (quoting @rmccorm4's onnx-simplifier workaround from earlier in the thread)

Hi @rmccorm4, after applying the instructions you mentioned in issues #386 and #439, I got a new error. Any idea how to fix it?

[01/20/2021-16:17:48] [W] [TRT] /workspace/TensorRT/parsers/onnx/onnx2trt_utils.cpp:227: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[01/20/2021-16:17:48] [W] [TRT] /workspace/TensorRT/parsers/onnx/onnx2trt_utils.cpp:253: One or more weights outside the range of INT32 was clamped
[01/20/2021-16:17:48] [W] [TRT] /workspace/TensorRT/parsers/onnx/onnx2trt_utils.cpp:253: One or more weights outside the range of INT32 was clamped
[01/20/2021-16:17:48] [W] [TRT] /workspace/TensorRT/parsers/onnx/onnx2trt_utils.cpp:253: One or more weights outside the range of INT32 was clamped
[01/20/2021-16:17:48] [I] [TRT] No importer registered for op: Round. Attempting to import as plugin.
[01/20/2021-16:17:48] [I] [TRT] Searching for plugin: Round, plugin_version: 1, plugin_namespace:
[01/20/2021-16:17:48] [E] [TRT] INVALID_ARGUMENT: getPluginCreator could not find plugin Round version 1
[01/20/2021-16:17:48] [E] [TRT] /workspace/TensorRT/parsers/onnx/ModelImporter.cpp:705: While parsing node number 7 [Loop -> "unused_loop_output___143"]:
[01/20/2021-16:17:48] [E] [TRT] /workspace/TensorRT/parsers/onnx/ModelImporter.cpp:706: --- Begin node ---
[01/20/2021-16:17:48] [E] [TRT] /workspace/TensorRT/parsers/onnx/ModelImporter.cpp:707: input: "trip_count__68"
....
[01/20/2021-16:17:48] [E] [TRT] /workspace/TensorRT/parsers/onnx/ModelImporter.cpp:708: --- End node ---
[01/20/2021-16:17:48] [E] [TRT] /workspace/TensorRT/parsers/onnx/ModelImporter.cpp:711: ERROR: /workspace/TensorRT/parsers/onnx/builtin_op_importers.cpp:3807 In function importFallbackPluginImporter:
[8] Assertion failed: creator && "Plugin not found, are the plugin name, version, and namespace correct?"
[01/20/2021-16:17:48] [E] Failed to parse onnx file
[01/20/2021-16:17:48] [E] Parsing model failed
[01/20/2021-16:17:48] [E] Engine creation failed
[01/20/2021-16:17:48] [E] Engine set up failed
&&&& FAILED TensorRT.trtexec # trtexec --onnx=/faster_rcnn_inceptionv2_coco_updated_model_opset11_onnx1.6_nocostfold.onnx --explicitBatch


@opeide commented Mar 9, 2021

Half a year later, the issue still persists.

@dedoogong commented

Same here.

ttyio added the triaged label on Apr 26, 2021
@wkl2013DeepVision commented

> Half a year later, the issue still persists.

Same here, wasted so much time! Has anyone solved it?

@ttyio (Collaborator) commented May 27, 2021

@opeide @dedoogong @wkl2013DeepVision

Could you try

      polygraphy surgeon sanitize --fold-constants

See the documentation at https://github.com/onnx/onnx-tensorrt/blob/master/docs/faq.md#common-assertion-errors. Thanks!
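
For reference, a full invocation might look like this (a sketch, assuming Polygraphy is installed, e.g. via python3 -m pip install polygraphy; file names are placeholders):

# Fold constant subgraphs so that inputs such as the Pad "pads" tensor
# become initializers, then parse the sanitized model with trtexec.
polygraphy surgeon sanitize model.onnx --fold-constants --output model_folded.onnx
trtexec --explicitBatch --onnx=model_folded.onnx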

@romain87400 commented

@ttyio
I tried
polygraphy surgeon sanitize model.onnx --fold-constants --output model_folded.onnx
and I'm back to the first error: [8] Assertion failed: mode == "constant" && value == 0.
When I use onnxsim to fix [8] Assertion failed: inputs.at(1).is_weights(), I come back to this error: [8] Assertion failed: mode == "constant" && value == 0.
Do you have a solution for this problem?
I'm working on an NVIDIA Xavier NX, JetPack 4.4, with TensorRT 7.1.3.

@ttyio (Collaborator) commented Jun 10, 2021

Sorry @romain87400, this command can help fold constants to solve some Assertion failed: inputs.at(1).is_weights() failures, but we still only support constant padding with value 0.
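
Concretely, an illustrative sketch (not from the thread) of that distinction: only zero-valued constant padding imports with this parser version, so the first module below is fine while the second trips the mode == "constant" && value == 0 assertion:

import torch.nn as nn

# Supported: constant padding with fill value 0.
pad_ok = nn.ConstantPad2d((1, 0, 0, 0), 0)

# Not supported by this parser version: a non-zero fill value.
pad_unsupported = nn.ConstantPad2d((1, 0, 0, 0), 3.5)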

@romain87400 commented

@ttyio Thanks for your answer.
So, currently, we cannot convert a model whose padding is anything other than constant 0 padding.
But in the future, do you plan to fix this use case?
Or maybe you have an idea for getting around this problem?
I read that we must convert models to ONNX with opset 11, but when I do that I hit [8] Assertion failed: inputs.at(1).is_weights().

Sorry, I'm a beginner on the subject.
Have a nice day.

@ttyio (Collaborator) commented Jun 10, 2021

Hello @romain87400, I am not sure if you can adjust your model: find the failing node, change it to pad with 0s, and fine-tune the model.

The inputs.at(1).is_weights() assertion failure can usually be fixed by polygraphy surgeon sanitize --fold-constants, unless the pads input really is an activation (a tensor computed at runtime) in your model.

@phamdat09 commented

@ttyio Any update on fixing this? Thanks.

@ttyio (Collaborator) commented Aug 20, 2021

@phamdat09 Non-constant padding will be supported in the next release (in around 2 months). Thanks!

@phamdat09 commented Aug 20, 2021

@ttyio Thanks for the info, but in my case I think it is constant padding.
I got an error like this:
[ERROR] [TRT] /home/bigbigboy/Documents/Test/TensortRT_Dynamic/TensorRT/parsers/onnx/ModelImporter.cpp:727: input: "496"
input: "554"
input: "556"
output: "557"
name: "Pad_151"
op_type: "Pad"
attribute {
  name: "mode"
  s: "constant"
  type: STRING
}

[ERROR] [TRT] /home/bigbigboy/Documents/Test/TensortRT_Dynamic/TensorRT/parsers/onnx/ModelImporter.cpp:728: --- End node ---
[ERROR] [TRT] /home/bigbigboy/Documents/Test/TensortRT_Dynamic/TensorRT/parsers/onnx/ModelImporter.cpp:731: ERROR: /home/bigbigboy/Documents/Test/TensortRT_Dynamic/TensorRT/parsers/onnx/builtin_op_importers.cpp:2985 In function importPad:
[8] Assertion failed: inputs.at(1).is_weights() && "The input pads is required to be an initializer."

@HaohaoNJU commented

> Perhaps the constant-folding functionality of the current torch2onnx export doesn't support this particular structure yet. Pre-opset 11, the pads input of the Pad node was an attribute instead of an input; can you try exporting to opset 10 and inspecting the resulting graph?

Works for me, thanks!

@nvpohanh (Collaborator) commented

@copaah Could you try TRT 8.2/8.4 and see if the issue still exists? If it does, we will debug it. Thanks

@nvpohanh (Collaborator) commented Jul 1, 2022

Closing for now due to >14 days without activity. Please feel free to reopen if the issue still exists. Thanks

nvpohanh closed this as completed on Jul 1, 2022
@donrax commented Jul 20, 2022

This is perhaps related to #847
