Skip to content

Implementing repeat function #875

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 24, 2025
5 changes: 2 additions & 3 deletions ci/Numba-array-api-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ array_api_tests/test_creation_functions.py::test_empty_like
array_api_tests/test_data_type_functions.py::test_finfo[complex64]
array_api_tests/test_manipulation_functions.py::test_squeeze
array_api_tests/test_has_names.py::test_has_names[utility-diff]
array_api_tests/test_has_names.py::test_has_names[manipulation-repeat]

array_api_tests/test_has_names.py::test_has_names[manipulation-tile]
array_api_tests/test_has_names.py::test_has_names[manipulation-unstack]
array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_sum]
Expand All @@ -79,7 +79,7 @@ array_api_tests/test_has_names.py::test_has_names[indexing-take_along_axis]
array_api_tests/test_has_names.py::test_has_names[searching-count_nonzero]
array_api_tests/test_has_names.py::test_has_names[searching-searchsorted]
array_api_tests/test_signatures.py::test_func_signature[diff]
array_api_tests/test_signatures.py::test_func_signature[repeat]

array_api_tests/test_signatures.py::test_func_signature[tile]
array_api_tests/test_signatures.py::test_func_signature[unstack]
array_api_tests/test_signatures.py::test_func_signature[take_along_axis]
Expand Down Expand Up @@ -107,7 +107,6 @@ array_api_tests/test_statistical_functions.py::test_cumulative_sum
array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_1[None]
array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_2[1]
array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_2[None]
array_api_tests/test_manipulation_functions.py::test_repeat
array_api_tests/test_searching_functions.py::test_count_nonzero
array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_manipulation_functions.py::test_tile
Expand Down
2 changes: 2 additions & 0 deletions sparse/numba_backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
permute_dims,
prod,
real,
repeat,
reshape,
round,
squeeze,
Expand Down Expand Up @@ -335,6 +336,7 @@
"where",
"zeros",
"zeros_like",
"repeat",
]


Expand Down
44 changes: 43 additions & 1 deletion sparse/numba_backend/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np

from ._coo import as_coo
from ._coo import as_coo, expand_dims
from ._sparse_array import SparseArray
from ._utils import (
_zero_of_dtype,
Expand Down Expand Up @@ -3104,3 +3104,45 @@ def vecdot(x1, x2, /, *, axis=-1):
x1 = np.conjugate(x1)

return np.sum(x1 * x2, axis=axis, dtype=np.result_type(x1, x2))


def repeat(a, repeats, axis=None):
"""
Repeat each element of an array after themselves

Parameters
----------
a : SparseArray
Input sparse arrays
repeats : int
The number of repetitions for each element.
(Uneven repeats are not yet Implemented.)
axis : int, optional
The axis along which to repeat values. Returns a flattened sparse array if not specified.

Returns
-------
out : SparseArray
A sparse array which has the same shape as a, except along the given axis.
"""
if not isinstance(a, SparseArray):
raise TypeError("`a` must be a SparseArray.")

if not isinstance(repeats, int):
raise Exception("`repeats` must be an integer, uneven repeats are not yet Implemented.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
raise Exception("`repeats` must be an integer, uneven repeats are not yet Implemented.")
raise ValueError("`repeats` must be an integer, uneven repeats are not supported.")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it not possible to implement uneven repeats for sparse arrays? It's not possible via broadcasting, we might have to find a different way.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's required by the Array API standard, I'd implement it as follows: Take the ceiling; use the current implementation then truncate to the desired length.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the standard specifies this: https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.repeat.html#repeat

Should I modify the behaviour in this PR or in future(keeping it as not implemented)?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is good as the change is related.


axes = list(range(a.ndim))
new_shape = list(a.shape)
axis_is_none = False
if axis is None:
a = a.reshape(-1)
axis = 0
axis_is_none = True
axes[a.ndim - 1], axes[axis] = axes[axis], axes[a.ndim - 1]
new_shape[axis] *= repeats
a = expand_dims(a, axis=axis + 1)
shape_to_broadcast = a.shape[: axis + 1] + (a.shape[axis + 1] * repeats,) + a.shape[axis + 2 :]
a = broadcast_to(a, shape_to_broadcast)
if not axis_is_none:
return a.reshape(new_shape)
return a.reshape(new_shape).flatten()
26 changes: 26 additions & 0 deletions sparse/numba_backend/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1926,3 +1926,29 @@ def test_xH_x():
assert_eq(Ysp.conj().T @ Y, Y.conj().T @ Y)
assert_eq(Ysp.conj().T @ Ysp, Y.conj().T @ Y)
assert_eq(Y.conj().T @ Ysp.conj().T, Y.conj().T @ Y.conj().T)


@pytest.mark.parametrize("ndim", range(1, 5))
@pytest.mark.parametrize("repeats", [1, 2, 3])
def test_repeat(ndim, repeats):
rng = np.random.default_rng()
shape = tuple(rng.integers(1, 4) for _ in range(ndim))
a = rng.integers(1, 10, size=shape)
with pytest.raises(TypeError, match="`a` must be a SparseArray"):
sparse.repeat(a, repeats=2)
sparse_a = COO.from_numpy(a)
with pytest.raises(Exception, match="`repeats` must be an integer"):
sparse.repeat(sparse_a, repeats=[2, 2, 2])
for axis in range(ndim):
expected = np.repeat(a, repeats=repeats, axis=axis)
result_sparse = sparse.repeat(sparse_a, repeats=repeats, axis=axis)
actual = result_sparse.todense()
assert actual.shape == expected.shape, f"Shape mismatch on axis {axis}: {actual.shape} vs {expected.shape}"
np.testing.assert_array_equal(actual, expected)

expected = np.repeat(a, repeats=repeats, axis=None)
result_sparse = sparse.repeat(sparse_a, repeats=repeats, axis=None)
actual = result_sparse.todense()
print(f"Expected: {expected}, Actual: {actual}")
assert actual.shape == expected.shape
np.testing.assert_array_equal(actual, expected)
1 change: 1 addition & 0 deletions sparse/numba_backend/tests/test_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def test_namespace():
"real",
"reciprocal",
"remainder",
"repeat",
"reshape",
"result_type",
"roll",
Expand Down
Loading