Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions test/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -3279,6 +3279,8 @@ def test_hip_device_count(self):
{"CUDA_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None},
{"CUDA_VISIBLE_DEVICES": None, "HIP_VISIBLE_DEVICES": "0"},
{"CUDA_VISIBLE_DEVICES": "0,1,2,3", "HIP_VISIBLE_DEVICES": "0"},
{"ROCR_VISIBLE_DEVICES": "1,2,3", "HIP_VISIBLE_DEVICES": "0"},
{"ROCR_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None},
]

for env_config in custom_envs:
Expand Down
20 changes: 19 additions & 1 deletion torch/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,25 @@ def _parse_visible_devices() -> Union[List[int], List[str]]:

if torch.version.hip:
hip_devices = os.getenv("HIP_VISIBLE_DEVICES")
if hip_devices is not None:
rocr_devices = os.getenv("ROCR_VISIBLE_DEVICES")

# You must take care if both HIP and ROCR env vars are set as they have
# different meanings. Both env vars accept either a list of ints or a
# list of UUIDs. The ROCR env var is processed first which then reduces
# the number of GPUs that HIP can select from.
if rocr_devices is not None:
rocr_count = len(rocr_devices.split(","))
if hip_devices is not None:
# sanity check if both env vars are set
if len(hip_devices.split(",")) > rocr_count:
raise RuntimeError(
"HIP_VISIBLE_DEVICES contains more devices than ROCR_VISIBLE_DEVICES"
)
# HIP_VISIBLE_DEVICES is preferred over ROCR_VISIBLE_DEVICES
var = hip_devices
else:
return list(range(rocr_count))
elif hip_devices is not None:
var = hip_devices

if var is None:
Expand Down