From 4f0f1a30835bdf14ce695a72f857eb5d846b7b84 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 23 Oct 2024 09:30:17 -0500 Subject: [PATCH 1/4] Correctly handle inputs which are all Python scalars (weak types) If dtypes list is empty, populate it with first Python scalar from weak_dtypes list. --- dpctl/tensor/_type_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index 5defd154df..f279052f94 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -767,6 +767,9 @@ def result_type(*arrays_and_dtypes): target_dev = d inspected = True + if not dtypes and weak_dtypes: + dtypes.append(weak_dtypes[0].get()) + if not (has_fp16 and has_fp64): for dt in dtypes: if not _dtype_supported_by_device_impl(dt, has_fp16, has_fp64): From 0a68c343e0120eee1984f038bd1b888f8dfa2920 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 23 Oct 2024 09:44:05 -0500 Subject: [PATCH 2/4] Add test that result_types(dtypes) works the same for Python/NumPy scalars --- dpctl/tests/test_usm_ndarray_manipulation.py | 21 ++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index 882a001827..f0565c8009 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -15,6 +15,8 @@ # limitations under the License. +import itertools + import numpy as np import pytest from numpy.testing import assert_, assert_array_equal, assert_raises_regex @@ -1555,3 +1557,22 @@ def test_repeat_0_size(): res = dpt.repeat(x, repetitions, axis=1) axis_sz = 2 * x.shape[1] assert res.shape == (0, axis_sz, 0) + + +def test_result_type_bug_1874(): + dts_bool = [True, np.bool_(True)] + dts_ints = [int(1), np.int64(1)] + dts_floats = [float(1), np.float64(1)] + dts_complexes = [complex(1), np.complex128(1)] + + # iterate over two categories + for dts1, dts2 in itertools.product( + [dts_bool, dts_ints, dts_floats, dts_complexes], repeat=2 + ): + res_dts = [] + # iterate over Python scalar/NumPy scalar choices within categories + for dt1, dt2 in itertools.product(dts1, dts2): + res_dt = dpt.result_type(dt1, dt2) + res_dts.append(res_dt) + # check that all results are the same + assert res_dts and all(res_dts[0] == el for el in res_dts[1:]) From 6c4211c94e69808300dee92adf8281664d891cf9 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 23 Oct 2024 10:54:27 -0500 Subject: [PATCH 3/4] Record chagne to result_type in CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 97c06affac..586b65652c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Improved performance of `tensor.sort` and `tensor.argsort` for short arrays in the range [16, 64] elements [gh-1866](https://github.com/IntelPython/dpctl/pull/1866) ### Fixed +* Fix for `tensor.result_type` when all inputs are Python built-in scalars [gh-1877](https://github.com/IntelPython/dpctl/pull/1877) ### Maintenance From dc1887e7f65bd336239975c3727d27a171495ff5 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 23 Oct 2024 13:32:10 -0500 Subject: [PATCH 4/4] Changed test_result_type_bug_1874 to work with NumPy 2 and older on Win --- dpctl/tests/test_usm_ndarray_manipulation.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index f0565c8009..4bfd6dab9f 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -1560,8 +1560,12 @@ def test_repeat_0_size(): def test_result_type_bug_1874(): - dts_bool = [True, np.bool_(True)] - dts_ints = [int(1), np.int64(1)] + py_sc = True + np_sc = np.asarray([py_sc])[0] + dts_bool = [py_sc, np_sc] + py_sc = int(1) + np_sc = np.asarray([py_sc])[0] + dts_ints = [py_sc, np_sc] dts_floats = [float(1), np.float64(1)] dts_complexes = [complex(1), np.complex128(1)]