
Multi-GPU inference throws error when thread uses sessions allocated on different device #25104

Open
@PawelPeczek-Roboflow


Describe the issue

Hello team.

I am running onnxruntime-gpu==1.22.0 with CUDA 12.4, and I believe I am seeing a problem with CUDA context management when a single thread uses multiple models (sessions) allocated on different devices, with each session run via IOBinding. Illustration below:

>>> run(model_2) # inference in the loop with model on cuda:1
100%|████████████████████████████████████████████████████████████████| 1000/1000 [00:12<00:00, 83.04it/s]
>>> run(model_1) # inference in the loop with model on cuda:0
  0%|                                                                           | 0/1000 [00:00<?, ?it/s]
RuntimeError: Error in execution: /onnxruntime_src/onnxruntime/core/providers/cuda/cuda_call.cc:129 std::conditional_t<THRW, void, onnxruntime::common::Status> onnxruntime::CudaCall(ERRTYPE, const char*, const char*, SUCCTYPE, const char*, const char*, int) [with ERRTYPE = cudaError; bool THRW = true; SUCCTYPE = cudaError; std::conditional_t<THRW, void, common::Status> = void] /onnxruntime_src/onnxruntime/core/providers/cuda/cuda_call.cc:121 std::conditional_t<THRW, void, onnxruntime::common::Status> onnxruntime::CudaCall(ERRTYPE, const char*, const char*, SUCCTYPE, const char*, const char*, int) [with ERRTYPE = cudaError; bool THRW = true; SUCCTYPE = cudaError; std::conditional_t<THRW, void, common::Status> = void] CUDA failure 400: invalid resource handle ; GPU=0 ; hostname=pawel-gpu-dev-4x-l4 ; file=/onnxruntime_src/onnxruntime/core/providers/cuda/cuda_stream_handle.cc ; line=41 ; expr=cudaEventRecord(event_, static_cast<cudaStream_t>(stream_.GetHandle()));


>>> run(model_1) # inference in the loop with model on cuda:0
100%|████████████████████████████████████████████████████████████████| 1000/1000 [00:12<00:00, 82.81it/s]
>>> run(model_2) # inference in the loop with model on cuda:1
  0%|                                                                           | 0/1000 [00:00<?, ?it/s]

RuntimeError: Error in execution: /onnxruntime_src/onnxruntime/core/providers/cuda/cuda_call.cc:129 std::conditional_t<THRW, void, onnxruntime::common::Status> onnxruntime::CudaCall(ERRTYPE, const char*, const char*, SUCCTYPE, const char*, const char*, int) [with ERRTYPE = cudaError; bool THRW = true; SUCCTYPE = cudaError; std::conditional_t<THRW, void, common::Status> = void] /onnxruntime_src/onnxruntime/core/providers/cuda/cuda_call.cc:121 std::conditional_t<THRW, void, onnxruntime::common::Status> onnxruntime::CudaCall(ERRTYPE, const char*, const char*, SUCCTYPE, const char*, const char*, int) [with ERRTYPE = cudaError; bool THRW = true; SUCCTYPE = cudaError; std::conditional_t<THRW, void, common::Status> = void] CUDA failure 400: invalid resource handle ; GPU=1 ; hostname=pawel-gpu-dev-4x-l4 ; file=/onnxruntime_src/onnxruntime/core/providers/cuda/cuda_stream_handle.cc ; line=41 ; expr=cudaEventRecord(event_, static_cast<cudaStream_t>(stream_.GetHandle()));


>>> run(model_2)  # inference in the loop with model on cuda:1
100%|████████████████████████████████████████████████████████████████| 1000/1000 [00:12<00:00, 82.61it/s]

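So when I use model_2 and then switch back to model_1, the first forward pass fails, and running model_1 again afterwards succeeds. The same happens in the other direction (model_1, then model_2).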

As a workaround (which seems to solve the problem), I created a context manager that switches CUDA device contexts for me. It uses each device's primary context, so disposing of the context afterwards is not an issue:

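import contextlib
from typing import Generator

import pycuda.driver as cuda

cuda.init()  # the CUDA driver API must be initialized before creating cuda.Device objects


@contextlib.contextmanager
def use_primary_cuda_context(cuda_device: cuda.Device) -> Generator[cuda.Context, None, None]:
    # Retain the device's primary context (the same one the CUDA runtime uses)
    context = cuda_device.retain_primary_context()
    # Make it the current context on the calling thread
    context.push()
    try:
        yield context
    finally:
        # Restore the previously current context, even if the run raised
        context.pop()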

and wrapped each session run in this context manager.
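
One way to make that wrapping hard to forget is to bind each session's run call to its device once, up front. The sketch below is a minimal illustration under that assumption, not code from this report: DeviceBoundRunner is a hypothetical helper (not part of onnxruntime or pycuda) that takes a zero-argument run callable plus a factory returning the matching device's context manager, and enters that context around every call.

```python
from typing import Any, Callable, ContextManager


class DeviceBoundRunner:
    """Hypothetical helper: ties a zero-argument run callable to the CUDA device
    its session lives on, so every call happens with that device's context
    current on the calling thread."""

    def __init__(
        self,
        run_fn: Callable[[], Any],
        device_context: Callable[[], ContextManager[Any]],
    ) -> None:
        self._run_fn = run_fn
        # A factory rather than a context-manager instance, because objects
        # produced by @contextlib.contextmanager can only be entered once.
        self._device_context = device_context

    def __call__(self) -> Any:
        # Enter (push) the device context, run, and always leave (pop) it,
        # even when the run raises.
        with self._device_context():
            return self._run_fn()
```

With the context manager from this report, a runner for the model on cuda:1 could be built as DeviceBoundRunner(lambda: session.run_with_iobinding(binding), lambda: use_primary_cuda_context(cuda.Device(1))), where session and binding are placeholder names for that model's InferenceSession and IOBinding.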

When running ONNX models at scale, I have also occasionally seen similar errors even on a single GPU; I am not sure whether this is related or just a coincidence.

To reproduce

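  1. Take a model (for example, an Ultralytics YOLO model) and export it to ONNX.
  2. Write code that runs the session with IOBinding.
  3. Run the code on a multi-GPU server, alternating in a single thread between sessions placed on different devices.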
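
The steps above could be sketched roughly as follows. This is an untested sketch, not the reporter's code: make_iobinding_runner and run_schedule are hypothetical names, the model is assumed to have a single input, and it is assumed (not verified) that binding a host input and device-allocated outputs is enough to trigger the error; the original setup may bind inputs differently.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple


def make_iobinding_runner(model_path: str, device_id: int, input_array: Any) -> Callable[[], None]:
    """Hypothetical helper: returns a zero-argument callable that performs one
    forward pass of `model_path` on cuda:<device_id> through IOBinding."""
    # Imported lazily so run_schedule below can be exercised without a GPU.
    import onnxruntime as ort

    session = ort.InferenceSession(
        model_path,
        providers=[("CUDAExecutionProvider", {"device_id": device_id})],
    )
    binding = session.io_binding()
    # Assumes a single-input model; the input is copied from host memory.
    binding.bind_cpu_input(session.get_inputs()[0].name, input_array)
    for output in session.get_outputs():
        # Let onnxruntime allocate each output on the session's own GPU.
        binding.bind_output(output.name, "cuda", device_id)

    def run_once() -> None:
        session.run_with_iobinding(binding)

    return run_once


def run_schedule(
    runners: Dict[str, Callable[[], None]],
    schedule: Sequence[str],
    iterations: int,
) -> List[Tuple[str, int, str]]:
    """Runs each named model in `schedule` order, `iterations` passes per step,
    all on the calling thread. Returns (model, iteration, message) for each
    step that failed; a step stops at its first failure, like the loops in the log."""
    failures: List[Tuple[str, int, str]] = []
    for name in schedule:
        for i in range(iterations):
            try:
                runners[name]()
            except RuntimeError as exc:  # the log above shows failures surfacing as RuntimeError
                failures.append((name, i, str(exc)))
                break
    return failures
```

For example, building runners for "model_1" on device 0 and "model_2" on device 1 from the same ONNX file, then calling run_schedule(runners, ["model_2", "model_1", "model_1", "model_2", "model_2"], 1000), would mirror the sequence in the log. Failures are collected rather than raised so the whole sequence can be observed in one pass.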

Sorry for not providing end-to-end code; I have no time at the moment, but I should be able to add it later. Perhaps a code owner looking at the code would be enough to tell whether there is a bug.

Urgency

No response

Platform

Linux

OS Version

Ubuntu 22.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.22.0

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

CUDA 12.4

Metadata

Assignees

No one assigned

Labels

ep:CUDA (issues related to the CUDA execution provider)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests