
Mismatched type error when generating an engine for a quantized stereo-depth model #2131

Closed
deephog opened this issue Jul 7, 2022 · 15 comments
Assignees: zerollzeng
Labels: Quantization: QAT (Quantization-aware Training), triaged (Issue has been triaged by maintainers)

deephog commented Jul 7, 2022

ONNX model: link

I used trtexec to generate the engine:

trtexec --onnx=test_quant_sim.onnx --saveEngine=test_quant_sim.engine --workspace=4096 --int8 --fp16 --noTF32 --verbose --noDataTransfers --separateProfileRun --dumpProfile --useCudaGraph > test_quant_sim.log

When generating the engine for this ONNX model, the following errors occurred:

[07/07/2022-11:15:23] [W] [TRT] Skipping tactic 0 due to Myelin error: tensor "inputs_63'" has an unitialized dimension.
[07/07/2022-11:15:23] [W] [TRT] Skipping tactic 0 due to Myelin error: tensor "inputs_63'" has an unitialized dimension.
[07/07/2022-11:15:26] [W] [TRT] Skipping tactic 0 due to Myelin error: Mismatched type for tensor onnx__Where_3326', f16 vs. expected type:b.
[07/07/2022-11:15:26] [W] [TRT] Skipping tactic 0 due to Myelin error: Mismatched type for tensor onnx__Where_3326', f16 vs. expected type:b.
[07/07/2022-11:18:10] [W] [TRT] Skipping tactic 0 due to Myelin error: Could not infer output types for operation: 7: gather: onnx__Gather_3802'-(f32[1,2,3,11,13][]) | onnx__Gather_3801'_postCast-(f32[1,2,13,13][]), onnx__Gather_3772' FP32Input-(f32[3,11][]), stream = 0 // Gather_2972 | axis: 2 batch_dims: 0
[07/07/2022-11:18:10] [W] [TRT] Skipping tactic 0 due to Myelin error: Could not infer output types for operation: 7: gather: onnx__Gather_3802'-(f32[1,2,3,11,13][]) | onnx__Gather_3801'_postCast-(f32[1,2,13,13][]), onnx__Gather_3772' FP32Input-(f32[3,11][]), stream = 0 // Gather_2972 | axis: 2 batch_dims: 0
[07/07/2022-11:18:19] [W] [TRT] Skipping tactic 0 due to Myelin error: Mismatched type for tensor onnx__Where_5194', f16 vs. expected type:b.
[07/07/2022-11:18:19] [W] [TRT] Skipping tactic 0 due to Myelin error: Mismatched type for tensor onnx__Where_5194', f16 vs. expected type:b.
[07/07/2022-11:19:13] [W] [TRT] Skipping tactic 0 due to Myelin error: Could not infer output types for operation: 7: gather: onnx__Gather_5674'-(f32[1,2,3,22,24][]) | onnx__Gather_5673'_postCast-(f32[1,2,24,24][]), onnx__Gather_5644' FP32Input-(f32[3,22][]), stream = 0 // Gather_4237 | axis: 2 batch_dims: 0
[07/07/2022-11:19:13] [W] [TRT] Skipping tactic 0 due to Myelin error: Could not infer output types for operation: 7: gather: onnx__Gather_5674'-(f32[1,2,3,22,24][]) | onnx__Gather_5673'_postCast-(f32[1,2,24,24][]), onnx__Gather_5644' FP32Input-(f32[3,22][]), stream = 0 // Gather_4237 | axis: 2 batch_dims: 0
[07/07/2022-11:19:18] [W] [TRT] Skipping tactic 0 due to Myelin error: Mismatched type for tensor onnx__Where_5884', f16 vs. expected type:b.
[07/07/2022-11:19:18] [W] [TRT] Skipping tactic 0 due to Myelin error: Mismatched type for tensor onnx__Where_5884', f16 vs. expected type:b.
[07/07/2022-11:19:35] [W] [TRT] Skipping tactic 0 due to Myelin error: Mismatched type for tensor onnx__Where_6886', f16 vs. expected type:b.
[07/07/2022-11:19:35] [W] [TRT] Skipping tactic 0 due to Myelin error: Mismatched type for tensor onnx__Where_6886', f16 vs. expected type:b.
[07/07/2022-11:19:45] [W] [TRT] Skipping tactic 0 due to Myelin error: tensor "flow_up'" has an unitialized dimension.
[07/07/2022-11:19:45] [W] [TRT] Skipping tactic 0 due to Myelin error: tensor "flow_up'" has an unitialized dimension.
[07/07/2022-11:19:58] [W] [TRT] Skipping tactic 0 due to Myelin error: Mismatched type for tensor onnx__Where_8449', f16 vs. expected type:b.
[07/07/2022-11:19:58] [W] [TRT] Skipping tactic 0 due to Myelin error: Mismatched type for tensor onnx__Where_8449', f16 vs. expected type:b.
[07/07/2022-11:20:20] [W] [TRT] Skipping tactic 0 due to Myelin error: Mismatched type for tensor onnx__Where_9451', f16 vs. expected type:b.
[07/07/2022-11:20:20] [W] [TRT] Skipping tactic 0 due to Myelin error: Mismatched type for tensor onnx__Where_9451', f16 vs. expected type:b.
[07/07/2022-11:20:31] [W] [TRT] Skipping tactic 0 due to Myelin error: Could not infer output types for operation: 7: gather: onnx__Gather_10419'-(f32[1,2,3,90,92][]) | onnx__Gather_10418'_postCast-(f32[1,2,92,92][]), onnx__Gather_10389' FP32Input-(f32[3,90][]), stream = 0 // Gather_7892 | axis: 2 batch_dims: 0
[07/07/2022-11:20:31] [W] [TRT] Skipping tactic 0 due to Myelin error: Could not infer output types for operation: 7: gather: onnx__Gather_10419'-(f32[1,2,3,90,92][]) | onnx__Gather_10418'_postCast-(f32[1,2,92,92][]), onnx__Gather_10389' FP32Input-(f32[3,90][]), stream = 0 // Gather_7892 | axis: 2 batch_dims: 0
[07/07/2022-11:20:34] [E] Error[2]: [reformat.cpp::cpuReformat::372] Error Code 2: Internal Error (Assertion outScaleBuf == nullptr ? rFormatOut.dataType != DataType::kINT8 || hasZeroVolume(extent) : rFormatOut.dataType == DataType::kINT8 failed. )
[07/07/2022-11:20:34] [E] Error[2]: [builder.cpp::buildSerializedNetwork::609] Error Code 2: Internal Error (Assertion enginePtr != nullptr failed. )
[07/07/2022-11:20:34] [E] Engine could not be created from network
[07/07/2022-11:20:34] [E] Building engine failed
[07/07/2022-11:20:34] [E] Failed to create engine from model.
[07/07/2022-11:20:34] [E] Engine set up failed

These started out as warnings, but the build eventually failed, I guess because it exhausted all possible tactics.

The model was quantized with the tutorial settings: automatic replacement of all compatible layers, followed by calibration.
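For context, "tutorial settings" here means something like the standard pytorch-quantization recipe below (a minimal sketch, assuming that toolkit; build_model, calib_loader, dummy_left and dummy_right are placeholders for the real stereo-depth network and data):

```python
import torch
from pytorch_quantization import quant_modules
from pytorch_quantization import nn as quant_nn

quant_modules.initialize()            # auto-replace compatible layers with quantized versions
model = build_model().cuda().eval()   # placeholder: the stereo-depth network

# Calibration: collect statistics while fake-quantization is disabled.
for module in model.modules():
    if isinstance(module, quant_nn.TensorQuantizer):
        module.enable_calib()
        module.disable_quant()
with torch.no_grad():
    for left, right in calib_loader:  # placeholder calibration data
        model(left.cuda(), right.cuda())
for module in model.modules():
    if isinstance(module, quant_nn.TensorQuantizer):
        module.load_calib_amax()
        module.disable_calib()
        module.enable_quant()

# Export with explicit QuantizeLinear/DequantizeLinear (Q/DQ) nodes.
quant_nn.TensorQuantizer.use_fb_fake_quant = True
torch.onnx.export(model, (dummy_left, dummy_right), "test_quant_sim.onnx", opset_version=13)
```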

From the error messages it looks like the backbone or cost volume of the model is quantized to INT8, while the Gather and Where operations expect TF32, which causes the mismatch. Aren't there Dequantize layers that cast the results back to TF32 and prevent this kind of mismatch between quantizable and non-quantizable layers?

deephog closed this as completed Jul 7, 2022
deephog reopened this Jul 7, 2022

deephog (Author) commented Jul 7, 2022

I previously closed this issue because I thought I had figured it out by removing the --noTF32 flag. The engine did compile successfully, but it fell back entirely to TF32, so quantization was effectively not achieved.

This was done with TensorRT 8.2.5.1 as shipped in the newest docker image. I also tried 8.4.1.5, where compilation failed with the following message:

[07/07/2022-14:24:53] [E] Error[2]: [qdqGraphOptimizer.cpp::reportWeightlessTwoInputConvolutionAsError::230] Error Code 2: Internal Error (Conv_1444: Could not fuse 2nd input (kernel weights) of CONVOLUTION)

This issue also happens with another quantized model of mine; both were simplified with onnx-simplifier. I have seen other people report this exact error code before, but there is no solution yet.

zerollzeng (Collaborator) commented Jul 8, 2022

Tried to reproduce with TRT 8.4:

[07/08/2022-06:16:07] [V] [TRT] Swap the layer type of QuantizeLinear_7751_clone_1_clone_2 from QUANTIZE to kQDQ
[07/08/2022-06:16:07] [V] [TRT] Running: QDQToCopy on QuantizeLinear_7751_clone_0
[07/08/2022-06:16:07] [V] [TRT] Swap the layer type of QuantizeLinear_7751_clone_0 from QUANTIZE to kQDQ
[07/08/2022-06:16:07] [V] [TRT] Running: QDQToCopy on QuantizeLinear_7780_clone_0
[07/08/2022-06:16:07] [V] [TRT] Swap the layer type of QuantizeLinear_7780_clone_0 from QUANTIZE to kQDQ
[07/08/2022-06:16:07] [E] Error[2]: [qdqGraphOptimizer.cpp::reportWeightlessTwoInputConvolutionAsError::230] Error Code 2: Internal Error (Conv_791: Could not fuse 2nd input (kernel weights) of CONVOLUTION)
[07/08/2022-06:16:07] [E] Error[2]: [builder.cpp::buildSerializedNetwork::636] Error Code 2: Internal Error (Assertion engine != nullptr failed. )
[07/08/2022-06:16:07] [E] Engine could not be created from network
[07/08/2022-06:16:07] [E] Building engine failed
[07/08/2022-06:16:07] [E] Failed to create engine from model or file.
[07/08/2022-06:16:07] [E] Engine set up failed
&&&& FAILED TensorRT.trtexec [TensorRT v8401] # ./trtexec --onnx=/zeroz_sw/temp/test_quant_sim.onnx --int8 --fp16 --verbose

(screenshot)

zerollzeng (Collaborator) commented Jul 8, 2022

@ttyio any idea?

zerollzeng self-assigned this Jul 8, 2022
zerollzeng added the triaged (Issue has been triaged by maintainers) label Jul 8, 2022
deephog (Author) commented Jul 8, 2022

> Tried to reproduce with TRT 8.4:
>
> [07/08/2022-06:16:07] [V] [TRT] Swap the layer type of QuantizeLinear_7751_clone_1_clone_2 from QUANTIZE to kQDQ
> [07/08/2022-06:16:07] [V] [TRT] Running: QDQToCopy on QuantizeLinear_7751_clone_0
> [07/08/2022-06:16:07] [V] [TRT] Swap the layer type of QuantizeLinear_7751_clone_0 from QUANTIZE to kQDQ
> [07/08/2022-06:16:07] [V] [TRT] Running: QDQToCopy on QuantizeLinear_7780_clone_0
> [07/08/2022-06:16:07] [V] [TRT] Swap the layer type of QuantizeLinear_7780_clone_0 from QUANTIZE to kQDQ
> [07/08/2022-06:16:07] [E] Error[2]: [qdqGraphOptimizer.cpp::reportWeightlessTwoInputConvolutionAsError::230] Error Code 2: Internal Error (Conv_791: Could not fuse 2nd input (kernel weights) of CONVOLUTION)
> [07/08/2022-06:16:07] [E] Error[2]: [builder.cpp::buildSerializedNetwork::636] Error Code 2: Internal Error (Assertion engine != nullptr failed. )
> [07/08/2022-06:16:07] [E] Engine could not be created from network
> [07/08/2022-06:16:07] [E] Building engine failed
> [07/08/2022-06:16:07] [E] Failed to create engine from model or file.
> [07/08/2022-06:16:07] [E] Engine set up failed
> &&&& FAILED TensorRT.trtexec [TensorRT v8401] # ./trtexec --onnx=/zeroz_sw/temp/test_quant_sim.onnx --int8 --fp16 --verbose
>
> (screenshot)

Just trying to understand the error message. According to the documentation, TensorRT tries to propagate Q layers backward and DQ layers forward so that they span as much of the graph as possible, merging them where they meet another set of Q/DQ layers, in order to maximize the number of operations that run quantized.

From this error message, it sounds like during that propagation one of the weight quantizers was somehow misinterpreted as an input quantizer, resulting in a "weightless conv layer with two inputs"?

ttyio (Collaborator) commented Jul 11, 2022

fnet.conv1.weight is shared by multiple convolutions, and TRT currently cannot constant-fold weights that are shared. Dynamic weight input support will fix this; it is already on the TRT roadmap and actively under development.
Until then, @deephog, could you work around it by disabling the weight sharing? Maybe use graphsurgeon to make a copy of fnet.conv1.weight? Thanks!
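A minimal ONNX GraphSurgeon sketch of that workaround, assuming the shared weight shows up as a plain Conv initializer (in a Q/DQ graph it may instead sit behind DequantizeLinear nodes, which would then need the same cloning); file names are placeholders:

```python
import numpy as np
import onnx
import onnx_graphsurgeon as gs

graph = gs.import_onnx(onnx.load("test_quant_sim.onnx"))

seen = {}
for node in graph.nodes:
    if node.op != "Conv":
        continue
    weight = node.inputs[1]
    if not isinstance(weight, gs.Constant):
        continue  # weight is not a plain initializer (e.g. it sits behind a DequantizeLinear)
    if weight.name in seen:
        # This initializer already feeds another Conv: give this Conv its own copy.
        clone = gs.Constant(f"{weight.name}_copy{seen[weight.name]}",
                            values=np.array(weight.values))
        node.inputs[1] = clone
        seen[weight.name] += 1
    else:
        seen[weight.name] = 1

graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "test_quant_sim_unshared.onnx")
```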

ttyio added the Quantization: QAT (Quantization-aware Training) label Jul 11, 2022
deephog (Author) commented Jul 11, 2022

> fnet.conv1.weight is shared by multiple convolutions, and TRT currently cannot constant-fold weights that are shared. Dynamic weight input support will fix this; it is already on the TRT roadmap and actively under development. Until then, @deephog, could you work around it by disabling the weight sharing? Maybe use graphsurgeon to make a copy of fnet.conv1.weight? Thanks!

Thank you for your reply! Actually, I don't know why weight sharing is happening; I did not do it intentionally. I also searched the graph for the node you mentioned, and I don't see why you think the weight is shared.
(screenshot)

By the way, as I mentioned in the original post, the model compiles successfully with 8.2.5, although the type-mismatch problem means the quantization is effectively abandoned. And if I compile it in FP16 without quantization, the whole model compiles successfully under both 8.2.5 and 8.4.1.

Please share more of your insights with this information in mind. Thanks!

deephog (Author) commented Jul 11, 2022

> fnet.conv1.weight is shared by multiple convolutions, and TRT currently cannot constant-fold weights that are shared. Dynamic weight input support will fix this; it is already on the TRT roadmap and actively under development. Until then, @deephog, could you work around it by disabling the weight sharing? Maybe use graphsurgeon to make a copy of fnet.conv1.weight? Thanks!

I see now why you think it is weight sharing: the backbone (fnet) is indeed used twice. To prevent exactly this kind of issue, I declared the module twice under different names and then loaded the same weights into both, so the original ONNX file contains two modules with different names. Somehow onnx-simplifier detects this and eliminates the cloned backbone; I need to find out how to prevent it from doing that. By the way, the model will never work with TRT if it is not simplified.
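On the PyTorch side the duplication looks roughly like the sketch below (StereoNet and FeatureNet are placeholder names): the backbone is instantiated twice under different attribute names and the same state dict is loaded into both, so the exported ONNX initially contains two independent copies of the weights.

```python
import copy
import torch

class StereoNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fnet = FeatureNet()                    # placeholder backbone
        self.fnet_clone = copy.deepcopy(self.fnet)  # distinct parameters...
        self.fnet_clone.load_state_dict(self.fnet.state_dict())  # ...with identical values

    def forward(self, left, right):
        return self.fnet(left), self.fnet_clone(right)
```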

Besides this, my other question still holds: if this is a weight-sharing issue, why do other versions or precisions compile successfully?

deephog (Author) commented Jul 12, 2022

> fnet.conv1.weight is shared by multiple convolutions, and TRT currently cannot constant-fold weights that are shared. Dynamic weight input support will fix this; it is already on the TRT roadmap and actively under development. Until then, @deephog, could you work around it by disabling the weight sharing? Maybe use graphsurgeon to make a copy of fnet.conv1.weight? Thanks!

I disabled the "duplicate weight initializer elimination" step in onnx-simplifier so that layers with identical weights no longer trigger weight sharing automatically. This is the new simplified ONNX. The previous node seems fine now, but another node reports the same error, and this time it is clearly a convolution with two weight inputs and no activation input.

[qdqGraphOptimizer.cpp::reportWeightlessTwoInputConvolutionAsError::230] Error Code 2: Internal Error (Conv_3050 + Relu_3051: Could not fuse 2nd input (kernel weights) of CONVOLUTION)

(screenshot)

I don't know how this node ends up generated like this. I tried skipping all optimizers in onnx-simplifier and changing the opset version from 13 through 16, but the node is still there. Again, the model is perfectly fine when exported unquantized in FP16, and there is no such error when compiling with TRT 8.2.5, only the mismatched-type warnings.

And if this is a matter of weight sharing not being supported in INT8 mode, does that mean any kind of recurrent use of layers is unsupported?

deephog (Author) commented Jul 13, 2022

Here is an updated version of the ONNX file.

The previous "dual weights no input" errors are cleaned or evaded. However, a new error occurred which is difficult to trace.

Error[2]: [engineTacticSupplyHelpers.cpp::makeEngineTensor::55] Error Code 2: Internal Error (Assertion isMultiple(tensor.start[vectorDim], spv) failed. )

Could you also please advise me on which path I should pursue? All of these errors happen only in TRT 8.4. In TRT 8.2.5 the engine can be compiled, just with the mismatched-type problems, so explicit quantization is never actually achieved. If you can tell me that TRT 8.4 will not have the mismatched-type issue, I will keep digging into the TRT 8.4 errors; otherwise I will fall back to 8.2.5 and try to figure out that other set of issues. @ttyio @zerollzeng

Thanks!

ttyio (Collaborator) commented Aug 2, 2022

I created internal task 3741010 to track this issue.

@deephog, could you wait for the next major release? Thanks!

deephog (Author) commented Aug 2, 2022

> I created internal task 3741010 to track this issue.
>
> @deephog, could you wait for the next major release? Thanks!

Thank you for the effort! Yes, I will wait for the next major release.

ttyio (Collaborator) commented Aug 18, 2022

The model (https://drive.google.com/file/d/1xJyU7CnVqzc8tBU0_ruewTD1fgxrz1EA/view?usp=sharing) will be fixed in 8.5 EA. Closing, and thanks!

ttyio closed this as completed Aug 18, 2022
deephog (Author) commented Oct 10, 2022

> The model (https://drive.google.com/file/d/1xJyU7CnVqzc8tBU0_ruewTD1fgxrz1EA/view?usp=sharing) will be fixed in 8.5 EA. Closing, and thanks!

Today I tried the newest NVIDIA docker image, 22.09, and surprisingly it contains 8.5 EA (although this is not announced anywhere). The mismatched types still persist, though. Maybe it is not yet the version that fixes this issue; just FYI.
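(For reference, the TensorRT version bundled in the container can be confirmed from inside Python, e.g.:)

```python
# Print the TensorRT version shipped in the 22.09 container.
import tensorrt
print(tensorrt.__version__)
```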

ttyio (Collaborator) commented Oct 14, 2022

> The model (https://drive.google.com/file/d/1xJyU7CnVqzc8tBU0_ruewTD1fgxrz1EA/view?usp=sharing) will be fixed in 8.5 EA. Closing, and thanks!
>
> Today I tried the newest NVIDIA docker image, 22.09, and surprisingly it contains 8.5 EA (although this is not announced anywhere). The mismatched types still persist, though. Maybe it is not yet the version that fixes this issue; just FYI.

Sorry, did you try the model from https://drive.google.com/file/d/1xJyU7CnVqzc8tBU0_ruewTD1fgxrz1EA/view?usp=sharing in 22.09?

deephog (Author) commented Oct 17, 2022

> The model (https://drive.google.com/file/d/1xJyU7CnVqzc8tBU0_ruewTD1fgxrz1EA/view?usp=sharing) will be fixed in 8.5 EA. Closing, and thanks!
>
> Today I tried the newest NVIDIA docker image, 22.09, and surprisingly it contains 8.5 EA (although this is not announced anywhere). The mismatched types still persist, though. Maybe it is not yet the version that fixes this issue; just FYI.
>
> Sorry, did you try the model from https://drive.google.com/file/d/1xJyU7CnVqzc8tBU0_ruewTD1fgxrz1EA/view?usp=sharing in 22.09?

Sorry, I experimented a few more times and realized I was wrong. The only "mismatch" warnings I see now come from Where ops, where I think TensorRT tries to cast the boolean condition to something else and fails (which should never succeed anyway). The results show that latency did shrink significantly, so at least the build passes. I will rerun the quantization with more calibration data to see how accurate it is.
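For anyone hitting the same Where warnings, here is an illustrative (untested) GraphSurgeon sketch that pins the condition input of every Where node to BOOL with an explicit Cast, which is what the "f16 vs. expected type: b" warnings suggest TensorRT wants. It may be a no-op if the conditions are already BOOL in the graph, and it is not a confirmed fix; file names are placeholders.

```python
import numpy as np
import onnx
import onnx_graphsurgeon as gs

graph = gs.import_onnx(onnx.load("test_quant_sim.onnx"))

for node in [n for n in graph.nodes if n.op == "Where"]:
    cond = node.inputs[0]
    # Route the condition through an explicit Cast-to-BOOL node.
    cond_bool = gs.Variable(f"{cond.name}_as_bool", dtype=np.bool_)
    cast = gs.Node(op="Cast", inputs=[cond], outputs=[cond_bool],
                   attrs={"to": onnx.TensorProto.BOOL})
    graph.nodes.append(cast)
    node.inputs[0] = cond_bool

graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "test_quant_sim_boolwhere.onnx")
```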
