Inconsistent behavior in split() when num_outputs=1 #14064
Comments
@jgasthaus Thanks for pointing it out. Yes, you are right, it can be documented. It is not a bug, since it follows the definition.
@mxnet-label-bot add [Question, Doc, Good First Issue]
To me, such a diverging return type is an anti-pattern that complicates all downstream code, and it should be addressed as part of MXNet 2.0 #9686.
@ChaiBapchya Sure, I can try to submit a doc-change PR. I agree with @szha that this might be a good candidate for a breaking API change in 2.0. We have a few places in our code where we now have to branch because of this. It can also lead to fairly subtle bugs, especially since both lists and NDArrays support indexing, so you might not notice that you are indexing the returned array along the 0-th axis as opposed to the list as intended. Also note that this is not consistent with numpy's
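Until the return type is unified, the branching mentioned above can at least be confined to one place. The sketch below is a hypothetical helper, `as_list` (not an MXNet API), written under the assumption stated in this issue: `NDArray.split()` returns a Python list when `num_outputs > 1` and the bare array when `num_outputs=1`. An NDArray is neither a `list` nor a `tuple`, so the `isinstance` check separates the two cases without importing MXNet.

```python
from typing import Any, List


def as_list(result: Any) -> List[Any]:
    """Normalize the result of a split-like call to a list of pieces.

    Hypothetical helper, not part of MXNet. Assumes the call returns a
    list (or tuple) of pieces when num_outputs > 1 and a single bare
    array when num_outputs == 1, as described in this issue.
    """
    if isinstance(result, (list, tuple)):
        # Multiple outputs: already a sequence of pieces.
        return list(result)
    # Single output: wrap the bare array so that indexing the result
    # always selects a piece, never a slice along axis 0 of the array.
    return [result]
```

With MXNet installed, a call site would read `parts = as_list(mx.nd.split(x, num_outputs=n, axis=1))`; afterwards `parts[0]` always means "the first piece", which avoids the silent axis-0 indexing bug described above.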
Actually, I just noticed that the behavior differs between
Delving a bit deeper, here's what I found:
Similar behavior between NDArray and Symbol for num_outputs=2
NDArray num_outputs=2
Symbol num_outputs=2
Abnormal behavior between NDArray and Symbol for num_outputs=1
NDArray num_outputs=1
Symbol num_outputs=1
In 2.x we will focus on the numpy array instead of NDArray. np.split now behaves consistently.
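For comparison, plain NumPy's `np.split` returns a list of sub-arrays for every number of sections, including a single one. The sketch below uses NumPy itself rather than `mxnet.np`, on the assumption (per the comment above) that the latter follows NumPy's semantics here.

```python
import numpy as np

x = np.arange(8).reshape(2, 4)

# Two equal sections along axis 1: a list of two (2, 2) arrays.
two = np.split(x, 2, axis=1)

# One section: still a list, holding a single (2, 4) array,
# rather than the bare array itself.
one = np.split(x, 1, axis=1)

print(type(two).__name__, len(two), two[0].shape)  # list 2 (2, 2)
print(type(one).__name__, len(one), one[0].shape)  # list 1 (2, 4)
```

Because the container type never depends on the number of sections, callers can index the result as a list without branching.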
Description
The behaviour of `NDArray.split()` is inconsistent and surprising when `num_outputs=1`: when `num_outputs > 1`, `split()` returns a list containing the individual split elements, but when `num_outputs=1`, the single resulting array is returned directly, without being wrapped in a list. If this is the intended behavior, it appears to be undocumented.
See also #8827.
Environment info (Required)
Package used (Python/R/Scala/Julia):
Python