From baa35b291f15a3002b2859a5964a2bf70169f199 Mon Sep 17 00:00:00 2001 From: Dmitry Nikolaev <139769634+dnikolaev-amd@users.noreply.github.com> Date: Mon, 14 Jul 2025 18:58:24 +0200 Subject: [PATCH] [release/2.6] Fix dtype before comparing torch and numpy tensors (#2340) Cast numpy dtype result to torch dtype result before compare Numpy returns `np.power(float32, int64) => float64` [Promotion rules for Python scalars](https://numpy.org/neps/nep-0050-scalar-promotion.html) Pytorch returns `torch.pow(float32, int64) => float32` Reverts https://github.com/ROCm/pytorch/pull/2287 and fixes tests in a different way Fixes: - SWDEV-538110 - `'dtype' do not match: torch.float32 != torch.float64` > - test_binary_ufuncs.py::TestBinaryUfuncsCUDA::test_cuda_tensor_pow_scalar_tensor_cuda - SWDEV-539171 - `AttributeError: 'float' object has no attribute 'dtype` > - test_binary_ufuncs.py::TestBinaryUfuncsCUDA::test_long_tensor_pow_floats_cuda > - test_binary_ufuncs.py::TestBinaryUfuncsCUDA::test_complex_scalar_pow_tensor_cuda_* > - test_binary_ufuncs.py::TestBinaryUfuncsCUDA::test_float_scalar_pow_float_tensor_cuda_* --- test/test_binary_ufuncs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index b3f81f1332863..7772134fd1534 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -1447,7 +1447,7 @@ def to_np(value): try: np_res = np.power(to_np(base), to_np(np_exponent)) expected = ( - torch.from_numpy(np_res).to(dtype=base.dtype) + torch.from_numpy(np_res) if isinstance(np_res, np.ndarray) else torch.tensor(np_res, dtype=base.dtype) ) @@ -1480,8 +1480,8 @@ def to_np(value): self.assertRaisesRegex(RuntimeError, regex, base.pow_, exponent) elif torch.can_cast(torch.result_type(base, exponent), base.dtype): actual2 = actual.pow_(exponent) - self.assertEqual(actual, expected) - self.assertEqual(actual2, expected) + self.assertEqual(actual, expected.to(actual)) + self.assertEqual(actual2, expected.to(actual)) else: self.assertRaisesRegex( RuntimeError,