diff --git a/dpctl/tensor/_reduction.py b/dpctl/tensor/_reduction.py index 0fcda387ec..b918152467 100644 --- a/dpctl/tensor/_reduction.py +++ b/dpctl/tensor/_reduction.py @@ -20,6 +20,7 @@ import dpctl.tensor as dpt import dpctl.tensor._tensor_impl as ti import dpctl.tensor._tensor_reductions_impl as tri +from dpctl.utils import ExecutionPlacementError from ._type_utils import ( _default_accumulation_dtype, @@ -33,6 +34,7 @@ def _reduction_over_axis( axis, dtype, keepdims, + out, _reduction_fn, _dtype_supported, _default_reduction_type_fn, @@ -42,13 +44,16 @@ def _reduction_over_axis( nd = x.ndim if axis is None: axis = tuple(range(nd)) - if not isinstance(axis, (tuple, list)): - axis = (axis,) - axis = normalize_axis_tuple(axis, nd, "axis") + perm = list(axis) + arr = x + else: + if not isinstance(axis, (tuple, list)): + axis = (axis,) + axis = normalize_axis_tuple(axis, nd, "axis") + perm = [i for i in range(nd) if i not in axis] + list(axis) + arr = dpt.permute_dims(x, perm) red_nd = len(axis) - perm = [i for i in range(nd) if i not in axis] + list(axis) - arr2 = dpt.permute_dims(x, perm) - res_shape = arr2.shape[: nd - red_nd] + res_shape = arr.shape[: nd - red_nd] q = x.sycl_queue inp_dt = x.dtype if dtype is None: @@ -58,39 +63,90 @@ def _reduction_over_axis( res_dt = _to_device_supported_dtype(res_dt, q.sycl_device) res_usm_type = x.usm_type - if red_nd == 0: - return dpt.astype(x, res_dt, copy=True) - host_tasks_list = [] - if _dtype_supported(inp_dt, res_dt, res_usm_type, q): - res = dpt.empty( + implemented_types = _dtype_supported(inp_dt, res_dt, res_usm_type, q) + if dtype is None and not implemented_types: + raise RuntimeError( + "Automatically determined reduction data type does not " + "have direct implementation" + ) + orig_out = out + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + f"output array must be of usm_ndarray type, got {type(out)}" + ) + if not out.flags.writable: + raise ValueError("provided `out` 
array is read-only") + if not keepdims: + final_res_shape = res_shape + else: + inp_shape = x.shape + final_res_shape = tuple( + inp_shape[i] if i not in axis else 1 for i in range(nd) + ) + if not out.shape == final_res_shape: + raise ValueError( + "The shape of input and output arrays are inconsistent. " + f"Expected output shape is {final_res_shape}, got {out.shape}" + ) + if res_dt != out.dtype: + raise ValueError( + f"Output array of type {res_dt} is needed, got {out.dtype}" + ) + if dpctl.utils.get_execution_queue((q, out.sycl_queue)) is None: + raise ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + if keepdims: + out = dpt.squeeze(out, axis=axis) + orig_out = out + if ti._array_overlap(x, out) and implemented_types: + out = dpt.empty_like(out) + else: + out = dpt.empty( res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q ) - ht_e, _ = _reduction_fn( - src=arr2, trailing_dims_to_reduce=red_nd, dst=res, sycl_queue=q + + host_tasks_list = [] + if red_nd == 0: + ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr, dst=out, sycl_queue=q + ) + host_tasks_list.append(ht_e_cpy) + if not (orig_out is None or orig_out is out): + ht_e_cpy2, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, dst=orig_out, sycl_queue=q, depends=[cpy_e] + ) + host_tasks_list.append(ht_e_cpy2) + out = orig_out + dpctl.SyclEvent.wait_for(host_tasks_list) + return out + + if implemented_types: + ht_e, red_e = _reduction_fn( + src=arr, trailing_dims_to_reduce=red_nd, dst=out, sycl_queue=q ) host_tasks_list.append(ht_e) - else: - if dtype is None: - raise RuntimeError( - "Automatically determined reduction data type does not " - "have direct implementation" + if not (orig_out is None or orig_out is out): + ht_e_cpy, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, dst=orig_out, sycl_queue=q, depends=[red_e] ) + host_tasks_list.append(ht_e_cpy) + out = orig_out + else: if _dtype_supported(res_dt, res_dt, res_usm_type, q): 
tmp = dpt.empty( - arr2.shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q + arr.shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q ) ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( - src=arr2, dst=tmp, sycl_queue=q + src=arr, dst=tmp, sycl_queue=q ) host_tasks_list.append(ht_e_cpy) - res = dpt.empty( - res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q - ) ht_e_red, _ = _reduction_fn( src=tmp, trailing_dims_to_reduce=red_nd, - dst=res, + dst=out, sycl_queue=q, depends=[cpy_e], ) @@ -98,18 +154,15 @@ def _reduction_over_axis( else: buf_dt = _default_reduction_type_fn(inp_dt, q) tmp = dpt.empty( - arr2.shape, dtype=buf_dt, usm_type=res_usm_type, sycl_queue=q + arr.shape, dtype=buf_dt, usm_type=res_usm_type, sycl_queue=q ) ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( - src=arr2, dst=tmp, sycl_queue=q + src=arr, dst=tmp, sycl_queue=q ) tmp_res = dpt.empty( res_shape, dtype=buf_dt, usm_type=res_usm_type, sycl_queue=q ) host_tasks_list.append(ht_e_cpy) - res = dpt.empty( - res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q - ) ht_e_red, r_e = _reduction_fn( src=tmp, trailing_dims_to_reduce=red_nd, @@ -119,20 +172,20 @@ def _reduction_over_axis( ) host_tasks_list.append(ht_e_red) ht_e_cpy2, _ = ti._copy_usm_ndarray_into_usm_ndarray( - src=tmp_res, dst=res, sycl_queue=q, depends=[r_e] + src=tmp_res, dst=out, sycl_queue=q, depends=[r_e] ) host_tasks_list.append(ht_e_cpy2) if keepdims: res_shape = res_shape + (1,) * red_nd inv_perm = sorted(range(nd), key=lambda d: perm[d]) - res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm) + out = dpt.permute_dims(dpt.reshape(out, res_shape), inv_perm) dpctl.SyclEvent.wait_for(host_tasks_list) - return res + return out -def sum(x, /, *, axis=None, dtype=None, keepdims=False): +def sum(x, /, *, axis=None, dtype=None, keepdims=False, out=None): """ Calculates the sum of elements in the input array ``x``. 
@@ -142,8 +195,8 @@ def sum(x, /, *, axis=None, dtype=None, keepdims=False): axis (Optional[int, Tuple[int, ...]]): axis or axes along which sums must be computed. If a tuple of unique integers, sums are computed over multiple axes. - If `None`, the sum is computed over the entire array. - Default: `None`. + If ``None``, the sum is computed over the entire array. + Default: ``None``. dtype (Optional[dtype]): data type of the returned array. If ``None``, the default data type is inferred from the "kind" of the input array data type. @@ -156,7 +209,7 @@ def sum(x, /, *, axis=None, dtype=None, keepdims=False): where input array ``x`` is allocated. * If ``x`` has unsigned integral data type, the returned array will have the default unsigned integral type for the device - where input array `x` is allocated. + where input array ``x`` is allocated. array ``x`` is allocated. * If ``x`` has a boolean data type, the returned array will have the default signed integral type for the device @@ -172,6 +225,11 @@ def sum(x, /, *, axis=None, dtype=None, keepdims=False): compatible with the input arrays according to Array Broadcasting rules. Otherwise, if ``False``, the reduced axes are not included in the returned array. Default: ``False``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result or (if provided) ``dtype``. + If ``None`` then a new array is returned. Default: ``None``. Returns: usm_ndarray: @@ -185,13 +243,14 @@ def sum(x, /, *, axis=None, dtype=None, keepdims=False): axis, dtype, keepdims, + out, tri._sum_over_axis, tri._sum_over_axis_dtype_supported, _default_accumulation_dtype, ) -def prod(x, /, *, axis=None, dtype=None, keepdims=False): +def prod(x, /, *, axis=None, dtype=None, keepdims=False, out=None): """ Calculates the product of elements in the input array ``x``. 
@@ -230,6 +289,11 @@ def prod(x, /, *, axis=None, dtype=None, keepdims=False): compatible with the input arrays according to Array Broadcasting rules. Otherwise, if ``False``, the reduced axes are not included in the returned array. Default: ``False``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result or (if provided) ``dtype``. + If ``None`` then a new array is returned. Default: ``None``. Returns: usm_ndarray: @@ -243,13 +307,14 @@ def prod(x, /, *, axis=None, dtype=None, keepdims=False): axis, dtype, keepdims, + out, tri._prod_over_axis, tri._prod_over_axis_dtype_supported, _default_accumulation_dtype, ) -def logsumexp(x, /, *, axis=None, dtype=None, keepdims=False): +def logsumexp(x, /, *, axis=None, dtype=None, keepdims=False, out=None): """ Calculates the logarithm of the sum of exponentials of elements in the input array ``x``. @@ -270,7 +335,7 @@ def logsumexp(x, /, *, axis=None, dtype=None, keepdims=False): returned array will have the same data type as ``x``. * If ``x`` has a boolean or integral data type, the returned array will have the default floating point data type for the device - where input array `x` is allocated. + where input array ``x`` is allocated. * If ``x`` has a complex-valued floating-point data type, an error is raised. @@ -284,6 +349,11 @@ def logsumexp(x, /, *, axis=None, dtype=None, keepdims=False): compatible with the input arrays according to Array Broadcasting rules. Otherwise, if ``False``, the reduced axes are not included in the returned array. Default: ``False``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result or (if provided) ``dtype``. + If ``None`` then a new array is returned. Default: ``None``. 
Returns: usm_ndarray: @@ -297,6 +367,7 @@ def logsumexp(x, /, *, axis=None, dtype=None, keepdims=False): axis, dtype, keepdims, + out, tri._logsumexp_over_axis, lambda inp_dt, res_dt, *_: tri._logsumexp_over_axis_dtype_supported( inp_dt, res_dt @@ -305,7 +376,7 @@ def logsumexp(x, /, *, axis=None, dtype=None, keepdims=False): ) -def reduce_hypot(x, /, *, axis=None, dtype=None, keepdims=False): +def reduce_hypot(x, /, *, axis=None, dtype=None, keepdims=False, out=None): """ Calculates the square root of the sum of squares of elements in the input array ``x``. @@ -326,7 +397,7 @@ def reduce_hypot(x, /, *, axis=None, dtype=None, keepdims=False): returned array will have the same data type as ``x``. * If ``x`` has a boolean or integral data type, the returned array will have the default floating point data type for the device - where input array `x` is allocated. + where input array ``x`` is allocated. * If ``x`` has a complex-valued floating-point data type, an error is raised. @@ -339,6 +410,11 @@ def reduce_hypot(x, /, *, axis=None, dtype=None, keepdims=False): compatible with the input arrays according to Array Broadcasting rules. Otherwise, if ``False``, the reduced axes are not included in the returned array. Default: ``False``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result or (if provided) ``dtype``. + If ``None`` then a new array is returned. Default: ``None``. 
Returns: usm_ndarray: @@ -352,6 +428,7 @@ def reduce_hypot(x, /, *, axis=None, dtype=None, keepdims=False): axis, dtype, keepdims, + out, tri._hypot_over_axis, lambda inp_dt, res_dt, *_: tri._hypot_over_axis_dtype_supported( inp_dt, res_dt @@ -360,60 +437,105 @@ def reduce_hypot(x, /, *, axis=None, dtype=None, keepdims=False): ) -def _comparison_over_axis(x, axis, keepdims, _reduction_fn): +def _comparison_over_axis(x, axis, keepdims, out, _reduction_fn): if not isinstance(x, dpt.usm_ndarray): raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") nd = x.ndim if axis is None: axis = tuple(range(nd)) - if not isinstance(axis, (tuple, list)): - axis = (axis,) - axis = normalize_axis_tuple(axis, nd, "axis") + perm = list(axis) + x_tmp = x + else: + if not isinstance(axis, (tuple, list)): + axis = (axis,) + axis = normalize_axis_tuple(axis, nd, "axis") + perm = [i for i in range(nd) if i not in axis] + list(axis) + x_tmp = dpt.permute_dims(x, perm) red_nd = len(axis) - perm = [i for i in range(nd) if i not in axis] + list(axis) - x_tmp = dpt.permute_dims(x, perm) + if any([x_tmp.shape[i] == 0 for i in range(-red_nd, 0)]): + raise ValueError("reduction cannot be performed over zero-size axes") res_shape = x_tmp.shape[: nd - red_nd] exec_q = x.sycl_queue res_dt = x.dtype res_usm_type = x.usm_type - if x.size == 0: - if any([x.shape[i] == 0 for i in axis]): - raise ValueError( - "reduction cannot be performed over zero-size axes" + + orig_out = out + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + f"output array must be of usm_ndarray type, got {type(out)}" ) + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + if not keepdims: + final_res_shape = res_shape else: - return dpt.empty( - res_shape, - dtype=res_dt, - usm_type=res_usm_type, - sycl_queue=exec_q, + inp_shape = x.shape + final_res_shape = tuple( + inp_shape[i] if i not in axis else 1 for i in range(nd) + ) + if not out.shape == 
final_res_shape: + raise ValueError( + "The shape of input and output arrays are inconsistent. " + f"Expected output shape is {final_res_shape}, got {out.shape}" ) + if res_dt != out.dtype: + raise ValueError( + f"Output array of type {res_dt} is needed, got {out.dtype}" + ) + if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None: + raise ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + if keepdims: + out = dpt.squeeze(out, axis=axis) + orig_out = out + if ti._array_overlap(x, out): + out = dpt.empty_like(out) + else: + out = dpt.empty( + res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q + ) + + host_tasks_list = [] if red_nd == 0: - return dpt.copy(x) + ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=x_tmp, dst=out, sycl_queue=exec_q + ) + host_tasks_list.append(ht_e_cpy) + if not (orig_out is None or orig_out is out): + ht_e_cpy2, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, dst=orig_out, sycl_queue=exec_q, depends=[cpy_e] + ) + host_tasks_list.append(ht_e_cpy2) + out = orig_out + dpctl.SyclEvent.wait_for(host_tasks_list) + return out - res = dpt.empty( - res_shape, - dtype=res_dt, - usm_type=res_usm_type, - sycl_queue=exec_q, - ) - hev, _ = _reduction_fn( + hev, red_ev = _reduction_fn( src=x_tmp, trailing_dims_to_reduce=red_nd, - dst=res, + dst=out, sycl_queue=exec_q, ) + host_tasks_list.append(hev) + if not (orig_out is None or orig_out is out): + ht_e_cpy2, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, dst=orig_out, sycl_queue=exec_q, depends=[red_ev] + ) + host_tasks_list.append(ht_e_cpy2) + out = orig_out if keepdims: res_shape = res_shape + (1,) * red_nd inv_perm = sorted(range(nd), key=lambda d: perm[d]) - res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm) - hev.wait() - return res + out = dpt.permute_dims(dpt.reshape(out, res_shape), inv_perm) + dpctl.SyclEvent.wait_for(host_tasks_list) + return out -def max(x, /, *, axis=None, 
keepdims=False): +def max(x, /, *, axis=None, keepdims=False, out=None): """ Calculates the maximum value of the input array ``x``. @@ -431,6 +553,11 @@ def max(x, /, *, axis=None, keepdims=False): compatible with the input arrays according to Array Broadcasting rules. Otherwise, if ``False``, the reduced axes are not included in the returned array. Default: ``False``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result. + If ``None`` then a new array is returned. Default: ``None``. Returns: usm_ndarray: @@ -438,10 +565,10 @@ def max(x, /, *, axis=None, keepdims=False): entire array, a zero-dimensional array is returned. The returned array has the same data type as ``x``. """ - return _comparison_over_axis(x, axis, keepdims, tri._max_over_axis) + return _comparison_over_axis(x, axis, keepdims, out, tri._max_over_axis) -def min(x, /, *, axis=None, keepdims=False): +def min(x, /, *, axis=None, keepdims=False, out=None): """ Calculates the minimum value of the input array ``x``. @@ -459,6 +586,11 @@ def min(x, /, *, axis=None, keepdims=False): compatible with the input arrays according to Array Broadcasting rules. Otherwise, if ``False``, the reduced axes are not included in the returned array. Default: ``False``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result. + If ``None`` then a new array is returned. Default: ``None``. Returns: usm_ndarray: @@ -466,69 +598,106 @@ def min(x, /, *, axis=None, keepdims=False): entire array, a zero-dimensional array is returned. The returned array has the same data type as ``x``. 
""" - return _comparison_over_axis(x, axis, keepdims, tri._min_over_axis) + return _comparison_over_axis(x, axis, keepdims, out, tri._min_over_axis) -def _search_over_axis(x, axis, keepdims, _reduction_fn): +def _search_over_axis(x, axis, keepdims, out, _reduction_fn): if not isinstance(x, dpt.usm_ndarray): raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") nd = x.ndim if axis is None: axis = tuple(range(nd)) - elif isinstance(axis, int): - axis = (axis,) + perm = list(axis) + x_tmp = x else: - raise TypeError( - f"`axis` argument expected `int` or `None`, got {type(axis)}" - ) + if isinstance(axis, int): + axis = (axis,) + else: + raise TypeError( + f"`axis` argument expected `int` or `None`, got {type(axis)}" + ) + axis = normalize_axis_tuple(axis, nd, "axis") + perm = [i for i in range(nd) if i not in axis] + list(axis) + x_tmp = dpt.permute_dims(x, perm) axis = normalize_axis_tuple(axis, nd, "axis") red_nd = len(axis) - perm = [i for i in range(nd) if i not in axis] + list(axis) - x_tmp = dpt.permute_dims(x, perm) + if any([x_tmp.shape[i] == 0 for i in range(-red_nd, 0)]): + raise ValueError("reduction cannot be performed over zero-size axes") res_shape = x_tmp.shape[: nd - red_nd] exec_q = x.sycl_queue res_dt = ti.default_device_index_type(exec_q.sycl_device) res_usm_type = x.usm_type - if x.size == 0: - if any([x.shape[i] == 0 for i in axis]): - raise ValueError( - "reduction cannot be performed over zero-size axes" + + orig_out = out + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + f"output array must be of usm_ndarray type, got {type(out)}" ) + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + if not keepdims: + final_res_shape = res_shape else: - return dpt.empty( - res_shape, - dtype=res_dt, - usm_type=res_usm_type, - sycl_queue=exec_q, + inp_shape = x.shape + final_res_shape = tuple( + inp_shape[i] if i not in axis else 1 for i in range(nd) ) - if red_nd == 0: - 
return dpt.zeros( + if not out.shape == final_res_shape: + raise ValueError( + "The shape of input and output arrays are inconsistent. " + f"Expected output shape is {final_res_shape}, got {out.shape}" + ) + if res_dt != out.dtype: + raise ValueError( + f"Output array of type {res_dt} is needed, got {out.dtype}" + ) + if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None: + raise ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + if keepdims: + out = dpt.squeeze(out, axis=axis) + orig_out = out + if ti._array_overlap(x, out) and red_nd > 0: + out = dpt.empty_like(out) + else: + out = dpt.empty( res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q ) - res = dpt.empty( - res_shape, - dtype=res_dt, - usm_type=res_usm_type, - sycl_queue=exec_q, - ) - hev, _ = _reduction_fn( + if red_nd == 0: + ht_e_fill, _ = ti._full_usm_ndarray( + fill_value=0, dst=out, sycl_queue=exec_q + ) + ht_e_fill.wait() + return out + + host_tasks_list = [] + hev, red_ev = _reduction_fn( src=x_tmp, trailing_dims_to_reduce=red_nd, - dst=res, + dst=out, sycl_queue=exec_q, ) + host_tasks_list.append(hev) + if not (orig_out is None or orig_out is out): + ht_e_cpy2, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, dst=orig_out, sycl_queue=exec_q, depends=[red_ev] + ) + host_tasks_list.append(ht_e_cpy2) + out = orig_out if keepdims: res_shape = res_shape + (1,) * red_nd inv_perm = sorted(range(nd), key=lambda d: perm[d]) - res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm) - hev.wait() - return res + out = dpt.permute_dims(dpt.reshape(out, res_shape), inv_perm) + dpctl.SyclEvent.wait_for(host_tasks_list) + return out -def argmax(x, /, *, axis=None, keepdims=False): +def argmax(x, /, *, axis=None, keepdims=False, out=None): """ Returns the indices of the maximum values of the input array ``x`` along a specified axis. 
@@ -549,6 +718,11 @@ def argmax(x, /, *, axis=None, keepdims=False): compatible with the input arrays according to Array Broadcasting rules. Otherwise, if ``False``, the reduced axes are not included in the returned array. Default: ``False``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result. + If ``None`` then a new array is returned. Default: ``None``. Returns: usm_ndarray: @@ -557,10 +731,10 @@ def argmax(x, /, *, axis=None, keepdims=False): zero-dimensional array is returned. The returned array has the default array index data type for the device of ``x``. """ - return _search_over_axis(x, axis, keepdims, tri._argmax_over_axis) + return _search_over_axis(x, axis, keepdims, out, tri._argmax_over_axis) -def argmin(x, /, *, axis=None, keepdims=False): +def argmin(x, /, *, axis=None, keepdims=False, out=None): """ Returns the indices of the minimum values of the input array ``x`` along a specified axis. @@ -581,6 +755,11 @@ def argmin(x, /, *, axis=None, keepdims=False): compatible with the input arrays according to Array Broadcasting rules. Otherwise, if ``False``, the reduced axes are not included in the returned array. Default: ``False``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result. + If ``None`` then a new array is returned. Default: ``None``. Returns: usm_ndarray: @@ -589,4 +768,4 @@ def argmin(x, /, *, axis=None, keepdims=False): zero-dimensional array is returned. The returned array has the default array index data type for the device of ``x``. 
""" - return _search_over_axis(x, axis, keepdims, tri._argmin_over_axis) + return _search_over_axis(x, axis, keepdims, out, tri._argmin_over_axis) diff --git a/dpctl/tensor/_search_functions.py b/dpctl/tensor/_search_functions.py index 87bc533a11..94982d0c82 100644 --- a/dpctl/tensor/_search_functions.py +++ b/dpctl/tensor/_search_functions.py @@ -18,6 +18,7 @@ import dpctl.tensor as dpt import dpctl.tensor._tensor_impl as ti from dpctl.tensor._manipulation_functions import _broadcast_shapes +from dpctl.utils import ExecutionPlacementError from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK from ._type_utils import _all_data_types, _can_cast @@ -40,37 +41,41 @@ def _where_result_type(dt1, dt2, dev): return None -def where(condition, x1, x2): - """where(condition, x1, x2) - +def where(condition, x1, x2, /, *, order="K", out=None): + """ Returns :class:`dpctl.tensor.usm_ndarray` with elements chosen - from `x1` or `x2` depending on `condition`. + from ``x1`` or ``x2`` depending on ``condition``. Args: - condition (usm_ndarray): When True yields from `x1`, - and otherwise yields from `x2`. - Must be compatible with `x1` and `x2` according + condition (usm_ndarray): When ``True`` yields from ``x1``, + and otherwise yields from ``x2``. + Must be compatible with ``x1`` and ``x2`` according to broadcasting rules. x1 (usm_ndarray): Array from which values are chosen when - `condition` is True. - Must be compatible with `condition` and `x2` according + ``condition`` is ``True``. + Must be compatible with ``condition`` and ``x2`` according to broadcasting rules. x2 (usm_ndarray): Array from which values are chosen when - `condition` is not True. - Must be compatible with `condition` and `x2` according + ``condition`` is not ``True``. + Must be compatible with ``condition`` and ``x2`` according to broadcasting rules. + order (``"K"``, ``"C"``, ``"F"``, ``"A"``, optional): + Memory layout of the new output arra, + if parameter ``out`` is ``None``. 
+ Default: ``"K"``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of `out` must match the expected shape and the + expected data type of the result. + If ``None`` then a new array is returned. Default: ``None``. Returns: usm_ndarray: - An array with elements from `x1` where `condition` is True, - and elements from `x2` elsewhere. + An array with elements from ``x1`` where ``condition`` is ``True``, + and elements from ``x2`` elsewhere. The data type of the returned array is determined by applying - the Type Promotion Rules to `x1` and `x2`. - - The memory layout of the returned array is - F-contiguous (column-major) when all inputs are F-contiguous, - and C-contiguous (row-major) otherwise. + the Type Promotion Rules to ``x1`` and ``x2``. """ if not isinstance(condition, dpt.usm_ndarray): raise TypeError( @@ -84,6 +89,8 @@ def where(condition, x1, x2): raise TypeError( "Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x2)}" ) + if order not in ["K", "C", "F", "A"]: + order = "K" exec_q = dpctl.utils.get_execution_queue( ( condition.sycl_queue, @@ -93,7 +100,7 @@ def where(condition, x1, x2): ) if exec_q is None: raise dpctl.utils.ExecutionPlacementError - dst_usm_type = dpctl.utils.get_coerced_usm_type( + out_usm_type = dpctl.utils.get_coerced_usm_type( ( condition.usm_type, x1.usm_type, @@ -103,8 +110,8 @@ def where(condition, x1, x2): x1_dtype = x1.dtype x2_dtype = x2.dtype - dst_dtype = _where_result_type(x1_dtype, x2_dtype, exec_q.sycl_device) - if dst_dtype is None: + out_dtype = _where_result_type(x1_dtype, x2_dtype, exec_q.sycl_device) + if out_dtype is None: raise TypeError( "function 'where' does not support input " f"types ({x1_dtype}, {x2_dtype}), " @@ -114,15 +121,90 @@ def where(condition, x1, x2): res_shape = _broadcast_shapes(condition, x1, x2) - if condition.size == 0: - return dpt.empty( - res_shape, dtype=dst_dtype, usm_type=dst_usm_type, sycl_queue=exec_q + orig_out = out + if out is not None: 
+ if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + "output array must be of usm_ndarray type, got " f"{type(out)}" + ) + + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + + if out.shape != res_shape: + raise ValueError( + "The shape of input and output arrays are " + f"inconsistent. Expected output shape is {res_shape}, " + f"got {out.shape}" + ) + + if out_dtype != out.dtype: + raise ValueError( + f"Output array of type {out_dtype} is needed, " + f"got {out.dtype}" + ) + + if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None: + raise ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + + if ti._array_overlap(condition, out): + if not ti._same_logical_tensors(condition, out): + out = dpt.empty_like(out) + + if ti._array_overlap(x1, out): + if not ti._same_logical_tensors(x1, out): + out = dpt.empty_like(out) + + if ti._array_overlap(x2, out): + if not ti._same_logical_tensors(x2, out): + out = dpt.empty_like(out) + + if order == "A": + order = ( + "F" + if all( + arr.flags.f_contiguous + for arr in ( + condition, + x1, + x2, + ) + ) + else "C" ) + if condition.size == 0: + if out is not None: + return out + else: + if order == "K": + return _empty_like_triple_orderK( + condition, + x1, + x2, + out_dtype, + res_shape, + out_usm_type, + exec_q, + ) + else: + return dpt.empty( + res_shape, + dtype=out_dtype, + order=order, + usm_type=out_usm_type, + sycl_queue=exec_q, + ) + deps = [] wait_list = [] - if x1_dtype != dst_dtype: - _x1 = _empty_like_orderK(x1, dst_dtype) + if x1_dtype != out_dtype: + if order == "K": + _x1 = _empty_like_orderK(x1, out_dtype) + else: + _x1 = dpt.empty_like(x1, dtype=out_dtype, order=order) ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray( src=x1, dst=_x1, sycl_queue=exec_q ) @@ -130,8 +212,11 @@ def where(condition, x1, x2): deps.append(copy1_ev) wait_list.append(ht_copy1_ev) - if x2_dtype != dst_dtype: - _x2 = 
_empty_like_orderK(x2, dst_dtype) + if x2_dtype != out_dtype: + if order == "K": + _x2 = _empty_like_orderK(x2, out_dtype) + else: + _x2 = dpt.empty_like(x2, dtype=out_dtype, order=order) ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray( src=x2, dst=_x2, sycl_queue=exec_q ) @@ -139,23 +224,43 @@ def where(condition, x1, x2): deps.append(copy2_ev) wait_list.append(ht_copy2_ev) - dst = _empty_like_triple_orderK( - condition, x1, x2, dst_dtype, res_shape, dst_usm_type, exec_q - ) + if out is None: + if order == "K": + out = _empty_like_triple_orderK( + condition, x1, x2, out_dtype, res_shape, out_usm_type, exec_q + ) + else: + out = dpt.empty( + res_shape, + dtype=out_dtype, + order=order, + usm_type=out_usm_type, + sycl_queue=exec_q, + ) condition = dpt.broadcast_to(condition, res_shape) x1 = dpt.broadcast_to(x1, res_shape) x2 = dpt.broadcast_to(x2, res_shape) - hev, _ = ti._where( + hev, where_ev = ti._where( condition=condition, x1=x1, x2=x2, - dst=dst, + dst=out, sycl_queue=exec_q, depends=deps, ) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[where_ev], + ) + ht_copy_out_ev.wait() + out = orig_out dpctl.SyclEvent.wait_for(wait_list) hev.wait() - return dst + return out diff --git a/dpctl/tensor/libtensor/source/where.cpp b/dpctl/tensor/libtensor/source/where.cpp index e3dbbfed6c..3af3fb3ee2 100644 --- a/dpctl/tensor/libtensor/source/where.cpp +++ b/dpctl/tensor/libtensor/source/where.cpp @@ -114,7 +114,12 @@ py_where(const dpctl::tensor::usm_ndarray &condition, } auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); - if (overlap(dst, condition) || overlap(dst, x1) || overlap(dst, x2)) { + auto const &same_logical_tensors = + dpctl::tensor::overlap::SameLogicalTensors(); + if ((overlap(dst, condition) && !same_logical_tensors(dst, condition)) || + (overlap(dst, 
x1) && !same_logical_tensors(dst, x1)) || + (overlap(dst, x2) && !same_logical_tensors(dst, x2))) + { throw py::value_error("Destination array overlaps with input."); } diff --git a/dpctl/tests/test_usm_ndarray_reductions.py b/dpctl/tests/test_usm_ndarray_reductions.py index 8b22d049e3..4e66ae29a7 100644 --- a/dpctl/tests/test_usm_ndarray_reductions.py +++ b/dpctl/tests/test_usm_ndarray_reductions.py @@ -22,6 +22,7 @@ import dpctl.tensor as dpt from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported +from dpctl.utils import ExecutionPlacementError _no_complex_dtypes = [ "?", @@ -497,3 +498,174 @@ def test_tree_reduction_axis1_axis0(): rtol=tol, atol=tol, ) + + +def test_numeric_reduction_out_kwarg(): + get_queue_or_skip() + + n1, n2, n3 = 3, 4, 5 + x = dpt.ones((n1, n2, n3), dtype="i8") + out = dpt.zeros((2 * n1, 3 * n2), dtype="i8") + res = dpt.sum(x, axis=-1, out=out[::-2, 1::3]) + assert dpt.all(out[::-2, 0::3] == 0) + assert dpt.all(out[::-2, 2::3] == 0) + assert dpt.all(out[::-2, 1::3] == res) + assert dpt.all(out[::-2, 1::3] == 5) + + out = dpt.zeros((2 * n1, 3 * n2, 1), dtype="i8") + res = dpt.sum(x, axis=-1, keepdims=True, out=out[::-2, 1::3]) + assert res.shape == (n1, n2, 1) + assert dpt.all(out[::-2, 0::3] == 0) + assert dpt.all(out[::-2, 2::3] == 0) + assert dpt.all(out[::-2, 1::3] == res) + assert dpt.all(out[::-2, 1::3] == 5) + + res = dpt.sum(x, axis=0, out=x[-1]) + assert dpt.all(x[-1] == res) + assert dpt.all(x[-1] == 3) + assert dpt.all(x[0:-1] == 1) + + # test no-op case + x = dpt.ones((n1, n2, n3), dtype="i8") + out = dpt.zeros((2 * n1, 3 * n2, n3), dtype="i8") + res = dpt.sum(x, axis=(), out=out[::-2, 1::3]) + assert dpt.all(out[::-2, 0::3] == 0) + assert dpt.all(out[::-2, 2::3] == 0) + assert dpt.all(out[::-2, 1::3] == x) + + # test with dtype kwarg + x = dpt.ones((n1, n2, n3), dtype="i4") + out = dpt.zeros((2 * n1, 3 * n2), dtype="f4") + res = dpt.sum(x, axis=-1, dtype="f4", out=out[::-2, 1::3]) + zero_res = 
dpt.zeros_like(res) + assert dpt.allclose(out[::-2, 0::3], zero_res) + assert dpt.allclose(out[::-2, 2::3], zero_res) + assert dpt.allclose(out[::-2, 1::3], res) + assert dpt.allclose(out[::-2, 1::3], dpt.full_like(res, 5, dtype="f4")) + + +def test_comparison_reduction_out_kwarg(): + get_queue_or_skip() + + n1, n2, n3 = 3, 4, 5 + x = dpt.reshape(dpt.arange(n1 * n2 * n3, dtype="i4"), (n1, n2, n3)) + out = dpt.zeros((2 * n1, 3 * n2), dtype="i4") + res = dpt.max(x, axis=-1, out=out[::-2, 1::3]) + assert dpt.all(out[::-2, 0::3] == 0) + assert dpt.all(out[::-2, 2::3] == 0) + assert dpt.all(out[::-2, 1::3] == res) + assert dpt.all(out[::-2, 1::3] == x[:, :, -1]) + + out = dpt.zeros((2 * n1, 3 * n2, 1), dtype="i4") + res = dpt.max(x, axis=-1, keepdims=True, out=out[::-2, 1::3]) + assert res.shape == (n1, n2, 1) + assert dpt.all(out[::-2, 0::3] == 0) + assert dpt.all(out[::-2, 2::3] == 0) + assert dpt.all(out[::-2, 1::3] == res) + assert dpt.all(out[::-2, 1::3] == x[:, :, -1, dpt.newaxis]) + + # test no-op case + out = dpt.zeros((2 * n1, 3 * n2, n3), dtype="i4") + res = dpt.max(x, axis=(), out=out[::-2, 1::3]) + assert dpt.all(out[::-2, 0::3] == 0) + assert dpt.all(out[::-2, 2::3] == 0) + assert dpt.all(out[::-2, 1::3] == x) + + # test overlap + res = dpt.max(x, axis=0, out=x[0]) + assert dpt.all(x[0] == res) + assert dpt.all(x[0] == x[-1]) + + +def test_search_reduction_out_kwarg(): + get_queue_or_skip() + + n1, n2, n3 = 3, 4, 5 + dt = dpt.__array_namespace_info__().default_dtypes()["indexing"] + + x = dpt.reshape(dpt.arange(n1 * n2 * n3, dtype=dt), (n1, n2, n3)) + out = dpt.zeros((2 * n1, 3 * n2), dtype=dt) + res = dpt.argmax(x, axis=-1, out=out[::-2, 1::3]) + assert dpt.all(out[::-2, 0::3] == 0) + assert dpt.all(out[::-2, 2::3] == 0) + assert dpt.all(out[::-2, 1::3] == res) + assert dpt.all(out[::-2, 1::3] == n3 - 1) + + out = dpt.zeros((2 * n1, 3 * n2, 1), dtype=dt) + res = dpt.argmax(x, axis=-1, keepdims=True, out=out[::-2, 1::3]) + assert res.shape == (n1, n2, 1) + 
assert dpt.all(out[::-2, 0::3] == 0) + assert dpt.all(out[::-2, 2::3] == 0) + assert dpt.all(out[::-2, 1::3] == res) + assert dpt.all(out[::-2, 1::3] == n3 - 1) + + # test no-op case + x = dpt.ones((), dtype=dt) + out = dpt.ones(2, dtype=dt) + res = dpt.argmax(x, axis=None, out=out[1]) + assert dpt.all(out[0] == 1) + assert dpt.all(out[1] == 0) + + # test overlap + x = dpt.reshape(dpt.arange(n1 * n2, dtype=dt), (n1, n2)) + res = dpt.argmax(x, axis=0, out=x[0]) + assert dpt.all(x[0] == res) + assert dpt.all(x[0] == n1 - 1) + + +def test_reduction_out_kwarg_arg_validation(): + q1 = get_queue_or_skip() + q2 = get_queue_or_skip() + + ind_dt = dpt.__array_namespace_info__().default_dtypes()["indexing"] + + x = dpt.ones(10, dtype="f4") + out_wrong_queue = dpt.empty((), dtype="f4", sycl_queue=q2) + out_wrong_dtype = dpt.empty((), dtype="i4", sycl_queue=q1) + out_wrong_shape = dpt.empty(1, dtype="f4", sycl_queue=q1) + out_wrong_keepdims = dpt.empty((), dtype="f4", sycl_queue=q1) + out_not_writable = dpt.empty((), dtype="f4", sycl_queue=q1) + out_not_writable.flags["W"] = False + + with pytest.raises(TypeError): + dpt.sum(x, out=dict()) + with pytest.raises(TypeError): + dpt.max(x, out=dict()) + with pytest.raises(TypeError): + dpt.argmax(x, out=dict()) + with pytest.raises(ExecutionPlacementError): + dpt.sum(x, out=out_wrong_queue) + with pytest.raises(ExecutionPlacementError): + dpt.max(x, out=out_wrong_queue) + with pytest.raises(ExecutionPlacementError): + dpt.argmax(x, out=dpt.empty_like(out_wrong_queue, dtype=ind_dt)) + with pytest.raises(ValueError): + dpt.sum(x, out=out_wrong_dtype) + with pytest.raises(ValueError): + dpt.max(x, out=out_wrong_dtype) + with pytest.raises(ValueError): + dpt.argmax(x, out=dpt.empty_like(out_wrong_dtype, dtype="f4")) + with pytest.raises(ValueError): + dpt.sum(x, out=out_wrong_shape) + with pytest.raises(ValueError): + dpt.max(x, out=out_wrong_shape) + with pytest.raises(ValueError): + dpt.argmax(x, out=dpt.empty_like(out_wrong_shape, 
dtype=ind_dt)) + with pytest.raises(ValueError): + dpt.sum(x, out=out_not_writable) + with pytest.raises(ValueError): + dpt.max(x, out=out_not_writable) + with pytest.raises(ValueError): + search_not_writable = dpt.empty_like(out_not_writable, dtype=ind_dt) + search_not_writable.flags["W"] = False + dpt.argmax(x, out=search_not_writable) + with pytest.raises(ValueError): + dpt.sum(x, keepdims=True, out=out_wrong_keepdims) + with pytest.raises(ValueError): + dpt.max(x, keepdims=True, out=out_wrong_keepdims) + with pytest.raises(ValueError): + dpt.argmax( + x, + keepdims=True, + out=dpt.empty_like(out_wrong_keepdims, dtype=ind_dt), + ) diff --git a/dpctl/tests/test_usm_ndarray_search_functions.py b/dpctl/tests/test_usm_ndarray_search_functions.py index 04c06e7179..38e106fb9f 100644 --- a/dpctl/tests/test_usm_ndarray_search_functions.py +++ b/dpctl/tests/test_usm_ndarray_search_functions.py @@ -386,31 +386,47 @@ def test_where_order(): ar1 = dpt.zeros(test_sh, dtype=dt1, order="C") ar2 = dpt.ones(test_sh, dtype=dt2, order="C") condition = dpt.zeros(test_sh, dtype="?", order="C") - res = dpt.where(condition, ar1, ar2) - assert res.flags.c_contiguous + res1 = dpt.where(condition, ar1, ar2, order="C") + assert res1.flags.c_contiguous + res2 = dpt.where(condition, ar1, ar2, order="F") + assert res2.flags.f_contiguous + res3 = dpt.where(condition, ar1, ar2, order="A") + assert res3.flags.c_contiguous + res4 = dpt.where(condition, ar1, ar2, order="K") + assert res4.flags.c_contiguous ar1 = dpt.ones(test_sh, dtype=dt1, order="F") ar2 = dpt.ones(test_sh, dtype=dt2, order="F") condition = dpt.zeros(test_sh, dtype="?", order="F") - res = dpt.where(condition, ar1, ar2) - assert res.flags.f_contiguous + res1 = dpt.where(condition, ar1, ar2, order="C") + assert res1.flags.c_contiguous + res2 = dpt.where(condition, ar1, ar2, order="F") + assert res2.flags.f_contiguous + res3 = dpt.where(condition, ar1, ar2, order="A") + assert res3.flags.f_contiguous + res4 = dpt.where(condition, 
ar1, ar2, order="K") + assert res4.flags.f_contiguous ar1 = dpt.ones(test_sh2, dtype=dt1, order="C")[:20, ::-2] ar2 = dpt.ones(test_sh2, dtype=dt2, order="C")[:20, ::-2] condition = dpt.zeros(test_sh2, dtype="?", order="C")[:20, ::-2] - res = dpt.where(condition, ar1, ar2) - assert res.strides == (n, -1) + res1 = dpt.where(condition, ar1, ar2, order="K") + assert res1.strides == (n, -1) + res2 = dpt.where(condition, ar1, ar2, order="C") + assert res2.strides == (n, 1) ar1 = dpt.ones(test_sh2, dtype=dt1, order="C")[:20, ::-2].mT ar2 = dpt.ones(test_sh2, dtype=dt2, order="C")[:20, ::-2].mT condition = dpt.zeros(test_sh2, dtype="?", order="C")[:20, ::-2].mT - res = dpt.where(condition, ar1, ar2) - assert res.strides == (-1, n) + res1 = dpt.where(condition, ar1, ar2, order="K") + assert res1.strides == (-1, n) + res2 = dpt.where(condition, ar1, ar2, order="C") + assert res2.strides == (n, 1) ar1 = dpt.ones(n, dtype=dt1, order="C") ar2 = dpt.broadcast_to(dpt.ones(n, dtype=dt2, order="C"), test_sh) condition = dpt.zeros(n, dtype="?", order="C") - res = dpt.where(condition, ar1, ar2) + res = dpt.where(condition, ar1, ar2, order="K") assert res.strides == (20, 1) @@ -423,3 +439,86 @@ def test_where_unaligned(): expected = dpt.full(512, 2, dtype="i4") assert dpt.all(dpt.where(x[1:], a, b) == expected) + + +def test_where_out(): + get_queue_or_skip() + + n1, n2, n3 = 3, 4, 5 + ar1 = dpt.reshape(dpt.arange(n1 * n2 * n3, dtype="i4"), (n1, n2, n3)) + ar2 = dpt.full_like(ar1, -5) + condition = dpt.tile( + dpt.reshape( + dpt.asarray([True, False, False, True], dtype="?"), (1, n2, 1) + ), + (n1, 1, n3), + ) + + out = dpt.zeros((2 * n1, 3 * n2, n3), dtype="i4") + res = dpt.where(condition, ar1, ar2, out=out[::-2, 1::3, :]) + + assert dpt.all(res == out[::-2, 1::3, :]) + assert dpt.all(out[::-2, 0::3, :] == 0) + assert dpt.all(out[::-2, 2::3, :] == 0) + + assert dpt.all(res[:, 1:3, :] == -5) + assert dpt.all(res[:, 0, :] == ar1[:, 0, :]) + assert dpt.all(res[:, 3, :] == ar1[:, 3, 
:]) + + condition = dpt.tile( + dpt.reshape(dpt.asarray([1, 0], dtype="i4"), (1, 2, 1)), + (n1, 2, n3), + ) + res = dpt.where( + condition[:, ::-1, :], condition[:, ::-1, :], condition, out=condition + ) + assert dpt.all(res == condition) + assert dpt.all(condition == 1) + + condition = dpt.tile( + dpt.reshape(dpt.asarray([True, False], dtype="?"), (1, 2, 1)), + (n1, 2, n3), + ) + ar1 = dpt.full((n1, n2, n3), 7, dtype="i4") + ar2 = dpt.full_like(ar1, -5) + res = dpt.where(condition, ar1, ar2, out=ar2[:, ::-1, :]) + assert dpt.all(ar2[:, ::-1, :] == res) + assert dpt.all(ar2[:, ::2, :] == -5) + assert dpt.all(ar2[:, 1::2, :] == 7) + + condition = dpt.tile( + dpt.reshape(dpt.asarray([True, False], dtype="?"), (1, 2, 1)), + (n1, 2, n3), + ) + ar1 = dpt.full((n1, n2, n3), 7, dtype="i4") + ar2 = dpt.full_like(ar1, -5) + res = dpt.where(condition, ar1, ar2, out=ar1[:, ::-1, :]) + assert dpt.all(ar1[:, ::-1, :] == res) + assert dpt.all(ar1[:, ::2, :] == -5) + assert dpt.all(ar1[:, 1::2, :] == 7) + + +def test_where_out_arg_validation(): + q1 = get_queue_or_skip() + q2 = get_queue_or_skip() + + condition = dpt.ones(5, dtype="i4", sycl_queue=q1) + x1 = dpt.ones(5, dtype="i4", sycl_queue=q1) + x2 = dpt.ones(5, dtype="i4", sycl_queue=q1) + + out_wrong_queue = dpt.empty_like(condition, sycl_queue=q2) + out_wrong_dtype = dpt.empty_like(condition, dtype="f4") + out_wrong_shape = dpt.empty(6, dtype="i4", sycl_queue=q1) + out_not_writable = dpt.empty_like(condition) + out_not_writable.flags["W"] = False + + with pytest.raises(TypeError): + dpt.where(condition, x1, x2, out=dict()) + with pytest.raises(ExecutionPlacementError): + dpt.where(condition, x1, x2, out=out_wrong_queue) + with pytest.raises(ValueError): + dpt.where(condition, x1, x2, out=out_wrong_dtype) + with pytest.raises(ValueError): + dpt.where(condition, x1, x2, out=out_wrong_shape) + with pytest.raises(ValueError): + dpt.where(condition, x1, x2, out=out_not_writable)