This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Question about mx.nd.split() ? #8827

Closed
kobenaxie opened this issue Nov 27, 2017 · 4 comments

@kobenaxie

kobenaxie commented Nov 27, 2017

mxnet 0.12.0

When the size of the data along the split axis is 1, `mx.nd.split()` returns an incorrect result. For example:

```
In [1]: import mxnet as mx

In [2]: data = mx.nd.ones((2,3,4))

In [3]: data_s = mx.nd.split(data, axis=1, num_outputs=data.shape[1], squeeze_axis=False)

In [4]: len(data_s)
Out[4]: 3

In [5]: data_s[0].shape
Out[5]: (2L, 1L, 4L)
```

**But when splitting `x` along axis=1 below, it returns 2 NDArrays:**

```
In [6]: x = data_s[0]  # with shape (2L, 1L, 4L)

In [8]: x_s = mx.nd.split(x, axis=1, num_outputs=x.shape[1], squeeze_axis=False)

In [9]: len(x_s)
Out[9]: 2

In [10]: x_s[0].shape
Out[10]: (1L, 4L)
```

Is this a bug?

@solin319
Contributor

`len(x_s)` returns the size of the first axis of `x_s`, which is 2.
`x_s[0]` returns the first slice along axis 0, which has shape (1, 4).

@kobenaxie
Author

Since I split along axis=1 and `x` has shape (2, 1, 4), I would expect a list containing a single NDArray of shape (2, 1, 4), rather than two NDArrays of shape (1, 4) each.

@solin319
Contributor

solin319 commented Nov 27, 2017

Try printing `x_s`: it is a single (2, 1, 4) NDArray.

```
x_s = mx.nd.split(x, axis=1, num_outputs=x.shape[1], squeeze_axis=False)
x_s

[[[ 1.  1.  1.  1.]]

 [[ 1.  1.  1.  1.]]]
<NDArray 2x1x4 @cpu(0)>

len(x_s)
2
x_s[0].shape
(1L, 4L)
x_s.shape
(2L, 1L, 4L)
```
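As the output above shows, an NDArray follows the same `len()` and indexing conventions as a NumPy array: `len()` counts the first axis, and `[0]` drops it. The minimal sketch below uses NumPy as a stand-in for MXNet (an assumption made only so it runs without MXNet installed) to show why one (2, 1, 4) array reports a length of 2.

```python
import numpy as np

# NumPy stand-in for `x_s` from the thread: a single array of shape (2, 1, 4)
x_s = np.ones((2, 1, 4))

# len() reports the size of axis 0, not a number of split outputs
print(len(x_s))        # 2
# Indexing selects one slice along axis 0, removing that axis
print(x_s[0].shape)    # (1, 4)
print(x_s.shape)       # (2, 1, 4)
```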

@kobenaxie
Author

I found that when the split produces only one NDArray, `mx.nd.split()` returns that NDArray directly rather than a list.
Thank you~
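One way to guard against this in calling code is to normalize the result to a list before iterating over it. The sketch below is hypothetical: `mx_style_split` is a NumPy stand-in that only mimics the return convention described in this thread (a bare array when `num_outputs` is 1, a list otherwise), and `as_list` is an illustrative helper name. With MXNet itself, the equivalent check would presumably be `isinstance(out, mx.nd.NDArray)`.

```python
import numpy as np


def mx_style_split(data, num_outputs, axis):
    """Hypothetical NumPy stand-in mimicking the return convention reported
    for mx.nd.split: a bare array for one output, a list for several."""
    parts = np.split(data, num_outputs, axis=axis)
    return parts[0] if num_outputs == 1 else parts


def as_list(result):
    """Wrap a single array in a list so callers can always iterate outputs."""
    if isinstance(result, (list, tuple)):
        return list(result)
    return [result]


data = np.ones((2, 3, 4))
x = mx_style_split(data, num_outputs=data.shape[1], axis=1)[0]  # shape (2, 1, 4)

# Splitting an axis of size 1 yields a bare array, so len() counts axis 0
x_s = mx_style_split(x, num_outputs=x.shape[1], axis=1)
print(type(x_s).__name__, len(x_s))  # ndarray 2

# After normalization there is exactly one output of shape (2, 1, 4)
parts = as_list(x_s)
print(len(parts), parts[0].shape)    # 1 (2, 1, 4)
```

Normalizing at the call site keeps loops like `for part in as_list(out)` correct regardless of whether the split axis has size 1.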
