diff --git a/CHANGELOG.rst b/CHANGELOG.rst index ef4056fb..27f92b3c 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -90,6 +90,10 @@ Breaking changes Deprecations ^^^^^^^^^^^^ +- The function `mesmer.utils.select.extract_time_period` is now deprecated and will be + removed in a future version. Please raise an issue if you use this function (`#243 + `_). By `Mathias Hauser + `_. Bug fixes ^^^^^^^^^ diff --git a/mesmer/utils/select.py b/mesmer/utils/select.py index e3a8f771..e3a3e103 100644 --- a/mesmer/utils/select.py +++ b/mesmer/utils/select.py @@ -104,12 +104,12 @@ def extract_land(var, reg_dict=None, wgt=None, ls=None, threshold_land=0.25): return var_l, {}, ls -def extract_time_period(var, time, start, end): +def extract_time_period(data, time, start, end): """Extract selected time period. Parameters ---------- - var : np.ndarray + data : np.ndarray variable in 1-4d array - (time); @@ -139,17 +139,15 @@ def extract_time_period(var, time, start, end): """ - # find index of start and end of time period - idx_start = np.where(time == int(start))[0][0] - idx_end = np.where(time == int(end))[0][0] + 1 # to include the end year + warnings.warn( + "`extract_time_period` is deprecated in v0.9.0 and will be remove in a future " + "version. Please raise an issue if you still use this function.", + FutureWarning, + ) - # extract time period from variable dictionary - if len(var.shape) > 1: - var_tp = var[:, idx_start:idx_end] - else: - var_tp = var[idx_start:idx_end] + sel = (time >= start) & (time <= end) - # extract time period from time vector - time_tp = time[idx_start:idx_end] + time = time[sel] + data = data[:, sel, ...] if data.ndim > 1 else data[sel] - return var_tp, time_tp + return data, time diff --git a/tests/unit/test_extract_time_period.py b/tests/unit/test_extract_time_period.py new file mode 100644 index 00000000..56679293 --- /dev/null +++ b/tests/unit/test_extract_time_period.py @@ -0,0 +1,63 @@ +import numpy as np +import pytest + +from mesmer.utils.select import extract_time_period + + +def test_extract_time_period_deprecation(): + + time = np.arange(1950, 2050) + + data = np.linspace(0, 1, time.size) + + with pytest.warns(FutureWarning, match="`extract_time_period` is deprecated"): + extract_time_period(data, time, 1955, 2005) + + +@pytest.mark.filterwarnings("ignore:`extract_time_period` is deprecated") +def test_extract_time_period_1D(): + + time = np.arange(1950, 2050) + + data = np.linspace(0, 1, time.size) + + result_data, result_time = extract_time_period(data, time, 1955, 2005) + + expected_data = data[5 : 5 + 50 + 1] + expected_time = np.arange(1955, 2005 + 1) + + np.testing.assert_equal(result_data, expected_data) + np.testing.assert_equal(result_time, expected_time) + + +@pytest.mark.filterwarnings("ignore:`extract_time_period` is deprecated") +def test_extract_time_period_2D(): + + time = np.arange(1900, 2000) + + data = np.arange(3 * time.size).reshape(3, -1) + + result_data, result_time = extract_time_period(data, time, 1911, 1995) + + expected_data = data[:, 11 : 95 + 1] + expected_time = np.arange(1911, 1995 + 1) + + np.testing.assert_equal(result_data, expected_data) + np.testing.assert_equal(result_time, expected_time) + + +@pytest.mark.filterwarnings("ignore:`extract_time_period` is deprecated") +def test_extract_time_period_3D(): + + time = np.arange(1900, 2000) + + # (run, time, gridpoint) + data = np.arange(3 * 5 * time.size).reshape(3, -1, 5) + + result_data, result_time = extract_time_period(data, time, 1911, 1995) + + expected_data = data[:, 11 : 95 + 1, :] + expected_time = np.arange(1911, 1995 + 1) + + np.testing.assert_equal(result_data, expected_data) + np.testing.assert_equal(result_time, expected_time)