Add prototype of class structure #109

Draft · wants to merge 28 commits into base: main

Changes from 3 commits

Commits (28)
93239ff
Sketch out prototype
znicholls Oct 21, 2021
9ae897f
Mark path which is never used to simplify things
znicholls Oct 24, 2021
35b5fed
Block out another branch
znicholls Oct 24, 2021
92f1e4e
Implement prototype method
znicholls Oct 24, 2021
47fa2c6
Tidy up
znicholls Oct 25, 2021
b408db0
Format
znicholls Oct 25, 2021
b93de58
Tidy up a bit more
znicholls Oct 25, 2021
de1d293
Change calibration pattern to simplify use of classes
znicholls Oct 25, 2021
3d97ec1
Try doing auto-regression implementation
znicholls Oct 26, 2021
14c429d
Clean up loops
znicholls Oct 27, 2021
4fd57aa
Remove pdb statement
znicholls Oct 27, 2021
1026161
Sketch out test for train lv
znicholls Nov 24, 2021
d4e4e7f
More notes about how to do train lv
znicholls Nov 25, 2021
3470ee4
Start working on geodesic functions
znicholls Nov 25, 2021
3b23a6d
Get legacy training running
znicholls Nov 26, 2021
88ec4ad
Finish reimplementing train_lv
znicholls Nov 26, 2021
592441f
Merge branch 'main' into prototype
mathause Sep 21, 2023
08a5134
linting
mathause Sep 21, 2023
fe1d1bc
fix: test_prototype_train_lv
mathause Sep 21, 2023
1b4b283
Merge branch 'main' into prototype
mathause Sep 21, 2023
c854fa4
clean train_lt.py
mathause Sep 21, 2023
c3bf5e0
remove prototype/utils.py after #298, #299, and #300
mathause Sep 21, 2023
c727a6e
Merge branch 'main' into prototype
mathause Sep 23, 2023
26f6761
allow selected ar order to be None
mathause Sep 25, 2023
a2fec9b
Merge branch 'main' into prototype
mathause Sep 25, 2023
7143934
Merge branch 'main' into prototype
mathause Dec 12, 2023
d3eb99d
fix for gaspari_cohn and geodist_exact
mathause Dec 12, 2023
8f19cd5
small refactor
mathause Dec 12, 2023
115 changes: 115 additions & 0 deletions mesmer/prototype/calibrate_multiple.py
@@ -1,8 +1,10 @@
import numpy as np
import pandas as pd
import scipy.stats
import xarray as xr

from .calibrate import AutoRegression1DOrderSelection, AutoRegression1D
from .utils import calculate_gaspari_cohn_correlation_matrices


def _get_predictor_dims(predictors):
@@ -33,6 +35,15 @@ def _check_coords_match(obj, obj_other, check_coord):
raise AssertionError(f"{check_coord} is not the same on {obj} and {obj_other}")


def _flatten(inp, dims_to_flatten):
stack_coord_name = _get_stack_coord_name(inp)
inp_flat = inp.stack({stack_coord_name: dims_to_flatten}).dropna(
stack_coord_name
)

return inp_flat, stack_coord_name


def _flatten_predictors(predictors, dims_to_flatten, stack_coord_name):
predictors_flat = []
for v, vals in predictors.items():
@@ -163,3 +174,107 @@ def calibrate_auto_regressive_process_multiple_scenarios_and_ensemble_members(
ar_params = _derive_auto_regressive_process_parameters(target, ar_order)

return ar_params
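
The helper _derive_auto_regressive_process_parameters is collapsed in this diff view. As a rough sketch of what such a derivation could look like (statsmodels-based; the function name and return keys here are assumptions, not taken from this diff):

import numpy as np
from statsmodels.tsa.ar_model import AutoReg


def derive_ar_parameters_sketch(timeseries, order):
    # fit an AR(order) process with intercept and return its parameters
    fitted = AutoReg(timeseries, lags=order).fit()
    return {
        "intercept": fitted.params[0],
        "lag_coefficients": fitted.params[1:],
        "innovation_standard_deviation": np.sqrt(fitted.sigma2),
    }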


def calibrate_auto_regressive_process_with_spatially_correlated_errors_multiple_scenarios_and_ensemble_members(
target,
localisation_radii,
max_cross_validation_iterations=30,
gridpoint_dim_name="gridpoint",
):
gridpoint_autoregression_parameters = {
gridpoint: _derive_auto_regressive_process_parameters(gridpoint_vals, order=1)
for gridpoint, gridpoint_vals in target.groupby("gridpoint")
}

gaspari_cohn_correlation_matrices = calculate_gaspari_cohn_correlation_matrices(
target.lat,
target.lon,
localisation_radii,
)

localised_empirical_covariance_matrix = _calculate_localised_empirical_covariance_matrix(
target,
localisation_radii,
gaspari_cohn_correlation_matrices,
max_cross_validation_iterations,
gridpoint_dim_name=gridpoint_dim_name,
)

    gridpoint_autoregression_coefficients = np.hstack(
        [v["lag_coefficients"] for v in gridpoint_autoregression_parameters.values()]
    )

localised_empirical_covariance_matrix_with_ar1_errors = (
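        # for an AR(1) process x(t) = phi * x(t-1) + eps(t), the innovation variance
        # is (1 - phi**2) * var(x), so this rescales the localised covariance of the
        # process to the covariance of the AR(1) innovations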
        (1 - gridpoint_autoregression_coefficients ** 2)
* localised_empirical_covariance_matrix
)

return localised_empirical_covariance_matrix_with_ar1_errors


def _calculate_localised_empirical_covariance_matrix(
target,
localisation_radii,
gaspari_cohn_correlation_matrices,
max_cross_validation_iterations,
gridpoint_dim_name="gridpoint",
):
dims_to_flatten = [d for d in target.dims if d != gridpoint_dim_name]
target_flattened, stack_coord_name = _flatten(target, dims_to_flatten)
target_flattened = target_flattened.transpose(stack_coord_name, gridpoint_dim_name)

number_samples = target_flattened[stack_coord_name].shape[0]
number_iterations = min([number_samples, max_cross_validation_iterations])

    # set up cross-validation folds: iteration i holds out every
    # max_cross_validation_iterations-th sample, starting at sample i
index_cross_validation_out = np.zeros([number_iterations, number_samples], dtype=bool)

for i in range(number_iterations):
index_cross_validation_out[i, i::max_cross_validation_iterations] = True

    # initialise the running maximum with a very low sentinel so the first
    # localisation radius is always accepted
    log_likelihood_cross_validation_sum_max = -10000

for lr in localisation_radii:
log_likelihood_cross_validation_sum = 0

for i in range(number_iterations):
            # split the samples into an estimation set and a held-out set
            # (the cross-validation "folds")
target_estimator = target_flattened.isel(**{stack_coord_name: ~index_cross_validation_out[i]}).values
target_cross_validation = target_flattened.isel(**{stack_coord_name: index_cross_validation_out[i]}).values
# selecting relevant weights goes in here

empirical_covariance = np.cov(target_estimator, rowvar=False)
# must be a better way to handle ensuring that the dimensions line up correctly (rather than
# just cheating by using `.values`)
empirical_covariance_localised = empirical_covariance * gaspari_cohn_correlation_matrices[lr].values

# calculate likelihood of cross validation samples
log_likelihood_cross_validation_samples = scipy.stats.multivariate_normal.logpdf(
target_cross_validation,
mean=np.zeros(gaspari_cohn_correlation_matrices[lr].shape[0]),
cov=empirical_covariance_localised,
allow_singular=True,
)
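            # the (possibly weighted) mean times the number of held-out samples gives
            # the corresponding (weighted) sum of log-likelihoods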
log_likelihood_cross_validation_samples_weighted_sum = np.average(
log_likelihood_cross_validation_samples,
# weights=wgt_scen_eq_cv # TODO: weights handling
) * log_likelihood_cross_validation_samples.shape[0]

# add to full sum over all folds
log_likelihood_cross_validation_sum += log_likelihood_cross_validation_samples_weighted_sum

if log_likelihood_cross_validation_sum > log_likelihood_cross_validation_sum_max:
log_likelihood_cross_validation_sum_max = log_likelihood_cross_validation_sum
else:
# experience tells us that once we start selecting large localisation radii, performance
# will not improve ==> break (reduces computational effort and number of singular matrices
# encountered)
break

# TODO: replace print with logging
print(f"Selected localisation radius: {lr}")

empirical_covariance = np.cov(target_flattened.values, rowvar=False)
empirical_covariance_localised = empirical_covariance * gaspari_cohn_correlation_matrices[lr].values

return empirical_covariance_localised
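
To make the selection loop easier to review, here is a minimal, self-contained sketch of the same cross-validation logic with toy data and a simple exponential stand-in for the Gaspari-Cohn matrices (all names and values are hypothetical, not part of this diff):

import numpy as np
import scipy.stats

rng = np.random.default_rng(0)
n_samples, n_gridpoints = 50, 4
target = rng.standard_normal((n_samples, n_gridpoints))

# stand-in localisation matrices keyed by radius (not actual Gaspari-Cohn values)
distances = np.abs(np.subtract.outer(np.arange(n_gridpoints), np.arange(n_gridpoints)))
localisation_matrices = {lr: np.exp(-distances / lr) for lr in (1.0, 2.0, 4.0)}

max_iterations = 30
n_iterations = min(n_samples, max_iterations)
folds = np.zeros((n_iterations, n_samples), dtype=bool)
for i in range(n_iterations):
    folds[i, i::max_iterations] = True  # hold out every max_iterations-th sample

selected, best_sum = None, -np.inf
for lr in sorted(localisation_matrices):
    log_likelihood_sum = 0.0
    for i in range(n_iterations):
        covariance = np.cov(target[~folds[i]], rowvar=False) * localisation_matrices[lr]
        log_likelihood_sum += scipy.stats.multivariate_normal.logpdf(
            target[folds[i]], mean=np.zeros(n_gridpoints), cov=covariance, allow_singular=True
        ).sum()
    if log_likelihood_sum > best_sum:
        selected, best_sum = lr, log_likelihood_sum
    else:
        break  # larger radii are unlikely to improve the fit further

print(f"Selected localisation radius: {selected}")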
136 changes: 136 additions & 0 deletions mesmer/prototype/utils.py
@@ -0,0 +1,136 @@
import numpy as np
import pyproj
import xarray as xr


def calculate_gaspari_cohn_correlation_matrices(
latitudes,
longitudes,
localisation_radii,
):
"""
    Calculate Gaspari-Cohn correlation matrices for a range of localisation radii

Parameters
----------
latitudes : :obj:`xr.DataArray`
Latitudes (one-dimensional)

longitudes : :obj:`xr.DataArray`
Longitudes (one-dimensional)

localisation_radii : list-like
        Localisation radii to test (in kilometres)

Returns
-------
    dict[float, :obj:`xr.DataArray`]
Gaspari-Cohn correlation matrix (values) for each localisation radius (keys)

Notes
-----
    Values in ``localisation_radii`` should not exceed 10'000 km by much because
    larger radii can lead to ``ValueError: the input matrix must be positive semidefinite``
"""
    # open question: whether xarray can apply a function to all pairs of points
    # in the arrays directly
geodistance = calculate_geodistance_exact(latitudes, longitudes)

    gaspari_cohn_correlation_matrices = {
        lr: calculate_gaspari_cohn_values(geodistance / lr)
        for lr in localisation_radii
    }

return gaspari_cohn_correlation_matrices


def calculate_geodistance_exact(latitudes, longitudes):
"""
    Calculate exact great circle distance based on WGS 84

Parameters
----------
latitudes : :obj:`xr.DataArray`
Latitudes (one-dimensional)

longitudes : :obj:`xr.DataArray`
Longitudes (one-dimensional)

Returns
-------
    :obj:`xr.DataArray`
        2D array of great circle distances (in kilometres) between the points
        represented by ``latitudes`` and ``longitudes``
"""
if longitudes.shape != latitudes.shape or longitudes.ndim != 1:
raise ValueError("lon and lat need to be 1D arrays of the same shape")

geod = pyproj.Geod(ellps="WGS84")

n_points = longitudes.shape[0]

geodistance = np.zeros([n_points, n_points])

# calculate only the upper right half of the triangle first
for i in range(n_points):

# need to duplicate gridpoint (required by geod.inv)
lat = np.tile(latitudes[i], n_points - (i + 1))
lon = np.tile(longitudes[i], n_points - (i + 1))

        geodistance[i, i + 1 :] = geod.inv(
            lon, lat, longitudes.values[i + 1 :], latitudes.values[i + 1 :]
        )[2]

# convert m to km
geodistance /= 1000

# fill the lower left half of the triangle (in-place)
geodistance += np.transpose(geodistance)

if latitudes.dims != longitudes.dims:
raise AssertionError(f"latitudes and longitudes have different dims: {latitudes.dims} vs. {longitudes.dims}")

    geodistance = xr.DataArray(
        geodistance, dims=list(latitudes.dims) * 2, coords=latitudes.coords
    )

return geodistance


def calculate_gaspari_cohn_values(inputs):
"""
    Calculate values of the smooth, compactly supported Gaspari-Cohn correlation function

Parameters
----------
    inputs : :obj:`xr.DataArray`
        Inputs at which to calculate the value of the smooth, compactly supported
        Gaspari-Cohn correlation function (e.g. geographical distances normalised
        by the localisation radius)

Returns
-------
:obj:`xr.DataArray`
Gaspari-Cohn correlation function applied to each point in ``inputs``
"""
inputs_abs = abs(inputs)
out = np.zeros_like(inputs)
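    # compact support: values at |inputs| >= 2 remain zero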

sel_zero_to_one = (inputs_abs.values >= 0) & (inputs_abs.values < 1)
r_s = inputs_abs.values[sel_zero_to_one]
out[sel_zero_to_one] = (
1 - 5 / 3 * r_s ** 2 + 5 / 8 * r_s ** 3 + 1 / 2 * r_s ** 4 - 1 / 4 * r_s ** 5
)

sel_one_to_two = (inputs_abs.values >= 1) & (inputs_abs.values < 2)
r_s = inputs_abs.values[sel_one_to_two]

out[sel_one_to_two] = (
4
- 5 * r_s
+ 5 / 3 * r_s ** 2
+ 5 / 8 * r_s ** 3
- 1 / 2 * r_s ** 4
+ 1 / 12 * r_s ** 5
- 2 / (3 * r_s)
)

out = xr.DataArray(out, dims=inputs.dims, coords=inputs.coords)

return out
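
As a usage sketch for the two utilities above (self-contained, re-implemented inline with plain numpy/pyproj rather than importing the prototype module; the coordinates and radii below are made up):

import numpy as np
import pyproj

# three made-up gridpoints
latitudes = np.array([0.0, 10.0, 20.0])
longitudes = np.array([0.0, 10.0, 20.0])

# pairwise great circle distances on the WGS 84 ellipsoid, upper triangle first,
# mirroring calculate_geodistance_exact
geod = pyproj.Geod(ellps="WGS84")
n_points = latitudes.size
geodistance = np.zeros((n_points, n_points))
for i in range(n_points):
    lat = np.full(n_points - (i + 1), latitudes[i])
    lon = np.full(n_points - (i + 1), longitudes[i])
    geodistance[i, i + 1 :] = geod.inv(lon, lat, longitudes[i + 1 :], latitudes[i + 1 :])[2]
geodistance /= 1000  # m -> km
geodistance += geodistance.T


def gaspari_cohn(r):
    # piecewise polynomial correlation function with compact support on [0, 2)
    r = np.abs(r)
    out = np.zeros_like(r)
    low = r < 1
    out[low] = (
        1 - 5 / 3 * r[low] ** 2 + 5 / 8 * r[low] ** 3
        + 1 / 2 * r[low] ** 4 - 1 / 4 * r[low] ** 5
    )
    mid = (r >= 1) & (r < 2)
    out[mid] = (
        4 - 5 * r[mid] + 5 / 3 * r[mid] ** 2 + 5 / 8 * r[mid] ** 3
        - 1 / 2 * r[mid] ** 4 + 1 / 12 * r[mid] ** 5 - 2 / (3 * r[mid])
    )
    return out


# one correlation matrix per localisation radius (radii in km, matching the note above)
correlation_matrices = {lr: gaspari_cohn(geodistance / lr) for lr in (2000.0, 5000.0)}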