diff --git a/aten/src/ATen/native/cuda/SortStable.cu b/aten/src/ATen/native/cuda/SortStable.cu index 546989d09839a..4d956616371de 100644 --- a/aten/src/ATen/native/cuda/SortStable.cu +++ b/aten/src/ATen/native/cuda/SortStable.cu @@ -225,8 +225,9 @@ void launch_stable_sort_kernel( return; } - int64_t numel_or_intmax = - std::min(numel, static_cast<int64_t>(std::numeric_limits<int>::max())); + const int64_t intmax = static_cast<int64_t>(std::numeric_limits<int>::max()); + // On ROCm, std::min -> ::min did not work as expected when input values >= 2147483648 + int64_t numel_or_intmax = numel < intmax ? numel : intmax; int64_t nsort = self.size(dim); int64_t nbatch = (numel_or_intmax / nsort) * nsort; TORCH_CHECK(nbatch > 0, "Cannot sort dimension of length ", nsort); @@ -238,7 +239,8 @@ void launch_stable_sort_kernel( scalar_t* values_ptr = values.mutable_data_ptr<scalar_t>(); int64_t remaining = numel; while (remaining > 0) { - int64_t n = std::min(remaining, nbatch); + // On ROCm, std::min -> ::min did not work as expected when input values >= 2147483648 + int64_t n = remaining < nbatch ? remaining : nbatch; int64_t nsegments = n / nsort; if (nsegments == 1 || diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py index 360dc058212a0..669f165529e71 100644 --- a/test/test_sort_and_select.py +++ b/test/test_sort_and_select.py @@ -215,7 +215,7 @@ def test_stable_sort(self, device, dtype): ) @onlyCUDA - @dtypes(torch.uint8) + @dtypes(torch.float16) @largeTensorTest("200GB") # Unfortunately 80GB A100 is not large enough def test_sort_large(self, device, dtype): t0 = torch.randperm(8192, device=device).to(dtype)