From b03beb32595fcc103bbe747a4145e2ae5f465fc3 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Sun, 28 May 2023 20:50:22 -0700 Subject: [PATCH] Fixes incorrect result when multiplying real array and complex scalar - Adds a test for the fix - Resolves #1219 --- dpctl/tensor/_elementwise_common.py | 4 ++-- dpctl/tests/elementwise/test_multiply.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py index 6677670e73..a775376d95 100644 --- a/dpctl/tensor/_elementwise_common.py +++ b/dpctl/tensor/_elementwise_common.py @@ -255,7 +255,7 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev): raise ValueError o1_kind_num = _weak_type_num_kind(o1_dtype) o2_kind_num = _strong_dtype_num_kind(o2_dtype) - if o1_kind_num > o2_kind_num: + if o1_kind_num > o2_kind_num or o1_kind_num == 2: if isinstance(o1_dtype, WeakBooleanType): return dpt.bool, o2_dtype if isinstance(o1_dtype, WeakIntegralType): @@ -273,7 +273,7 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev): ): o1_kind_num = _strong_dtype_num_kind(o1_dtype) o2_kind_num = _weak_type_num_kind(o2_dtype) - if o2_kind_num > o1_kind_num: + if o2_kind_num > o1_kind_num or o2_kind_num == 2: if isinstance(o2_dtype, WeakBooleanType): return o1_dtype, dpt.bool if isinstance(o2_dtype, WeakIntegralType): diff --git a/dpctl/tests/elementwise/test_multiply.py b/dpctl/tests/elementwise/test_multiply.py index cd506cd182..1305154021 100644 --- a/dpctl/tests/elementwise/test_multiply.py +++ b/dpctl/tests/elementwise/test_multiply.py @@ -152,3 +152,18 @@ def test_multiply_python_scalar(arr_dt): assert isinstance(R, dpt.usm_ndarray) R = dpt.multiply(sc, X) assert isinstance(R, dpt.usm_ndarray) + + +def test_multiply_python_scalar_gh1219(): + q = get_queue_or_skip() + + X = dpt.ones(4, dtype="f4", sycl_queue=q) + + r = dpt.multiply(X, 2j) + expected = dpt.multiply(X, dpt.asarray(2j, sycl_queue=q)) + assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q) + + # symmetric case + r = dpt.multiply(2j, X) + expected = dpt.multiply(dpt.asarray(2j, sycl_queue=q), X) + assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)