diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index b3f81f1332863..7772134fd1534 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -1447,7 +1447,7 @@ def to_np(value): try: np_res = np.power(to_np(base), to_np(np_exponent)) expected = ( - torch.from_numpy(np_res).to(dtype=base.dtype) + torch.from_numpy(np_res) if isinstance(np_res, np.ndarray) else torch.tensor(np_res, dtype=base.dtype) ) @@ -1480,8 +1480,8 @@ def to_np(value): self.assertRaisesRegex(RuntimeError, regex, base.pow_, exponent) elif torch.can_cast(torch.result_type(base, exponent), base.dtype): actual2 = actual.pow_(exponent) - self.assertEqual(actual, expected) - self.assertEqual(actual2, expected) + self.assertEqual(actual, expected.to(actual)) + self.assertEqual(actual2, expected.to(actual)) else: self.assertRaisesRegex( RuntimeError,