Skip to content

Commit de26663

Browse files
Use dpnp.get_usm_ndarray in take and update examples
1 parent 98c2003 commit de26663

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,8 @@ def take(x, indices, /, *, axis=None, out=None, mode="wrap"):
580580
>>> np.take(x, indices)
581581
array([4, 3, 6])
582582
583+
In this example "fancy" indexing can be used.
584+
583585
>>> x[indices]
584586
array([4, 3, 6])
585587
@@ -589,6 +591,7 @@ def take(x, indices, /, *, axis=None, out=None, mode="wrap"):
589591
590592
>>> np.take(x, indices, mode="clip")
591593
array([4, 4, 4, 8, 8])
594+
592595
"""
593596

594597
if dpnp.is_supported_array_type(x) and dpnp.is_supported_array_type(
@@ -605,12 +608,8 @@ def take(x, indices, /, *, axis=None, out=None, mode="wrap"):
605608
elif mode not in ("clip", "wrap"):
606609
pass
607610
else:
608-
dpt_array = x.get_array() if isinstance(x, dpnp_array) else x
609-
dpt_indices = (
610-
indices.get_array()
611-
if isinstance(indices, dpnp_array)
612-
else indices
613-
)
611+
dpt_array = dpnp.get_usm_ndarray(x)
612+
dpt_indices = dpnp.get_usm_ndarray(indices)
614613
return dpnp_array._create_from_usm_ndarray(
615614
dpt.take(dpt_array, dpt_indices, axis=axis, mode=mode)
616615
)

0 commit comments

Comments
 (0)