From 99644835013d79f9de51bdcaa67b4c6383a67139 Mon Sep 17 00:00:00 2001 From: Natalia Polina Date: Thu, 28 Jul 2022 14:57:40 -0500 Subject: [PATCH 1/3] Added dpctl.tensor.stack feature and tests stack() function joins a sequence of arrays along a new axis and follows array API spec. https://data-apis.org/array-api/latest/API_specification/generated/signatures.manipulation_functions.stack.html#signatures.manipulation_functions.stack --- dpctl/tensor/__init__.py | 2 + dpctl/tensor/_manipulation_functions.py | 78 ++++++++++--- dpctl/tests/test_usm_ndarray_manipulation.py | 110 +++++++++++++++++++ 3 files changed, 175 insertions(+), 15 deletions(-) diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index de5532c5ed..44e77aa6b3 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -45,6 +45,7 @@ permute_dims, roll, squeeze, + stack, ) from dpctl.tensor._reshape import reshape from dpctl.tensor._usmarray import usm_ndarray @@ -68,6 +69,7 @@ "reshape", "roll", "concat", + "stack", "broadcast_arrays", "broadcast_to", "expand_dims", diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index 2e36b26dc1..365aa91f5d 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -288,12 +288,7 @@ def roll(X, shift, axes=None): return res -def concat(arrays, axis=0): - """ - concat(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray - - Joins a sequence of arrays along an existing axis. - """ +def arrays_validation(arrays): n = len(arrays) if n == 0: raise TypeError("Missing 1 required positional argument: 'arrays'") @@ -324,11 +319,23 @@ def concat(arrays, axis=0): for i in range(1, n): if X0.ndim != arrays[i].ndim: raise ValueError( - "All the input arrays must have same number of " - "dimensions, but the array at index 0 has " - f"{X0.ndim} dimension(s) and the array at index " - f"{i} has {arrays[i].ndim} dimension(s)" + "All the input arrays must have same number of dimensions, " + f"but the array at index 0 has {X0.ndim} dimension(s) and the " + f"array at index {i} has {arrays[i].ndim} dimension(s)" ) + return res_dtype, res_usm_type, exec_q + + +def concat(arrays, axis=0): + """ + concat(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray + + Joins a sequence of arrays along an existing axis. + """ + res_dtype, res_usm_type, exec_q = arrays_validation(arrays) + + n = len(arrays) + X0 = arrays[0] axis = normalize_axis_index(axis, X0.ndim) X0_shape = X0.shape @@ -337,11 +344,10 @@ def concat(arrays, axis=0): for j in range(X0.ndim): if X0_shape[j] != Xi_shape[j] and j != axis: raise ValueError( - "All the input array dimensions for the " - "concatenation axis must match exactly, but " - f"along dimension {j}, the array at index 0 " - f"has size {X0_shape[j]} and the array at " - f"index {i} has size {Xi_shape[j]}" + "All the input array dimensions for the concatenation " + f"axis must match exactly, but along dimension {j}, the " + f"array at index 0 has size {X0_shape[j]} and the array " + f"at index {i} has size {Xi_shape[j]}" ) res_shape_axis = 0 @@ -373,3 +379,45 @@ def concat(arrays, axis=0): dpctl.SyclEvent.wait_for(hev_list) return res + + +def stack(arrays, axis=0): + """ + stack(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray + + Joins a sequence of arrays along a new axis. + """ + res_dtype, res_usm_type, exec_q = arrays_validation(arrays) + + n = len(arrays) + X0 = arrays[0] + res_ndim = X0.ndim + 1 + axis = normalize_axis_index(axis, res_ndim) + X0_shape = X0.shape + + for i in range(1, n): + if X0_shape != arrays[i].shape: + raise ValueError("All input arrays must have the same shape") + + res_shape = tuple( + X0_shape[i - 1 * (i >= axis)] if i != axis else n + for i in range(res_ndim) + ) + + res = dpt.empty( + res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q + ) + + hev_list = [] + for i in range(n): + c_shapes_copy = tuple( + i if j == axis else np.s_[:] for j in range(res_ndim) + ) + hev, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=arrays[i], dst=res[c_shapes_copy], sycl_queue=exec_q + ) + hev_list.append(hev) + + dpctl.SyclEvent.wait_for(hev_list) + + return res diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index 9e99372639..363464f186 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -889,3 +889,113 @@ def test_concat_3arrays(data): R = dpt.concat([X, Y, Z], axis=axis) assert_array_equal(Rnp, dpt.asnumpy(R)) + + +def test_stack_incorrect_shape(): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Queue could not be created") + + X = dpt.ones((1,), sycl_queue=q) + Y = dpt.ones((2,), sycl_queue=q) + + pytest.raises(ValueError, dpt.stack, [X, Y], 0) + + +@pytest.mark.parametrize( + "data", + [ + [(6,), 0], + [(2, 3), 1], + [(3, 2), -1], + [(1, 6), 2], + [(2, 1, 3), 2], + ], +) +def test_stack_1array(data): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Queue could not be created") + + shape, axis = data + + Xnp = np.arange(6).reshape(shape) + X = dpt.asarray(Xnp, sycl_queue=q) + + Ynp = np.stack([Xnp], axis=axis) + Y = dpt.stack([X], axis=axis) + + assert_array_equal(Ynp, dpt.asnumpy(Y)) + + Ynp = np.stack((Xnp,), axis=axis) + Y = dpt.stack((X,), axis=axis) + + assert_array_equal(Ynp, dpt.asnumpy(Y)) + + +@pytest.mark.parametrize( + "data", + [ + [(1,), 0], + [(0, 2), 0], + [(2, 0), 0], + [(2, 3), 0], + [(2, 3), 1], + [(2, 3), 2], + [(2, 3), -1], + [(2, 3), -2], + [(2, 2, 2), 1], + ], +) +def test_stack_2arrays(data): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Queue could not be created") + + shape, axis = data + + Xnp = np.ones(shape) + X = dpt.asarray(Xnp, sycl_queue=q) + + Ynp = np.zeros(shape) + Y = dpt.asarray(Ynp, sycl_queue=q) + + Znp = np.stack([Xnp, Ynp], axis=axis) + print(Znp.shape) + Z = dpt.stack([X, Y], axis=axis) + + assert_array_equal(Znp, dpt.asnumpy(Z)) + + +@pytest.mark.parametrize( + "data", + [ + [(1,), 0], + [(0, 2), 0], + [(2, 1, 2), 1], + ], +) +def test_stack_3arrays(data): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Queue could not be created") + + shape, axis = data + + Xnp = np.ones(shape) + X = dpt.asarray(Xnp, sycl_queue=q) + + Ynp = np.zeros(shape) + Y = dpt.asarray(Ynp, sycl_queue=q) + + Znp = np.full(shape, 2.0) + Z = dpt.asarray(Znp, sycl_queue=q) + + Rnp = np.stack([Xnp, Ynp, Znp], axis=axis) + R = dpt.stack([X, Y, Z], axis=axis) + + assert_array_equal(Rnp, dpt.asnumpy(R)) From 9fce1a255e5448d9f5e1f511d1d903bf8ac44b1e Mon Sep 17 00:00:00 2001 From: Natalia Polina Date: Thu, 28 Jul 2022 15:47:37 -0500 Subject: [PATCH 2/3] Renames an internal function --- dpctl/tensor/_manipulation_functions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index 365aa91f5d..90c5cc895c 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -288,7 +288,7 @@ def roll(X, shift, axes=None): return res -def arrays_validation(arrays): +def _arrays_validation(arrays): n = len(arrays) if n == 0: raise TypeError("Missing 1 required positional argument: 'arrays'") @@ -332,7 +332,7 @@ def concat(arrays, axis=0): Joins a sequence of arrays along an existing axis. """ - res_dtype, res_usm_type, exec_q = arrays_validation(arrays) + res_dtype, res_usm_type, exec_q = _arrays_validation(arrays) n = len(arrays) X0 = arrays[0] @@ -387,7 +387,7 @@ def stack(arrays, axis=0): Joins a sequence of arrays along a new axis. """ - res_dtype, res_usm_type, exec_q = arrays_validation(arrays) + res_dtype, res_usm_type, exec_q = _arrays_validation(arrays) n = len(arrays) X0 = arrays[0] From 7e21c25592d2db84065910158e99f3bb8fb24c62 Mon Sep 17 00:00:00 2001 From: Natalia Polina Date: Thu, 28 Jul 2022 15:49:32 -0500 Subject: [PATCH 3/3] Noted the change in CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index fc4cf6dc32..7b26c3ab44 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Wrote manual page about working with `dpctl.SyclQueue` [#829](https://github.com/IntelPython/dpctl/pull/829). * Added cmake scripts to dpctl package layout and a way to query the location [#853](https://github.com/IntelPython/dpctl/pull/853). * Implemented `dpctl.tensor.concat` function from array-API [#867](https://github.com/IntelPython/dpctl/867). +* Implemented `dpctl.tensor.stack` function from array-API [#872](https://github.com/IntelPython/dpctl/872). ### Changed