diff --git a/mesmer/prototype/calibrate_multiple.py b/mesmer/prototype/calibrate_multiple.py index 049b1ce6..a6d847e1 100644 --- a/mesmer/prototype/calibrate_multiple.py +++ b/mesmer/prototype/calibrate_multiple.py @@ -3,8 +3,11 @@ import scipy.stats import xarray as xr +from ..core.computation import ( + calc_gaspari_cohn_correlation_matrices, + calc_geodist_exact, +) from .calibrate import AutoRegression1D, AutoRegression1DOrderSelection -from .utils import calculate_gaspari_cohn_correlation_matrices def _get_predictor_dims(predictors): @@ -185,10 +188,9 @@ def calibrate_auto_regressive_process_with_spatially_correlated_errors_multiple_ for gridpoint, gridpoint_vals in target.groupby("gridpoint") } - gaspari_cohn_correlation_matrices = calculate_gaspari_cohn_correlation_matrices( - target.lat, - target.lon, - localisation_radii, + geodist = calc_geodist_exact(target.lon, target.lat) + gaspari_cohn_correlation_matrices = calc_gaspari_cohn_correlation_matrices( + geodist, localisation_radii ) localised_empirical_covariance_matrix = ( diff --git a/mesmer/prototype/utils.py b/mesmer/prototype/utils.py deleted file mode 100644 index 397bb5ed..00000000 --- a/mesmer/prototype/utils.py +++ /dev/null @@ -1,141 +0,0 @@ -import numpy as np -import pyproj -import xarray as xr - - -def calculate_gaspari_cohn_correlation_matrices( - latitudes, - longitudes, - localisation_radii, -): - """ - Calculate Gaspari-Cohn correlation matrices for a range of localisation radiis - - Parameters - ---------- - latitudes : :obj:`xr.DataArray` - Latitudes (one-dimensional) - - longitudes : :obj:`xr.DataArray` - Longitudes (one-dimensional) - - localisation_radii : list-like - Localisation radii to test (in metres) - - Returns - ------- - dict[float: :obj:`xr.DataArray`] - Gaspari-Cohn correlation matrix (values) for each localisation radius (keys) - - Notes - ----- - Values in ``localisation_radii`` should not exceed 10'000 by much because - it can lead to ``ValueError: the input matrix must be positive semidefinite`` - """ - # I wonder if xarray can apply a function to all pairs of points in arrays - # or something - geodistance = calculate_geodistance_exact(latitudes, longitudes) - - gaspari_cohn_correlation_matrices = { - lr: calculate_gaspari_cohn_values(geodistance / lr) for lr in localisation_radii - } - - return gaspari_cohn_correlation_matrices - - -def calculate_geodistance_exact(latitudes, longitudes): - """ - Calculate exact great circle distance based on WSG 84 - - Parameters - ---------- - latitudes : :obj:`xr.DataArray` - Latitudes (one-dimensional) - - longitudes : :obj:`xr.DataArray` - Longitudes (one-dimensional) - - Returns - ------- - :obj:`xr.DataArray` - 2D array of great circle distances between points represented by ``latitudes`` - and ``longitudes`` - """ - if longitudes.shape != latitudes.shape or longitudes.ndim != 1: - raise ValueError("lon and lat need to be 1D arrays of the same shape") - - geod = pyproj.Geod(ellps="WGS84") - - n_points = longitudes.shape[0] - - geodistance = np.zeros([n_points, n_points]) - - # calculate only the upper right half of the triangle first - for i in range(n_points): - - # need to duplicate gridpoint (required by geod.inv) - lat = np.tile(latitudes[i], n_points - (i + 1)) - lon = np.tile(longitudes[i], n_points - (i + 1)) - - geodistance[i, i + 1 :] = geod.inv( - lon, lat, longitudes.values[i + 1 :], latitudes.values[i + 1 :] - )[2] - - # convert m to km - geodistance /= 1000 - - # fill the lower left half of the triangle (in-place) - geodistance += np.transpose(geodistance) - - if latitudes.dims != longitudes.dims: - raise AssertionError( - f"latitudes and longitudes have different dims: {latitudes.dims} vs. {longitudes.dims}" - ) - - geodistance = xr.DataArray( - geodistance, dims=list(latitudes.dims) * 2, coords=latitudes.coords - ) - - return geodistance - - -def calculate_gaspari_cohn_values(inputs): - """ - Calculate smooth, exponentially decaying Gaspari-Cohn values - - Parameters - ---------- - inputs : :obj:`xr.DataArray` - Inputs at which to calculate the value of the smooth, exponentially decaying Gaspari-Cohn - correlation function (these could be e.g. normalised geographical distances) - - Returns - ------- - :obj:`xr.DataArray` - Gaspari-Cohn correlation function applied to each point in ``inputs`` - """ - inputs_abs = abs(inputs) - out = np.zeros_like(inputs) - - sel_zero_to_one = (inputs_abs.values >= 0) & (inputs_abs.values < 1) - r_s = inputs_abs.values[sel_zero_to_one] - out[sel_zero_to_one] = ( - 1 - 5 / 3 * r_s**2 + 5 / 8 * r_s**3 + 1 / 2 * r_s**4 - 1 / 4 * r_s**5 - ) - - sel_one_to_two = (inputs_abs.values >= 1) & (inputs_abs.values < 2) - r_s = inputs_abs.values[sel_one_to_two] - - out[sel_one_to_two] = ( - 4 - - 5 * r_s - + 5 / 3 * r_s**2 - + 5 / 8 * r_s**3 - - 1 / 2 * r_s**4 - + 1 / 12 * r_s**5 - - 2 / (3 * r_s) - ) - - out = xr.DataArray(out, dims=inputs.dims, coords=inputs.coords) - - return out diff --git a/tests/integration/test_prototype.py b/tests/integration/test_prototype.py index e23b3418..efef96a5 100644 --- a/tests/integration/test_prototype.py +++ b/tests/integration/test_prototype.py @@ -7,13 +7,16 @@ from mesmer.calibrate_mesmer.train_gv import train_gv from mesmer.calibrate_mesmer.train_lt import train_lt from mesmer.calibrate_mesmer.train_lv import train_lv +from mesmer.core.computation import ( + calc_gaspari_cohn_correlation_matrices, + calc_geodist_exact, +) from mesmer.prototype.calibrate import LinearRegression from mesmer.prototype.calibrate_multiple import ( calibrate_auto_regressive_process_multiple_scenarios_and_ensemble_members, calibrate_auto_regressive_process_with_spatially_correlated_errors_multiple_scenarios_and_ensemble_members, flatten_predictors_and_target, ) -from mesmer.prototype.utils import calculate_gaspari_cohn_correlation_matrices class _MockConfig: @@ -325,11 +328,13 @@ def _do_legacy_run_train_lv( .values ) - gaspari_cohn_correlation_matrices = calculate_gaspari_cohn_correlation_matrices( - latitudes=esm_tas_residual_local_variability.lat, - longitudes=esm_tas_residual_local_variability.lon, - localisation_radii=localisation_radii, + geodist = calc_geodist_exact( + esm_tas_residual_local_variability.lon, esm_tas_residual_local_variability.lat + ) + gaspari_cohn_correlation_matrices = calc_gaspari_cohn_correlation_matrices( + geodist, localisation_radii ) + gaspari_cohn_correlation_matrices = { k: v.values for k, v in gaspari_cohn_correlation_matrices.items() }