WARNING: Your ONNX model has been generated with INT64 weights #2542

Closed
FrancescoSaverioZuppichini opened this issue Dec 12, 2022 · 23 comments
Labels: triaged (Issue has been triaged by maintainers)

@FrancescoSaverioZuppichini

Description

Hello There!

I hope you are all doing well :)

There are other similar issues, but not one of them has a fix for this problem.

TensorRT takes a lot of time casting INT64 to INT32, which makes it impossible to use in practice.

My conversion code:

import onnxruntime as ort
import torch
from torchvision.models import ConvNeXt_Small_Weights, convnext_small

torch.set_default_tensor_type("torch.FloatTensor")
torch.set_default_tensor_type("torch.cuda.FloatTensor")

model_name = "model.onnx"
# get the model and put it in half precision
model = convnext_small(weights=ConvNeXt_Small_Weights.IMAGENET1K_V1).eval().half().cuda()

with torch.autocast("cuda", dtype=torch.float16):
    x = torch.randn(1, 3, 224, 224, device="cuda")
    # Export the model
    torch.onnx.export(
        model,  # model being run
        x,  # model input (or a tuple for multiple inputs)
        model_name,  # where to save the model (can be a file or file-like object)
        opset_version=16,
        export_params=True,  # store the trained parameter weights inside the model file
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=["image"],  # the model's input names
        output_names=["output"],  # the model's output names
        dynamic_axes={
            "image": {0: "batch_size"},  # variable length axes
            "output": {0: "batch_size"},
        },
    )

# let's check
print("Checking")
x = torch.randn(1, 3, 224, 224, device="cuda")
ort_session = ort.InferenceSession(model_name, providers=["CUDAExecutionProvider"])
outputs = ort_session.run(None, {"image": x.cpu().numpy()})
print(outputs[0].shape, outputs[0].dtype)

The code takes around 3-4 minutes to convert the weights, then it outputs the following:

2022-12-05 13:55:19.402807667 [W:onnxruntime:Default, tensorrt_execution_provider.h:60 log] [2022-12-05 13:55:19 WARNING] TensorRT encountered issues when converting weights between types and that could affect accuracy.
2022-12-05 13:55:19.402832897 [W:onnxruntime:Default, tensorrt_execution_provider.h:60 log] [2022-12-05 13:55:19 WARNING] If this is not the desired behavior, please modify the weights or retrain with regularization to adjust the magnitude of the weights.
2022-12-05 13:55:19.402838637 [W:onnxruntime:Default, tensorrt_execution_provider.h:60 log] [2022-12-05 13:55:19 WARNING] Check verbose logs for the list of affected weights.
2022-12-05 13:55:19.402844777 [W:onnxruntime:Default, tensorrt_execution_provider.h:60 log] [2022-12-05 13:55:19 WARNING] - 8 weights are affected by this issue: Detected subnormal FP16 values.

Environment

I am using the latest NVIDIA container
TensorRT Version: 8.5.1
ONNX-TensorRT Version / Branch:
GPU Type: RTX 3090
Nvidia Driver Version:
CUDA Version: 11.7
CUDNN Version:
Operating System + Version:
Python Version (if applicable): 3.8
TensorFlow + TF2ONNX Version (if applicable): 1.13
PyTorch Version (if applicable):
Baremetal or Container (if container which image + tag):

Relevant Files

My ONNX model (a ConvNeXt):

link to drive

Steps To Reproduce

  1. Download the ONNX model
  2. Copy this code:

import onnxruntime as ort
import torch

x = torch.randn(1, 3, 224, 224, device="cuda")
ort_session = ort.InferenceSession(
    "model.onnx",
    providers=[
        (
            "TensorrtExecutionProvider",
            {
                "device_id": 0,
                "trt_fp16_enable": True,
                "trt_max_workspace_size": 2147483648,
            },
        ),
        "CUDAExecutionProvider",
    ],
)
outputs = ort_session.run(None, {"image": x.cpu().numpy()})

  3. Run it with Python
@zerollzeng (Collaborator)

8 weights are affected by this issue: Detected subnormal FP16 values.

Can you remove the half() in model = convnext_small(weights=ConvNeXt_Small_Weights.IMAGENET1K_V1).eval().half().cuda()? This warning means your model contains weights that fall outside the normal FP16 range (e.g. subnormal values with magnitude below roughly 6e-5, or values above the FP16 maximum of 65504).

Your ONNX model has been generated with INT64 weights
Most ONNX models have INT64 weights, and TRT only supports INT32 weights. Usually this is not a problem, because your weight values are unlikely to fall outside the INT32 range (-2,147,483,648 to +2,147,483,647), so this warning can be ignored.
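
For illustration only (the suggestion above is just "remove the half()"): a minimal sketch of what that change could look like, dropping .half() and the autocast context so the exported ONNX stays in FP32 and reusing the export settings from the first message; "model_fp32.onnx" is a placeholder filename.

import torch
from torchvision.models import ConvNeXt_Small_Weights, convnext_small

# Keep the model and the export in FP32; FP16 can be enabled later at engine-build time.
model = convnext_small(weights=ConvNeXt_Small_Weights.IMAGENET1K_V1).eval().cuda()
x = torch.randn(1, 3, 224, 224, device="cuda")
torch.onnx.export(
    model,
    x,
    "model_fp32.onnx",
    opset_version=16,
    input_names=["image"],
    output_names=["output"],
    dynamic_axes={"image": {0: "batch_size"}, "output": {0: "batch_size"}},
)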

zerollzeng self-assigned this Dec 13, 2022
zerollzeng added the triaged label Dec 13, 2022
@FrancescoSaverioZuppichini (Author)

Hi @zerollzeng, thanks for the reply. I would like to keep my model in half precision; any advice on how to achieve that while getting rid of the warning?

@zerollzeng (Collaborator)

You can retrain your model in FP16 so that all the weights stay within the FP16 range, or apply some normalization to the subnormal weights.
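
In case it helps, a hedged sketch of one way to "normalize" the subnormal weights offline: flush values whose magnitude would be subnormal in FP16 to zero directly in the ONNX file, using the onnx and numpy packages. This is only one interpretation of the suggestion, not an official NVIDIA tool; "model_flushed.onnx" is a placeholder.

import numpy as np
import onnx
from onnx import numpy_helper

FP16_MIN_NORMAL = 2.0 ** -14  # ~6.1e-5, the smallest positive normal float16 value

model = onnx.load("model.onnx")
for init in model.graph.initializer:
    w = numpy_helper.to_array(init)
    if w.dtype in (np.float16, np.float32):
        mask = (w != 0) & (np.abs(w) < FP16_MIN_NORMAL)
        if mask.any():
            # Zero out the FP16-subnormal values and write the tensor back.
            flushed = np.where(mask, 0, w).astype(w.dtype)
            init.CopyFrom(numpy_helper.from_array(flushed, init.name))
onnx.save(model, "model_flushed.onnx")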

@FrancescoSaverioZuppichini (Author)

Retraining is not possible. Is there a way to serve a mixed-precision model with TensorRT? My task is very simple: I have a model in PyTorch, I exported it in mixed precision, and I want to run it with TensorRT, nothing really crazy. I am just evaluating the technology.

If I remove the .half(), will the model still be in mixed precision? I am not sure whether torch.autocast alone ensures that most of it stays in float16.

One thing I don't understand is where the INT64 weights are coming from. Any idea?

Thanks a lot :)

@FrancescoSaverioZuppichini (Author)

Hey @zerollzeng, I am trying everything here. See microsoft/onnxconverter-common#251 and microsoft/onnxconverter-common#252, but they are not TensorRT related. Have you ever created a mixed-precision model in ONNX?

@zerollzeng (Collaborator) commented Dec 20, 2022

Yes, mixed precision is feasible: just export the ONNX model in FP32, and when you build the engine, enable FP16 (you don't need to export an FP16 model in the first place) and pin those subnormal layers to FP32.
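
A rough sketch of that workflow with the TensorRT Python API (the layer and file names are placeholders, and this is only an illustration of the idea, not code from this thread): parse the FP32 ONNX, enable FP16 globally, and pin selected layers back to FP32.

import tensorrt as trt

# Placeholder: fill in with the layer names reported as problematic in the verbose log.
FP32_LAYERS = {"some_problematic_layer_name"}

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
with open("model_fp32.onnx", "rb") as f:
    if not parser.parse(f.read()):
        raise RuntimeError(parser.get_error(0))

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)                        # build mostly in FP16
config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)  # honor the per-layer overrides below

for i in range(network.num_layers):
    layer = network.get_layer(i)
    if layer.name in FP32_LAYERS:
        layer.precision = trt.DataType.FLOAT                 # keep this layer in FP32
        for j in range(layer.num_outputs):
            layer.set_output_type(j, trt.DataType.FLOAT)

engine_bytes = builder.build_serialized_network(network, config)
with open("model.engine", "wb") as f:
    f.write(engine_bytes)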

@FrancescoSaverioZuppichini (Author)

@zerollzeng how do I do that? I used the functions from the ONNX docs but they are not working. I'd like to have an automatic way to find the subnormal layers. If you have experience with it, do you mind sharing a snippet?
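
Since no snippet is shared in this thread, here is a hedged sketch of one way to find them automatically: scan the ONNX initializers for FP16-subnormal magnitudes and print the names of the nodes that consume them, so they can be fed into the FP32_LAYERS set of the builder sketch above (matching these ONNX node names to TensorRT layer names is a heuristic, not guaranteed).

import numpy as np
import onnx
from onnx import numpy_helper

FP16_MIN_NORMAL = 2.0 ** -14  # smallest positive normal float16 value

model = onnx.load("model.onnx")
subnormal_inits = set()
for init in model.graph.initializer:
    w = numpy_helper.to_array(init).astype(np.float64)
    nz = w[w != 0]
    if nz.size and np.abs(nz).min() < FP16_MIN_NORMAL:
        subnormal_inits.add(init.name)

# Print the nodes whose inputs include a subnormal weight tensor.
for node in model.graph.node:
    if any(name in subnormal_inits for name in node.input):
        print(node.name or node.op_type)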

@zerollzeng (Collaborator)

You can refer to https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#layer-level-control
API: https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_layer.html

If you are using trtexec, you can specify per-layer precision with --layerPrecisions and --precisionConstraints; see trtexec -h for details.
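
For example (flag spellings as printed by trtexec -h in TensorRT 8.5; the layer name below is a placeholder to be taken from the build log):

trtexec --onnx=model_fp32.onnx --fp16 \
        --precisionConstraints=obey \
        --layerPrecisions="some_problematic_layer_name":fp32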

@FrancescoSaverioZuppichini (Author)

@zerollzeng I think I didn't explain myself. I know this is not TensorRT related but ONNX related now, so somewhat off topic.

So far I haven't found anyone who has exported a PyTorch model to mixed precision in ONNX and documented the process. Thus, I was wondering if you know how to do it :)

@zerollzeng (Collaborator)

We don't currently support building engines that follow the ONNX precision. Maybe it's on our roadmap, but I'm not sure. @kevinch-nv, do you know?

For now I think you can only mark the per-layer precision manually with the TRT API. It should be possible to write some automated code that parses the original ONNX precision and marks the layers in TensorRT accordingly, but I don't have such a script at hand :-(.

@FrancescoSaverioZuppichini (Author)

@zerollzeng thanks for the reply :) Sorry for my noobiness, but how does it work internally? Assuming I have an ONNX graph with mixed-precision nodes (float32 and float16), should that be enough for TensorRT?

@zerollzeng (Collaborator)

Yes, that's enough. For example, you can use https://github.com/NVIDIA/TensorRT/tree/master/tools/onnx-graphsurgeon: write a Python script that reads the ONNX model and iterates over all operators; whenever an operator's precision is FP16, set the corresponding TRT layer to FP16. Then enable both FP32 and FP16 precision, and you will finally get a mixed-precision TRT engine (a rough sketch follows below).
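
A compact sketch of that per-operator walk with onnx-graphsurgeon (an illustration under the assumption that shape inference has populated the tensor dtypes and that ONNX node names can be matched against TensorRT layer names; "model_mixed.onnx" is a placeholder):

import numpy as np
import onnx
import onnx_graphsurgeon as gs

# Run shape inference first so intermediate tensors carry their element type.
model = onnx.shape_inference.infer_shapes(onnx.load("model_mixed.onnx"))
graph = gs.import_onnx(model)

fp16_nodes = {
    node.name
    for node in graph.nodes
    if node.outputs and node.outputs[0].dtype == np.float16
}
print(f"{len(fp16_nodes)} ONNX nodes produce float16 outputs")

# These names can then be matched against network.get_layer(i).name in the builder
# loop shown earlier, this time setting layer.precision = trt.DataType.HALF.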

@FrancescoSaverioZuppichini (Author)

Thanks @zerollzeng, so the current TensorRT limitation is that it cannot directly honor an ONNX model that mixes FP16 and FP32 nodes, and I have to export the model in FP32 and manually mark the layers that should run in FP16 once I've loaded the model in TensorRT?

In your experience, is it better to run everything in FP16, or does mixed precision actually have an advantage? It looks super hard to achieve in practice.

@zerollzeng (Collaborator)

In your experience, is it better to run everything in FP16, or does mixed precision actually have an advantage? It looks super hard to achieve in practice.

I would suggest exporting the ONNX model in FP32 and running everything in FP16 with TensorRT first; it will give you better performance, and most likely the accuracy will only degrade a little (it may even get better 🥇). If the accuracy is not good (e.g. some layer outputs or weights exceed the FP16 range and you see a warning in the log), you can then offload those layers back to FP32, which gives a good balance between performance and accuracy.

You can also try INT8 if you want the best performance; we support both QAT and PTQ for INT8 quantization. A possible trtexec progression is sketched below.
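
For what it's worth, the suggested progression could look like this with trtexec (file and layer names are placeholders; meaningful INT8 accuracy additionally needs calibration data or a QAT model):

trtexec --onnx=model_fp32.onnx --fp16                                                                 # step 1: everything in FP16
trtexec --onnx=model_fp32.onnx --fp16 --precisionConstraints=obey --layerPrecisions="some_problematic_layer_name":fp32   # step 2: pin weak layers to FP32
trtexec --onnx=model_fp32.onnx --int8 --fp16                                                          # step 3: try INT8 + FP16 for peak speed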

@FrancescoSaverioZuppichini (Author)

But you can't do INT8 on a GPU, or can you?

@zerollzeng (Collaborator) commented Jan 9, 2023

GPUs with compute capability >= 6.1 support INT8. A quick way to check is with trtexec: trtexec --onnx=model.onnx --int8, or trtexec --onnx=model.onnx --int8 --fp16 (this enables both INT8 and FP16 and thus gives the best performance).

@FrancescoSaverioZuppichini (Author)

damn, really? I didn't know that!

@hannah12356

How did you get the ONNX file?

@FrancescoSaverioZuppichini (Author)

Hey @hannah12356, it's generated by the code in the first message of this issue :)

@wangjl1993

I'm having a similar issue with YOLOv5! I converted yolov5s from '.pt' to '.engine' (FP32), and their inference results are inconsistent.

ultralytics/yolov5#11255 (comment)
@zerollzeng @FrancescoSaverioZuppichini

@zerollzeng (Collaborator)

I'm closing this since there is no action left.

@wangjl1993 I just saw that someone is helping you in the yolov5 repo. If you find TRT accuracy issues, please file a new issue against TensorRT OSS. Thanks!

FYI: we can't say the accuracy is bad until we test on a validation dataset; it might just be fluctuation on some images.

@FrancescoSaverioZuppichini (Author)

@zerollzeng why was this issue closed?

@zerollzeng (Collaborator)

@FrancescoSaverioZuppichini please reopen it if you have any further questions :)
