# CDAT Migration Regression Testing Notebook (`.nc` files)

This notebook is used to perform regression testing between the development and
production versions of a diagnostic set.

## How it works

It compares the `ref` and `test` variables saved by the dev branch against those saved by the `main` branch, and reports any relative differences that exceed the tolerance.
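The comparison below relies on `np.testing.assert_allclose` with `ATOL = 0` and `RTOL = 1e-4`, which treats two arrays as equal when every element satisfies `|actual - desired| <= atol + rtol * |desired|` and NaNs sit in the same positions. A minimal sketch of that element-wise criterion; `count_violations` is a hypothetical helper for illustration, not part of the notebook:

```python
import numpy as np


def count_violations(actual, desired, rtol=1e-4, atol=0.0):
    """Count elements outside the assert_allclose-style tolerance.

    Hypothetical helper: it mirrors |actual - desired| <= atol + rtol * |desired|,
    treats NaNs in the same position as equal, and counts a NaN in only one
    array as a violation.
    """
    actual = np.asarray(actual, dtype=float)
    desired = np.asarray(desired, dtype=float)

    both_nan = np.isnan(actual) & np.isnan(desired)
    within = np.abs(actual - desired) <= atol + rtol * np.abs(desired)

    return int(np.count_nonzero(~(within | both_nan)))


dev = np.array([1.0, 2.0, 3.0, np.nan])
main = np.array([1.0, 2.0001, 3.1, np.nan])
print(count_violations(dev, main))  # 1: only 3.0 vs. 3.1 exceeds rtol=1e-4
```

Because `atol` is 0, the allowed error scales entirely with the magnitude of the `main` (desired) value.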

## How to use

PREREQUISITE: The diagnostic set's netCDF output (`.nc` files) stored in two directories
(dev and `main` branches).

1. Make a copy of this notebook under `auxiliary_tools/cdat_regression_testing/<DIR_NAME>`.
2. Run `mamba create -n cdat_regression_test -y -c conda-forge "python<3.12" xarray netcdf4 dask pandas matplotlib-base ipykernel`
3. Run `mamba activate cdat_regression_test`
4. Update `DEV_DIR` and `MAIN_DIR` in the copy of your notebook.
5. Run all cells IN ORDER.
6. Review results for any outstanding differences (exceeding the 1e-4 relative tolerance, `RTOL`).
   - Debug these differences (e.g., bug in metrics functions, incorrect variable references, etc.)


## Setup Code


In [21]:
import glob
from typing import List

import numpy as np
import xarray as xr

from e3sm_diags.derivations.derivations import DERIVED_VARIABLES

# The path to the development data.
DEV_DIR = "25-01-15-branch-907"
DEV_PATH = f"/global/cfs/cdirs/e3sm/www/e3sm_diags/complete_run/{DEV_DIR}/"
DEV_GLOB = sorted(glob.glob(DEV_PATH + "**/**/*.nc"))
DEV_NUM_FILES = len(DEV_GLOB)

# The path to the production data to compare against.
MAIN_DIR = "v2.12.1v2"
MAIN_PATH = f"/global/cfs/cdirs/e3sm/www/e3sm_diags/complete_run/{MAIN_DIR}/"
MAIN_GLOB = sorted(glob.glob(MAIN_PATH + "**/**/*.nc"))
MAIN_NUM_FILES = len(MAIN_GLOB)


def _remove_unwanted_files(file_glob: List[str]) -> List[str]:
    """Remove files that we don't want to compare.

    * area_mean_time_series -- `main` does not generate netCDF
    * enso_diags -- `main` does not generate netCDF
    * qbo -- variable name differs
    * diurnal_cycle -- variable name differs
    * diff -- comparing the difference between regridded files is not helpful
      between branches because of the influence of floating point errors.
    * ERA5_ext-U10-ANN-global_ref and ERA5_ext-U10-JJA-global_ref -- dev
      branch does not generate these files because it is a model-only run.
    * tropical_subseasonal -- also excluded from this comparison.

    Parameters
    ----------
    file_glob : List[str]
        The file paths to filter.

    Returns
    -------
    List[str]
        The file paths with the unwanted files removed.
    """

    new_glob = []

    for fp in file_glob:
        if (
            "area_mean_time_series" in fp
            or "enso_diags" in fp
            or "qbo" in fp
            or "diurnal_cycle" in fp
            or "diff" in fp
            or "ERA5_ext-U10-ANN-global_ref" in fp
            or "ERA5_ext-U10-JJA-global_ref" in fp
            or "tropical_subseasonal" in fp
        ):
            continue

        new_glob.append(fp)

    return new_glob


DEV_GLOB = _remove_unwanted_files(DEV_GLOB)
MAIN_GLOB = _remove_unwanted_files(MAIN_GLOB)

In [22]:
def _get_relative_diffs():
    # The absolute and relative tolerances for the tests.
    ATOL = 0
    RTOL = 1e-4

    results = {
        "missing_files": [],
        "missing_vars": [],
        "matching_files": [],
        "mismatch_errors": [],
        "not_equal_errors": [],
        "key_errors": [],
    }

    for fp_main in MAIN_GLOB:
        fp_dev = fp_main.replace(MAIN_DIR, DEV_DIR)

        if "annual_cycle_zonal_mean" in fp_main:
            if "test.nc" in fp_main:
                fp_dev = fp_dev.replace("test.nc", "ref.nc")
            elif "ref.nc" in fp_main:
                fp_dev = fp_dev.replace("ref.nc", "test.nc")

        try:
            ds1 = xr.open_dataset(fp_dev)
            ds2 = xr.open_dataset(fp_main)
        except OSError as e:
            # FileNotFoundError is a subclass of OSError, so both are caught here.
            print(f"    {e}")
            results["missing_files"].append(fp_dev)

            continue

        var_key = fp_main.split("-")[-3]

        # For 3D vars such as T-200, skip over the pressure level.
        if var_key.isdigit():
            var_key = fp_main.split("-")[-4]

        dev_data = _get_var_data(ds1, var_key)
        main_data = _get_var_data(ds2, var_key)

        if dev_data is None or main_data is None:
            if dev_data is None:
                results["missing_vars"].append(fp_dev)
            elif main_data is None:
                results["missing_vars"].append(fp_main)

            print("    * Could not find variable key in the dataset(s)")

            continue

        try:
            np.testing.assert_allclose(
                dev_data,
                main_data,
                atol=ATOL,
                rtol=RTOL,
            )
            results["matching_files"].append(fp_main)
        except (KeyError, AssertionError) as e:
            print("Comparing:")
            print(f"    * {fp_dev}")
            print(f"    * {fp_main}")
            print(f"    * var_key: {var_key}")
            msg = str(e)

            print(f"    {msg}")

            if "mismatch" in msg:
                results["mismatch_errors"].append(fp_dev)
            elif "Not equal to tolerance" in msg:
                results["not_equal_errors"].append(fp_dev)

    return results


def _get_var_data(ds: xr.Dataset, var_key: str) -> np.ndarray:
    """Get the variable data using a list of matching keys.

    The `main` branch saves the dataset using the original variable name,
    while the dev branch saves the variable with the derived variable name.
    The dev branch is performing the expected behavior here.

    Parameters
    ----------
    ds : xr.Dataset
        The dataset to search.
    var_key : str
        The variable key parsed from the file name.

    Returns
    -------
    np.ndarray
        The variable data, or None if no matching key is found.
    """

    data = None

    try:
        data = ds[var_key].values
    except KeyError:
        # Fall back to the source variables of the derived variable (if any).
        derived = DERIVED_VARIABLES.get(var_key) or DERIVED_VARIABLES.get(
            var_key.upper(), {}
        )
        var_keys = [var_key] + list(sum(derived.keys(), ()))

        for key in var_keys:
            if key in ds.data_vars.keys():
                data = ds[key].values
                break

    return data
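Both functions above first parse the variable key from the file name (the paths in this notebook look like `<source>-<var>[-<level>]-<season>-<region>_{test,ref}.nc`) and, if that key is not in the dataset, fall back to the source variable names listed for it in `DERIVED_VARIABLES`. A dependency-free sketch of both steps; `parse_var_key`, `resolve_var_name`, and `FAKE_DERIVED` are hypothetical names, and `FAKE_DERIVED` only mimics the shape the code above relies on (tuple keys of source variable names mapped to derivation functions):

```python
from typing import Callable, Dict, Iterable, Optional, Tuple


def parse_var_key(filepath: str) -> str:
    """Return the variable key from a name like 'ERA5-OMEGA-500-ANN-global_test.nc'."""
    parts = filepath.split("-")
    var_key = parts[-3]

    # 3D variables carry a pressure level (e.g. OMEGA-500), so the
    # variable name sits one position further left.
    if var_key.isdigit():
        var_key = parts[-4]

    return var_key


# Hypothetical stand-in for DERIVED_VARIABLES: tuples of source variable
# names mapped to the function that derives the variable from them.
FAKE_DERIVED: Dict[str, Dict[Tuple[str, ...], Callable]] = {
    "PRECT": {("PRECC", "PRECL"): lambda a, b: a + b, ("pr",): lambda a: a},
}


def resolve_var_name(
    available: Iterable[str], var_key: str, derived=FAKE_DERIVED
) -> Optional[str]:
    """Return the first name in `available` that matches var_key or its sources."""
    names = set(available)
    if var_key in names:
        return var_key

    # Unlike the notebook, return None instead of raising when var_key is unknown.
    entry = derived.get(var_key, derived.get(var_key.upper(), {}))
    for source_names in entry:
        for name in source_names:
            if name in names:
                return name

    return None
```

With xarray, iterating `ds.data_vars` yields variable names, so a call such as `name = resolve_var_name(ds.data_vars, parse_var_key(fp))` followed by `ds[name].values` reproduces the lookup. Note that `split("-")` runs on the full path, which works here because the variable key is counted from the end of the file name.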

In [23]:
def _check_if_files_found():
    if DEV_NUM_FILES == 0 or MAIN_NUM_FILES == 0:
        raise IOError(
            "No files found at DEV_DIR and/or MAIN_DIR. "
            f"Please check {DEV_PATH} and {MAIN_PATH}."
        )


def _check_if_matching_filecount():
    if DEV_NUM_FILES != MAIN_NUM_FILES:
        raise IOError(
            "The number of files does not match between DEV_DIR and MAIN_DIR "
            f"({DEV_NUM_FILES} vs. {MAIN_NUM_FILES})."
        )

    print(f"Matching file count ({DEV_NUM_FILES} and {MAIN_NUM_FILES}).")


def _check_if_missing_files():
    missing_dev_files = []
    missing_main_files = []

    for fp_main in MAIN_GLOB:
        fp_dev = fp_main.replace(MAIN_DIR, DEV_DIR)

        if fp_dev not in DEV_GLOB:
            missing_dev_files.append(fp_dev)

    for fp_dev in DEV_GLOB:
        fp_main = fp_dev.replace(DEV_DIR, MAIN_DIR)

        if fp_main not in MAIN_GLOB:
            missing_main_files.append(fp_main)

    return missing_dev_files, missing_main_files
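`_check_if_missing_files` tests membership with `in` on lists, which is quadratic in the number of files. Because both directories share the same relative layout, a set difference on relative paths gives the same answer in linear time. A minimal sketch; `find_missing` is a hypothetical name, and unlike the notebook it returns relative rather than absolute paths:

```python
import os
from typing import Iterable, List, Tuple


def find_missing(
    dev_paths: Iterable[str],
    main_paths: Iterable[str],
    dev_root: str,
    main_root: str,
) -> Tuple[List[str], List[str]]:
    """Return (missing_in_dev, missing_in_main) as sorted relative paths."""
    dev_rel = {os.path.relpath(p, dev_root) for p in dev_paths}
    main_rel = {os.path.relpath(p, main_root) for p in main_paths}

    missing_in_dev = sorted(main_rel - dev_rel)
    missing_in_main = sorted(dev_rel - main_rel)

    return missing_in_dev, missing_in_main
```

For example, `find_missing(DEV_GLOB, MAIN_GLOB, DEV_PATH, MAIN_PATH)` would return two empty lists for the file sets checked below.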

## 1. Check for matching and equal number of files


In [24]:
_check_if_files_found()

In [25]:
DEV_GLOB = [fp for fp in DEV_GLOB if "diff.nc" not in fp]
MAIN_GLOB = [fp for fp in MAIN_GLOB if "diff.nc" not in fp]

In [26]:
len(DEV_GLOB), len(MAIN_GLOB)

(590, 590)

In [27]:
missing_dev_files, missing_main_files = _check_if_missing_files()

print(f"Missing dev files: {len(missing_dev_files)}")
print(f"Missing main files: {len(missing_main_files)}")

Missing dev files: 0
Missing main files: 0


### Check missing main files (not a concern)

Results:

- The missing files are due to a recent .cfg update in [PR #830](https://github.com/E3SM-Project/e3sm_diags/pull/830)


In [28]:
missing_main_files

[]

### Check missing dev files

Results:

- The missing reference files are not saved to netCDF because they are identical to the test files (skipped for model-only runs).


In [29]:
missing_dev_files

[]

## 2. Compare the netCDF files between branches

- Compare "ref" and "test" files
- "diff" files are ignored because relative differences are not meaningful for them: difference fields contain many near-zero values, so tiny floating point differences push the relative difference above tolerance
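With `atol=0`, the allowed error is proportional to `|desired|`, so an identical absolute error passes for a large value but fails for a near-zero one. A small illustration with made-up numbers:

```python
import numpy as np

# Made-up values: one large element and one near-zero element.
desired = np.array([250.0, 1e-6])
actual = desired + 1e-7  # the same absolute error for both elements

rel_diff = np.abs(actual - desired) / np.abs(desired)
print(rel_diff)  # roughly [4e-10, 0.1]

# Only the near-zero element exceeds rtol=1e-4.
print(rel_diff > 1e-4)  # [False  True]
```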


In [30]:
results = _get_relative_diffs()

Comparing:
    * /global/cfs/cdirs/e3sm/www/e3sm_diags/complete_run/25-01-15-branch-907/lat_lon/COREv2_Flux/COREv2_Flux-PminusE-ANN-global_test.nc
    * /global/cfs/cdirs/e3sm/www/e3sm_diags/complete_run/v2.12.1v2/lat_lon/COREv2_Flux/COREv2_Flux-PminusE-ANN-global_test.nc
    * var_key: PminusE
    
Not equal to tolerance rtol=0.0001, atol=0

Mismatched elements: 47 / 64800 (0.0725%)
Max absolute difference among violations: 3.47921741e-07
Max relative difference among violations: 0.00244282
 ACTUAL: array([[0.154003, 0.154003, 0.154003, ..., 0.188599, 0.188599, 0.188599],
       [0.13969 , 0.139701, 0.139724, ..., 0.162309, 0.162268, 0.162247],
       [0.12946 , 0.12946 , 0.12946 , ..., 0.143414, 0.143414, 0.143414],...
 DESIRED: array([[0.154003, 0.154003, 0.154003, ..., 0.188599, 0.188599, 0.188599],
       [0.13969 , 0.139701, 0.139724, ..., 0.162309, 0.162268, 0.162247],
       [0.12946 , 0.12946 , 0.12946 , ..., 0.143414, 0.143414, 0.143414],...
Comparing:
    * /global/cfs/cdirs

In [31]:
import pandas as pd

(
    missing_files,
    missing_vars,
    matching_files,
    mismatch_errors,
    not_equal_errors,
    key_errors,
) = results.values()

In [12]:
# Count the files in each result category.
matching_files_count = len(matching_files)
missing_vars_count = len(missing_vars)
mismatch_errors_count = len(mismatch_errors)
not_equal_errors_count = len(not_equal_errors)
key_errors_count = len(key_errors)
missing_files_count = len(missing_files)

sum_files_compared = (
    matching_files_count
    + missing_vars_count
    + mismatch_errors_count
    + not_equal_errors_count
    + key_errors_count
    + missing_files_count
)

pct_match = (matching_files_count / sum_files_compared) * 100

# Collect statistics into a dictionary
statistics = {
    "stat_name": [
        "matching_files_count",
        "missing_vars_count",
        "mismatch_errors_count",
        "not_equal_errors_count",
        "key_errors_count",
        "missing_files_count",
    ],
    "value": [
        matching_files_count,
        missing_vars_count,
        mismatch_errors_count,
        not_equal_errors_count,
        key_errors_count,
        missing_files_count,
    ],
    "pct": [
        matching_files_count / sum_files_compared,
        missing_vars_count / sum_files_compared,
        mismatch_errors_count / sum_files_compared,
        not_equal_errors_count / sum_files_compared,
        key_errors_count / sum_files_compared,
        missing_files_count / sum_files_compared,
    ],
}

# Convert the dictionary to a pandas DataFrame
df = pd.DataFrame(statistics)

# Display the DataFrame
print(df)

                stat_name  value       pct
0    matching_files_count    553  0.937288
1      missing_vars_count      0  0.000000
2   mismatch_errors_count      4  0.006780
3  not_equal_errors_count     33  0.055932
4        key_errors_count      0  0.000000
5     missing_files_count      0  0.000000
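The summary cell above can be written more compactly by iterating over the `results` dictionary directly. A sketch, assuming the same six keys that `_get_relative_diffs` returns (filled here with placeholder file names); `summarize` is a hypothetical name, and its rows follow the dictionary's key order rather than the order used above:

```python
import pandas as pd


def summarize(results: dict) -> pd.DataFrame:
    """Build the stat_name/value/pct table from the results dictionary.

    As in the cell above, `pct` is a fraction of all files compared (0-1).
    """
    counts = {f"{key}_count": len(paths) for key, paths in results.items()}
    df = pd.DataFrame({"stat_name": list(counts), "value": list(counts.values())})
    df["pct"] = df["value"] / df["value"].sum()
    return df


# Placeholder results with the same keys that _get_relative_diffs returns.
example = {
    "missing_files": [],
    "missing_vars": [],
    "matching_files": ["a.nc", "b.nc", "c.nc"],
    "mismatch_errors": ["d.nc"],
    "not_equal_errors": [],
    "key_errors": [],
}
print(summarize(example))
```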


### `NaN` Mismatching Errors

These `nan` mismatch errors occur due to either:

1. Regional subsetting with the "ccb" flag in CDAT adding coordinate points -- removing these coordinates produces matching results
2. Slightly different masking of the data between xCDAT and CDAT via xESMF/ESMF -- the same number of `nan`s, just shifted over some coordinate points

- Refer to PR [#794](https://github.com/E3SM-Project/e3sm_diags/pull/794)
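A small illustration of the second case, using made-up values: two arrays with the same number of NaNs in shifted positions. `np.testing.assert_allclose` reports a NaN location mismatch, and because that message contains "mismatch", `_get_relative_diffs` sorts such files into `mismatch_errors`:

```python
import numpy as np

# Made-up data: same NaN count, shifted by one grid point.
dev = np.array([1.0, np.nan, 3.0, 4.0])
main = np.array([1.0, 2.0, np.nan, 4.0])

same_nan_count = np.isnan(dev).sum() == np.isnan(main).sum()
shifted = np.argwhere(np.isnan(dev) != np.isnan(main)).ravel()
print(same_nan_count, shifted)  # True [1 2]

try:
    np.testing.assert_allclose(dev, main, rtol=1e-4, atol=0)
    msg = ""
except AssertionError as e:
    msg = str(e)

print("mismatch" in msg)  # True
```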


In [13]:
mismatch_errors = [
    f
    for f in mismatch_errors
    # https://github.com/E3SM-Project/e3sm_diags/pull/794
    if "TAUXY" not in f and "ERA5-TREFHT" not in f and "MERRA2-TREFHT" not in f
    # https://github.com/E3SM-Project/e3sm_diags/pull/798#issuecomment-2251287986
    and "ceres_ebaf_toa_v4.1-ALBEDO" not in f
]

In [14]:
mismatch_errors

[]

### Not Equal Errors

- Some files are omitted because the root causes of their diffs are known (not a concern).


Let's remove them and compare statistics between v3.0.0 and v2.12.1.


In [15]:
not_equal_errors_new = [
    f
    for f in not_equal_errors
    # Large diffs due to bug on `main` (https://github.com/E3SM-Project/e3sm_diags/issues/797)
    if "MISRCOSP-CLDLOW_TAU1.3_9.4_MISR" not in f
    and "MISRCOSP-CLDLOW_TAU1.3_MISR" not in f
    and "MISRCOSP-CLDLOW_TAU9.4_MISR" not in f
    # only 1 mismatching element with max abs diff of 9.84e-06 and max rel diff of 0.00094
    and "ERA5-OMEGA-500" not in f
    # only 3 mismatching elements with max abs diff of 2.88e-05 and max rel diff of 0.0047
    and "ERA5-OMEGA-850" not in f
    # only 47 mismatching elements with max abs diff of 1.853e-06
    and "COREv2_Flux-PminusE" not in f
    # only 22 mismatching elements with max abs diff of 2.116e-05 and max rel diff of 0.0052
    and "GPCP_OAFLux-PminusE" not in f
    # only 1 mismatching element with max abs diff of 9.841e-6 and max rel diff of 0.000947
    and "MERRA2-OMEGA-200" not in f
    # only 1 mismatching element with max abs diff of 2.884e-05 and max rel diff of 0.00474
    and "MERRA2-OMEGA-500" not in f
    # only 3 mismatching elements with max abs diff of 2.884e-05 and max rel diff of 0.00474
    # and "MERRA2-OMEGA-850" not in f
    # only 1 mismatching element with max abs diff of 4.453e-07 and max rel diff of 0.00011
    and "MERRA2-OMEGA-ANN" not in f
    # only 4 mismatching elements with max abs diff of 1.086e-05 and max rel diff of 0.0000622
    and "ERA5-U-850" not in f
    # https://github.com/E3SM-Project/e3sm_diags/issues/787
    and "MERRA2-U" not in f
    # only 1 mismatching element with max abs diff of 0.000103 and max rel diff of 0.00027
    and "ERA5-OMEGA-ANN" not in f
    # only 1 mismatching element with max abs diff of 0.000103 and max rel diff of 0.00027
    and "MERRA2-OMEGA-ANN" not in f
    # https://github.com/E3SM-Project/e3sm_diags/issues/852
    and "AOD_550" not in f
]
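The chained `"..." not in f` conditions can also be expressed as a tuple of substrings checked with `any()`, which keeps the exclusion list easy to extend and makes accidental duplicates easier to spot. A sketch with a shortened, illustrative list; `KNOWN_DIFFS` and `exclude_known` are hypothetical names:

```python
from typing import Iterable, List, Tuple

# Illustrative subset of the exclusions above; each entry is a substring
# matched against the full file path.
KNOWN_DIFFS: Tuple[str, ...] = (
    "MISRCOSP-CLDLOW_TAU1.3_9.4_MISR",
    "ERA5-OMEGA-500",
    "COREv2_Flux-PminusE",
    "AOD_550",
)


def exclude_known(paths: Iterable[str], patterns: Tuple[str, ...] = KNOWN_DIFFS) -> List[str]:
    """Keep only the paths that contain none of the given substrings."""
    return [p for p in paths if not any(pat in p for pat in patterns)]


paths = [
    "lat_lon/ERA5/ERA5-OMEGA-500-ANN-global_test.nc",
    "lat_lon/ERA5/ERA5-NET_FLUX_SRF-ANN-global_test.nc",
]
print(exclude_known(paths))  # ['lat_lon/ERA5/ERA5-NET_FLUX_SRF-ANN-global_test.nc']
```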

In [16]:
not_equal_errors_new

['/global/cfs/cdirs/e3sm/www/e3sm_diags/complete_run/25-01-10-branch-907-no-arm-diags/lat_lon/Cloud MISR/MISRCOSP-CLDTOT_TAU1.3_9.4_MISR-ANN-global_test.nc',
 '/global/cfs/cdirs/e3sm/www/e3sm_diags/complete_run/25-01-10-branch-907-no-arm-diags/lat_lon/Cloud MISR/MISRCOSP-CLDTOT_TAU1.3_MISR-ANN-global_test.nc',
 '/global/cfs/cdirs/e3sm/www/e3sm_diags/complete_run/25-01-10-branch-907-no-arm-diags/lat_lon/ERA5/ERA5-NET_FLUX_SRF-ANN-global_test.nc',
 '/global/cfs/cdirs/e3sm/www/e3sm_diags/complete_run/25-01-10-branch-907-no-arm-diags/lat_lon/ERA5/ERA5-OMEGA-200-ANN-global_test.nc',
 '/global/cfs/cdirs/e3sm/www/e3sm_diags/complete_run/25-01-10-branch-907-no-arm-diags/lat_lon/GPCP_v2.3/GPCP_v2.3-PRECT-ANN-global_ref.nc',
 '/global/cfs/cdirs/e3sm/www/e3sm_diags/complete_run/25-01-10-branch-907-no-arm-diags/lat_lon/MERRA2/MERRA2-NET_FLUX_SRF-ANN-global_test.nc',
 '/global/cfs/cdirs/e3sm/www/e3sm_diags/complete_run/25-01-10-branch-907-no-arm-diags/lat_lon/MERRA2/MERRA2-OMEGA-850-ANN-global_test

In [17]:
import xarray as xr


def get_stats_for_not_equal_files(filepaths):
    for fp_dev in filepaths:
        fp_main = fp_dev.replace(DEV_DIR, MAIN_DIR)

        ds1 = xr.open_dataset(fp_dev)
        ds2 = xr.open_dataset(fp_main)

        var_key = fp_main.split("-")[-3]

        # For 3D vars such as T-200, skip over the pressure level.
        if var_key.isdigit():
            var_key = fp_main.split("-")[-4]

        dev_mean = ds1[var_key].mean().item()
        main_mean = ds2[var_key].mean().item()

        dev_sum = ds1[var_key].sum().item()
        main_sum = ds2[var_key].sum().item()

        print(f"Checking variable {var_key}")
        print(f"Dev Path: {fp_dev}")
        print(f"Main Path: {fp_main}")
        print("-------------------------------------")

        mean_diff = dev_mean - main_mean
        sum_diff = dev_sum - main_sum

        absolute_mean_diff = abs(mean_diff)
        absolute_sum_diff = abs(sum_diff)

        relative_mean_diff = (
            absolute_mean_diff / abs(main_mean) if main_mean != 0 else float("inf")
        )
        relative_sum_diff = (
            absolute_sum_diff / abs(main_sum) if main_sum != 0 else float("inf")
        )
        dev_min = ds1[var_key].min().item()
        dev_max = ds1[var_key].max().item()

        main_min = ds2[var_key].min().item()
        main_max = ds2[var_key].max().item()

        print(f"* Min - dev: {dev_min:.6f}, main: {main_min:.6f}")
        print(f"* Max - dev: {dev_max:.6f}, main: {main_max:.6f}")

        print(f"* Mean - dev: {dev_mean:.6f}, main: {main_mean:.6f}")
        print(f"    * Absolute Mean Diff: {absolute_mean_diff}")
        print(f"    * Relative Mean Diff: {relative_mean_diff * 100:.6f}%")

        print(f"* Sum - dev: {dev_sum:.6f}, main: {main_sum:.6f}")
        print(f"    * Absolute Sum Diff: {absolute_sum_diff}")
        print(f"    * Relative Sum Diff: {relative_sum_diff * 100:.6f}%")
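`get_stats_for_not_equal_files` prints the min, max, mean, and sum of each variable along with the absolute and relative differences of the mean and sum. The same numbers can be returned as a dictionary, which is easier to tabulate or assert on. A sketch operating on plain arrays; `compare_stats` is a hypothetical name, and NaNs are skipped to match xarray's default `skipna=True` for float data:

```python
import math
from typing import Dict

import numpy as np


def _rel(diff: float, ref: float) -> float:
    # Relative difference; infinite when the reference is zero (as in the cell above).
    return float(abs(diff) / abs(ref)) if ref != 0 else math.inf


def compare_stats(dev: np.ndarray, main: np.ndarray) -> Dict[str, float]:
    """Return NaN-aware summary statistics and their differences."""
    dev_mean, main_mean = np.nanmean(dev), np.nanmean(main)
    dev_sum, main_sum = np.nansum(dev), np.nansum(main)

    return {
        "dev_min": float(np.nanmin(dev)),
        "main_min": float(np.nanmin(main)),
        "dev_max": float(np.nanmax(dev)),
        "main_max": float(np.nanmax(main)),
        "abs_mean_diff": float(abs(dev_mean - main_mean)),
        "rel_mean_diff": _rel(dev_mean - main_mean, main_mean),
        "abs_sum_diff": float(abs(dev_sum - main_sum)),
        "rel_sum_diff": _rel(dev_sum - main_sum, main_sum),
    }


stats = compare_stats(np.array([1.0, 2.0, np.nan]), np.array([1.0, 2.002, np.nan]))
print(stats["abs_sum_diff"])
```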

In [18]:
get_stats_for_not_equal_files(not_equal_errors_new)

Checking variable CLDTOT_TAU1.3_9.4_MISR
Dev Path: /global/cfs/cdirs/e3sm/www/e3sm_diags/complete_run/25-01-10-branch-907-no-arm-diags/lat_lon/Cloud MISR/MISRCOSP-CLDTOT_TAU1.3_9.4_MISR-ANN-global_test.nc
Main Path: /global/cfs/cdirs/e3sm/www/e3sm_diags/complete_run/v2.12.1v2/lat_lon/Cloud MISR/MISRCOSP-CLDTOT_TAU1.3_9.4_MISR-ANN-global_test.nc
-------------------------------------
* Min - dev: 0.055400, main: 0.055400
* Max - dev: 52.521282, main: 52.521290
* Mean - dev: 23.644886, main: 23.644806
    * Absolute Mean Diff: 8.027062228066484e-05
    * Relative Mean Diff: 0.000339%
* Sum - dev: 1532184.750000, main: 1532183.412355
    * Absolute Sum Diff: 1.3376447223126888
    * Relative Sum Diff: 0.000087%
Checking variable CLDTOT_TAU1.3_MISR
Dev Path: /global/cfs/cdirs/e3sm/www/e3sm_diags/complete_run/25-01-10-branch-907-no-arm-diags/lat_lon/Cloud MISR/MISRCOSP-CLDTOT_TAU1.3_MISR-ANN-global_test.nc
Main Path: /global/cfs/cdirs/e3sm/www/e3sm_diags/complete_run/v2.12.1v2/lat_lon/Cloud 

Now let's remove the files that are not a concern: files with nearly identical min and max values and close mean and sum values.
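The criterion described above can be made explicit with `math.isclose`. A sketch in which `is_not_a_concern` and its `rel_tol=1e-5` threshold are assumptions, not values the notebook uses; the inputs are the rounded statistics printed for `CLDTOT_TAU1.3_9.4_MISR`:

```python
import math
from typing import Dict

STATS = ("min", "max", "mean", "sum")


def is_not_a_concern(
    dev: Dict[str, float], main: Dict[str, float], rel_tol: float = 1e-5
) -> bool:
    """True when every summary statistic agrees within rel_tol (assumed threshold)."""
    return all(math.isclose(dev[s], main[s], rel_tol=rel_tol) for s in STATS)


# Rounded values printed for CLDTOT_TAU1.3_9.4_MISR above.
dev = {"min": 0.0554, "max": 52.521282, "mean": 23.644886, "sum": 1532184.75}
main = {"min": 0.0554, "max": 52.521290, "mean": 23.644806, "sum": 1532183.412355}
print(is_not_a_concern(dev, main))  # True
```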


In [19]:
not_equal_errors_new = [
    f
    for f in not_equal_errors_new
    if "MISRCOSP-CLDTOT_TAU1.3_9.4_MISR" not in f
    and "MISRCOSP-CLDTOT_TAU1.3_MISR" not in f
    and "ERA5-NET_FLUX_SRF" not in f
    and "ERA5-OMEGA-200" not in f
    and "GPCP_v2.3-PRECT-ANN" not in f
    and "MERRA2-NET_FLUX_SRF" not in f
    and "MERRA2-OMEGA-850" not in f
    and "GPCP_v2.3-PRECT" not in f
]

In [20]:
not_equal_errors_new

[]

## Results

```text
                stat_name  value       pct
0    matching_files_count    553  0.937288
1      missing_vars_count      0  0.000000
2   mismatch_errors_count      4  0.006780
3  not_equal_errors_count     33  0.055932
4        key_errors_count      0  0.000000
5     missing_files_count      0  0.000000
```

- 553 of 590 files match within the relative tolerance (rtol) of 1e-4, which is awesome.
- 4 mismatch errors are known `nan` location issues (see the `NaN` Mismatching Errors section above).
- 33 not equal errors are not a concern because they affect a very small number of elements in each dataset. The stats (min, max, mean, and sum) of the datasets between branches are close.
