From 520170a552b315d82c9ee43d5dee0809a50e6132 Mon Sep 17 00:00:00 2001
From: Jagadish Krishnamoorthy
Date: Tue, 14 Jan 2025 09:44:53 -0800
Subject: [PATCH] [release/2.5] [reland][AMD] Turn on TF32 for aten::mm

Ported patch from https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/143549.patch

Resolved conflict in the aten/src/ATen/Context.cpp file.
test/dynamo/test_graph_region_tracker.py skipped since it is not present in the release/2.5 branch.

Signed-off-by: Jagadish Krishnamoorthy
---
 aten/src/ATen/Context.cpp                  | 18 +++++++
 aten/src/ATen/cuda/CUDABlas.cpp            |  4 --
 test/dynamo/test_misc.py                   | 55 +++++++++++++++-------
 test/test_cuda.py                          | 33 +++++++++++++
 torch/utils/hipify/cuda_to_hip_mappings.py |  4 ++
 5 files changed, 92 insertions(+), 22 deletions(-)

diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index ff8ceed7e8de8..492e58cdd6991 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -3,6 +3,7 @@
 #include
 #include
+#include <c10/util/env.h>
 #include
 #include
@@ -173,6 +174,9 @@ bool Context::userEnabledOverrideableSDP() const {
 static const char cublas_config_var_name[] = "CUBLAS_WORKSPACE_CONFIG";
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
 static const char* const cublas_deterministic_configs[] = { ":4096:8", ":16:8" };
+#ifdef USE_ROCM
+static constexpr const auto hipblaslt_allow_tf32 = "HIPBLASLT_ALLOW_TF32";
+#endif
 
 bool Context::checkCuBLASConfigDeterministic() {
   bool cublas_config_deterministic = true;
@@ -228,10 +232,24 @@ void Context::setBenchmarkLimitCuDNN(int b) {
 }
 
 bool Context::allowTF32CuBLAS() const {
+#ifdef USE_ROCM
+  const static auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
+  if (allow_tf32 != true) {
+    return false;
+  }
+#endif
   return float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
 }
 
 void Context::setAllowTF32CuBLAS(bool b) {
+#ifdef USE_ROCM
+  const static auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
+  if (allow_tf32 != true) {
+    LOG(INFO) << "torch.backends.cuda.matmul.allow_tf32 is not supported on ROCm by default. "
+              << "Please set environment variable HIPBLASLT_ALLOW_TF32=1 to enable it.";
+    return;
+  }
+#endif
   float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
 }
 
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index 9b3fd5dc6e4dd..ad8d6d0809050 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -329,11 +329,9 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
     computeType = CUBLAS_COMPUTE_64F;
     scaleType = CUDA_R_64F;
   } else if constexpr (std::is_same_v<Dtype, float>) {
-#ifndef USE_ROCM
     if (at::globalContext().allowTF32CuBLAS()) {
       computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
     }
-#endif
   } else if constexpr (std::is_same_v<Dtype, c10::complex<double>>) {
     abcType = CUDA_C_64F;
     computeType = CUBLAS_COMPUTE_64F;
@@ -1205,11 +1203,9 @@ void gemm_and_bias(
     computeType = CUBLAS_COMPUTE_64F;
     scaleType = CUDA_R_64F;
   } else if constexpr (std::is_same_v<Dtype, float>) {
-#ifndef USE_ROCM
     if (at::globalContext().allowTF32CuBLAS()) {
       computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
     }
-#endif
     abcType = CUDA_R_32F;
   } else if constexpr (std::is_same_v<Dtype, at::Half>) {
     abcType = CUDA_R_16F;
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index ab1b343044a26..6e5c11f1196e0 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -7717,24 +7717,43 @@ def write_state(state):
         def fn(x):
             return x + 1
 
-        initial_state = read_state()
-        y = torch.randn(10)
-        try:
-            for round in range(3):
-                for i in range(len(initial_state)):
-                    new_state = [False] * len(initial_state)
-                    new_state[i] = True
-                    write_state(new_state)
-                    assert read_state() == new_state
-                    last_state.clear()
-                    fn(y)
-                    assert last_state == new_state
-                    if round == 0:
-                        assert cnt == i + 1
-                    else:
-                        assert cnt == len(initial_state)
-        finally:
-            write_state(initial_state)
+        import contextlib
+
+        @contextlib.contextmanager
+        def _hip_allow_tf32():
+            # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
+            # and only for MI300+
+            hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
+            os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
+
+            try:
+                yield
+            finally:
+                if hip_allow_tf32 is not None:
+                    os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
+                else:
+                    del os.environ["HIPBLASLT_ALLOW_TF32"]
+
+        tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
+        with tf32_ctx():
+            initial_state = read_state()
+            y = torch.randn(10)
+            try:
+                for round in range(3):
+                    for i in range(len(initial_state)):
+                        new_state = [False] * len(initial_state)
+                        new_state[i] = True
+                        write_state(new_state)
+                        assert read_state() == new_state
+                        last_state.clear()
+                        fn(y)
+                        assert last_state == new_state
+                        if round == 0:
+                            assert cnt == i + 1
+                        else:
+                            assert cnt == len(initial_state)
+            finally:
+                write_state(initial_state)
 
     def test_grad_state_mutated(self):
         prior = torch.is_grad_enabled()
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 381b8705b7dc0..7480d4ef27797 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -433,7 +433,33 @@ def check_workspace_size(inp):
 
         torch._C._cuda_clearCublasWorkspaces()
 
+    @contextlib.contextmanager
+    def _hip_allow_tf32(self):
+        # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
+        # and only for MI300+
+        hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
+        os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
+
+        try:
+            yield
+        finally:
+            if hip_allow_tf32 is not None:
+                os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
+            else:
+                del os.environ["HIPBLASLT_ALLOW_TF32"]
+
     def test_cublas_allow_tf32_get_set(self):
+        """
+        We only turn on TF32 for MI300 with a special env var. This is because TF32
+        is only available in MI300+ and is in experimental mode (hipblaslt support
+        is current WIP)
+        """
+        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
+
+        with tf32_ctx():
+            self._test_cublas_allow_tf32_get_set_inner()
+
+    def _test_cublas_allow_tf32_get_set_inner(self):
         skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
             os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
         )
@@ -448,6 +474,12 @@ def test_cublas_allow_tf32_get_set(self):
         torch.backends.cuda.matmul.allow_tf32 = orig
 
     def test_float32_matmul_precision_get_set(self):
+        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
+
+        with tf32_ctx():
+            self._test_float32_matmul_precision_get_set_inner()
+
+    def _test_float32_matmul_precision_get_set_inner(self):
         orig = torch.get_float32_matmul_precision()
         skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
             os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
         )
@@ -459,6 +491,7 @@ def test_float32_matmul_precision_get_set(self):
             self.assertEqual(torch.get_float32_matmul_precision(), "highest")
         else:
             self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
+
         for p in ("medium", "high"):
             torch.set_float32_matmul_precision(p)
             self.assertEqual(torch.get_float32_matmul_precision(), p)
diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py
index a002a0d491f78..f8827e0681e9d 100644
--- a/torch/utils/hipify/cuda_to_hip_mappings.py
+++ b/torch/utils/hipify/cuda_to_hip_mappings.py
@@ -7284,6 +7284,10 @@
         "CUBLAS_COMPUTE_32F",
         ("HIPBLAS_COMPUTE_32F", CONV_MATH_FUNC, API_BLAS)
     ),
+    (
+        "CUBLAS_COMPUTE_32F_FAST_TF32",
+        ("HIPBLAS_COMPUTE_32F_FAST_TF32", CONV_MATH_FUNC, API_BLAS)
+    ),
     (
         "CUBLAS_COMPUTE_64F",
         ("HIPBLAS_COMPUTE_64F", CONV_MATH_FUNC, API_BLAS)
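
Usage sketch (not part of the diff above): with this patch applied to a ROCm build, TF32 matmul stays disabled unless the HIPBLASLT_ALLOW_TF32 environment variable is set, since allowTF32CuBLAS()/setAllowTF32CuBLAS() consult the variable once via c10::utils::check_env. A minimal Python illustration of the intended workflow, assuming such a build:

    # Sketch: enabling TF32 for aten::mm on a ROCm build carrying this patch.
    # The env var is read once (const static), so set it before touching the flag.
    import os
    os.environ.setdefault("HIPBLASLT_ALLOW_TF32", "1")

    import torch

    # Without HIPBLASLT_ALLOW_TF32=1 this assignment is a no-op on ROCm and only
    # logs an informational message; with it, matmuls may use TF32 compute.
    torch.backends.cuda.matmul.allow_tf32 = True
    print(torch.backends.cuda.matmul.allow_tf32)

    # Equivalent knob: "high"/"medium" allow TF32, "highest" keeps full FP32.
    torch.set_float32_matmul_precision("high")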