Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

make array.reshape compatible with numpy #9790

Merged
merged 4 commits into from
Feb 19, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,12 +926,12 @@ def _at(self, idx):
self.handle, mx_uint(idx), ctypes.byref(handle)))
return NDArray(handle=handle, writable=self.writable)

def reshape(self, shape):
def reshape(self, *shape, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(self, *args, shape=None)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't work in py2

"""Returns a **view** of this array with a new shape without altering any data.

Parameters
----------
shape : tuple of int
shape : tuple of int, or n ints
The new shape should not change the array size, namely
``np.prod(new_shape)`` should be equal to ``np.prod(self.shape)``.

Expand Down Expand Up @@ -960,6 +960,11 @@ def reshape(self, shape):
[ 4., 5.]], dtype=float32)
>>> y = x.reshape((3,-1))
>>> y.asnumpy()
array([[ 0., 1.],
[ 2., 3.],
[ 4., 5.]], dtype=float32)
>>> y = x.reshape(3,2)
>>> y.asnumpy()
array([[ 0., 1.],
[ 2., 3.],
[ 4., 5.]], dtype=float32)
Expand All @@ -968,6 +973,14 @@ def reshape(self, shape):
array([[-1., -1., -1.],
[-1., -1., -1.]], dtype=float32)
"""
if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
shape = shape[0]
elif not len(shape):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens for reshape(1, 2, shape=1)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

keyword argument is ignored.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would that be a problem?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

an error should be raised

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Numpy throws an error too, though the error is confusing.

In [1]: import numpy as np

In [2]: a = np.ones((3,5))

In [3]: a.reshape(3,5,shape=(7,))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-cac6c36bb3d2> in <module>()
----> 1 a.reshape(3,5,shape=(7,))

TypeError: 'shape' is an invalid keyword argument for this function

for key, value in kwargs.items():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to iterate. test for len(kwargs) == 1 and kwargs.get('shape', None) directly

Copy link
Member Author

@szha szha Feb 15, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The loop also helps throw exception with the invalid argument name, which is consistent with numpy's behavior. The check would be lost in your proposal.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is slow. Use test for len(kwargs) == 1, then kwargs.get('shape', None)

if key == 'shape':
shape = value
else:
raise TypeError("'%s' is an invalid keyword argument for this function"%key)
handle = NDArrayHandle()

# Actual reshape
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/ndarray/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _at(self, idx):
def _slice(self, start, stop):
raise NotSupportedForSparseNDArray(self._slice, None, start, stop)

def reshape(self, shape):
def reshape(self, *shape, **kwargs):
raise NotSupportedForSparseNDArray(self.reshape, None, shape)

@property
Expand Down
5 changes: 5 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,11 @@ def test_ndarray_reshape():
[5, 6],
[7, 8]])
assert same(tensor.reshape((-1, 2)).asnumpy(), true_res.asnumpy())
assert same(tensor.reshape(4, 2).asnumpy(), true_res.asnumpy())
assert same(tensor.reshape(-1, 2).asnumpy(), true_res.asnumpy())
true_res = mx.nd.arange(8) + 1
assert same(tensor.reshape(-1).asnumpy(), true_res.asnumpy())
assert same(tensor.reshape(8).asnumpy(), true_res.asnumpy())


def test_ndarray_choose():
Expand Down