[Performance] Regression observed when using CUDA execution provider #20712

Open
krishung5 opened this issue May 17, 2024 · 15 comments
Labels
ep:CUDA issues related to the CUDA execution provider


@krishung5

krishung5 commented May 17, 2024

Describe the issue

We are seeing a latency regression when using onnxruntime with the CUDA execution provider, starting with version 1.14.1. Versions before 1.14.1 do not show the regression. It persists in subsequent versions, including the latest, 1.17.1.

To reproduce

Install different versions of onnxruntime-gpu with pip (for example, pip install onnxruntime-gpu==1.13.1) and run the script below:

import onnxruntime as ort
import numpy as np
import time

def run_inference(session, input_data):
    """
    Run a single inference and return its latency in seconds.
    """
    start_time = time.perf_counter()  # monotonic, high-resolution timer
    _ = session.run(None, input_data)
    return time.perf_counter() - start_time

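def generate_dummy_img_input(input_name, batch_size=1):
    # Random NCHW float32 batch; ResNet-50 expects 3x224x224 images
    input_data = np.random.rand(batch_size, 3, 224, 224).astype(np.float32)
    return {input_name: input_data}

model_path = "/path/to/resnet50-1.2.onnx"
session = ort.InferenceSession(model_path, providers=['CUDAExecutionProvider'])
# Look up the input name ("gpu_0/data_0" for resnet50-1.2) instead of hard-coding it
input_data = generate_dummy_img_input(session.get_inputs()[0].name)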

latencies = []
for _ in range(100):
    latency = run_inference(session, input_data)
    latencies.append(latency)
# print(f"Latencies (s): {latencies}")  # values are in seconds
# Remove the first value for model warm-up, calculate the average of the rest
latencies = latencies[1:]
average_latency = sum(latencies) / len(latencies)
print(f"Average latency (ms): {average_latency * 1000}")
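The script reports a mean over 99 runs, which a few slow iterations can pull upward. As a hedged sketch (not part of the original report), a summary that also reports the median and 90th percentile may make version-to-version comparisons less sensitive to outliers. summarize_latencies is a hypothetical helper name.

```python
import statistics


def summarize_latencies(latencies_s, warmup=1):
    """Summarize per-run latencies (seconds) as mean/median/p90 in milliseconds.

    The first `warmup` samples are discarded, like the warm-up run the
    reproduction script drops.
    """
    samples_ms = [t * 1000.0 for t in latencies_s[warmup:]]
    if not samples_ms:
        raise ValueError("need at least one sample after warm-up")
    if len(samples_ms) > 1:
        # n=10 yields the 9 decile cut points; the last one is the 90th percentile.
        p90 = statistics.quantiles(samples_ms, n=10, method="inclusive")[-1]
    else:
        p90 = samples_ms[0]
    return {
        "mean_ms": statistics.fmean(samples_ms),
        "median_ms": statistics.median(samples_ms),
        "p90_ms": p90,
    }
```

Because the script already slices off the warm-up sample, it would be called at the end as summarize_latencies(latencies, warmup=0).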

Sharing some numbers (five runs of the script per version):
onnxruntime-gpu == 1.13.1

Average latency (ms): 3.55936060048113
Average latency (ms): 3.6028346630057904
Average latency (ms): 3.526743012245255
Average latency (ms): 3.408121340202563
Average latency (ms): 3.644895071935172

onnxruntime-gpu == 1.14.1

Average latency (ms): 4.1407864503186165
Average latency (ms): 4.213094711303711
Average latency (ms): 3.9710492798776342
Average latency (ms): 3.977019377429076
Average latency (ms): 3.976980845133464

onnxruntime-gpu == 1.17.1

Average latency (ms): 4.22009795603126
Average latency (ms): 4.130989614159169
Average latency (ms): 4.265643129445086
Average latency (ms): 4.415384446731721
Average latency (ms): 4.095142537897283

The latency increase from version 1.13.1 to version 1.14.1 is approximately 14.29%.
The latency increase from version 1.14.1 to version 1.17.1 is approximately 4.18%.
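The two percentages can be reproduced from the five runs listed for each version by taking the mean per version and then the relative change between means. A minimal sketch; the variable and function names are illustrative:

```python
from statistics import fmean

# Average latencies (ms) copied from the five runs reported above.
runs_ms = {
    "1.13.1": [3.55936060048113, 3.6028346630057904, 3.526743012245255,
               3.408121340202563, 3.644895071935172],
    "1.14.1": [4.1407864503186165, 4.213094711303711, 3.9710492798776342,
               3.977019377429076, 3.976980845133464],
    "1.17.1": [4.22009795603126, 4.130989614159169, 4.265643129445086,
               4.415384446731721, 4.095142537897283],
}


def pct_increase(old, new):
    """Relative change of `new` over `old`, in percent."""
    return (new - old) / old * 100.0


means = {version: fmean(values) for version, values in runs_ms.items()}
inc_13_to_14 = pct_increase(means["1.13.1"], means["1.14.1"])
inc_14_to_17 = pct_increase(means["1.14.1"], means["1.17.1"])
print(f"1.13.1 -> 1.14.1: {inc_13_to_14:.2f}%")
print(f"1.14.1 -> 1.17.1: {inc_14_to_17:.2f}%")
```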

Note that when using the CPU execution provider, there is no regression. It only occurs when using the CUDA execution provider.

For a simpler reproduction, the script above uses a ResNet-50 model. With our own model, we observe a latency increase of more than 20x:
onnxruntime-gpu == 1.13.1

Average:0.1200742244720459

onnxruntime-gpu == 1.16.3

Average:3.157811665534973

The regression can also be observed with the C++ API.

Urgency

High

Platform

Linux

OS Version

22.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.17.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

CUDA 12.2 and 11.8

Model File

No response

Is this a quantized model?

No

@github-actions github-actions bot added the ep:CUDA issues related to the CUDA execution provider label May 17, 2024
@gedoensmax
Contributor

Which cuDNN version are you using?

@krishung5
Author

@gedoensmax I am using cuDNN 8.7.0.84. I tried to use cuDNN 9 with onnxruntime-gpu 1.17.1, but it still looks for cuDNN 8 and fails to load the CUDA provider:

2024-05-17 19:57:13.917623151 [E:onnxruntime:Default, provider_bridge_ort.cc:1548 TryGetProviderInfo_CUDA] /onnxruntime_src/onnxruntime/core/session/provider_bridge_ort.cc:1209 onnxruntime::Provider& onnxruntime::ProviderLibrary::Get() [ONNXRuntimeError] : 1 : FAIL : Failed to load library libonnxruntime_providers_cuda.so with error: libcudnn.so.8: cannot open shared object file: No such file or directory
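When the CUDA provider library cannot be loaded (as with the missing libcudnn.so.8 above), onnxruntime may fall back to the CPU execution provider, and the measured latency would then no longer reflect the CUDA EP at all. A small hedged guard, relying only on InferenceSession.get_providers(), which lists the providers a session actually uses; require_cuda_ep is a hypothetical helper:

```python
def require_cuda_ep(session):
    """Raise if `session` is not actually running on the CUDA execution provider.

    `session` is expected to be an onnxruntime.InferenceSession; only its
    get_providers() method is used, so any object exposing it will do.
    """
    active = session.get_providers()
    if "CUDAExecutionProvider" not in active:
        raise RuntimeError(
            f"CUDAExecutionProvider is not active (active providers: {active}); "
            "check that the installed CUDA/cuDNN versions match the onnxruntime-gpu build"
        )
    return active
```

Calling require_cuda_ep(session) right after the session is created in the reproduction script would make such a fallback fail loudly instead of silently skewing the numbers.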

@krishung5
Author

Hi team, I was wondering if we have any update on this issue?

@JackWeiw

Hi team, I was wondering if we have any update on this issue?

Hello, do you have any idea what is causing the performance degradation? I have tested the performance of onnxruntime 1.17, and it is even worse than torch 2.0.1.

@gedoensmax
Contributor

@tianleiwu can you help out with this? My initial guess was that the regression might be due to cuDNN shipping fewer kernels, but the cuDNN version was the same across the different onnxruntime versions.

@gedoensmax
Contributor

@krishung5 I would recommend trying CUDA graphs; they might help reduce the execution time for such small networks.
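A hedged sketch of what trying CUDA graphs could look like for the ResNet-50 reproduction, using the CUDA EP's enable_cuda_graph provider option. CUDA graph replay needs the input and output buffers to stay at fixed device addresses, so the sketch binds preallocated OrtValue objects through IOBinding and updates the input in place between runs. The single input/output and the [batch, 1000] output shape are assumptions for ResNet-50; benchmark_with_cuda_graph and cuda_graph_providers are hypothetical names, and the warm-up count is a guess rather than a documented requirement.

```python
import time

import numpy as np


def cuda_graph_providers():
    """CUDA EP provider list with CUDA graph capture/replay turned on."""
    return [("CUDAExecutionProvider", {"enable_cuda_graph": "1"})]


def benchmark_with_cuda_graph(model_path, runs=100, warmup=3, batch_size=1):
    """Average latency (ms) using CUDA graph replay.

    Requires onnxruntime-gpu and a CUDA device. Assumes a single image input
    of shape [batch, 3, 224, 224] and a single [batch, 1000] float32 output,
    as for ResNet-50.
    """
    import onnxruntime as ort  # lazy import so cuda_graph_providers() works without a GPU build

    session = ort.InferenceSession(model_path, providers=cuda_graph_providers())
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    # CUDA graphs replay against fixed device addresses, so allocate the
    # input and output buffers once on the GPU and bind them via IOBinding.
    shape = (batch_size, 3, 224, 224)
    x_dev = ort.OrtValue.ortvalue_from_numpy(
        np.random.rand(*shape).astype(np.float32), "cuda", 0)
    y_dev = ort.OrtValue.ortvalue_from_shape_and_type(
        [batch_size, 1000], np.float32, "cuda", 0)
    binding = session.io_binding()
    binding.bind_ortvalue_input(input_name, x_dev)
    binding.bind_ortvalue_output(output_name, y_dev)

    # Warm-up runs; the graph is captured during these.
    for _ in range(warmup):
        session.run_with_iobinding(binding)

    latencies = []
    for _ in range(runs):
        # Write new data into the same device buffer instead of rebinding.
        x_dev.update_inplace(np.random.rand(*shape).astype(np.float32))
        start = time.perf_counter()
        session.run_with_iobinding(binding)
        binding.synchronize_outputs()  # make sure the GPU work has finished
        latencies.append(time.perf_counter() - start)
    return sum(latencies) / len(latencies) * 1000.0
```

Graph capture only works when every node runs on the CUDA EP, so models with control-flow nodes may be rejected.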

@JackWeiw

@gedoensmax Sir, one thing I am confused about: if I install onnxruntime with pip install onnxruntime-gpu==1.17, will that package be the right one for my machine, i.e. will it match the CUDA 11.8 installation and the corresponding cuBLAS and cuDNN libraries? Could you explain? Thanks a lot!

@gedoensmax
Copy link
Contributor

The default 1.17 package is built with CUDA 11. For CUDA 12 there is a separate package: https://onnxruntime.ai/docs/install/#install-onnx-runtime-gpu-cuda-11x

@JackWeiw

The default 1.17 shipment is with CUDA 11. To install onnxruntime with CUDA 12 there is a separate package. https://onnxruntime.ai/docs/install/#install-onnx-runtime-gpu-cuda-11x

OK, thank you very much. Could you please take a look at this issue about dynamic quantization? There are some problems with dynamically quantizing the vicuna-7b model from fp16 to int8.

@krishung5
Author

Hi @pranavsharma, just wanted to follow up and see if we have any update on this, thank you!

@tianleiwu
Contributor

I reproduced the issue with https://github.com/onnx/models/blob/main/validated/vision/classification/resnet/model/resnet50-v1-12.onnx on an A100. The average latency (ms) output:
ORT 1.13.1: 2.98
ORT 1.14.0: 3.20
ORT 1.17.1: 3.20
So there is some regression from 1.13.1 to 1.14.0. I will take a look at the cause.

@tianleiwu
Contributor

tianleiwu commented May 30, 2024

The root cause seems to be the change of the default value of cudnn_conv_use_max_workspace from 0 to 1 in #13981.

The solution is to set the value to 0 for ResNet:

session = ort.InferenceSession(model_path, providers=[("CUDAExecutionProvider", {"cudnn_conv_use_max_workspace": '0'})])
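To check whether a given model benefits, one hedged approach is to time the same input under both settings of cudnn_conv_use_max_workspace. The sketch assumes a single image input of shape [batch, 3, 224, 224] as in the reproduction script; cuda_options, mean_latency_ms, and compare_max_workspace are hypothetical helper names.

```python
import time

import numpy as np


def cuda_options(max_workspace):
    """CUDA EP provider list with cudnn_conv_use_max_workspace set explicitly."""
    # Provider option values are passed as strings, as in the workaround above.
    value = "1" if max_workspace else "0"
    return [("CUDAExecutionProvider", {"cudnn_conv_use_max_workspace": value})]


def mean_latency_ms(session, feed, runs=100, warmup=5):
    """Mean wall-clock latency of session.run over `runs` calls, after warm-up."""
    for _ in range(warmup):
        session.run(None, feed)
    start = time.perf_counter()
    for _ in range(runs):
        session.run(None, feed)  # outputs are copied back to host, so this blocks
    return (time.perf_counter() - start) / runs * 1000.0


def compare_max_workspace(model_path, batch_size=1):
    """Return {max_workspace_flag: mean latency in ms} for both settings."""
    import onnxruntime as ort  # lazy import: needs onnxruntime-gpu and a CUDA device

    x = np.random.rand(batch_size, 3, 224, 224).astype(np.float32)
    results = {}
    for flag in (False, True):
        session = ort.InferenceSession(model_path, providers=cuda_options(flag))
        feed = {session.get_inputs()[0].name: x}
        results[flag] = mean_latency_ms(session, feed)
    return results
```

A fresh session is created per setting because provider options are fixed when the session is built.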

For debugging, setting an environment variable that limits the cuDNN workspace size (in MiB) could help:

CUDNN_CONV_WSCAP_DBG=128 python test.py

@gedoensmax, do you know why a larger workspace causes a performance drop in some convolutional networks (we've enabled conv algo tuning by default)?

@gedoensmax
Copy link
Contributor

@tianleiwu I just saw that conv algo tuning is now set to exhaustive search. This should guarantee the best possible performance, but using the heuristics is usually sufficient.
Could you capture an Nsight Systems trace with and without the limited workspace size? I would like to confirm which kernels are used; it might no longer transform from NCHW to NHWC to leverage Tensor Cores. It still surprises me that the exhaustive search did not pick that strategy.
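The two requested traces could be captured with the Nsight Systems CLI by running the same benchmark script once with the default workspace and once with CUDNN_CONV_WSCAP_DBG capping it. A hedged sketch using subprocess: test.py is the script name used earlier in the thread, while the output names and the 128 MiB cap are illustrative assumptions, and the nsys flags (-o, --trace) should be checked against the installed Nsight Systems version.

```python
import os
import subprocess
import sys


def nsys_command(script, output_name):
    # Trace CUDA plus cuDNN/cuBLAS API calls so the chosen conv kernels show up.
    return ["nsys", "profile", "-o", output_name,
            "--trace=cuda,cudnn,cublas", sys.executable, script]


def trace_env(wscap_mib=None):
    """Copy of the current environment, optionally capping the cuDNN conv workspace (MiB)."""
    env = dict(os.environ)
    env.pop("CUDNN_CONV_WSCAP_DBG", None)  # default run must not inherit a cap
    if wscap_mib is not None:
        env["CUDNN_CONV_WSCAP_DBG"] = str(wscap_mib)
    return env


def capture_traces(script="test.py", wscap_mib=128):
    """Run the script under nsys twice: default workspace, then capped workspace."""
    runs = {"resnet_default_ws": None, f"resnet_wscap_{wscap_mib}": wscap_mib}
    for output_name, cap in runs.items():
        subprocess.run(nsys_command(script, output_name),
                       env=trace_env(cap), check=True)
```

Each run is a separate process, so the environment variable is in place before cuDNN is loaded.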

@tianleiwu
Copy link
Contributor

@gedoensmax,

The Nsight trace files:
resnet_nsys.zip

@krishung5
Copy link
Author

@gedoensmax Using CUDA graphs does seem to help performance. I wasn't able to run the model used by the RIVA team because of this error:

onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : This session cannot use the graph capture feature as requested by the user as the model has control flow nodes which can't be supported byCUDAExecutionProvider

but with the ResNet model, I'm seeing an improvement of approximately 19.18% in average latency.

ONNX Runtime 1.18 with CUDA Graph:

Latencies (ms):
2.595525799375592
2.116176817152235
2.7692823699026397
2.5585733278833254
2.085702587859799

ONNX Runtime 1.18 without CUDA Graph:

Latencies (ms):
3.0858926098756116
2.4176077409224077
2.685696187645498
3.6532445387406782
3.1608499661840574

4 participants