Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions torch/utils/cpp_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -2407,11 +2407,18 @@ def _get_cuda_arch_flags(cflags: Optional[list[str]] = None) -> list[str]:

def _get_rocm_arch_flags(cflags: Optional[list[str]] = None) -> list[str]:
# If cflags is given, there may already be user-provided arch flags in it
# (from `extra_compile_args`)
# (from `extra_compile_args`). If the user also specified -fgpu-rdc or -fno-gpu-rdc, we
# assume they know what they're doing. Otherwise, we force the -fno-gpu-rdc default.
has_gpu_rdc_flag = False
if cflags is not None:
has_custom_flags = False
for flag in cflags:
if 'amdgpu-target' in flag or 'offload-arch' in flag:
return ['-fno-gpu-rdc']
has_custom_flags = True
elif 'gpu-rdc' in flag:
has_gpu_rdc_flag = True
if has_custom_flags:
return [] if has_gpu_rdc_flag else ['-fno-gpu-rdc']
# Use same defaults as used for building PyTorch
# Allow env var to override, just like during initial cmake build.
_archs = os.environ.get('PYTORCH_ROCM_ARCH', None)
Expand All @@ -2424,7 +2431,7 @@ def _get_rocm_arch_flags(cflags: Optional[list[str]] = None) -> list[str]:
else:
archs = _archs.replace(' ', ';').split(';')
flags = [f'--offload-arch={arch}' for arch in archs]
flags += ['-fno-gpu-rdc']
flags += [] if has_gpu_rdc_flag else ['-fno-gpu-rdc']
return flags

def _get_build_directory(name: str, verbose: bool) -> str:
Expand Down Expand Up @@ -2612,8 +2619,8 @@ def _write_ninja_file_to_build_library(path,

if with_cuda and IS_HIP_EXTENSION:
cuda_flags = ['-DWITH_HIP'] + cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS
cuda_flags += extra_cuda_cflags
cuda_flags += _get_rocm_arch_flags(cuda_flags)
cuda_flags += extra_cuda_cflags
elif with_cuda:
cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags()
if IS_WINDOWS:
Expand Down