Skip to content

Commit f4c8a3d

Browse files
authored
Implementing repeat function (#875)
1 parent 9988853 commit f4c8a3d

File tree

5 files changed

+77
-4
lines changed

5 files changed

+77
-4
lines changed

ci/Numba-array-api-xfails.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ array_api_tests/test_creation_functions.py::test_empty_like
7070
array_api_tests/test_data_type_functions.py::test_finfo[complex64]
7171
array_api_tests/test_manipulation_functions.py::test_squeeze
7272
array_api_tests/test_has_names.py::test_has_names[utility-diff]
73-
array_api_tests/test_has_names.py::test_has_names[manipulation-repeat]
7473
array_api_tests/test_has_names.py::test_has_names[manipulation-tile]
7574
array_api_tests/test_has_names.py::test_has_names[manipulation-unstack]
7675
array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_sum]
@@ -79,7 +78,6 @@ array_api_tests/test_has_names.py::test_has_names[indexing-take_along_axis]
7978
array_api_tests/test_has_names.py::test_has_names[searching-count_nonzero]
8079
array_api_tests/test_has_names.py::test_has_names[searching-searchsorted]
8180
array_api_tests/test_signatures.py::test_func_signature[diff]
82-
array_api_tests/test_signatures.py::test_func_signature[repeat]
8381
array_api_tests/test_signatures.py::test_func_signature[tile]
8482
array_api_tests/test_signatures.py::test_func_signature[unstack]
8583
array_api_tests/test_signatures.py::test_func_signature[take_along_axis]
@@ -107,7 +105,6 @@ array_api_tests/test_statistical_functions.py::test_cumulative_sum
107105
array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_1[None]
108106
array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_2[1]
109107
array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_2[None]
110-
array_api_tests/test_manipulation_functions.py::test_repeat
111108
array_api_tests/test_searching_functions.py::test_count_nonzero
112109
array_api_tests/test_searching_functions.py::test_searchsorted
113110
array_api_tests/test_manipulation_functions.py::test_tile

sparse/numba_backend/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@
119119
permute_dims,
120120
prod,
121121
real,
122+
repeat,
122123
reshape,
123124
round,
124125
squeeze,
@@ -335,6 +336,7 @@
335336
"where",
336337
"zeros",
337338
"zeros_like",
339+
"repeat",
338340
]
339341

340342

sparse/numba_backend/_common.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import numpy as np
1212

13-
from ._coo import as_coo
13+
from ._coo import as_coo, expand_dims
1414
from ._sparse_array import SparseArray
1515
from ._utils import (
1616
_zero_of_dtype,
@@ -3104,3 +3104,46 @@ def vecdot(x1, x2, /, *, axis=-1):
31043104
x1 = np.conjugate(x1)
31053105

31063106
return np.sum(x1 * x2, axis=axis, dtype=np.result_type(x1, x2))
3107+
3108+
3109+
def repeat(a, repeats, axis=None):
3110+
"""
3111+
Repeat each element of an array after themselves
3112+
3113+
Parameters
3114+
----------
3115+
a : SparseArray
3116+
Input sparse arrays
3117+
repeats : int
3118+
The number of repetitions for each element.
3119+
(Uneven repeats are not yet Implemented.)
3120+
axis : int, optional
3121+
The axis along which to repeat values. Returns a flattened sparse array if not specified.
3122+
3123+
Returns
3124+
-------
3125+
out : SparseArray
3126+
A sparse array which has the same shape as a, except along the given axis.
3127+
"""
3128+
if not isinstance(a, SparseArray):
3129+
raise TypeError("`a` must be a SparseArray.")
3130+
3131+
if not isinstance(repeats, int):
3132+
raise ValueError("`repeats` must be an integer, uneven repeats are not yet Implemented.")
3133+
axes = list(range(a.ndim))
3134+
new_shape = list(a.shape)
3135+
axis_is_none = False
3136+
if axis is None:
3137+
a = a.reshape(-1)
3138+
axis = 0
3139+
axis_is_none = True
3140+
if axis < 0:
3141+
axis = a.ndim + axis
3142+
axes[a.ndim - 1], axes[axis] = axes[axis], axes[a.ndim - 1]
3143+
new_shape[axis] *= repeats
3144+
a = expand_dims(a, axis=axis + 1)
3145+
shape_to_broadcast = a.shape[: axis + 1] + (a.shape[axis + 1] * repeats,) + a.shape[axis + 2 :]
3146+
a = broadcast_to(a, shape_to_broadcast)
3147+
if not axis_is_none:
3148+
return a.reshape(new_shape)
3149+
return a.reshape(new_shape).flatten()

sparse/numba_backend/tests/test_coo.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1926,3 +1926,33 @@ def test_xH_x():
19261926
assert_eq(Ysp.conj().T @ Y, Y.conj().T @ Y)
19271927
assert_eq(Ysp.conj().T @ Ysp, Y.conj().T @ Y)
19281928
assert_eq(Y.conj().T @ Ysp.conj().T, Y.conj().T @ Y.conj().T)
1929+
1930+
1931+
def test_repeat_invalid_input():
1932+
a = np.eye(3)
1933+
with pytest.raises(TypeError, match="`a` must be a SparseArray"):
1934+
sparse.repeat(a, repeats=2)
1935+
with pytest.raises(ValueError, match="`repeats` must be an integer"):
1936+
sparse.repeat(COO.from_numpy(a), repeats=[2, 2, 2])
1937+
1938+
1939+
@pytest.mark.parametrize("ndim", range(1, 5))
1940+
@pytest.mark.parametrize("repeats", [1, 2, 3])
1941+
def test_repeat(ndim, repeats):
1942+
rng = np.random.default_rng()
1943+
shape = tuple(rng.integers(1, 4) for _ in range(ndim))
1944+
a = rng.integers(1, 10, size=shape)
1945+
sparse_a = COO.from_numpy(a)
1946+
for axis in [*range(-ndim, ndim), None]:
1947+
expected = np.repeat(a, repeats=repeats, axis=axis)
1948+
result_sparse = sparse.repeat(sparse_a, repeats=repeats, axis=axis)
1949+
actual = result_sparse.todense()
1950+
assert actual.shape == expected.shape, f"Shape mismatch on axis {axis}: {actual.shape} vs {expected.shape}"
1951+
np.testing.assert_array_equal(actual, expected)
1952+
1953+
expected = np.repeat(a, repeats=repeats, axis=None)
1954+
result_sparse = sparse.repeat(sparse_a, repeats=repeats, axis=None)
1955+
actual = result_sparse.todense()
1956+
print(f"Expected: {expected}, Actual: {actual}")
1957+
assert actual.shape == expected.shape
1958+
np.testing.assert_array_equal(actual, expected)

sparse/numba_backend/tests/test_namespace.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def test_namespace():
134134
"real",
135135
"reciprocal",
136136
"remainder",
137+
"repeat",
137138
"reshape",
138139
"result_type",
139140
"roll",

0 commit comments

Comments
 (0)