-
-
Notifications
You must be signed in to change notification settings - Fork 132
Implementing repeat
function
#875
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
8182dec
ac37ba0
4541b77
d7850ac
5551ae6
917ed7f
39418bd
a55e5a1
7c8d5b2
db78f49
f2d67ce
ea38652
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -10,7 +10,7 @@ | |||||
|
||||||
import numpy as np | ||||||
|
||||||
from ._coo import as_coo | ||||||
from ._coo import as_coo, expand_dims | ||||||
from ._sparse_array import SparseArray | ||||||
from ._utils import ( | ||||||
_zero_of_dtype, | ||||||
|
@@ -3104,3 +3104,45 @@ def vecdot(x1, x2, /, *, axis=-1): | |||||
x1 = np.conjugate(x1) | ||||||
|
||||||
return np.sum(x1 * x2, axis=axis, dtype=np.result_type(x1, x2)) | ||||||
|
||||||
|
||||||
def repeat(a, repeats, axis=None): | ||||||
""" | ||||||
Repeat each element of an array after themselves | ||||||
|
||||||
Parameters | ||||||
---------- | ||||||
a : SparseArray | ||||||
Input sparse arrays | ||||||
repeats : int | ||||||
The number of repetitions for each element. | ||||||
(Uneven repeats are not yet Implemented.) | ||||||
axis : int, optional | ||||||
The axis along which to repeat values. Returns a flattened sparse array if not specified. | ||||||
|
||||||
Returns | ||||||
------- | ||||||
out : SparseArray | ||||||
A sparse array which has the same shape as a, except along the given axis. | ||||||
""" | ||||||
if not isinstance(a, SparseArray): | ||||||
raise TypeError("`a` must be a SparseArray.") | ||||||
|
||||||
if not isinstance(repeats, int): | ||||||
raise Exception("`repeats` must be an integer, uneven repeats are not yet Implemented.") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it not possible to implement uneven repeats for sparse arrays? It's not possible via broadcasting, we might have to find a different way. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it's required by the Array API standard, I'd implement it as follows: Take the ceiling; use the current implementation then truncate to the desired length. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, the standard specifies this: https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.repeat.html#repeat Should I modify the behaviour in this PR or in future(keeping it as not implemented)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This PR is good as the change is related. |
||||||
|
||||||
axes = list(range(a.ndim)) | ||||||
new_shape = list(a.shape) | ||||||
axis_is_none = False | ||||||
if axis is None: | ||||||
a = a.reshape(-1) | ||||||
axis = 0 | ||||||
axis_is_none = True | ||||||
axes[a.ndim - 1], axes[axis] = axes[axis], axes[a.ndim - 1] | ||||||
new_shape[axis] *= repeats | ||||||
a = expand_dims(a, axis=axis + 1) | ||||||
shape_to_broadcast = a.shape[: axis + 1] + (a.shape[axis + 1] * repeats,) + a.shape[axis + 2 :] | ||||||
a = broadcast_to(a, shape_to_broadcast) | ||||||
if not axis_is_none: | ||||||
return a.reshape(new_shape) | ||||||
return a.reshape(new_shape).flatten() |
Uh oh!
There was an error while loading. Please reload this page.