Skip to content

Commit

Permalink
Lazy weighted RMS calculation (#5017)
Browse files Browse the repository at this point in the history
* Enable lazy RMS aggregation with weights

Note that the referenced dask issue was fixed by dask#4236 which was included
in v1.1.0.

* whatsnew

* docstring

* add note about NEP13/18 not applying

* whatsnew PR num fix
  • Loading branch information
rcomer committed Feb 27, 2023
1 parent 8d1e96a commit 21b8a2c
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 42 deletions.
3 changes: 3 additions & 0 deletions docs/src/whatsnew/latest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ This document explains the changes made to Iris for this release
:ref:`documentation page<community_plugins>` for further information.
(:pull:`5144`)

#. `@rcomer`_ enabled lazy evaluation of :obj:`~iris.analysis.RMS` calculations
with weights. (:pull:`5017`)


🐛 Bugs Fixed
=============
Expand Down
34 changes: 14 additions & 20 deletions lib/iris/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,27 +1583,19 @@ def _lazy_max_run(array, axis=-1, **kwargs):


def _rms(array, axis, **kwargs):
# XXX due to the current limitations in `da.average` (see below), maintain
# an explicit non-lazy aggregation function for now.
# Note: retaining this function also means that if weights are passed to
# the lazy aggregator, the aggregation will fall back to using this
# non-lazy aggregator.
rval = np.sqrt(ma.average(np.square(array), axis=axis, **kwargs))
if not ma.isMaskedArray(array):
rval = np.asarray(rval)
rval = np.sqrt(ma.average(array**2, axis=axis, **kwargs))

return rval


@_build_dask_mdtol_function
def _lazy_rms(array, axis, **kwargs):
# XXX This should use `da.average` and not `da.mean`, as does the above.
# However `da.average` currently doesn't handle masked weights correctly
# (see https://github.com/dask/dask/issues/3846).
# To work around this we use da.mean, which doesn't support weights at
# all. Thus trying to use this aggregator with weights will currently
# raise an error in dask due to the unexpected keyword `weights`,
# rather than silently returning the wrong answer.
return da.sqrt(da.mean(array**2, axis=axis, **kwargs))
# Note that, since we specifically need the ma version of average to handle
# weights correctly with masked data, we cannot rely on NEP13/18 and need
# to implement a separate lazy RMS function.

rval = da.sqrt(da.ma.average(array**2, axis=axis, **kwargs))

return rval


def _sum(array, **kwargs):
Expand Down Expand Up @@ -2071,14 +2063,16 @@ def interp_order(length):
the root mean square over a :class:`~iris.cube.Cube`, as computed by
((x0**2 + x1**2 + ... + xN-1**2) / N) ** 0.5.
Additional kwargs associated with the use of this aggregator:
Parameters
----------
* weights (float ndarray):
weights : array-like, optional
Weights matching the shape of the cube or the length of the window for
rolling window operations. The weights are applied to the squares when
taking the mean.
**For example**:
Example
-------
To compute the zonal root mean square over the *longitude* axis of a cube::
Expand Down
32 changes: 10 additions & 22 deletions lib/iris/tests/unit/analysis/test_RMS.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,16 @@ def test_1d_weighted(self):
data = as_lazy_data(np.array([4, 7, 10, 8], dtype=np.float64))
weights = np.array([1, 4, 3, 2], dtype=np.float64)
expected_rms = 8.0
# https://github.com/dask/dask/issues/3846.
with self.assertRaisesRegex(TypeError, "unexpected keyword argument"):
rms = RMS.lazy_aggregate(data, 0, weights=weights)
self.assertAlmostEqual(rms, expected_rms)
rms = RMS.lazy_aggregate(data, 0, weights=weights)
self.assertAlmostEqual(rms, expected_rms)

def test_1d_lazy_weighted(self):
# 1-dimensional input with lazy weights.
data = as_lazy_data(np.array([4, 7, 10, 8], dtype=np.float64))
weights = as_lazy_data(np.array([1, 4, 3, 2], dtype=np.float64))
expected_rms = 8.0
# https://github.com/dask/dask/issues/3846.
with self.assertRaisesRegex(TypeError, "unexpected keyword argument"):
rms = RMS.lazy_aggregate(data, 0, weights=weights)
self.assertAlmostEqual(rms, expected_rms)
rms = RMS.lazy_aggregate(data, 0, weights=weights)
self.assertAlmostEqual(rms, expected_rms)

def test_2d_weighted(self):
# 2-dimensional input with weights.
Expand All @@ -123,20 +119,16 @@ def test_2d_weighted(self):
)
weights = np.array([[1, 4, 3, 2], [2, 1, 1.5, 0.5]], dtype=np.float64)
expected_rms = np.array([8.0, 16.0], dtype=np.float64)
# https://github.com/dask/dask/issues/3846.
with self.assertRaisesRegex(TypeError, "unexpected keyword argument"):
rms = RMS.lazy_aggregate(data, 1, weights=weights)
self.assertArrayAlmostEqual(rms, expected_rms)
rms = RMS.lazy_aggregate(data, 1, weights=weights)
self.assertArrayAlmostEqual(rms, expected_rms)

def test_unit_weighted(self):
# Unit weights should be the same as no weights.
data = as_lazy_data(np.array([5, 2, 6, 4], dtype=np.float64))
weights = np.ones_like(data)
expected_rms = 4.5
# https://github.com/dask/dask/issues/3846.
with self.assertRaisesRegex(TypeError, "unexpected keyword argument"):
rms = RMS.lazy_aggregate(data, 0, weights=weights)
self.assertAlmostEqual(rms, expected_rms)
rms = RMS.lazy_aggregate(data, 0, weights=weights)
self.assertAlmostEqual(rms, expected_rms)

def test_masked(self):
# Masked entries should be completely ignored.
Expand All @@ -152,9 +144,6 @@ def test_masked(self):
self.assertAlmostEqual(rms, expected_rms)

def test_masked_weighted(self):
# Weights should work properly with masked arrays, but currently don't
# (see https://github.com/dask/dask/issues/3846).
# For now, masked weights are simply not supported.
data = as_lazy_data(
ma.array(
[4, 7, 18, 10, 11, 8],
Expand All @@ -164,9 +153,8 @@ def test_masked_weighted(self):
)
weights = np.array([1, 4, 5, 3, 8, 2])
expected_rms = 8.0
with self.assertRaisesRegex(TypeError, "unexpected keyword argument"):
rms = RMS.lazy_aggregate(data, 0, weights=weights)
self.assertAlmostEqual(rms, expected_rms)
rms = RMS.lazy_aggregate(data, 0, weights=weights)
self.assertAlmostEqual(rms, expected_rms)


class Test_name(tests.IrisTest):
Expand Down

0 comments on commit 21b8a2c

Please sign in to comment.