ONNX networks can't use INT8 calibration and batching #289
Comments
Hi @gcp,

Sorry for the delay, I'm on holiday and was hoping to do this in my free time, but it's still been a busy holiday 😅 I made a little sample workflow to demonstrate how I believe this works.

**1. Export the trained model to 2 ONNX models (one fixed batch, one dynamic batch)**

I tweaked the AlexNet demo from here: https://pytorch.org/docs/stable/onnx.html

```python
import torch
import torchvision

dummy_input = torch.randn(10, 3, 224, 224, device='cuda')
model = torchvision.models.alexnet(pretrained=True).cuda()

input_names = ["actual_input_1"]  # + ["learned_%d" % i for i in range(16)]
output_names = ["output1"]

# Fixed shape
torch.onnx.export(model, dummy_input, "alexnet_fixed.onnx", verbose=True, opset_version=11,
                  input_names=input_names, output_names=output_names)

# Dynamic shape
dynamic_axes = dict(zip(input_names, [{0: 'batch_size'} for _ in range(len(input_names))]))
print(dynamic_axes)
torch.onnx.export(model, dummy_input, "alexnet_dynamic.onnx", verbose=True, opset_version=11,
                  input_names=input_names, output_names=output_names,
                  dynamic_axes=dynamic_axes)
```

**2. Do INT8 calibration on the fixed-shape model and save the calibration cache**

This is based on code from the 20.01 branch here: https://github.com/rmccorm4/tensorrt-utils/blob/20.01/classification/imagenet/onnx_to_tensorrt.py

```
# Fixed batch model
$ python onnx_to_tensorrt.py --fp16 --int8 \
--calibration-cache=alexnet.cache \
--calibration-data=/imagenet/val \
--preprocess_func=preprocess_imagenet \
--explicit-batch \
--onnx=../../../alexnet_fixed.onnx
2019-12-30 00:20:15 - __main__ - INFO - TRT_LOGGER Verbosity: Severity.ERROR
2019-12-30 00:20:15 - __main__ - INFO - Using FP16 build flag
2019-12-30 00:20:15 - __main__ - INFO - Using INT8 build flag
2019-12-30 00:20:15 - utils - INFO - Collecting calibration files from: /imagenet/val
2019-12-30 00:20:21 - utils - INFO - Number of Calibration Files found: 50000
2019-12-30 00:20:21 - utils - WARNING - Capping number of calibration images to max_calibration_size: 512
2019-12-30 00:20:22 - __main__ - DEBUG - network.get_input(0).shape = (10, 3, 224, 224)
2019-12-30 00:20:22 - __main__ - DEBUG - network.get_input(0).name = actual_input_1
2019-12-30 00:20:22 - __main__ - INFO - Explicit batch size is fixed (10), creating one optimization profile...
2019-12-30 00:20:22 - __main__ - INFO - Optimization profile: Min(10, 3, 224, 224), Opt(10, 3, 224, 224), Max(10, 3, 224, 224)
2019-12-30 00:20:22 - __main__ - INFO - Building Engine...
2019-12-30 00:20:25 - ImagenetCalibrator - INFO - Calibration images pre-processed: 32/512
2019-12-30 00:20:26 - processing - DEBUG - Received grayscale image. Reshaped to (3, 224, 224)
2019-12-30 00:20:27 - ImagenetCalibrator - INFO - Calibration images pre-processed: 64/512
2019-12-30 00:20:28 - ImagenetCalibrator - INFO - Calibration images pre-processed: 96/512
2019-12-30 00:20:30 - ImagenetCalibrator - INFO - Calibration images pre-processed: 128/512
2019-12-30 00:20:30 - processing - DEBUG - Received grayscale image. Reshaped to (3, 224, 224)
2019-12-30 00:20:31 - ImagenetCalibrator - INFO - Calibration images pre-processed: 160/512
2019-12-30 00:20:32 - ImagenetCalibrator - INFO - Calibration images pre-processed: 192/512
2019-12-30 00:20:32 - processing - DEBUG - Received grayscale image. Reshaped to (3, 224, 224)
2019-12-30 00:20:32 - processing - DEBUG - Received grayscale image. Reshaped to (3, 224, 224)
2019-12-30 00:20:32 - processing - DEBUG - Received grayscale image. Reshaped to (3, 224, 224)
2019-12-30 00:20:33 - ImagenetCalibrator - INFO - Calibration images pre-processed: 224/512
2019-12-30 00:20:33 - processing - DEBUG - Received grayscale image. Reshaped to (3, 224, 224)
2019-12-30 00:20:34 - ImagenetCalibrator - INFO - Calibration images pre-processed: 256/512
2019-12-30 00:20:36 - ImagenetCalibrator - INFO - Calibration images pre-processed: 288/512
2019-12-30 00:20:37 - ImagenetCalibrator - INFO - Calibration images pre-processed: 320/512
2019-12-30 00:20:38 - processing - DEBUG - Received grayscale image. Reshaped to (3, 224, 224)
2019-12-30 00:20:38 - ImagenetCalibrator - INFO - Calibration images pre-processed: 352/512
2019-12-30 00:20:40 - ImagenetCalibrator - INFO - Calibration images pre-processed: 384/512
2019-12-30 00:20:41 - ImagenetCalibrator - INFO - Calibration images pre-processed: 416/512
2019-12-30 00:20:41 - processing - DEBUG - Received grayscale image. Reshaped to (3, 224, 224)
2019-12-30 00:20:42 - processing - DEBUG - Received grayscale image. Reshaped to (3, 224, 224)
2019-12-30 00:20:42 - ImagenetCalibrator - INFO - Calibration images pre-processed: 448/512
2019-12-30 00:20:44 - ImagenetCalibrator - INFO - Calibration images pre-processed: 480/512
2019-12-30 00:20:45 - ImagenetCalibrator - INFO - Calibration images pre-processed: 512/512
2019-12-30 00:20:45 - ImagenetCalibrator - INFO - Caching calibration data for future use: alexnet.cache <-------- # Calibration cache saved here
2019-12-30 00:20:55 - __main__ - INFO - Writing engine to model.engine
```

**3. Use the saved calibration cache on the dynamic-shape model to create an INT8 dynamic engine**

```
# Dynamic batch model
$ python onnx_to_tensorrt.py --fp16 --int8 \
--calibration-cache=alexnet.cache \
--calibration-data=/imagenet/val \
--preprocess_func=preprocess_imagenet \
--explicit-batch \
--onnx=../../../alexnet_dynamic.onnx
2019-12-30 00:27:57 - __main__ - INFO - TRT_LOGGER Verbosity: Severity.ERROR
2019-12-30 00:27:58 - __main__ - INFO - Using FP16 build flag
2019-12-30 00:27:58 - __main__ - INFO - Using INT8 build flag
2019-12-30 00:27:58 - __main__ - INFO - Skipping calibration files, using calibration cache: alexnet.cache
2019-12-30 00:27:58 - __main__ - DEBUG - network.get_input(0).shape = (-1, 3, 224, 224)
2019-12-30 00:27:58 - __main__ - DEBUG - network.get_input(0).name = actual_input_1
2019-12-30 00:27:58 - __main__ - INFO - Explicit batch size is dynamic (-1), creating several optimization profiles...
2019-12-30 00:27:58 - __main__ - INFO - Optimization profile: Min(1, 3, 224, 224), Opt(1, 3, 224, 224), Max(1, 3, 224, 224)
2019-12-30 00:27:58 - __main__ - INFO - Optimization profile: Min(2, 3, 224, 224), Opt(2, 3, 224, 224), Max(2, 3, 224, 224)
2019-12-30 00:27:58 - __main__ - INFO - Optimization profile: Min(4, 3, 224, 224), Opt(4, 3, 224, 224), Max(4, 3, 224, 224)
2019-12-30 00:27:58 - __main__ - INFO - Optimization profile: Min(8, 3, 224, 224), Opt(8, 3, 224, 224), Max(8, 3, 224, 224)
2019-12-30 00:27:58 - __main__ - INFO - Optimization profile: Min(16, 3, 224, 224), Opt(16, 3, 224, 224), Max(16, 3, 224, 224)
2019-12-30 00:27:58 - __main__ - INFO - Optimization profile: Min(32, 3, 224, 224), Opt(32, 3, 224, 224), Max(32, 3, 224, 224)
2019-12-30 00:27:58 - __main__ - INFO - Building Engine...
2019-12-30 00:27:59 - ImagenetCalibrator - INFO - Using calibration cache to save time: alexnet.cache
2019-12-30 00:27:59 - ImagenetCalibrator - INFO - Using calibration cache to save time: alexnet.cache
2019-12-30 00:28:41 - __main__ - INFO - Writing engine to model.engine
```

**4. Smoke test inference on the TRT engine**

I'm hitting an OOM error on this part and don't have time to debug right now, but hopefully this helps.
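(For anyone who wants to finish step 4, here is a rough smoke-test sketch. It assumes the fixed-batch engine from step 2, pycuda for buffer management, and an [input, output] binding order; it is not rmccorm4's actual code and is untested here, given the OOM mentioned above.)

```python
import numpy as np
import pycuda.autoinit  # noqa: F401  (creates a CUDA context on import)
import pycuda.driver as cuda
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.ERROR)

# Deserialize the engine written by onnx_to_tensorrt.py
with open("model.engine", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())

with engine.create_execution_context() as context:
    # Fixed-batch AlexNet engine from step 2: input (10, 3, 224, 224), output (10, 1000)
    inp = np.random.random((10, 3, 224, 224)).astype(np.float32)
    out = np.empty((10, 1000), dtype=np.float32)

    d_inp = cuda.mem_alloc(inp.nbytes)
    d_out = cuda.mem_alloc(out.nbytes)

    cuda.memcpy_htod(d_inp, inp)
    context.execute_v2([int(d_inp), int(d_out)])
    cuda.memcpy_dtoh(out, d_out)

    print("output shape:", out.shape, "max logit:", out.max())
```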
This solution does not work with the C++ implementation. Is there a fix upcoming?
It does work; the prototype for the calibrator stays the same, i.e. [...]

But, and this is IMHO the real issue: NVIDIA should really have fixed the INT8 calibrators to use explicit batch sizes, or (vastly preferable for backwards compatibility, which is now totally broken) never have removed implicit batch size support from the ONNX parser.
These are not documented for the NVIDIA-provided calibrators. For example, there is no documentation for what writeCalibrationCache is supposed to do.
Writing out the cache is what your code has to do. You get passed a byte buffer once calibration is done, and the contents of it are not your problem nor your business; you just have to get it onto persistent storage (i.e. disk). You don't have to fill in the contents yourself. Then, on the next startup, TensorRT calls readCalibrationCache and you have to read in the file you previously wrote, returning a pointer to it (and setting the length output parameter).
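(To make the cache round-trip concrete, here is a minimal sketch using the TensorRT Python bindings; the class name, cache path, and batch source are illustrative, not from any NVIDIA sample. The C++ interface described above has the same four methods.)

```python
import os
import tensorrt as trt

class CacheOnlyCalibrator(trt.IInt8EntropyCalibrator2):
    """Sketch of the calibrator contract: TensorRT drives all four methods."""

    def __init__(self, cache_file, batch_source=None, batch_size=1):
        trt.IInt8EntropyCalibrator2.__init__(self)
        self.cache_file = cache_file
        self.batch_source = batch_source  # anything yielding device pointers to input batches
        self.batch_size = batch_size

    def get_batch_size(self):
        return self.batch_size

    def get_batch(self, names):
        # Return a list of device pointers for the next calibration batch,
        # or None when there is no more data (TensorRT then finishes calibrating).
        if self.batch_source is None:
            return None
        return next(self.batch_source, None)

    def read_calibration_cache(self):
        # Called by TensorRT at the start of a build: hand back the bytes we
        # previously persisted, or None to force a fresh calibration.
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()
        return None

    def write_calibration_cache(self, cache):
        # Called by TensorRT once calibration is done; `cache` is an opaque
        # buffer whose contents are not our business. We just store it.
        with open(self.cache_file, "wb") as f:
            f.write(cache)
```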
That is not correct. The default method of constructing a network ([...]). The NVIDIA-provided calibrator ([...]).
Notice how this derives from https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/c_api/classnvinfer1_1_1_i_int8_calibrator.html, which defines exactly that. The derived interface also provides the methods of the base class; that's how inheritance works. So it is definitely documented. As for your "the calibration buffers are not passed back to the user", this is simply wrong, or you're hitting a bug. But I assure you it works like I described, and your custom class, derived from one of the provided calibrator interfaces, will get the cache handed to it.
It doesn't seem like you read my comment:
This is the recommended method of creating an engine with a custom calibration as described in the examples. You can in fact see from the documentation that [...]
I assure you I did read your comment. I'm not clear on what part of my explanation is not getting through. The buffers are passed back through the writeCalibrationCache callback.
Sure, if you have the cache you can write it out. The provided calibrators do not have a method of obtaining the cache in the first place. My comment above should have been "there is no provided method of obtaining the calibration cache."
Hi @gcp,

There was a bug with the calibration code regarding implicit/explicit batch handling. It was fixed for the next release.
Hi @rmccorm4, thanks a lot for your answer. But I ran into a problem when following your instructions above:

```
root@26d9b8a5f596:/mnt/current/classification/imagenet# python onnx_to_tensorrt.py --fp16 --int8 \
    --calibration-cache=caches/alexnet.cache --calibration-data=/dataset/coco/val2017 \
    --preprocess_func=preprocess_imagenet --explicit-batch --onnx=models/alexnet_fixed.onnx
2020-01-22 02:44:09 - __main__ - INFO - TRT_LOGGER Verbosity: Severity.ERROR
2020-01-22 02:44:10 - __main__ - INFO - Using FP16 build flag
2020-01-22 02:44:10 - __main__ - INFO - Using INT8 build flag
2020-01-22 02:44:10 - ImagenetCalibrator - INFO - Collecting calibration files from: /dataset/coco/val2017
2020-01-22 02:44:11 - ImagenetCalibrator - INFO - Number of Calibration Files found: 5000
2020-01-22 02:44:11 - ImagenetCalibrator - WARNING - Capping number of calibration images to max_calibration_size: 512
2020-01-22 02:44:12 - __main__ - ERROR - Failed to parse model.
```

I exported the model with PyTorch version 1.3.1. Could you please help me out?
This fix does not work for the C++ TensorRT API regardless.
Hi @ShawnNew, I'm assuming you're running the code on the 19.10 branch. I'm not at a computer right now so I'll fix this later, but in case you're in a different time zone: can you change the parsing syntax to do the error check in a for loop like here, and share the output showing why parsing failed? https://github.com/rmccorm4/tensorrt-utils/blob/10238af41f7b9de57b7a649a630f3d07f3f0cec5/classification/imagenet/onnx_to_tensorrt.py#L142
Hi @rmccorm4,

Yes, I am using the 19.10 branch. Here is the output:

```
2020-01-22 07:47:10 - __main__ - INFO - TRT_LOGGER Verbosity: Severity.ERROR
2020-01-22 07:47:11 - __main__ - INFO - Using FP16 build flag
2020-01-22 07:47:11 - __main__ - INFO - Using INT8 build flag
2020-01-22 07:47:11 - ImagenetCalibrator - INFO - Collecting calibration files from: /dataset/coco/val2017
2020-01-22 07:47:14 - ImagenetCalibrator - INFO - Number of Calibration Files found: 5000
2020-01-22 07:47:14 - ImagenetCalibrator - WARNING - Capping number of calibration images to max_calibration_size: 512
ERROR: Failed to parse the ONNX file: models/alexnet_fixed.onnx
In node 0 (importModel): INVALID_GRAPH: Assertion failed: tensors.count(input_name)
```
If you're using the same AlexNet code as above, try PyTorch 1.3.0; I think that's what I used at the time. I also think I was using onnx==1.6.0, if that makes a difference. I've also noticed some PyTorch models producing weird ONNX graphs with PyTorch 1.3.*, and I think some of those issues may have been fixed in PyTorch 1.4, though I don't know whether that will work with TensorRT 6 or not; do let me know. I believe your error is basically the same as this failing:

```python
import onnx

model = onnx.load("model.onnx")
onnx.checker.check_model(model)
```
Hey @ShawnNew,

**1. TensorRT >= 6.0 + PyTorch >= 1.3**

I believe this error is a known issue with TRT 6 + PyTorch >= 1.3 and was fixed in the upstream OSS ONNX parser. If you build the OSS ONNX parser, you should be able to parse it:

```
# Start TensorRT 6 Container
nvidia-docker run -it -v ${PWD}:/mnt --workdir=/mnt nvcr.io/nvidia/tensorrt:19.12-py3
# Install dependencies
pip install torch==1.4 torchvision==0.5
# Download test model code
wget https://gist.githubusercontent.com/rmccorm4/b72abac18aed6be4c1725db18eba4930/raw/fbdb009152ef54abe7d8b23a9fd57a0f250e03e4/alexnet_onnx.py
# Create test model
python alexnet_onnx.py --opset=11
# Parse test model
trtexec --onnx=alexnet_fixed.onnx --explicitBatch
...
In node 0 (importModel): INVALID_GRAPH: Assertion failed: tensors.count(input_name)
&&&& FAILED TensorRT.trtexec # trtexec --onnx=alexnet_fixed.onnx --explicitBatch
```

If you're in the 19.12 container, run:

```
bash /opt/tensorrt/install_opensource.sh
```

Then try parsing the model again:

```
trtexec --onnx=alexnet_fixed.onnx --explicitBatch
...
&&&& PASSED TensorRT.trtexec # trtexec --onnx=alexnet_fixed.onnx --explicitBatch
```

**2. TensorRT 6.0 + PyTorch == 1.2**

Alternatively, in some situations (like this one) you can just use pytorch==1.2 instead and it should work:

```
# Start TensorRT 6 Container
nvidia-docker run -it -v ${PWD}:/mnt --workdir=/mnt nvcr.io/nvidia/tensorrt:19.12-py3
# Install dependencies
pip install torch==1.2 torchvision==0.4
# Download test model code
wget https://gist.githubusercontent.com/rmccorm4/b72abac18aed6be4c1725db18eba4930/raw/fbdb009152ef54abe7d8b23a9fd57a0f250e03e4/alexnet_onnx.py
# Note --opset=10
python alexnet_onnx.py --opset=10
# Note no --explicitBatch flag
trtexec --onnx=alexnet_fixed.onnx
...
&&&& PASSED TensorRT.trtexec # trtexec --onnx=alexnet_fixed.onnx
```
Hi @rmccorm4,

```
[TensorRT] ERROR: engine.cpp (529) - Cuda Error in commonEmitTensor: 1 (invalid argument)
[TensorRT] ERROR: FAILED_ALLOCATION: std::exception
[TensorRT] ERROR: ../rtSafe/cuda/caskConvolutionRunner.cpp (334) - Cuda Error in execute: 1 (invalid argument)
[TensorRT] ERROR: FAILED_EXECUTION: std::exception
```

But I solved this when I changed [...]. I think this explicit batch bug should be fixed because it can be confusing.
Hi @gcp,

I met a problem and found it may be related to your comment. I am just wondering: [...]
I have no idea about (1); I didn't write onnx2trt. As for (2), I'm not sure how to explain it better than I already did. If you enable INT8 mode, TensorRT will call your calibrator's writeCalibrationCache once calibration is done. On the next launch, TensorRT will call readCalibrationCache. (Technically I think I'm just repeating what I wrote here: #289 (comment).) The point I was trying to get across (and clearly failed to do) is that you don't make a call to get the calibration cache from TensorRT; TensorRT will call your code when it has finished calibrating.
Did you actually check whether the resulting calibration is usable? I had to force batch size to 1 myself to avoid critical errors, but didn't end up getting useful results (horrible accuracy decrease). The difference in my case is that I am actually working with a fixed-batch model all the way (using [...]).

FOLLOW UP: Interestingly, near-zero values in the output tensor are well preserved, while higher values are consistently too low with the INT8 model. It's almost as if relative values are well preserved but absolute values aren't. I will try calibrators other than the one I have been using.
Yes, I'm using this in production. I'm struggling with a long-standing bug in TensorRT where using many examples for calibration decreases the accuracy (my best guess is that NVIDIA's code is overflowing), but with a limited set of, say, 500 images, my results are good.
In my particular case, [...]
Well, my issue is back. Due to a decent speed-up, I've started doing input data normalization on the GPU as well. This results in the first few operations in my ONNX network being [...]. The [...]
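(For illustration only: a small PyTorch sketch of what "normalization on GPU" can look like when it is folded into the exported graph. The wrapper class and the mean/std values are assumptions, not the poster's code; the point is that the subtraction and division export as elementwise nodes at the very top of the ONNX network, which then take part in INT8 calibration.)

```python
import torch
import torch.nn as nn
import torchvision

class NormalizedModel(nn.Module):
    """Wrap a backbone so mean/std normalization is part of the exported graph."""

    def __init__(self, backbone, mean, std):
        super().__init__()
        self.backbone = backbone
        # Buffers are exported as constants in the ONNX graph.
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, x):
        # These two ops become the first nodes of the exported network.
        x = (x - self.mean) / self.std
        return self.backbone(x)

model = NormalizedModel(torchvision.models.alexnet(pretrained=True),
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).eval()
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy, "alexnet_normalized.onnx", opset_version=11)
```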
Me too. This error can be solved by using `y = y.flatten(1)`, but it still fails to calibrate at INT8 precision with the error "TensorRT Assertion failed: q > p ../builder/cudnnCalibrator.cpp:584".

The error "TensorRT Assertion failed: q > p ../builder/cudnnCalibrator.cpp:584" was solved by checking the calibration dataset. So it works well now.
@gcp Hello, I'm doing INT8 calibration on an ONNX model with C++. getBatchSize(), readCalibrationCache() and writeCalibrationCache() were called, but getBatch() was not. I tried .batch and .ppm files for calibration but both failed. Did you ever come across similar problems? And can we directly use .jpg/.png files for calibration, or should we implement that ourselves? Thank you.
@simejanko @Peppa-cs Hi, I ran into the same error and I am confused about what the cause is. How did you solve the [...]? I was trying to do INT8 calibration from an ONNX model file. I'm using TRT version 7.0.

Command: [...]

Error info: [...]

I'd appreciate any help.
Maybe you can try replacing `y = y.view(y.size(0), -1)` with `y = y.flatten(1)`. It works for me.
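(A small runnable sketch of the suggested change; the tensor shape is just an example. Both forms give the same values, but they export to different ONNX ops.)

```python
import torch

y = torch.randn(8, 256, 6, 6)  # example feature map inside a model's forward()

# view(size(0), -1) exports as a Shape/Gather/Concat/Reshape chain,
# which some TensorRT versions fail to parse or calibrate.
flattened_view = y.view(y.size(0), -1)

# flatten(1) exports as a single Flatten node and tends to parse cleanly.
flattened = y.flatten(1)

assert torch.equal(flattened_view, flattened)
```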
Notice your output shape is `(-1, -1, -1, -1)`. It seems that TRT doesn't support that dynamic shape.
Hi @gcp,

With the release of TensorRT 7.1, hopefully this issue is resolved now. Feel free to update this issue if you're still having problems with TRT 7.1 + ONNX + INT8 calibration.

NOTE: My [...]
Hi @rmccorm4, is calibration-batch-size the same as `network.get_input(0)[0]`? Models downloaded from the ONNX Model Zoo have a fixed batch size, so profiles for different batches won't be generated in create_optimization_profiles(). But it still calibrates the models with the default batch size when I don't set calibration-batch-size. In that situation, is only the first sample used? Are the histograms of activations generated incorrectly? Thanks.
Hi @Y-U-KAY,

I'm not 100% sure on this, but I believe it works something like this: for an EXPLICIT_BATCH ONNX model, calibration will use the batch size of the model. Meaning if it has a fixed batch size, it will use that batch size. If it has a dynamic batch size (-1), then it will use the batch size specified in the kOPT shape of the calibration profile. If no calibration profile is specified, I believe it will default to the first optimization profile.
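(A sketch of how the calibration batch size can be pinned through a calibration profile in the Python API, assuming TensorRT >= 7.1, where IBuilderConfig exposes set_calibration_profile. The input name is taken from the AlexNet example above, and CacheOnlyCalibrator is the illustrative calibrator sketched earlier in this thread.)

```python
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)
with open("alexnet_dynamic.onnx", "rb") as f:
    parser.parse(f.read())

config = builder.create_builder_config()
config.max_workspace_size = 1 << 30
config.set_flag(trt.BuilderFlag.INT8)
config.int8_calibrator = CacheOnlyCalibrator("alexnet.cache")  # calibrator sketch from earlier

profile = builder.create_optimization_profile()
# For a dynamic (-1) batch dimension, the kOPT shape below is what calibration runs with.
profile.set_shape("actual_input_1", (1, 3, 224, 224), (8, 3, 224, 224), (32, 3, 224, 224))
config.add_optimization_profile(profile)
config.set_calibration_profile(profile)

engine = builder.build_engine(network, config)
```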
Hi @rmccorm4,

Thanks for your reply. It's a fixed batch size, and the network description is as shown below: [...]

I got these messages: [...]

According to what you said, even if 32 images are loaded onto the GPU, only 1 image can be processed because of the fixed batch size of the calibration profile? Thanks.
Hi @Y-U-KAY,

Assuming that's my INT8 script you're running (based on the log outputs), I haven't updated it much since before CalibrationProfiles were implemented. I'm not sure what the expected behavior is, but the 32 there is probably hard-coded in my script as a default batch_size value or something. You can try experimenting by removing that and seeing what happens by default.
My issue was resolved by upgrading to version 7.1.3
This issue was also resolved for me by upgrading to version 7.1.3. I'm using an explicit-batch-size model and can now use batch sizes >1 for INT8 calibration. Am I risking anything by using a batch size for INT8 calibration that's different from the batch size in my explicit batch model? There were no warnings when I ran it.
It seems there is no remaining issue in this thread, and starting from 7.2 we have [...]. I will close this; please reopen if you still have issues. Thanks all!
Description
This is due to mutually incompatible changes in the TRT7 release:
https://docs.nvidia.com/deeplearning/sdk/tensorrt-release-notes/tensorrt-7.html
versus
This means the ONNX network must be exported at a fixed batch size in order to get INT8 calibration working, but then it's no longer possible to specify the batch size. I also verified that manually fixing up the inputs with setDimensions(...-1...) does not work; you will hit an assertion

```
mg.nodes[mg.regionIndices[outputRegion]].size == mg.nodes[mg.regionIndices[inputRegion]].size
```

while building.

One would think there might be some sort of workaround by exporting two different networks, one with a fixed batch size and a second one with dynamic_axes, and then using the calibration from one for the other.
However, even here there are severe pitfalls: a calibration cache that is generated for, say, batch_size=1 won't necessarily work for larger batch sizes, presumably because they will generate a different convolution strategy that causes different accuracy issues. (Edit: this might've been another issue.)

Lastly, the calibrator itself appears to be using implicit batch sizes, and it breaks on batch size > 1 as follows:

```
TRT: Starting Calibration with batch size 16.
Calibrated 16 images.
TRT: Explicit batch network detected and batch size specified, use execute without batch size instead.
TRT: C:\source\builder\cudnnCalibrator.cpp (707) - Cuda Error in nvinfer1::builder::Histogram::add: 700 (an illegal memory access was encountered)
TRT: FAILED_ALLOCATION: Unknown exception
TRT: C:\source\builder\cudnnCalibrator.cpp (703) - Cuda Error in nvinfer1::builder::Histogram::add: 700 (an illegal memory access was encountered)
TRT: FAILED_ALLOCATION: Unknown exception
TRT: C:\source\rtSafe\cuda\caskConvolutionRunner.cpp (233) - Cuda Error in nvinfer1::rt::task::CaskConvolutionRunner::allocateContextResources: 700 (an illegal memory access was encountered)
TRT: FAILED_EXECUTION: Unknown exception
TRT: Calibrated batch 0 in 2.62865 seconds.
Cuda failure: 700
```

With batch_size == 1, it's also hitting assertions:

```
TRT: Explicit batch network detected and batch size specified, use execute without batch size instead.
TRT: Assertion failed: d.nbDims >= 1
C:\source\rtSafe\safeHelpers.cpp:419
Aborting...
```

The combination of all these failures means that you can't really use ONNX networks in INT8 mode; at the least, the "Using a fixed shape input to build the engine in the first pass" recommendation hits all kinds of internal assertions, as you can see above.
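(For context, the explicit-batch network construction that the TRT 7 ONNX parser requires looks roughly like this; a Python sketch with an illustrative model filename, the C++ path being analogous via createNetworkV2 and kEXPLICIT_BATCH.)

```python
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

builder = trt.Builder(TRT_LOGGER)
# TRT 7's ONNX parser only accepts explicit-batch networks, which is what
# collides with the implicit-batch assumptions inside the INT8 calibrator.
flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flags)
parser = trt.OnnxParser(network, TRT_LOGGER)

with open("model.onnx", "rb") as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
```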
Environment
TensorRT Version: 7.0.0.11
GPU Type: RTX 2080
Nvidia Driver Version: 441.22
CUDA Version: 10.2
CUDNN Version: 7.6.0.5
Operating System + Version: Windows 10
Python Version (if applicable): 3.6
TensorFlow Version (if applicable):
PyTorch Version (if applicable): 1.3 stable
Baremetal or Container (if container which image + tag): bare
Relevant Files
Steps To Reproduce