From 50ddaf7fc91ab2014bd66280070a614466e01deb Mon Sep 17 00:00:00 2001 From: Dmitry Nikolaev <139769634+dnikolaev-amd@users.noreply.github.com> Date: Mon, 25 Aug 2025 18:49:26 +0200 Subject: [PATCH] [rocm7.1_internal_testing] fix large tensor sort on ROCm (#2543) Currently std::min -> ::min did not work as expected on ROCm when input values >= 2147483648 Replace std::min to ternary statement Also std::min can be replaced by explicit typing std::min fixes on ROCm: test_sort_and_select.py::TestSortAndSelectCUDA::test_sort_large_cuda_float16 error: RuntimeError: Cannot sort dimension of length 8192 Combines upstream PRs: - https://github.com/pytorch/pytorch/pull/161054 to fix std::min on ROCm - https://github.com/pytorch/pytorch/pull/155546 fix python test - https://github.com/pytorch/pytorch/pull/159939 change test dtype from int8 to float16 Fixes: SWDEV-526432 --- aten/src/ATen/native/cuda/SortStable.cu | 8 +++++--- test/test_sort_and_select.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/cuda/SortStable.cu b/aten/src/ATen/native/cuda/SortStable.cu index 546989d09839a..4d956616371de 100644 --- a/aten/src/ATen/native/cuda/SortStable.cu +++ b/aten/src/ATen/native/cuda/SortStable.cu @@ -225,8 +225,9 @@ void launch_stable_sort_kernel( return; } - int64_t numel_or_intmax = - std::min(numel, static_cast(std::numeric_limits::max())); + const int64_t intmax = static_cast(std::numeric_limits::max()); + // On ROCm, std::min -> ::min did not work as expected when input values >= 2147483648 + int64_t numel_or_intmax = numel < intmax ? 
numel : intmax; int64_t nsort = self.size(dim); int64_t nbatch = (numel_or_intmax / nsort) * nsort; TORCH_CHECK(nbatch > 0, "Cannot sort dimension of length ", nsort); @@ -238,7 +239,8 @@ void launch_stable_sort_kernel( scalar_t* values_ptr = values.mutable_data_ptr(); int64_t remaining = numel; while (remaining > 0) { - int64_t n = std::min(remaining, nbatch); + // On ROCm, std::min -> ::min did not work as expected when input values >= 2147483648 + int64_t n = remaining < nbatch ? remaining : nbatch; int64_t nsegments = n / nsort; if (nsegments == 1 || diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py index daa3996437498..e0c5e3cd7174a 100644 --- a/test/test_sort_and_select.py +++ b/test/test_sort_and_select.py @@ -215,7 +215,7 @@ def test_stable_sort(self, device, dtype): ) @onlyCUDA - @dtypes(torch.uint8) + @dtypes(torch.float16) @largeTensorTest("200GB") # Unfortunately 80GB A100 is not large enough def test_sort_large(self, device, dtype): t0 = torch.randperm(8192, device=device).to(dtype)