From ccdfabbe1166880c396f72cf9a8d0b1ef49da8a4 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Fri, 24 Mar 2023 09:48:21 -0500 Subject: [PATCH 1/2] add unstack, moveaxis, swapaxes --- dpctl/tensor/__init__.py | 6 ++ dpctl/tensor/_manipulation_functions.py | 101 +++++++++++++++++++ dpctl/tests/test_usm_ndarray_manipulation.py | 76 ++++++++++++++ 3 files changed, 183 insertions(+) diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 77f102fa56..e970af98f5 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -69,11 +69,14 @@ finfo, flip, iinfo, + moveaxis, permute_dims, result_type, roll, squeeze, stack, + swapaxes, + unstack, ) from dpctl.tensor._print import ( get_print_options, @@ -143,6 +146,9 @@ "complex128", "iinfo", "finfo", + "unstack", + "moveaxis", + "swapaxes", "can_cast", "result_type", "meshgrid", diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index c2a77c3599..94995b5433 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -741,6 +741,107 @@ def finfo(dtype): return finfo_object(dtype) +def unstack(X, axis=0): + """ + Args: + x (usm_ndarray): input array + + axis (int): axis along which X is unstacked. + If `X` has rank (i.e, number of dimensions) `N`, + a valid `axis` must reside in the half-open interval `[-N, N)`. + default value is axis=0. + + Returns: + out (usm_narray): A tuple of arrays. + + Raises: + AxisError: if provided axis position is invalid. + """ + if not isinstance(X, dpt.usm_ndarray): + raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") + + axis = normalize_axis_index(axis, X.ndim) + Y = dpt.moveaxis(X, axis, 0) + + return tuple(Y[i] for i in range(Y.shape[0])) + + +def moveaxis(X, src, dst): + """ + Args: + x (usm_ndarray): input array + + src (int or a sequence of int): Original positions of the axes to move. + These must be unique. If `X` has rank (i.e., number of dimensions) `N`, + a valid `axis` must reside in the half-open interval `[-N, N)`. + + dst (int or a sequence of int): Destination positions for each of the + original axes. These must also be unique. If `X` has rank + (i.e., number of dimensions) `N`, a valid `axis` must reside + in the half-open interval `[-N, N)`. + + Returns: + out (usm_narray): Array with moved axes. + The returned array must has the same data type as `X`, + is created on the same device as `X` and has the same USM allocation + type as `X`. + + Raises: + AxisError: if provided axis position is invalid. + """ + if not isinstance(X, dpt.usm_ndarray): + raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") + + if not isinstance(src, (tuple, list)): + src = (src,) + + if not isinstance(dst, (tuple, list)): + dst = (dst,) + + src = normalize_axis_tuple(src, X.ndim, "src") + dst = normalize_axis_tuple(dst, X.ndim, "dst") + ind = list(range(0, X.ndim)) + for i in range(len(src)): + ind.remove(src[i]) # using the value here which is the same as index + ind.insert(dst[i], src[i]) + + return dpt.permute_dims(X, tuple(ind)) + + +def swapaxes(X, axis1, axis2): + """ + Args: + x (usm_ndarray): input array + + axis1 (int): First axis. + If `X` has rank (i.e., number of dimensions) `N`, + a valid `axis` must reside in the half-open interval `[-N, N)`. + + axis2 (int): Second axis. + If `X` has rank (i.e., number of dimensions) `N`, + a valid `axis` must reside in the half-open interval `[-N, N)`. + + Returns: + out (usm_narray): Swapped array. + The returned array must has the same data type as `X`, + is created on the same device as `X` and has the same USM allocation + type as `X`. + + Raises: + AxisError: if provided axis position is invalid. + """ + if not isinstance(X, dpt.usm_ndarray): + raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") + + axis1 = normalize_axis_index(axis1, X.ndim, "axis1") + axis2 = normalize_axis_index(axis2, X.ndim, "axis2") + + ind = list(range(0, X.ndim)) + ind[axis1] = axis2 + ind[axis2] = axis1 + return dpt.permute_dims(X, tuple(ind)) + + def _supported_dtype(dtypes): for dtype in dtypes: if dtype.char not in "?bBhHiIlLqQefdFD": diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index 5dc227409d..a6e8ec244c 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -1046,3 +1046,79 @@ def test_result_type(): X_np = [np.ones((2), dtype=np.int64), np.int32, "float16"] assert dpt.result_type(*X) == np.result_type(*X_np) + + +def test_swapaxes_1d(): + x = np.array([[1, 2, 3]]) + exp = np.swapaxes(x, 0, 1) + + y = dpt.asarray([[1, 2, 3]]) + res = dpt.swapaxes(y, 0, 1) + + assert_array_equal(exp, dpt.asnumpy(res)) + + +def test_swapaxes_2d(): + x = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]) + exp = np.swapaxes(x, 0, 2) + + y = dpt.asarray([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]) + res = dpt.swapaxes(y, 0, 2) + + assert_array_equal(exp, dpt.asnumpy(res)) + + +def test_moveaxis_1axis(): + x = np.arange(60).reshape((3, 4, 5)) + exp = np.moveaxis(x, 0, -1) + + y = dpt.reshape(dpt.arange(60), (3, 4, 5)) + res = dpt.moveaxis(y, 0, -1) + + assert_array_equal(exp, dpt.asnumpy(res)) + + +def test_moveaxis_2axes(): + x = np.arange(60).reshape((3, 4, 5)) + exp = np.moveaxis(x, [0, 1], [-1, -2]) + + y = dpt.reshape(dpt.arange(60), (3, 4, 5)) + res = dpt.moveaxis(y, [0, 1], [-1, -2]) + + assert_array_equal(exp, dpt.asnumpy(res)) + + +def test_moveaxis_3axes(): + x = np.arange(60).reshape((3, 4, 5)) + exp = np.moveaxis(x, [0, 1, 2], [-1, -2, -3]) + + y = dpt.reshape(dpt.arange(60), (3, 4, 5)) + res = dpt.moveaxis(y, [0, 1, 2], [-1, -2, -3]) + + assert_array_equal(exp, dpt.asnumpy(res)) + + +def test_unstack_axis0(): + y = dpt.reshape(dpt.arange(6), (2, 3)) + res = dpt.unstack(y) + + assert_array_equal(dpt.asnumpy(y[0, ...]), dpt.asnumpy(res[0])) + assert_array_equal(dpt.asnumpy(y[1, ...]), dpt.asnumpy(res[1])) + + +def test_unstack_axis1(): + y = dpt.reshape(dpt.arange(6), (2, 3)) + res = dpt.unstack(y, 1) + + assert_array_equal(dpt.asnumpy(y[:, 0, ...]), dpt.asnumpy(res[0])) + assert_array_equal(dpt.asnumpy(y[:, 1, ...]), dpt.asnumpy(res[1])) + assert_array_equal(dpt.asnumpy(y[:, 2, ...]), dpt.asnumpy(res[2])) + + +def test_unstack_axis2(): + y = dpt.reshape(dpt.arange(60), (4, 5, 3)) + res = dpt.unstack(y, 2) + + assert_array_equal(dpt.asnumpy(y[:, :, 0, ...]), dpt.asnumpy(res[0])) + assert_array_equal(dpt.asnumpy(y[:, :, 1, ...]), dpt.asnumpy(res[1])) + assert_array_equal(dpt.asnumpy(y[:, :, 2, ...]), dpt.asnumpy(res[2])) From 18e680bacb292048ef6954c604f56c132e724b9c Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 27 Mar 2023 06:19:48 -0500 Subject: [PATCH 2/2] Fixed docstrings for unstack, swapaxes, moveaxis --- dpctl/tensor/_manipulation_functions.py | 74 ++++++++++++++----------- 1 file changed, 43 insertions(+), 31 deletions(-) diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index 94995b5433..caa36a3dc8 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -1,6 +1,6 @@ # Data Parallel Control (dpctl) # -# Copyright 2020-2022 Intel Corporation +# Copyright 2020-2023 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -742,20 +742,23 @@ def finfo(dtype): def unstack(X, axis=0): - """ + """unstack(x, axis=0) + + Splits an array in a sequence of arrays along the given axis. + Args: x (usm_ndarray): input array - axis (int): axis along which X is unstacked. - If `X` has rank (i.e, number of dimensions) `N`, - a valid `axis` must reside in the half-open interval `[-N, N)`. - default value is axis=0. + axis (int, optional): axis along which `x` is unstacked. + If `x` has rank (i.e, number of dimensions) `N`, + a valid `axis` must reside in the half-open interval `[-N, N)`. + Default: `0`. Returns: - out (usm_narray): A tuple of arrays. + Tuple[usm_ndarray,...]: A tuple of arrays. Raises: - AxisError: if provided axis position is invalid. + AxisError: if the `axis` value is invalid. """ if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") @@ -767,27 +770,33 @@ def unstack(X, axis=0): def moveaxis(X, src, dst): - """ + """moveaxis(x, src, dst) + + Moves axes of an array to new positions. + Args: x (usm_ndarray): input array - src (int or a sequence of int): Original positions of the axes to move. - These must be unique. If `X` has rank (i.e., number of dimensions) `N`, - a valid `axis` must reside in the half-open interval `[-N, N)`. + src (int or a sequence of int): + Original positions of the axes to move. + These must be unique. If `x` has rank (i.e., number of + dimensions) `N`, a valid `axis` must be in the + half-open interval `[-N, N)`. - dst (int or a sequence of int): Destination positions for each of the - original axes. These must also be unique. If `X` has rank - (i.e., number of dimensions) `N`, a valid `axis` must reside - in the half-open interval `[-N, N)`. + dst (int or a sequence of int): + Destination positions for each of the original axes. + These must also be unique. If `x` has rank + (i.e., number of dimensions) `N`, a valid `axis` must be + in the half-open interval `[-N, N)`. Returns: - out (usm_narray): Array with moved axes. - The returned array must has the same data type as `X`, - is created on the same device as `X` and has the same USM allocation - type as `X`. + usm_narray: Array with moved axes. + The returned array must has the same data type as `x`, + is created on the same device as `x` and has the same + USM allocation type as `x`. Raises: - AxisError: if provided axis position is invalid. + AxisError: if `axis` value is invalid. """ if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") @@ -809,26 +818,29 @@ def moveaxis(X, src, dst): def swapaxes(X, axis1, axis2): - """ + """swapaxes(x, axis1, axis2) + + Interchanges two axes of an array. + Args: x (usm_ndarray): input array axis1 (int): First axis. - If `X` has rank (i.e., number of dimensions) `N`, - a valid `axis` must reside in the half-open interval `[-N, N)`. + If `x` has rank (i.e., number of dimensions) `N`, + a valid `axis` must be in the half-open interval `[-N, N)`. axis2 (int): Second axis. - If `X` has rank (i.e., number of dimensions) `N`, - a valid `axis` must reside in the half-open interval `[-N, N)`. + If `x` has rank (i.e., number of dimensions) `N`, + a valid `axis` must be in the half-open interval `[-N, N)`. Returns: - out (usm_narray): Swapped array. - The returned array must has the same data type as `X`, - is created on the same device as `X` and has the same USM allocation - type as `X`. + usm_narray: Array with swapped axes. + The returned array must has the same data type as `x`, + is created on the same device as `x` and has the same USM + allocation type as `x`. Raises: - AxisError: if provided axis position is invalid. + AxisError: if `axis` value is invalid. """ if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")