Skip to content

Commit

Permalink
Merge pull request #83 from Layer-norm/main
Browse files Browse the repository at this point in the history
fix issue #82
  • Loading branch information
Fannovel16 committed Oct 21, 2023
2 parents ff56485 + bb14f52 commit ed320e6
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/controlnet_aux/dwpose/wholebody.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
from .types import PoseResult, BodyResult, Keypoint

ONNX_PROVIDERS = ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider"]
SUPPORT_PROVIDERS = []
def check_ort_gpu():
try:
import onnxruntime as ort
for provider in ONNX_PROVIDERS:
if provider in ort.get_available_providers():
SUPPORT_PROVIDERS.append(provider)
return True
return False
except:
Expand All @@ -28,8 +30,9 @@ def __init__(self, onnx_det: str, onnx_pose: str):
if check_ort_gpu():
import onnxruntime as ort
if ort_session_det is None:
ort_session_det = ort.InferenceSession(onnx_det, providers=ort.get_available_providers())
ort_session_pose = ort.InferenceSession(onnx_pose, providers=ort.get_available_providers())
SUPPORT_PROVIDERS.append('CPUExecutionProvider')
ort_session_det = ort.InferenceSession(onnx_det, providers=SUPPORT_PROVIDERS)
ort_session_pose = ort.InferenceSession(onnx_pose, providers=SUPPORT_PROVIDERS)
self.session_det = ort_session_det
self.session_pose = ort_session_pose
return
Expand Down

0 comments on commit ed320e6

Please sign in to comment.