Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 75 additions & 3 deletions scripts/collect_metrics.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import json
import re
from pathlib import Path

import climatebenchpress.compressor
import pandas as pd
import xarray as xr

import climatebenchpress.compressor

REPO = Path(__file__).parent.parent

EVALUATION_METRICS: dict[str, climatebenchpress.compressor.metrics.abc.Metric] = {
Expand All @@ -31,6 +33,8 @@ def main():
continue

for error_bound in dataset.iterdir():
variable2error_bound = parse_error_bounds(error_bound.name)

for compressor in error_bound.iterdir():
print(f"Evaluating {compressor.stem} on {dataset.name}...")

Expand Down Expand Up @@ -58,7 +62,9 @@ def main():
compressor_metrics.mkdir(parents=True, exist_ok=True)

metrics = compute_metrics(compressor_metrics, ds, ds_new)
tests = compute_tests(compressor_metrics, ds, ds_new)
tests = compute_tests(
compressor_metrics, variable2error_bound, ds, ds_new
)
measurements = load_measurements(compressed_dataset, compressor)

df = merge_metrics(measurements, metrics, tests)
Expand All @@ -70,6 +76,53 @@ def main():
all_results.to_csv(metrics_dir / "all_results.csv", index=False)


def parse_error_bounds(error_bound_str: str) -> dict[str, tuple[str, float]]:
"""
The error bound string is of the form
"{variable_name1}-{error_type1}={error_bound1}_{variable_name2}-{error_type2}={error_bound2}".
More than 2 variables are possible.
Each variable name can itself contain an underscore.
The error type is either "abs_error" or "rel_error".
The error bound is a floating point number represented either in decimal or scientific notation.

This function parses the string and returns a dictionary of the form
{
"variable_name1": (error_type1, error_bound1),
"variable_name2": (error_type2, error_bound2),
}

For example, the string
"pr-abs_error=3.108691982924938e-05_rlut-abs_error=0.2788982238769531"
would be parsed as
{
"pr": ("abs_error", 3.108691982924938e-05),
"rlut": ("abs_error", 0.2788982238769531),
}
"""
pattern = re.compile(
r"(?:_?)" # Underscore at the beginning separating the different variables.
r"(?P<variable>[\w]+)" # Variable name can any alphanumeric character.
r"-(?P<error_type>abs_error|rel_error)=" # Error type is either "abs_error" or "rel_error".
r"(?P<error_bound>\d+(\.\d+)?([eE][+-]?\d+)?)" # Error bound is a floating point number.
)
result = {}
for match in pattern.finditer(error_bound_str):
try:
error_bound = float(match["error_bound"])
except ValueError:
raise ValueError(
f"Error bound '{match['error_bound']}' from '{error_bound_str}' is not a valid float"
)

result[match["variable"]] = (match["error_type"], error_bound)

assert len(result) > 0, (
f"Error bound string {error_bound_str} does not match expected format"
)

return result


def compute_metrics(
compressor_metrics: Path, ds: xr.Dataset, ds_new: xr.Dataset
) -> pd.DataFrame:
Expand All @@ -94,7 +147,10 @@ def compute_metrics(


def compute_tests(
compressor_metrics: Path, ds: xr.Dataset, ds_new: xr.Dataset
compressor_metrics: Path,
variable2bound: dict[str, tuple[str, float]],
ds: xr.Dataset,
ds_new: xr.Dataset,
) -> pd.DataFrame:
tests_path = compressor_metrics / "tests.csv"
if tests_path.exists():
Expand All @@ -112,6 +168,22 @@ def compute_tests(
"Value": test_value,
}
)

for v in ds_new:
error_type, bound = variable2bound[str(v)]
test = climatebenchpress.compressor.tests.ErrorBound(
error_type=error_type, threshold=bound
)
test_result, test_value = test(ds[v], ds_new[v])
test_list.append(
{
"Test": "Satisfies Bound",
"Variable": v,
"Passed": test_result,
"Value": test_value,
}
)

tests = pd.DataFrame(test_list)
tests.to_csv(tests_path, index=False)
return tests
Expand Down
3 changes: 2 additions & 1 deletion src/climatebenchpress/compressor/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import abc as abc
from .error_bound import ErrorBound
from .r2_correlation import R2
from .spatial_relative_error import SRE

__all__ = ["SRE", "R2"]
__all__ = ["SRE", "R2", "ErrorBound"]
41 changes: 41 additions & 0 deletions src/climatebenchpress/compressor/tests/error_bound.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np
import xarray as xr

from .abc import Test


class ErrorBound(Test):
def __init__(self, error_type: str, threshold: float = 0.05):
self.threshold = threshold
assert error_type in [
"abs_error",
"rel_error",
], f"error_type must be either 'abs_error' or 'rel_error', not {error_type}"
self.error_type = error_type

def __call__(self, x: xr.DataArray, y: xr.DataArray) -> tuple[bool, float]:
# Check that two arrays are both floats
assert x.dtype.kind == "f", f"Expected x to be float, got {x.dtype}"
assert y.dtype.kind == "f", f"Expected y to be float, got {y.dtype}"

abs_error = np.abs(x - y)
Comment thread
juntyr marked this conversation as resolved.
relative_error = abs_error / abs(x)

error_to_check = abs_error if self.error_type == "abs_error" else relative_error
satisfied = error_to_check <= self.threshold
Comment thread
treigerm marked this conversation as resolved.

# The comparison does not work for NaN values, as `np.nan < threshold` is False.
# This check ensures that if x contains a NaN then y must also contain a NaN at
# the same location.
# Note, it is an error to have a NaN in y and not in x which will be caught.
x_and_y_nan = np.isnan(x) & np.isnan(y)
satisfied = satisfied | x_and_y_nan

# Similarly, np.inf - np.inf is NaN but should pass the test.
# The x == y condition ensures that their sign is the same.
x_and_y_inf = np.isinf(x) & np.isinf(y) & x == y
satisfied = satisfied | x_and_y_inf

# Proportion of entries that exceed the threshold.
exceed_thresh = np.sum(~satisfied) / x.size
Comment thread
juntyr marked this conversation as resolved.
return bool(exceed_thresh == 0.0), float(exceed_thresh)
Loading