5252
5353def _compute_res_dtype (* arrays , sycl_queue , dtype = None , casting = "no" ):
5454 """
55- Determines the output array data type and an intermediate data type
56- used in performing calculations related to a specific math function.
55+ Determines the output array data type.
5756 If dtype is ``None``, the output array data type of the operation is
5857 determined based on the Promotion Type Rule and device capabilities.
5958 Otherwise, `dtype` is used as output array dtype, if input arrays
6059 can cast to it according to the casting rule determined. If casting
6160 cannot be done, a ``TypeError`` is raised.
62- The intermediate data type is the data type used for performing the math
63- function calculations. If output array dtype is a floating-point data type,
64- it is also used for the intermediate data type. If output array dtype is an
65- integral data type, the default floating point data type of the device where
66- input arrays are allocated on are used for intermediate data type.
6761
6862 Parameters
6963 ----------
@@ -78,17 +72,13 @@ def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
7872
7973 Returns
8074 -------
81- compute_dtype, res_dtype :
82- `compute_dtype` is the data type used in performing math function calculations.
83- The input arrays of the math function are cast to `compute_dtype` and then
84- the calculations are performed.
75+ res_dtype :
8576 `res_dtype` is the output data type. When the result is obtained, it is cast
8677 to `res_dtype`.
8778
8879 """
8980
9081 res_dtype = dpnp .result_type (* arrays )
91- default_dtype = dpnp .default_float_type (sycl_queue = sycl_queue )
9282
9383 if dtype is not None :
9484 if dpnp .can_cast (res_dtype , dtype , casting = casting ):
@@ -98,11 +88,7 @@ def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
9888 f"Cannot cast from dtype({ res_dtype } ) to dtype({ dtype } ) with casting rule { casting } "
9989 )
10090
101- compute_dtype = (
102- res_dtype if dpnp .issubdtype (res_dtype , dpnp .inexact ) else default_dtype
103- )
104-
105- return compute_dtype , res_dtype
91+ return res_dtype
10692
10793
10894def _copy_array (x , copy_flag = False , dtype = None , order = "C" ):
@@ -749,17 +735,17 @@ def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False):
749735 _validate_out_array (out , exec_q )
750736
751737 # Determine the appropriate data types
752- dot_dtype , res_dtype = _compute_res_dtype (a , b , sycl_queue = exec_q )
738+ res_dtype = _compute_res_dtype (a , b , sycl_queue = exec_q )
753739
754740 result = _create_result_array (
755- a , b , out , (), dot_dtype , res_usm_type , exec_q
741+ a , b , out , (), res_dtype , res_usm_type , exec_q
756742 )
757743
758744 # input arrays should have the proper data type
759745 if dpnp .issubdtype (res_dtype , dpnp .inexact ):
760746 # copying is needed if dtypes of input arrays are different
761- a = _copy_array (a , dtype = dot_dtype )
762- b = _copy_array (b , dtype = dot_dtype )
747+ a = _copy_array (a , dtype = res_dtype )
748+ b = _copy_array (b , dtype = res_dtype )
763749
764750 _manager = dpu .SequentialOrderManager [exec_q ]
765751
@@ -777,14 +763,11 @@ def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False):
777763 )
778764 _manager .add_event_pair (ht_ev , dot_ev )
779765 else :
780- # oneapi::mkl::blas::dot is slow for integer data type ,
766+ # oneapi::mkl::blas::dot does not support integer dtypes ,
781767 # so using dpctl.tensor.vecdot instead
782- dpt_a = dpnp .get_usm_ndarray (a )
783- dpt_b = dpnp .get_usm_ndarray (b )
784- result = dpnp_array ._create_from_usm_ndarray (dpt .vecdot (dpt_a , dpt_b ))
785-
786- if dot_dtype != res_dtype :
787- result = result .astype (res_dtype , copy = False )
768+ a_usm = dpnp .get_usm_ndarray (a )
769+ b_usm = dpnp .get_usm_ndarray (b )
770+ result = dpnp_array ._create_from_usm_ndarray (dpt .vecdot (a_usm , b_usm ))
788771
789772 return dpnp .get_result_array (result , out , casting = casting )
790773
@@ -902,7 +885,7 @@ def dpnp_multiplication(
902885 axes_res = normalize_axis_tuple (axes_res , len (result_shape ), "axes" )
903886
904887 # Determine the appropriate data types
905- compute_dtype , res_dtype = _compute_res_dtype (
888+ res_dtype = _compute_res_dtype (
906889 x1 , x2 , dtype = dtype , casting = casting , sycl_queue = exec_q
907890 )
908891
@@ -998,7 +981,7 @@ def dpnp_multiplication(
998981 x2 ,
999982 out ,
1000983 res_shape ,
1001- compute_dtype ,
984+ res_dtype ,
1002985 res_usm_type ,
1003986 exec_q ,
1004987 res_order ,
@@ -1010,64 +993,72 @@ def dpnp_multiplication(
1010993 elif x1 .size == 0 or x2 .size == 0 :
1011994 result .fill (0 )
1012995 else :
1013- # input arrays should have the proper data type and
1014- # their base (last 2-dimensions) to be c-contiguous or f-contiguous
1015- x1 = _copy_array (
1016- x1 ,
1017- copy_flag = not x1_contig_flag ,
1018- dtype = compute_dtype ,
1019- order = res_order ,
1020- )
1021- x2 = _copy_array (
1022- x2 ,
1023- copy_flag = not x2_contig_flag ,
1024- dtype = compute_dtype ,
1025- order = res_order ,
1026- )
1027-
1028- if call_flag == "gemv" :
1029- if transpose :
1030- a_usm = dpnp .get_usm_ndarray (x2 )
1031- x_usm = dpnp .get_usm_ndarray (x1 )
1032- else :
1033- a_usm = dpnp .get_usm_ndarray (x1 )
1034- x_usm = dpnp .get_usm_ndarray (x2 )
1035-
1036- _manager = dpu .SequentialOrderManager [exec_q ]
1037-
1038- ht_ev , gemv_ev = bi ._gemv (
1039- exec_q ,
1040- a_usm ,
1041- x_usm ,
1042- dpnp .get_usm_ndarray (result ),
1043- transpose ,
1044- depends = _manager .submitted_events ,
1045- )
1046- _manager .add_event_pair (ht_ev , gemv_ev )
1047- elif call_flag == "gemm" :
1048- result = _gemm_matmul (
1049- exec_q ,
996+ if dpnp .issubdtype (res_dtype , dpnp .inexact ):
997+ # copying is needed if dtypes of input arrays are different or
998+ # their base (last 2-dimensions) is not c-contiguous or f-contiguous
999+ x1 = _copy_array (
10501000 x1 ,
1051- x2 ,
1052- result ,
1001+ copy_flag = not x1_contig_flag ,
1002+ dtype = res_dtype ,
1003+ order = res_order ,
10531004 )
1054- else : # call_flag == "gemm_batch"
1055- assert call_flag == "gemm_batch"
1056- result = _gemm_batch_matmul (
1057- exec_q ,
1058- x1 ,
1005+ x2 = _copy_array (
10591006 x2 ,
1060- result ,
1007+ copy_flag = not x2_contig_flag ,
1008+ dtype = res_dtype ,
1009+ order = res_order ,
10611010 )
10621011
1012+ if call_flag == "gemv" :
1013+ if transpose :
1014+ a_usm = dpnp .get_usm_ndarray (x2 )
1015+ x_usm = dpnp .get_usm_ndarray (x1 )
1016+ else :
1017+ a_usm = dpnp .get_usm_ndarray (x1 )
1018+ x_usm = dpnp .get_usm_ndarray (x2 )
1019+
1020+ _manager = dpu .SequentialOrderManager [exec_q ]
1021+
1022+ ht_ev , gemv_ev = bi ._gemv (
1023+ exec_q ,
1024+ a_usm ,
1025+ x_usm ,
1026+ dpnp .get_usm_ndarray (result ),
1027+ transpose ,
1028+ depends = _manager .submitted_events ,
1029+ )
1030+ _manager .add_event_pair (ht_ev , gemv_ev )
1031+ elif call_flag == "gemm" :
1032+ result = _gemm_matmul (
1033+ exec_q ,
1034+ x1 ,
1035+ x2 ,
1036+ result ,
1037+ )
1038+ else : # call_flag == "gemm_batch"
1039+ assert call_flag == "gemm_batch"
1040+ result = _gemm_batch_matmul (
1041+ exec_q ,
1042+ x1 ,
1043+ x2 ,
1044+ result ,
1045+ )
1046+ else :
1047+ # oneapi::mkl::blas::gemm/gemv do not support integer dtypes,
1048+ # so using dpctl.tensor.matmul instead
1049+ x1_usm = dpnp .get_usm_ndarray (x1 )
1050+ x2_usm = dpnp .get_usm_ndarray (x2 )
1051+ out_usm = dpnp .get_usm_ndarray (result )
1052+ res_usm = dpt .matmul (
1053+ x1_usm , x2_usm , out = out_usm , dtype = dtype , order = order
1054+ )
1055+ result = dpnp_array ._create_from_usm_ndarray (res_usm )
1056+
10631057 if NumPy_special_case :
10641058 result = dpnp .tile (result , out .shape )
10651059 elif res_shape != result_shape :
10661060 result = dpnp .reshape (result , result_shape )
10671061
1068- if compute_dtype != res_dtype :
1069- result = dpnp .astype (result , res_dtype , copy = False )
1070-
10711062 if out is None :
10721063 if axes is not None :
10731064 # Move the data back to the appropriate axes of the result array
@@ -1207,7 +1198,7 @@ def dpnp_vecdot(
12071198 )
12081199
12091200 # Determine the appropriate data types
1210- _ , res_dtype = _compute_res_dtype (
1201+ res_dtype = _compute_res_dtype (
12111202 x1 , x2 , dtype = dtype , casting = casting , sycl_queue = exec_q
12121203 )
12131204
0 commit comments