diff --git a/CHANGELOG.rst b/CHANGELOG.rst index e2659f58..9c850aca 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -74,9 +74,10 @@ New Features By `Mathias Hauser`_. -- Allow passing `xr.DataArray` to ``gaspari_cohn`` (`#298 `__). +- Allow passing `xr.DataArray` to ``gaspari_cohn`` (`#298 `__). By `Mathias Hauser`_. - +- Allow passing `xr.DataArray` to ``calc_geodist_exact`` (`#299 `__). + By `Zeb Nicholls`_ and `Mathias Hauser`_. Breaking changes diff --git a/mesmer/core/computation.py b/mesmer/core/computation.py index 4ee9c063..186fb36a 100644 --- a/mesmer/core/computation.py +++ b/mesmer/core/computation.py @@ -7,19 +7,21 @@ import pyproj import xarray as xr +from .utils import create_equal_dim_names + def gaspari_cohn(r): """smooth, exponentially decaying Gaspari-Cohn correlation function Parameters ---------- - r : xr.DataArray, np.array + r : xr.DataArray, np.ndarray Values for which to calculate the value of the Gaspari-Cohn correlation function (e.g. normalised geographical distances) Returns ------- - out : xr.DataArray, , np.array + out : xr.DataArray, np.ndarray Gaspari-Cohn correlation function Notes @@ -89,30 +91,60 @@ def _gaspari_cohn_np(r): return out -def calc_geodist_exact(lon, lat): +def calc_geodist_exact(lon, lat, equal_dim_suffixes=("_i", "_j")): """exact great circle distance based on WSG 84 Parameters ---------- - lon : array-like + lon : xr.DataArray, np.ndarray 1D array of longitudes - lat : array-like + lat : xr.DataArray, np.ndarray 1D array of latitudes + equal_dim_suffixes : tuple of str, default: ("_i", "_j") + Suffixes to add to the name of ``dim`` for the geodist array (xr.DataArray + cannot have two dimensions with the same name). Returns ------- - geodist : np.array + geodist : xr.DataArray, np.ndarray 2D array of great circle distances. """ + # TODO: allow Dataset (e.g. 
using cf_xarray) + if isinstance(lon, xr.Dataset) or isinstance(lat, xr.Dataset): + raise TypeError("Dataset is not supported, please pass a DataArray") + + # handle numpy arrays + if not isinstance(lon, xr.DataArray) or not isinstance(lat, xr.DataArray): + return _calc_geodist_exact(np.asarray(lon), np.asarray(lat)) + + # TODO: allow differently named lon and lat dims? + if lon.dims != lat.dims: + raise AssertionError( + f"lon and lat have different dims: {lon.dims} vs. {lat.dims}. Expected " + "equally named dimensions from a stacked array" + ) + + geodist = _calc_geodist_exact(lon.values, lat.values) + + (dim,) = lon.dims + dims = create_equal_dim_names(dim, equal_dim_suffixes) + + # TODO: assign coords? + geodist = xr.DataArray(geodist, dims=dims) + + return geodist + + +def _calc_geodist_exact(lon, lat): + # ensure correct shape - lon, lat = np.asarray(lon), np.asarray(lat) if lon.shape != lat.shape or lon.ndim != 1: - raise ValueError("lon and lat need to be 1D arrays of the same shape") + raise ValueError("lon and lat must be 1D arrays of the same shape") geod = pyproj.Geod(ellps="WGS84") - n_points = len(lon) + n_points = lon.size geodist = np.zeros([n_points, n_points]) @@ -120,8 +152,8 @@ def calc_geodist_exact(lon, lat): for i in range(n_points): # need to duplicate gridpoint (required by geod.inv) - lt = np.tile(lat[i], n_points - (i + 1)) - ln = np.tile(lon[i], n_points - (i + 1)) + lt = np.repeat(lat[i : i + 1], n_points - (i + 1)) + ln = np.repeat(lon[i : i + 1], n_points - (i + 1)) geodist[i, i + 1 :] = geod.inv(ln, lt, lon[i + 1 :], lat[i + 1 :])[2] diff --git a/mesmer/stats/localized_covariance.py b/mesmer/stats/localized_covariance.py index f9d4519f..35ac6998 100644 --- a/mesmer/stats/localized_covariance.py +++ b/mesmer/stats/localized_covariance.py @@ -95,9 +95,9 @@ def find_localized_empirical_covariance( Dimension along which to calculate the covariance. k_folds : int Number of folds to use for cross validation. 
- equal_dim_suffixes : tuple of str - Suffixes to add to the the name of ``dim`` for the covariance array (xr.DataArray cannot have two - dimensions with the same name). + equal_dim_suffixes : tuple of str, default: ("_i", "_j") + Suffixes to add to the name of ``dim`` for the covariance array + (xr.DataArray cannot have two dimensions with the same name). Returns ------- diff --git a/tests/unit/test_computation.py b/tests/unit/test_computation.py index 34840d39..020d0dfb 100644 --- a/tests/unit/test_computation.py +++ b/tests/unit/test_computation.py @@ -2,7 +2,7 @@ import pytest import xarray as xr -from mesmer.core.computation import gaspari_cohn +from mesmer.core.computation import calc_geodist_exact, gaspari_cohn def test_gaspari_cohn_error(): @@ -51,3 +51,99 @@ def test_gaspari_cohn_np(): # make sure shape is conserved values = np.arange(9).reshape(3, 3) assert gaspari_cohn(values).shape == (3, 3) + + +def test_calc_geodist_dataset_error(): + + ds = xr.Dataset() + da = xr.DataArray() + + with pytest.raises(TypeError, match="Dataset is not supported"): + calc_geodist_exact(ds, ds) + + with pytest.raises(TypeError, match="Dataset is not supported"): + calc_geodist_exact(ds, da) + + with pytest.raises(TypeError, match="Dataset is not supported"): + calc_geodist_exact(da, ds) + + +def test_calc_geodist_dataarray_equal_dims_required(): + + lon = xr.DataArray([0], dims="lon") + lat = xr.DataArray([0], dims="lat") + + with pytest.raises(AssertionError, match="lon and lat have different dims"): + calc_geodist_exact(lon, lat) + + +@pytest.mark.parametrize("as_dataarray", [True, False]) +def test_calc_geodist_not_same_shape_error(as_dataarray): + + lon, lat = [0, 0], [0] + + if as_dataarray: + lon, lat = xr.DataArray(lon), xr.DataArray(lat) + + with pytest.raises(ValueError, match="lon and lat must be 1D arrays"): + calc_geodist_exact(lon, lat) + + +@pytest.mark.parametrize("as_dataarray", [True, False]) +def test_calc_geodist_not_1D_error(as_dataarray): + + lon = 
lat = [[0, 0]] + + if as_dataarray: + lon, lat = xr.DataArray(lon), xr.DataArray(lat) + + with pytest.raises(ValueError, match=".*of the same shape"): + calc_geodist_exact(lon, lat) + + +@pytest.mark.parametrize("lon", [[0, 0], [0, 360], [1, 361], [180, -180]]) +@pytest.mark.parametrize("as_dataarray", [True, False]) +def test_calc_geodist_exact_equal(lon, as_dataarray): + """test points with distance 0""" + + expected = np.array([[0, 0], [0, 0]]) + + lat = [0, 0] + + if as_dataarray: + lon = xr.DataArray(lon) + + result = calc_geodist_exact(lon, lat) + np.testing.assert_equal(result, expected) + # when passing only one DataArray it's also returned as np.array + assert isinstance(result, np.ndarray) + + +@pytest.mark.parametrize("as_dataarray", [True, False]) +def test_calc_geodist_exact(as_dataarray): + """test some random points""" + + lon = [-180, 0, 3] + lat = [0, 0, 5] + + if as_dataarray: + lon = xr.DataArray(lon, dims="gp", coords={"lon": ("gp", lon)}) + lat = xr.DataArray(lat, dims="gp", coords={"lat": ("gp", lat)}) + + result = calc_geodist_exact(lon, lat) + expected = np.array( + [ + [0.0, 20003.93145863, 19366.51816487], + [20003.93145863, 0.0, 645.70051988], + [19366.51816487, 645.70051988, 0.0], + ] + ) + + if as_dataarray: + + expected = xr.DataArray(expected, dims=("gp_i", "gp_j")) + xr.testing.assert_allclose(expected, result) + + else: + + np.testing.assert_allclose(result, expected) diff --git a/tests/unit/test_phi_gc.py b/tests/unit/test_phi_gc.py index 38a7724c..6e055006 100644 --- a/tests/unit/test_phi_gc.py +++ b/tests/unit/test_phi_gc.py @@ -1,7 +1,5 @@ import numpy as np -import pytest -from mesmer.core.computation import calc_geodist_exact from mesmer.io import load_phi_gc, load_regs_ls_wgt_lon_lat @@ -57,46 +55,3 @@ def test_phi_gc_end_to_end(tmp_path): ] ) np.testing.assert_allclose(expected, actual[1000], rtol=1e-5) - - -def test_calc_geodist_exact_shape(): - - msg = "lon and lat need to be 1D arrays of the same shape" - - # not the same 
shape - with pytest.raises(ValueError, match=msg): - calc_geodist_exact([0, 0], [0]) - - # not 1D - with pytest.raises(ValueError, match=msg): - calc_geodist_exact([[0, 0]], [[0, 0]]) - - -def test_calc_geodist_exact_equal(): - """test points with distance 0""" - - expected = np.array([[0, 0], [0, 0]]) - - lat = [0, 0] - lons = [[0, 0], [0, 360], [1, 361], [180, -180]] - - for lon in lons: - result = calc_geodist_exact(lon, lat) - np.testing.assert_equal(result, expected) - - result = calc_geodist_exact(lon, lat) - np.testing.assert_equal(result, expected) - - -def test_calc_geodist_exact(): - """test some random points""" - result = calc_geodist_exact([-180, 0, 3], [0, 0, 5]) - expected = np.array( - [ - [0.0, 20003.93145863, 19366.51816487], - [20003.93145863, 0.0, 645.70051988], - [19366.51816487, 645.70051988, 0.0], - ] - ) - - np.testing.assert_allclose(result, expected)