def parse_error_bounds(error_bound_str: str) -> dict[str, tuple[str, float]]:
    """Parse an error-bound string into per-variable error bounds.

    The error bound string is of the form
    "{variable_name1}-{error_type1}={error_bound1}_{variable_name2}-{error_type2}={error_bound2}".
    More than 2 variables are possible.
    Each variable name can itself contain an underscore.
    The error type is either "abs_error" or "rel_error".
    The error bound is a floating point number represented either in decimal
    or scientific notation.

    For example, the string
        "pr-abs_error=3.108691982924938e-05_rlut-abs_error=0.2788982238769531"
    is parsed as
        {
            "pr": ("abs_error", 3.108691982924938e-05),
            "rlut": ("abs_error", 0.2788982238769531),
        }

    Args:
        error_bound_str: The encoded error bounds, e.g. the name of an
            error-bound directory.

    Returns:
        A dictionary mapping each variable name to a tuple of
        (error type, error bound).

    Raises:
        ValueError: If no "{variable}-{error_type}={error_bound}" entry can be
            found in the string, or if a matched error bound cannot be
            converted to a float.
    """
    pattern = re.compile(
        # Optional underscore separating consecutive variable entries.
        r"(?:_?)"
        # Variable name: letters, digits and underscores.
        r"(?P<variable>[\w]+)"
        # Error type is either "abs_error" or "rel_error".
        r"-(?P<error_type>abs_error|rel_error)="
        # Error bound in decimal or scientific notation.
        r"(?P<error_bound>\d+(\.\d+)?([eE][+-]?\d+)?)"
    )

    result: dict[str, tuple[str, float]] = {}
    for match in pattern.finditer(error_bound_str):
        try:
            error_bound = float(match["error_bound"])
        except ValueError as err:
            raise ValueError(
                f"Error bound '{match['error_bound']}' from '{error_bound_str}' is not a valid float"
            ) from err

        result[match["variable"]] = (match["error_type"], error_bound)

    # Raise explicitly rather than assert: assertions are stripped under `-O`,
    # which would let an unparsable string silently produce an empty mapping.
    if not result:
        raise ValueError(
            f"Error bound string {error_bound_str} does not match expected format"
        )

    return result
class ErrorBound(Test):
    """Test whether every entry of an array stays within an error bound.

    The error between a reference array ``x`` and a candidate array ``y`` is
    measured either as the absolute error ``|x - y|`` ("abs_error") or as the
    relative error ``|x - y| / |x|`` ("rel_error"). The test passes only if no
    entry exceeds ``threshold``.
    """

    _ERROR_TYPES = ("abs_error", "rel_error")

    def __init__(self, error_type: str, threshold: float = 0.05):
        """Configure the error measure and the bound to check.

        Args:
            error_type: Either "abs_error" or "rel_error".
            threshold: Largest error that still satisfies the bound.

        Raises:
            ValueError: If ``error_type`` is not one of the supported types.
        """
        if error_type not in self._ERROR_TYPES:
            raise ValueError(
                f"error_type must be either 'abs_error' or 'rel_error', not {error_type}"
            )
        self.error_type = error_type
        self.threshold = threshold

    def __call__(self, x: xr.DataArray, y: xr.DataArray) -> tuple[bool, float]:
        """Check ``y`` against ``x`` and report the fraction of violations.

        Args:
            x: Reference array; must have a floating point dtype.
            y: Array to check against ``x``; must have a floating point dtype.

        Returns:
            A tuple ``(passed, fraction)`` where ``fraction`` is the proportion
            of entries that violate the bound and ``passed`` is True only if
            that proportion is exactly zero.

        Raises:
            TypeError: If ``x`` or ``y`` does not have a floating point dtype.
        """
        if x.dtype.kind != "f":
            raise TypeError(f"Expected x to be float, got {x.dtype}")
        if y.dtype.kind != "f":
            raise TypeError(f"Expected y to be float, got {y.dtype}")

        # Nothing can violate the bound in an empty array; returning early also
        # avoids dividing by a size of zero below.
        if x.size == 0:
            return True, 0.0

        error = np.abs(x - y)
        if self.error_type == "rel_error":
            error = error / np.abs(x)
        satisfied = error <= self.threshold

        # The comparison does not work for NaN values, as `np.nan <= threshold`
        # is False. If x contains a NaN then y must also contain a NaN at the
        # same location. A NaN in only one of the two arrays is still counted
        # as a violation.
        x_and_y_nan = np.isnan(x) & np.isnan(y)

        # Exactly equal entries have zero error but can still produce a NaN
        # above: `inf - inf` is NaN, and so is the relative error `0 / 0`.
        # Comparing with `==` also guarantees that infinities share a sign.
        # This is a separate term because `a & b & x == y` would parse as
        # `(a & b & x) == y`.
        x_equals_y = x == y

        satisfied = satisfied | x_and_y_nan | x_equals_y

        # Proportion of entries that exceed the threshold.
        exceed_thresh = np.sum(~satisfied) / x.size
        return bool(exceed_thresh == 0.0), float(exceed_thresh)