
Why do TensorRT 8.5 and TensorRT 8.6 not support dynamic shape input during calibration? #3933

Open
yjiangling opened this issue Jun 11, 2024 · 11 comments


@yjiangling

When I use Polygraphy to run INT8 calibration and quantization, it fails. Why do TensorRT 8.5 and later versions not support dynamic shape input during calibration? It works fine in TensorRT 8.4. With TensorRT 8.5, it gives an error like:

[V]     Setting TensorRT Optimization Profiles
[V]     Input tensor: xs (dtype=DataType.FLOAT, shape=(1, -1)) | Setting input tensor shapes to: (min=(1, 1120), opt=(1, 160000), max=(1, 480000))
[V]     Input tensor: xlen (dtype=DataType.INT32, shape=(1,)) | Setting input tensor shapes to: (min=(1,), opt=(1,), max=(1,))
[I]     Configuring with profiles: [Profile().add('xs', min=(1, 1120), opt=(1, 160000), max=(1, 480000)).add('xlen', min=(1,), opt=(1,), max=(1,))]
[W] Will use `opt` shapes from profile 0 for calibration. Note that even though `min` != `max` in this profile, calibration will use fixed input shapes. This is not necessarily an issue.
[V] Loaded Module: numpy | Version: 1.21.6 | Path: ['/usr/local/lib/python3.8/dist-packages/numpy']
[W] Will use `opt` shapes from profile 0 for calibration. Note that even though `min` != `max` in this profile, calibration will use fixed input shapes. This is not necessarily an issue.
[I] Building engine with configuration:
    Flags                  | [INT8]
    Engine Capability      | EngineCapability.DEFAULT
    Memory Pools           | [WORKSPACE: 10002.44 MiB]
    Tactic Sources         | [CUBLAS, CUBLAS_LT, CUDNN, EDGE_MASK_CONVOLUTIONS, JIT_CONVOLUTIONS]
    Profiling Verbosity    | ProfilingVerbosity.DETAILED
    Calibrator             | Calibrator(<generator object calib_data at 0x7fd0c3163c80>, cache='trt85_calib.cache', BaseClass=<class 'tensorrt.tensorrt.IInt8MinMaxCalibrator'>)
[V] CPLX_M_rfftrfft__332: broadcasting input1 to make tensors conform, dims(input0)=[2,257,512][NONE] dims(input1)=[1,512,-1][NONE].
[W] Using PreviewFeature::kFASTER_DYNAMIC_SHAPES_0805 can help improve performance and resolve potential functional issues.
[V] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +6, GPU +10, now: CPU 834, GPU 347 (MiB)
[V] [MemUsageChange] Init cuDNN: CPU +1, GPU +10, now: CPU 835, GPU 357 (MiB)
[V] Timing cache disabled. Turning it on will improve builder speed.
[V] [GraphReduction] The approximate region cut reduction algorithm is called.
[V] Total Activation Memory: 11128477184
[V] Detected 2 inputs and 1 output network tensors.
[V] Total Host Persistent Memory: 67616
[V] Total Device Persistent Memory: 0
[V] Total Scratch Memory: 4194304
[V] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 2 MiB, GPU 153 MiB
[V] [BlockAssignment] Started assigning block shifts. This will take 1438 steps to complete.
[V] [BlockAssignment] Algorithm ShiftNTopDown took 1654.34ms to assign 183 blocks to 1438 nodes requiring 18010112 bytes.
[V] Total Activation Memory: 18010112
[V] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 1006, GPU 519 (MiB)
[V] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +1, GPU +10, now: CPU 1006, GPU 503 (MiB)
[V] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +17, now: CPU 0, GPU 170 (MiB)
[V] Starting Calibration.
**[!] Received an unexpected input from the data loader during calibration. For input: 'xs', expected a shape compatible with: (1, 160000), but received: (1, 41152)**
[V]   Post Processing Calibration data in 1.725e-06 seconds.
[E] 1: Unexpected exception _Map_base::at
[E] 2: [builder.cpp::buildSerializedNetwork::751] Error Code 2: Internal Error (Assertion engine != nullptr failed. )
[!] Invalid Engine. Please ensure the engine was built correctly
Traceback (most recent call last):
  File "quantilize.py", line 192, in <module>
    main()
  File "quantilize.py", line 178, in main
    with G_LOGGER.verbosity(G_LOGGER.VERBOSE), TrtRunner(build_engine) as runner:
  File "/usr/local/lib/python3.8/dist-packages/polygraphy/backend/base/runner.py", line 60, in __enter__
    self.activate()
  File "/usr/local/lib/python3.8/dist-packages/polygraphy/backend/base/runner.py", line 95, in activate
    self.activate_impl()
  File "/usr/local/lib/python3.8/dist-packages/polygraphy/backend/trt/runner.py", line 87, in activate_impl
    engine_or_context, owning = util.invoke_if_callable(self._engine_or_context)
  File "/usr/local/lib/python3.8/dist-packages/polygraphy/util/util.py", line 661, in invoke_if_callable
    ret = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/polygraphy/backend/base/loader.py", line 42, in __call__
    return self.call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/polygraphy/backend/trt/loader.py", line 526, in call_impl
    return engine_from_bytes(super().call_impl)
  File "<string>", line 3, in func_impl
  File "/usr/local/lib/python3.8/dist-packages/polygraphy/backend/base/loader.py", line 42, in __call__
    return self.call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/polygraphy/backend/trt/loader.py", line 550, in call_impl
    buffer, owns_buffer = util.invoke_if_callable(self._serialized_engine)
  File "/usr/local/lib/python3.8/dist-packages/polygraphy/util/util.py", line 661, in invoke_if_callable
    ret = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/polygraphy/backend/trt/loader.py", line 484, in call_impl
    G_LOGGER.critical("Invalid Engine. Please ensure the engine was built correctly")
  File "/usr/local/lib/python3.8/dist-packages/polygraphy/logger/logger.py", line 597, in critical
    raise PolygraphyException(message) from None

And with TensorRT 8.6, it gives an error like:

[V] Input filename:   ./onnx_model/model.onnx
[V] ONNX IR version:  0.0.7
[V] Opset version:    13
[V] Producer name:    tf2onnx
[V] Producer version: 1.11.1 1915fb
[V] Domain:           
[V] Model version:    0
[V] Doc string:       
[V] ----------------------------------------------------------------
[W] onnx2trt_utils.cpp:374: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[V]     Setting TensorRT Optimization Profiles
[V]     Input tensor: xs (dtype=DataType.FLOAT, shape=(1, -1)) | Setting input tensor shapes to: (min=(1, 1120), opt=(1, 160000), max=(1, 480000))
[V]     Input tensor: xlen (dtype=DataType.INT32, shape=(1,)) | Setting input tensor shapes to: (min=(1,), opt=(1,), max=(1,))
[I]     Configuring with profiles: [Profile().add('xs', min=(1, 1120), opt=(1, 160000), max=(1, 480000)).add('xlen', min=(1,), opt=(1,), max=(1,))]
**[W] TensorRT does not currently support using dynamic shapes during calibration. The `OPT` shapes from the calibration profile will be used for tensors with dynamic shapes. Calibration data is expected to conform to those shapes.** 
[V] Loaded Module: numpy | Version: 1.23.5 | Path: ['/usr/local/lib/python3.10/dist-packages/numpy']
[I] Building engine with configuration:
    Flags                  | [INT8]
    Engine Capability      | EngineCapability.DEFAULT
    Memory Pools           | [WORKSPACE: 10002.44 MiB, TACTIC_DRAM: 10002.44 MiB]
    Tactic Sources         | [CUBLAS, CUBLAS_LT, CUDNN, EDGE_MASK_CONVOLUTIONS, JIT_CONVOLUTIONS]
    Profiling Verbosity    | ProfilingVerbosity.DETAILED
    Preview Features       | [FASTER_DYNAMIC_SHAPES_0805, DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
    Calibrator             | Calibrator(<generator object calib_data at 0x7f1286f74ba0>, cache='trt86_calib.cache', BaseClass=<class 'tensorrt.tensorrt.IInt8MinMaxCalibrator'>)
[V] Graph optimization time: 0.061173 seconds.
[V] Timing cache disabled. Turning it on will improve builder speed.
[V] [GraphReduction] The approximate region cut reduction algorithm is called.
[V] Detected 2 inputs and 1 output network tensors.
[V] Total Host Persistent Memory: 580944
[V] Total Device Persistent Memory: 194560
[V] Total Scratch Memory: 2622464
[V] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 9 MiB, GPU 151 MiB
[V] [BlockAssignment] Started assigning block shifts. This will take 1386 steps to complete.
[V] [BlockAssignment] Algorithm ShiftNTopDown took 1684.36ms to assign 183 blocks to 1386 nodes requiring 17694208 bytes.
[V] Total Activation Memory: 17694208
[V] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +17, now: CPU 0, GPU 170 (MiB)
[V] Starting Calibration.
**[!] Received an unexpected input from the data loader during calibration. For input: 'xs', expected a shape compatible with: BoundedShape([1, 160000], min=None, max=None), but received: (1, 41152)**
[V]   Post Processing Calibration data in 1.209e-06 seconds.
[E] 1: Unexpected exception _Map_base::at
[!] Invalid Engine. Please ensure the engine was built correctly
Traceback (most recent call last):
  File "/media/tcl1/ASR/users/yujiangling/TensorRT_Engine/quantilize.py", line 192, in <module>
    main()
  File "/media/tcl1/ASR/users/yujiangling/TensorRT_Engine/quantilize.py", line 178, in main
    with G_LOGGER.verbosity(G_LOGGER.VERBOSE), TrtRunner(build_engine) as runner:
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/backend/base/runner.py", line 60, in __enter__
    self.activate()
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/backend/base/runner.py", line 95, in activate
    self.activate_impl()
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/util/util.py", line 694, in wrapped
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/backend/trt/runner.py", line 90, in activate_impl
    engine_or_context, owning = util.invoke_if_callable(self._engine_or_context)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/util/util.py", line 663, in invoke_if_callable
    ret = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/backend/base/loader.py", line 40, in __call__
    return self.call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/util/util.py", line 694, in wrapped
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/backend/trt/loader.py", line 617, in call_impl
    return engine_from_bytes(super().call_impl, runtime=self._runtime)
  File "<string>", line 3, in engine_from_bytes
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/backend/base/loader.py", line 40, in __call__
    return self.call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/util/util.py", line 694, in wrapped
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/backend/trt/loader.py", line 646, in call_impl
    buffer, owns_buffer = util.invoke_if_callable(self._serialized_engine)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/util/util.py", line 663, in invoke_if_callable
    ret = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/util/util.py", line 694, in wrapped
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/backend/trt/loader.py", line 550, in call_impl
    G_LOGGER.critical("Invalid Engine. Please ensure the engine was built correctly")
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/logger/logger.py", line 597, in critical
    raise PolygraphyException(message) from None
polygraphy.exception.exception.PolygraphyException: Invalid Engine. Please ensure the engine was built correctly

It seems that TensorRT 8.5 and 8.6 only use the opt shape for calibration. Why? Is there a solution?

@yjiangling
Author

@lix19937 May I have your help? Thank you very much!

@lix19937

[!] Received an unexpected input from the data loader during calibration. For input: 'xs', expected a shape compatible with: (1, 160000), but received: (1, 41152)

Calibration with dynamic shapes usually uses the opt shape when you use the C++ IInt8EntropyCalibrator2 or IInt8MinMaxCalibrator.

BTW, NVIDIA considers shapes other than the opt shape not allowed; see the changelog:
https://github.com/NVIDIA/TensorRT/blob/release/8.5/tools/Polygraphy/CHANGELOG.md#fixed

@yjiangling
Author

yjiangling commented Jun 17, 2024


Thanks a lot for the help. I see, this error is related to the Polygraphy version rather than TensorRT. With Polygraphy 0.33.0, calibration works fine with dynamic input shapes, while both 0.43.1 and 0.47.1 raise the error. But the CHANGELOG says the bug has been fixed since version 0.44.2 ("Fixed a bug where the calibrator would not accept inputs with a shape other than the OPT shape set in the profile."), which is quite strange...
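Since the behaviour depends on the Polygraphy version rather than on TensorRT alone, it can help to print both versions from the environment that runs the calibration. This is a minimal sketch using only the standard library; it assumes the packages are registered under the distribution names `polygraphy` and `tensorrt`, which may not hold for every install method.

```python
from importlib import metadata


def installed_versions(packages=("polygraphy", "tensorrt")):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not visible to this interpreter.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Running this inside each container (TensorRT 8.4, 8.5, 8.6) makes it easier to line up the observed behaviour with the Polygraphy CHANGELOG entries.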

@lix19937

Check the implementation of get_batch in backend/trt/calibrator.py. For example, in Polygraphy v0.33.0, https://github.com/NVIDIA/TensorRT/blob/release/8.0/tools/Polygraphy/polygraphy/backend/trt/calibrator.py#L118 has no check on input shapes.

But Polygraphy v0.42.1 does check them:
https://github.com/NVIDIA/TensorRT/blob/release/8.4/tools/Polygraphy/polygraphy/backend/trt/calibrator.py

                if not util.is_valid_shape_override(buffer.shape, expected_shape):
                    G_LOGGER.critical(
                        err_prefix
                        + f"For input: '{name}', expected a shape compatible with: {expected_shape}, but received: {buffer.shape}"
                    )

If you override the Calibrator, there is no restriction on the input shapes other than the batch size.
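To illustrate why the check quoted above rejects the data, here is a small stand-alone sketch of a shape-compatibility test. `shape_compatible` and `is_dynamic` are hypothetical helpers written for this example; they are not Polygraphy's `util.is_valid_shape_override`, only an approximation of the kind of comparison it performs (same rank, and every static dimension must match).

```python
def is_dynamic(dim):
    # TensorRT marks dynamic dimensions as -1; None is treated the same way here.
    return dim is None or dim < 0


def shape_compatible(received, expected):
    """True if `received` has the same rank as `expected` and matches each static dim."""
    if len(received) != len(expected):
        return False
    return all(is_dynamic(e) or r == e for r, e in zip(received, expected))


# Shapes taken from the error message in this issue.
print(shape_compatible((1, 41152), (1, 160000)))   # False: 41152 != 160000
print(shape_compatible((1, 160000), (1, 160000)))  # True: matches the opt shape
print(shape_compatible((1, 41152), (1, -1)))       # True: second dim is dynamic
```

Because the calibrator compares against the fixed opt shape `(1, 160000)` rather than the network's `(1, -1)`, any utterance of a different length fails the check.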

@lix19937

lix19937 commented Jun 18, 2024

BTW, this is NVIDIA's constraint:

To run INT8 calibration for a network with dynamic shapes, a calibration optimization profile must be set. Calibration is performed using kOPT values of the profile. Calibration input data size must match this profile.
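One way to satisfy this constraint is to bring every calibration sample to the opt length before it reaches the calibrator. The sketch below is an assumption-laden illustration, not code from this issue: `fit_to_length` and `calib_batches` are hypothetical names, `xlen` is assumed to carry the number of valid samples, and plain lists are used so the example runs without NumPy. In a real script each value would be converted with something like `numpy.asarray(..., dtype=numpy.float32)` (and `int32` for `xlen`) before being yielded to the calibrator.

```python
OPT_LEN = 160000  # opt length of 'xs' in the profile shown in the logs


def fit_to_length(samples, length=OPT_LEN, pad_value=0.0):
    """Return a new list of exactly `length` values: crop if longer, right-pad if shorter."""
    clipped = list(samples[:length])
    return clipped + [pad_value] * (length - len(clipped))


def calib_batches(utterances, length=OPT_LEN):
    """Yield one feed dict per utterance with 'xs' shaped (1, length) and 'xlen' shaped (1,)."""
    for audio in utterances:
        yield {
            "xs": [fit_to_length(audio, length)],
            # Assumption: 'xlen' is the count of valid (unpadded) samples.
            "xlen": [min(len(audio), length)],
        }


if __name__ == "__main__":
    demo = [[0.5] * 41152, [0.1] * 200000]  # one short and one long utterance
    for feed in calib_batches(demo):
        print(len(feed["xs"][0]), feed["xlen"])
```

Zero-padding changes the activation statistics the calibrator sees, so selecting or concatenating utterances close to the opt length may give more representative ranges than padding short clips.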

@yjiangling
Author


Oh, I see, got it. Your kindness is much appreciated. Thank you very much!

@yjiangling
Author

@lix19937 Sorry to bother you again. After overriding the Calibrator in TensorRT 8.6, calibration completes successfully and generates the calibration cache file, but the quantized TensorRT engine cannot be built. It gives the following error:

[W] onnx2trt_utils.cpp:374: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[I] Configuring with profiles:[
        Profile 0:
            {xs [min=[1, 1120], opt=[1, 160000], max=[1, 480000]],
             xlen [min=[1], opt=[1], max=[1]]}
    ]
[I] Using calibration profile: {xs [min=[1, 1120], opt=[1, 160000], max=[1, 480000]],
     xlen [min=[1], opt=[1], max=[1]]}
[W] TensorRT does not currently support using dynamic shapes during calibration. The `OPT` shapes from the calibration profile will be used for tensors with dynamic shapes. Calibration data is expected to conform to those shapes. 
[I] Building engine with configuration:
    Flags                  | [INT8]
    Engine Capability      | EngineCapability.DEFAULT
    Memory Pools           | [WORKSPACE: 10002.44 MiB, TACTIC_DRAM: 10002.44 MiB]
    Tactic Sources         | [CUBLAS, CUBLAS_LT, CUDNN, EDGE_MASK_CONVOLUTIONS, JIT_CONVOLUTIONS]
    Profiling Verbosity    | ProfilingVerbosity.DETAILED
    Preview Features       | [FASTER_DYNAMIC_SHAPES_0805, DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
    Calibrator             | Calibrator(DataLoader(seed=1, iterations=1, int_range=(1, 25), float_range=(-1.0, 1.0), val_range=(0.0, 1.0), data_loader_backend_module='numpy'), cache='trt86_minmax_calib.cache', BaseClass=<class 'tensorrt.tensorrt.IInt8MinMaxCalibrator'>)
[I] Loading calibration cache from trt86_minmax_calib.cache
[W] Missing scale and zero-point for tensor xlen, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 1) [Shuffle]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor sub_3:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 9) [Shuffle]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor floordiv_1:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor frame/range_1:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 59) [Shuffle]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor frame/mul_1:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor frame/Reshape_2:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 70) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor frame/add_2:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 74) [Shuffle]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor add:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 77) [Shuffle]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor subsamp/sub:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 80) [Shuffle]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor subsamp/floordiv:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor SequenceMask_1/ExpandDims:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor SequenceMask_1/Cast:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor SequenceMask/ExpandDims:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor SequenceMask/Cast:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 116) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 121) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 144) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 148) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 151) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 164) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor SequenceMask/Range:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor SequenceMask/Less:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor encoder/range:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor encoder/ExpandDims:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor SequenceMask_1/Less:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 246) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 274) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor encoder/Tile:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 283) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 295) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 303) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 337) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 574) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor encoder/enc_layer_0/encoder_layer/mha/sdpa/Softmax:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor encoder/enc_layer_1/encoder_layer/mha/sdpa/Softmax:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor encoder/enc_layer_2/encoder_layer/mha/sdpa/Softmax:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor encoder/enc_layer_3/encoder_layer/mha/sdpa/Softmax:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor encoder/enc_layer_4/encoder_layer/mha/sdpa/Softmax:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor encoder/enc_layer_5/encoder_layer/mha/sdpa/Softmax:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 2949) [Softmax]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 2950) [Unary]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
formats.cpp:2379: DCHECK(desired_so.size() == t->dim_count()) failed. 
[E] 10: Could not find any implementation for node {ForeignNode[frame/Reshape_3__10...Tensordot/Reshape]}.
[E] 10: [optimizer.cpp::computeCosts::3869] Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[frame/Reshape_3__10...Tensordot/Reshape]}.)
[!] Invalid Engine. Please ensure the engine was built correctly

Why can the model be converted to an FP32 TensorRT engine successfully, but fail for the INT8 quantized engine? Is INT8 inference not implemented in TensorRT 8.6 for the Tensordot/Reshape node? And why does it work fine in TensorRT 8.5 and 8.4?
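To see which tensors actually received a scale during calibration, the cache file can be inspected directly. The sketch below assumes the commonly observed text layout of a TensorRT calibration cache: a header line followed by `tensor name: <8 hex digits>` lines, where the hex digits are a big-endian float32 scale. That layout is not a documented contract, and `parse_calibration_cache` is a hypothetical helper, so treat the result as a debugging aid only. The sample text is synthetic.

```python
import struct


def parse_calibration_cache(text):
    """Map tensor name -> scale, assuming 'name: <8 hex digits>' lines after a header."""
    scales = {}
    for line in text.splitlines()[1:]:  # skip the header line
        # Tensor names may contain ':' (e.g. 'sub_3:0'), so split on the last ': '.
        name, sep, hex_value = line.rpartition(": ")
        if not sep:
            continue
        try:
            scales[name] = struct.unpack("!f", bytes.fromhex(hex_value.strip()))[0]
        except (ValueError, struct.error):
            continue  # not a float32 hex entry; ignore it
    return scales


# Synthetic example in the assumed layout, not data from this issue.
sample = "TRT-8601-MinMaxCalibration\nxs: 3f800000\nsub_3:0: 3c010a14\n"
scales = parse_calibration_cache(sample)
wanted = ["xs", "sub_3:0", "xlen"]
print([name for name in wanted if name not in scales])
```

Comparing the parsed names with the tensors listed in the "Missing scale and zero-point" warnings shows whether the cache was written from a complete calibration run.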

@lix19937

Why can the model be converted to an FP32 TensorRT engine successfully, but fail for the INT8 quantized engine?

There are many reasons.

I suggest you first calibrate with the opt shape on TensorRT 8.6, then generate a plan.

@yjiangling
Author


Thanks a lot for the suggestion, but it still gives the same error message even when I calibrate with the opt shape on TensorRT 8.6. It looks quite complicated...

[V] Graph optimization time: 0.996886 seconds.
[V] Global timing cache in use. Profiling results in this builder pass will be stored.
formats.cpp:2379: DCHECK(desired_so.size() == t->dim_count()) failed. 
[E] 10: Could not find any implementation for node {ForeignNode[frame/Reshape_3:0...Tensordot/Reshape]}.
[E] 10: [optimizer.cpp::computeCosts::3869] Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[frame/Reshape_3:0...Tensordot/Reshape]}.)
[!] Invalid Engine. Please ensure the engine was built correctly
Traceback (most recent call last):
  File "/media/tcl1/ASR/users/yujiangling/TensorRT_Engine/quantilize.py", line 197, in <module>
    main()
  File "/media/tcl1/ASR/users/yujiangling/TensorRT_Engine/quantilize.py", line 183, in main
    with G_LOGGER.verbosity(G_LOGGER.VERBOSE), TrtRunner(build_engine) as runner:
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/backend/base/runner.py", line 60, in __enter__
    self.activate()
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/backend/base/runner.py", line 95, in activate
    self.activate_impl()
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/util/util.py", line 694, in wrapped
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/backend/trt/runner.py", line 90, in activate_impl
    engine_or_context, owning = util.invoke_if_callable(self._engine_or_context)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/util/util.py", line 663, in invoke_if_callable
    ret = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/backend/base/loader.py", line 40, in __call__
    return self.call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/util/util.py", line 694, in wrapped
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/backend/trt/loader.py", line 617, in call_impl
    return engine_from_bytes(super().call_impl, runtime=self._runtime)
  File "<string>", line 3, in engine_from_bytes
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/backend/base/loader.py", line 40, in __call__
    return self.call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/util/util.py", line 694, in wrapped
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/backend/trt/loader.py", line 646, in call_impl
    buffer, owns_buffer = util.invoke_if_callable(self._serialized_engine)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/util/util.py", line 663, in invoke_if_callable
    ret = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/util/util.py", line 694, in wrapped
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/backend/trt/loader.py", line 550, in call_impl
    G_LOGGER.critical("Invalid Engine. Please ensure the engine was built correctly")
  File "/usr/local/lib/python3.10/dist-packages/polygraphy/logger/logger.py", line 597, in critical
    raise PolygraphyException(message) from None
polygraphy.exception.exception.PolygraphyException: Invalid Engine. Please ensure the engine was built correctly

@lix19937

Does trtexec --int8 --onnx=${onnx_file} --verbose pass? You can try it.
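For a model with dynamic inputs, the trtexec run suggested here would normally also carry the shape profile. The following sketch only assembles and prints such a command line; it does not run trtexec. `trtexec_cmd` and `shape_arg` are hypothetical helpers, the profile values are copied from the logs in this issue, and `model.onnx` and the cache file name are placeholders. The flags used (`--minShapes`, `--optShapes`, `--maxShapes`, `--calib`) are trtexec options in TensorRT 8.x.

```python
import shlex

# Profile copied from the Polygraphy logs earlier in this issue.
PROFILE = {
    "xs": {"min": (1, 1120), "opt": (1, 160000), "max": (1, 480000)},
    "xlen": {"min": (1,), "opt": (1,), "max": (1,)},
}


def shape_arg(profile, key):
    """Format one profile column as trtexec expects, e.g. 'xs:1x1120,xlen:1'."""
    return ",".join(
        f"{name}:{'x'.join(str(d) for d in shapes[key])}"
        for name, shapes in profile.items()
    )


def trtexec_cmd(onnx_file, profile, calib_cache=None):
    """Build the argv list for an INT8 trtexec build with a dynamic-shape profile."""
    cmd = ["trtexec", f"--onnx={onnx_file}", "--int8", "--verbose"]
    for key in ("min", "opt", "max"):
        cmd.append(f"--{key}Shapes={shape_arg(profile, key)}")
    if calib_cache:
        cmd.append(f"--calib={calib_cache}")  # reuse an existing calibration cache
    return cmd


print(shlex.join(trtexec_cmd("model.onnx", PROFILE, "trt86_minmax_calib.cache")))
```

If the same "Could not find any implementation for node" error appears with trtexec, that points to the TensorRT builder rather than to Polygraphy.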

@Egorundel

Egorundel commented Jul 29, 2024

@lix19937 What is the calibration_data.cache file? And why is it needed at all?
