Skip to content

Commit

Permalink
Grid conserves subclasses of ndarray (#56)
Browse files Browse the repository at this point in the history
- uses solution proposed by @AstroMike (np.asanyarray())
- fix #56
- add test (with np.ma.MaskedArray)
- update CHANGELOG
  • Loading branch information
orbeckst committed Apr 6, 2019
1 parent be64439 commit 04dcbc1
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG
Expand Up @@ -24,6 +24,8 @@ The rules for this file:
* Added missing floordivision to Grid (PR #53)
* fix test on ARM (#51)
* fix incorrect reading of ncstart and nrstart in CCP4 (#57)
* fix that arithemtical operations broke inheritance (#56)
* fix so that subclasses of ndarray are retained on input (#56)

Changes (do not affect user)

Expand Down
10 changes: 5 additions & 5 deletions gridData/core.py
Expand Up @@ -126,13 +126,13 @@ def __init__(self, grid=None, edges=None, origin=None, delta=None,
self.load(grid)
elif not (grid is None or edges is None):
# set up from histogramdd-type data
self.grid = numpy.asarray(grid)
self.grid = numpy.asanyarray(grid)
self.edges = edges
self._update()
elif not (grid is None or origin is None or delta is None):
# setup from generic data
origin = numpy.asarray(origin)
delta = numpy.asarray(delta)
origin = numpy.asanyarray(origin)
delta = numpy.asanyarray(delta)
if len(origin) != grid.ndim:
raise TypeError(
"Dimension of origin is not the same as grid dimension.")
Expand All @@ -148,7 +148,7 @@ def __init__(self, grid=None, edges=None, origin=None, delta=None,
self.edges = [origin[dim] +
(numpy.arange(m + 1) - 0.5) * delta[dim]
for dim, m in enumerate(grid.shape)]
self.grid = numpy.asarray(grid)
self.grid = numpy.asanyarray(grid)
self._update()
else:
# empty, must manually populate with load()
Expand Down Expand Up @@ -709,7 +709,7 @@ def ndmeshgrid(*arrs):
for i, arr in enumerate(arrs):
slc = [1] * dim
slc[i] = lens[i]
arr2 = numpy.asarray(arr).reshape(slc)
arr2 = numpy.asanyarray(arr).reshape(slc)
for j, sz in enumerate(lens):
if j != i:
arr2 = arr2.repeat(sz, axis=j)
Expand Down
18 changes: 15 additions & 3 deletions gridData/tests/test_grid.py
Expand Up @@ -9,6 +9,9 @@

from gridData import Grid

def f_arithmetic(g):
return g + g - 2.5 * g / (g + 5.3)

@pytest.fixture(scope="class")
def data():
d = dict(
Expand Down Expand Up @@ -148,11 +151,20 @@ class DerivedGrid(Grid):

dg = DerivedGrid(data['griddata'], origin=data['origin'],
delta=data['delta'])
result = dg + dg - 2.5 * dg / (dg + 5.3)
result = f_arithmetic(dg)

assert isinstance(result, DerivedGrid)

g = data['grid']
ref = g + g - 2.5 * g / (dg + 5.3)
ref = f_arithmetic(data['grid'])
assert_almost_equal(result.grid, ref.grid)

def test_anyarray(data):
ma = np.ma.MaskedArray(data['griddata'])
mg = Grid(ma, origin=data['origin'], delta=data['delta'])

assert isinstance(mg.grid, ma.__class__)

result = f_arithmetic(mg)
ref = f_arithmetic(data['grid'])

assert_almost_equal(result.grid, ref.grid)

0 comments on commit 04dcbc1

Please sign in to comment.