-
Notifications
You must be signed in to change notification settings - Fork 327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use version of NDArray split that always returns a list. #454
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -860,3 +860,28 @@ def uncast_conditionally(data: mx.sym.Symbol, dtype: str) -> mx.sym.Symbol: | |
if dtype != C.DTYPE_FP32: | ||
return mx.sym.cast(data=data, dtype=C.DTYPE_FP32) | ||
return data | ||
|
||
|
||
def split(data: mx.nd.NDArray, | ||
num_outputs: int, | ||
axis: int = 1, | ||
squeeze_axis: bool = False) -> List[mx.nd.NDArray]: | ||
""" | ||
Version of mxnet.ndarray.split that always returns a list. The original | ||
implementation only returns a list if num_outputs > 1: | ||
https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.split | ||
|
||
Splits an array along a particular axis into multiple sub-arrays. | ||
|
||
:param data: The input. | ||
:param num_outputs: Number of splits. Note that this should evenly divide | ||
the length of the axis. | ||
:param axis: Axis along which to split. | ||
:param squeeze_axis: If true, Removes the axis with length 1 from the shapes | ||
of the output arrays. | ||
:return: List of NDArrays resulting from the split. | ||
""" | ||
ndarray_or_list = data.split(num_outputs=num_outputs, axis=axis, squeeze_axis=squeeze_axis) | ||
if num_outputs == 1: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it make sense to avoid the split altogether when num_outputs is 1? If squeeze_axis==True, one only would need a reshape/squeeze which is essentially a no-op. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another good point. I think it's a toss-up between staying as close to the original as possible versus micro-optimizing the call. Since we're now using this in just one place, called once per batch, I would lean toward keeping it this way for clarity. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, probably not worth the additional complexity. |
||
return [ndarray_or_list] | ||
return ndarray_or_list |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we don't need this change here, as in the symbolic API split seems to return always an 'indexable' Symbol/SliceChannel. We do these source-factor related splits also in other places of the code and it works there just fine. Also, the util function isn't typed for symbols and I am surprised
data.split
(aka using the fluent method) works for symbols).This throws an error:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch!