diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 7d0f4c445f38c..a211f8f7a7894 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -450,9 +450,49 @@ class TORCH_API Context { c10::utils::check_env("TORCH_LINALG_PREFER_HIPSOLVER") == true) // alias ? at::LinalgBackend::Cusolver : at::LinalgBackend::Default; +<<<<<<< HEAD at::BlasBackend blas_preferred_backend = (c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true || c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") == true) // alias +======= +#ifdef USE_ROCM + // AMD Instinct targets prefer hipblaslt + const bool _hipblaslt_preferred_default = []() { + const std::vector archs = { + "gfx90a", "gfx942", +#if ROCM_VERSION >= 60300 + "gfx1200", "gfx1201", +#endif +#if ROCM_VERSION >= 60500 + "gfx950" +#endif + }; + for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) { + if (!detail::getCUDAHooks().isGPUArch(index, archs)) { + return false; + } + } + return true; + }(); +#else + const bool _hipblaslt_preferred_default = false; +#endif + const bool _blaslt_preferred = [&]() { + auto env = c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT"); + if (env.has_value()) { + return env.value(); + } + env = c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT"); + if (env.has_value()) { + return env.value(); + } +#ifdef USE_ROCM + return _hipblaslt_preferred_default; +#endif + return false; + }(); + at::BlasBackend blas_preferred_backend = _blaslt_preferred +>>>>>>> 1ded221de6 ([release/2.6] Change gfx110x BLAS preferred backend (#2053)) ? at::BlasBackend::Cublaslt : at::BlasBackend::Default; at::ROCmFABackend rocm_fa_preferred_backend =