Skip to content

Commit 4adad1b

Browse files
authored
Merge 1add60a into 5b140db
2 parents 5b140db + 1add60a commit 4adad1b

File tree

4 files changed

+141
-139
lines changed

4 files changed

+141
-139
lines changed

.github/workflows/array-api-skips.txt

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,10 @@ array_api_tests/test_linalg.py::test_svd
2626
array_api_tests/test_linalg.py::test_qr
2727
array_api_tests/test_operators_and_elementwise_functions.py::test_clip
2828

29-
# unexpected result is returned
29+
# unexpected result is returned - unmute when dpctl-1986 is resolved
3030
array_api_tests/test_operators_and_elementwise_functions.py::test_asin
3131
array_api_tests/test_operators_and_elementwise_functions.py::test_asinh
3232

3333
# missing 'correction' keyword argument
3434
array_api_tests/test_signatures.py::test_func_signature[std]
3535
array_api_tests/test_signatures.py::test_func_signature[var]
36-
37-
# arrays have different values
38-
array_api_tests/test_linalg.py::test_linalg_tensordot

.github/workflows/conda-package.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ jobs:
218218
id: run_tests_linux
219219
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
220220
with:
221-
timeout_minutes: 12
221+
timeout_minutes: 15
222222
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
223223
retry_on: any
224224
command: |

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 70 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -52,18 +52,12 @@
5252

5353
def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
5454
"""
55-
Determines the output array data type and an intermediate data type
56-
used in performing calculations related to a specific math function.
55+
Determines the output array data type.
5756
If dtype is ``None``, the output array data type of the operation is
5857
determined based on the Promotion Type Rule and device capabilities.
5958
Otherwise, `dtype` is used as output array dtype, if input arrays
6059
can cast to it according to the casting rule determined. If casting
6160
cannot be done, a ``TypeError`` is raised.
62-
The intermediate data type is the data type used for performing the math
63-
function calculations. If output array dtype is a floating-point data type,
64-
it is also used for the intermediate data type. If output array dtype is an
65-
integral data type, the default floating point data type of the device where
66-
input arrays are allocated on are used for intermediate data type.
6761
6862
Parameters
6963
----------
@@ -78,17 +72,13 @@ def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
7872
7973
Returns
8074
-------
81-
compute_dtype, res_dtype :
82-
`compute_dtype` is the data type used in performing math function calculations.
83-
The input arrays of the math function are cast to `compute_dtype` and then
84-
the calculations are performed.
75+
res_dtype :
8576
`res_dtype` is the output data type. When the result is obtained, it is cast
8677
to `res_dtype`.
8778
8879
"""
8980

9081
res_dtype = dpnp.result_type(*arrays)
91-
default_dtype = dpnp.default_float_type(sycl_queue=sycl_queue)
9282

9383
if dtype is not None:
9484
if dpnp.can_cast(res_dtype, dtype, casting=casting):
@@ -98,11 +88,7 @@ def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
9888
f"Cannot cast from dtype({res_dtype}) to dtype({dtype}) with casting rule {casting}"
9989
)
10090

101-
compute_dtype = (
102-
res_dtype if dpnp.issubdtype(res_dtype, dpnp.inexact) else default_dtype
103-
)
104-
105-
return compute_dtype, res_dtype
91+
return res_dtype
10692

10793

10894
def _copy_array(x, copy_flag=False, dtype=None, order="C"):
@@ -749,17 +735,17 @@ def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False):
749735
_validate_out_array(out, exec_q)
750736

751737
# Determine the appropriate data types
752-
dot_dtype, res_dtype = _compute_res_dtype(a, b, sycl_queue=exec_q)
738+
res_dtype = _compute_res_dtype(a, b, sycl_queue=exec_q)
753739

754740
result = _create_result_array(
755-
a, b, out, (), dot_dtype, res_usm_type, exec_q
741+
a, b, out, (), res_dtype, res_usm_type, exec_q
756742
)
757743

758744
# input arrays should have the proper data type
759745
if dpnp.issubdtype(res_dtype, dpnp.inexact):
760746
# copying is needed if dtypes of input arrays are different
761-
a = _copy_array(a, dtype=dot_dtype)
762-
b = _copy_array(b, dtype=dot_dtype)
747+
a = _copy_array(a, dtype=res_dtype)
748+
b = _copy_array(b, dtype=res_dtype)
763749

764750
_manager = dpu.SequentialOrderManager[exec_q]
765751

@@ -777,14 +763,11 @@ def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False):
777763
)
778764
_manager.add_event_pair(ht_ev, dot_ev)
779765
else:
780-
# oneapi::mkl::blas::dot is slow for integer data type,
766+
# oneapi::mkl::blas::dot does not support integer dtypes,
781767
# so using dpctl.tensor.vecdot instead
782-
dpt_a = dpnp.get_usm_ndarray(a)
783-
dpt_b = dpnp.get_usm_ndarray(b)
784-
result = dpnp_array._create_from_usm_ndarray(dpt.vecdot(dpt_a, dpt_b))
785-
786-
if dot_dtype != res_dtype:
787-
result = result.astype(res_dtype, copy=False)
768+
a_usm = dpnp.get_usm_ndarray(a)
769+
b_usm = dpnp.get_usm_ndarray(b)
770+
result = dpnp_array._create_from_usm_ndarray(dpt.vecdot(a_usm, b_usm))
788771

789772
return dpnp.get_result_array(result, out, casting=casting)
790773

@@ -902,7 +885,7 @@ def dpnp_multiplication(
902885
axes_res = normalize_axis_tuple(axes_res, len(result_shape), "axes")
903886

904887
# Determine the appropriate data types
905-
compute_dtype, res_dtype = _compute_res_dtype(
888+
res_dtype = _compute_res_dtype(
906889
x1, x2, dtype=dtype, casting=casting, sycl_queue=exec_q
907890
)
908891

@@ -998,7 +981,7 @@ def dpnp_multiplication(
998981
x2,
999982
out,
1000983
res_shape,
1001-
compute_dtype,
984+
res_dtype,
1002985
res_usm_type,
1003986
exec_q,
1004987
res_order,
@@ -1010,64 +993,72 @@ def dpnp_multiplication(
1010993
elif x1.size == 0 or x2.size == 0:
1011994
result.fill(0)
1012995
else:
1013-
# input arrays should have the proper data type and
1014-
# their base (last 2-dimensions) to be c-contiguous or f-contiguous
1015-
x1 = _copy_array(
1016-
x1,
1017-
copy_flag=not x1_contig_flag,
1018-
dtype=compute_dtype,
1019-
order=res_order,
1020-
)
1021-
x2 = _copy_array(
1022-
x2,
1023-
copy_flag=not x2_contig_flag,
1024-
dtype=compute_dtype,
1025-
order=res_order,
1026-
)
1027-
1028-
if call_flag == "gemv":
1029-
if transpose:
1030-
a_usm = dpnp.get_usm_ndarray(x2)
1031-
x_usm = dpnp.get_usm_ndarray(x1)
1032-
else:
1033-
a_usm = dpnp.get_usm_ndarray(x1)
1034-
x_usm = dpnp.get_usm_ndarray(x2)
1035-
1036-
_manager = dpu.SequentialOrderManager[exec_q]
1037-
1038-
ht_ev, gemv_ev = bi._gemv(
1039-
exec_q,
1040-
a_usm,
1041-
x_usm,
1042-
dpnp.get_usm_ndarray(result),
1043-
transpose,
1044-
depends=_manager.submitted_events,
1045-
)
1046-
_manager.add_event_pair(ht_ev, gemv_ev)
1047-
elif call_flag == "gemm":
1048-
result = _gemm_matmul(
1049-
exec_q,
996+
if dpnp.issubdtype(res_dtype, dpnp.inexact):
997+
# copying is needed if dtypes of input arrays are different or
998+
# their base (last 2-dimensions) is not c-contiguous or f-contiguous
999+
x1 = _copy_array(
10501000
x1,
1051-
x2,
1052-
result,
1001+
copy_flag=not x1_contig_flag,
1002+
dtype=res_dtype,
1003+
order=res_order,
10531004
)
1054-
else: # call_flag == "gemm_batch"
1055-
assert call_flag == "gemm_batch"
1056-
result = _gemm_batch_matmul(
1057-
exec_q,
1058-
x1,
1005+
x2 = _copy_array(
10591006
x2,
1060-
result,
1007+
copy_flag=not x2_contig_flag,
1008+
dtype=res_dtype,
1009+
order=res_order,
10611010
)
10621011

1012+
if call_flag == "gemv":
1013+
if transpose:
1014+
a_usm = dpnp.get_usm_ndarray(x2)
1015+
x_usm = dpnp.get_usm_ndarray(x1)
1016+
else:
1017+
a_usm = dpnp.get_usm_ndarray(x1)
1018+
x_usm = dpnp.get_usm_ndarray(x2)
1019+
1020+
_manager = dpu.SequentialOrderManager[exec_q]
1021+
1022+
ht_ev, gemv_ev = bi._gemv(
1023+
exec_q,
1024+
a_usm,
1025+
x_usm,
1026+
dpnp.get_usm_ndarray(result),
1027+
transpose,
1028+
depends=_manager.submitted_events,
1029+
)
1030+
_manager.add_event_pair(ht_ev, gemv_ev)
1031+
elif call_flag == "gemm":
1032+
result = _gemm_matmul(
1033+
exec_q,
1034+
x1,
1035+
x2,
1036+
result,
1037+
)
1038+
else: # call_flag == "gemm_batch"
1039+
assert call_flag == "gemm_batch"
1040+
result = _gemm_batch_matmul(
1041+
exec_q,
1042+
x1,
1043+
x2,
1044+
result,
1045+
)
1046+
else:
1047+
# oneapi::mkl::blas::gemm/gemv do not support integer dtypes,
1048+
# so using dpctl.tensor.matmul instead
1049+
x1_usm = dpnp.get_usm_ndarray(x1)
1050+
x2_usm = dpnp.get_usm_ndarray(x2)
1051+
out_usm = dpnp.get_usm_ndarray(result)
1052+
res_usm = dpt.matmul(
1053+
x1_usm, x2_usm, out=out_usm, dtype=dtype, order=order
1054+
)
1055+
result = dpnp_array._create_from_usm_ndarray(res_usm)
1056+
10631057
if NumPy_special_case:
10641058
result = dpnp.tile(result, out.shape)
10651059
elif res_shape != result_shape:
10661060
result = dpnp.reshape(result, result_shape)
10671061

1068-
if compute_dtype != res_dtype:
1069-
result = dpnp.astype(result, res_dtype, copy=False)
1070-
10711062
if out is None:
10721063
if axes is not None:
10731064
# Move the data back to the appropriate axes of the result array
@@ -1207,7 +1198,7 @@ def dpnp_vecdot(
12071198
)
12081199

12091200
# Determine the appropriate data types
1210-
_, res_dtype = _compute_res_dtype(
1201+
res_dtype = _compute_res_dtype(
12111202
x1, x2, dtype=dtype, casting=casting, sycl_queue=exec_q
12121203
)
12131204

0 commit comments

Comments
 (0)