Skip to content

Commit

Permalink
Fixed bugs in data reshaper
Browse files Browse the repository at this point in the history
  • Loading branch information
mlwong committed Aug 19, 2017
1 parent 21dce06 commit 80de62b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 21 deletions.
2 changes: 1 addition & 1 deletion floatpy/parallel/transpose_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(self, grid_partition, direction, dimension=3):
self._pencil_hi = self._pencil_hi - 1

# Initialize the data reshaper.
self._data_reshpaer = data_reshaper.DataReshaper(self._dim, data_order='F')
self._data_reshaper = data_reshaper.DataReshaper(self._dim, data_order='F')


@property
Expand Down
30 changes: 10 additions & 20 deletions floatpy/utilities/data_reshaper.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,42 +99,32 @@ def reshapeFrom3d(self, data, data_output=None):

shape_low_dim = numpy.array(data.shape)

# Check whether the shape of data is valid.
# Check whether the shape of data is valid and get the squeezed shape of data.

if component_idx is None:
if data.ndim != 3:
raise RuntimeError('Dimension of data is invalid!')

if not (numpy.all(data.shape[self._dim:] == 1)):
if data.ndim == 3:
if not all(e == 1 for e in data.shape[self._dim:]):
raise RuntimeError('Shape of data is invalid!')

shape_low_dim = numpy.array(shape_low_dim[:self._dim])

else:
if data.ndim != 4:
raise RuntimeError('Dimension of data is invalid!')

# Check whether the component_idx is valid and get the shape of the component's data.

elif data.ndim == 4:
if self._data_order == 'C':
if component_idx >= data.shape[0] or component_idx < 0:
raise RuntimeError('Component index is invalid!')

if not (numpy.all(data.shape[self._dim+1:] == 1)):
if not all(e == 1 for e in data.shape[self._dim+1:]):
raise RuntimeError('Shape of data is invalid!')

shape_low_dim = numpy.array(shape_low_dim[:1+self._dim])

else:
if component_idx >= data.shape[-1] or component_idx < 0:
raise RuntimeError('Component index is invalid!')

if not (numpy.all(data.shape[self._dim:-1] == 1)):
if not all(e == 1 for e in data.shape[self._dim:-1]):
raise RuntimeError('Shape of data is invalid!')

shape_low_dim = numpy.append(shape_low_dim[:self._dim], shape_low_dim[-1])

else:
raise RuntimeError('Dimension of data is invalid!')

if data_output is None:
return numpy.reshape(data, shape_low_dim, order=self._data_order)
else:
data_output = numpy.reshape(data, shape_low_dim, order=self._data_order)

0 comments on commit 80de62b

Please sign in to comment.