From bb2ec46061715075a1e2f44e81e79891cbe4927c Mon Sep 17 00:00:00 2001 From: Bouwe Andela Date: Fri, 23 Feb 2024 22:12:05 +0100 Subject: [PATCH 1/2] Lazy rolling_window --- lib/iris/util.py | 45 +++++++++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/lib/iris/util.py b/lib/iris/util.py index 020b67783a..13c5cde803 100644 --- a/lib/iris/util.py +++ b/lib/iris/util.py @@ -3,6 +3,7 @@ # This file is part of Iris and is released under the BSD license. # See LICENSE in the root of the repository for full licensing details. """Miscellaneous utility functions.""" +from __future__ import annotations from abc import ABCMeta, abstractmethod from collections.abc import Hashable, Iterable @@ -281,7 +282,12 @@ def guess_coord_axis(coord): return axis -def rolling_window(a, window=1, step=1, axis=-1): +def rolling_window( + a: np.ndarray | da.Array, + window: int = 1, + step: int = 1, + axis: int = -1, +) -> np.ndarray | da.Array: """Make an ndarray with a rolling window of the last dimension. Parameters @@ -322,8 +328,6 @@ def rolling_window(a, window=1, step=1, axis=-1): See more at :doc:`/userguide/real_and_lazy_data`. """ - # NOTE: The implementation of this function originates from - # https://github.com/numpy/numpy/pull/31#issuecomment-1304851 04/08/2011 if window < 1: raise ValueError("`window` must be at least 1.") if window > a.shape[axis]: @@ -331,25 +335,26 @@ def rolling_window(a, window=1, step=1, axis=-1): if step < 1: raise ValueError("`step` must be at least 1.") axis = axis % a.ndim - num_windows = (a.shape[axis] - window + step) // step - shape = a.shape[:axis] + (num_windows, window) + a.shape[axis + 1 :] - strides = ( - a.strides[:axis] - + (step * a.strides[axis], a.strides[axis]) - + a.strides[axis + 1 :] + array_module = da if isinstance(a, da.Array) else np + steps = tuple( + slice(None, None, step) if i == axis else slice(None) for i in range(a.ndim) ) - rw = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides) - if ma.isMaskedArray(a): - mask = ma.getmaskarray(a) - strides = ( - mask.strides[:axis] - + (step * mask.strides[axis], mask.strides[axis]) - + mask.strides[axis + 1 :] - ) - rw = ma.array( - rw, - mask=np.lib.stride_tricks.as_strided(mask, shape=shape, strides=strides), + + def _rolling_window(array): + return array_module.moveaxis( + array_module.lib.stride_tricks.sliding_window_view( + array, + window_shape=window, + axis=axis, + )[steps], + -1, + axis + 1, ) + + rw = _rolling_window(a) + if isinstance(da.utils.meta_from_array(a), np.ma.MaskedArray): + mask = _rolling_window(array_module.ma.getmaskarray(a)) + rw = array_module.ma.masked_array(rw, mask) return rw From 9a8b8d9b52e6da8636bd9a5c7208dd45a9e34b07 Mon Sep 17 00:00:00 2001 From: Bouwe Andela Date: Fri, 23 Feb 2024 23:18:28 +0100 Subject: [PATCH 2/2] Add test and whatsnew entry --- docs/src/whatsnew/latest.rst | 3 ++- lib/iris/tests/unit/util/test_rolling_window.py | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/src/whatsnew/latest.rst b/docs/src/whatsnew/latest.rst index ab5d18d3eb..6c3359b13f 100644 --- a/docs/src/whatsnew/latest.rst +++ b/docs/src/whatsnew/latest.rst @@ -48,7 +48,8 @@ This document explains the changes made to Iris for this release 🚀 Performance Enhancements =========================== -#. N/A +#. `@bouweandela`_ made :func:`iris.util.rolling_window` work with lazy arrays. + (:pull:`5775`) 🔥 Deprecations diff --git a/lib/iris/tests/unit/util/test_rolling_window.py b/lib/iris/tests/unit/util/test_rolling_window.py index 8a017e4e08..d70b398ed5 100644 --- a/lib/iris/tests/unit/util/test_rolling_window.py +++ b/lib/iris/tests/unit/util/test_rolling_window.py @@ -8,6 +8,7 @@ # importing anything else import iris.tests as tests # isort:skip +import dask.array as da import numpy as np import numpy.ma as ma @@ -35,6 +36,12 @@ def test_2d(self): result = rolling_window(a, window=3, axis=1) self.assertArrayEqual(result, expected_result) + def test_3d_lazy(self): + a = da.arange(2 * 3 * 4).reshape((2, 3, 4)) + expected_result = np.arange(2 * 3 * 4).reshape((1, 2, 3, 4)) + result = rolling_window(a, window=2, axis=0).compute() + self.assertArrayEqual(result, expected_result) + def test_1d_masked(self): # 1-d masked array input a = ma.array([0, 1, 2, 3, 4], mask=[0, 0, 1, 0, 0], dtype=np.int32)