Skip to content

Commit 97817ed

Browse files
committed
Removed extra copy for strided arrays in dot()
1 parent 10a656d commit 97817ed

File tree

2 files changed

+10
-15
lines changed

2 files changed

+10
-15
lines changed

dpnp/backend/kernels/dpnp_krnl_common.cpp

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -227,12 +227,10 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
227227
DPCTLSyclEventRef event_ref = nullptr;
228228
sycl::queue q = *(reinterpret_cast<sycl::queue *>(q_ref));
229229

230-
DPNPC_ptr_adapter<_DataType_input1> input1_ptr(q_ref, input1_in,
231-
input1_size);
232-
DPNPC_ptr_adapter<_DataType_input2> input2_ptr(q_ref, input2_in,
233-
input2_size);
234-
_DataType_input1 *input1 = input1_ptr.get_ptr();
235-
_DataType_input2 *input2 = input2_ptr.get_ptr();
230+
_DataType_input1 *input1 =
231+
static_cast<_DataType_input1 *>(const_cast<void *>(input1_in));
232+
_DataType_input2 *input2 =
233+
static_cast<_DataType_input2 *>(const_cast<void *>(input2_in));
236234
_DataType_output *result = reinterpret_cast<_DataType_output *>(result_out);
237235

238236
if (!input1_size || !input2_size) {
@@ -257,10 +255,12 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
257255
// if both arrays are vectors
258256
if ((input1_ndim == 1) && (input2_ndim == 1)) {
259257
assert(input1_size == input2_size);
258+
260259
sycl::event event = dot(q, result, input1, input2, input1_strides[0],
261260
input2_strides[0], input1_size);
262-
event.wait();
263-
return event_ref;
261+
262+
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
263+
return DPCTLEvent_Copy(event_ref);
264264
}
265265

266266
// 1D vector
@@ -318,10 +318,6 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
318318
// (looks like there are such another cases)
319319

320320
if (ext_input1_ndim == 2 && ext_input2_ndim == 2) {
321-
// there is a difference of behavior with trans and sizes params in previous
322-
// version of GEMM only new version is supported, in case of old version
323-
// computation goes in common way
324-
#if INTEL_MKL_VERSION >= 20210004
325321
// is mat1 F-contiguous, C-contiguous
326322
bool mat1_f_contig =
327323
(((ext_input1_shape[0] == 1) || (ext_input1_strides[0] == 1)) &&
@@ -389,7 +385,6 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
389385
} catch (const std::exception &e) {
390386
// do nothing, proceed to general case
391387
}
392-
#endif
393388
}
394389
}
395390
}

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -111,14 +111,14 @@ def dot(x1, x2, out=None, **kwargs):
111111
# TODO: copy_when_strides=False (now it's done for faster implementation with transpose arrays)
112112
x1_desc = dpnp.get_dpnp_descriptor(
113113
x1,
114-
copy_when_strides=True,
114+
copy_when_strides=False,
115115
copy_when_nondefault_queue=False,
116116
alloc_usm_type=usm_type,
117117
alloc_queue=queue,
118118
)
119119
x2_desc = dpnp.get_dpnp_descriptor(
120120
x2,
121-
copy_when_strides=True,
121+
copy_when_strides=False,
122122
copy_when_nondefault_queue=False,
123123
alloc_usm_type=usm_type,
124124
alloc_queue=queue,

0 commit comments

Comments
 (0)