In [1]:
"""
This script compares the CMORized E3SM datasets in CMIP6 format generated
from the "master" and dev branches.
How it works:
    All "*.nc" outputs are generated using the `tests/example_end-to-end_script.sh`
    script, which was run in two separate conda environments, each using a local
    conda build of `e3sm_to_cmip` from the respective branch. This script loops
    over all `*.nc` dataset files and compares their outputs using xarray.
How to use:
  1. Update `DEV_BRANCH` to the name of the subdirectory under `BASE_DIR` that
     contains your branch's outputs.
  2. Run the script and view the results.
"""

import datetime
import os
from math import inf
from pathlib import Path
from typing import Dict, Optional

import numpy as np
import pandas as pd
import tqdm
import xarray as xr

# Root directory with one subdirectory of outputs per branch ("master" and dev).
BASE_DIR = "/p/user_pub/e3sm/e3sm_to_cmip/test-cases"

# TODO: Update DEV_BRANCH to your branch.
DEV_BRANCH = "115-branch"

# CMIP6 output trees produced by each branch.
MASTER_OUTPUT = BASE_DIR + "/master/CMIP6"
DEV_OUTPUT = BASE_DIR + "/" + DEV_BRANCH + "/CMIP6"

In [2]:
def get_filepaths(dir: str) -> Dict[str, str]:
    """Collect absolute filepaths for every netCDF (`.nc`) file under a directory.

    The directory tree is walked recursively, e.g. the `CMIP6` output tree of
    either the "master" or the dev branch. Subdirectories and files are visited
    in sorted order so the result is reproducible across filesystems.

    Parameters
    ----------
    dir : str
        The root directory to search. A non-existent directory yields `{}`.

    Returns
    -------
    Dict[str, str]
        A mapping of variable key -> absolute filepath. The variable key is the
        filename prefix before the first underscore (e.g. "pr" for
        "pr_3hr_....nc"). If several files share a variable key, the one
        visited last (in sorted walk order) wins.
    """
    paths: Dict[str, str] = {}
    for root, dirs, files in os.walk(dir):
        # Sorting `dirs` in place makes `os.walk` descend in a deterministic order.
        dirs.sort()
        for filename in sorted(files):
            # Match only the `.nc` extension, not any name merely containing
            # ".nc" (e.g. "foo.nc.tmp" or "bar.ncml").
            if not filename.endswith(".nc"):
                continue
            var_key = filename.split("_")[0]
            paths[var_key] = str(Path(root, filename).absolute())
    return paths


def compare_results(
    master_paths: Dict[str, str],
    dev_branch: Optional[str] = None,
    atol: float = 1e-7,
    rtol: float = 1e-7,
) -> pd.DataFrame:
    """Compares the `.nc` outputs between the "master" and dev branches.

    Each "master" file is paired with the file at the same location under the
    dev branch's directory. The "result" column classifies each pair as:

      * "identical": `xr.testing.assert_identical` passes after removing the
        attributes whose values differ between both files (usually
        "creation_date", "history" and "tracking_id").
      * "close": not identical, but the values of the variable named by the
        filename prefix pass `np.testing.assert_allclose`. NOTE: other
        variables, coordinates and remaining metadata are not re-checked for
        this category.
      * "mismatch_nan" / "mismatch_shape" / "not_close": that variable's values
        fail `assert_allclose` because of NaN locations, shapes, or values.
      * The raised exception object: either file could not be opened (e.g. the
        dev file does not exist because CMORizing was not successful).

    Notes
    -----
    - https://docs.xarray.dev/en/stable/generated/xarray.Dataset.identical.html
      - "Like equals, but also checks all dataset attributes and the attributes
       on all variables and coordinates."
    - https://docs.xarray.dev/en/stable/generated/xarray.Dataset.equals.html#xarray.Dataset.equals
      - Two Datasets are equal if they have matching variables and coordinates,
        all of which are equal.
      - Datasets can still be equal (like pandas objects) if they have NaN
        values in the same locations.
      - This method is necessary because v1 == v2 for Dataset does element-wise
        comparisons (like numpy.ndarrays).

    Parameters
    ----------
    master_paths : Dict[str, str]
        Mapping of variable key -> absolute path of each `.nc` output on the
        "master" branch (as returned by `get_filepaths`).
    dev_branch : Optional[str]
        Name of the dev branch subdirectory. Defaults to the module-level
        `DEV_BRANCH`, read at call time.
    atol : float
        Absolute tolerance for the "close" check.
    rtol : float
        Relative tolerance for the "close" check.

    Returns
    -------
    pd.DataFrame
        One row per "master" file with the columns "var_key", "master_path",
        "dev_path", "result", "ds_diff_attrs" and "dv_diff_attrs".
    """
    if dev_branch is None:
        dev_branch = DEV_BRANCH

    rows = []

    for filename, path in tqdm.tqdm(master_paths.items()):
        var_key = filename.split("_")[0]

        # Swap only the branch directory segment; a bare `replace("master", ...)`
        # would also rewrite any other occurrence of "master" in the path.
        dev_path = path.replace("/master/", f"/{dev_branch}/", 1)

        row = {
            "var_key": var_key,
            "master_path": path,
            "dev_path": dev_path,
        }

        ds1 = ds2 = None
        try:
            ds1 = xr.open_dataset(path)
            ds2 = xr.open_dataset(dev_path)
        except Exception as e:
            # Same column names as the success branch, so failed rows do not
            # add a stray column to the DataFrame.
            row.update(
                {
                    "result": e,
                    "ds_diff_attrs": None,
                    "dv_diff_attrs": None,
                }
            )
        else:
            row.update(_compare_datasets(ds1, ds2, var_key, atol, rtol))
        finally:
            # Release file handles, including the master dataset when only the
            # dev file failed to open.
            for ds in (ds1, ds2):
                if ds is not None:
                    ds.close()

        rows.append(row)

    return pd.DataFrame(rows)


def _compare_datasets(
    ds1: xr.Dataset, ds2: xr.Dataset, var_key: str, atol: float, rtol: float
) -> Dict[str, object]:
    """Classify one opened master/dev dataset pair (see `compare_results`).

    Returns a dict with the "result", "ds_diff_attrs" and "dv_diff_attrs" row
    entries.
    """
    # Get the attributes that are different between both files and remove
    # them. The different attributes are most often "creation_date",
    # "history" and "tracking_id".
    ds_diff_attrs = _get_ds_diff_attrs(ds1.attrs, ds2.attrs)
    dv_diff_attrs = _get_dv_diff_attrs(ds1, ds2)
    ds1_clean = _remove_diff_attrs(ds1, ds_diff_attrs, dv_diff_attrs)
    ds2_clean = _remove_diff_attrs(ds2, ds_diff_attrs, dv_diff_attrs)

    try:
        xr.testing.assert_identical(ds1_clean, ds2_clean)
    except AssertionError:
        try:
            np.testing.assert_allclose(
                ds1_clean[var_key].values,
                ds2_clean[var_key].values,
                atol=atol,
                rtol=rtol,
            )
        except AssertionError as e:
            # Classify the failure from numpy's assertion message.
            msg = str(e)
            if "nan location" in msg:
                result = "mismatch_nan"
            elif "shape" in msg:
                result = "mismatch_shape"
            else:
                result = "not_close"
        else:
            result = "close"
    else:
        result = "identical"

    return {
        "result": result,
        "ds_diff_attrs": ds_diff_attrs,
        "dv_diff_attrs": dv_diff_attrs,
    }


def _get_ds_diff_attrs(
    attrs1: Dict[str, str], attrs2: Dict[str, str]
) -> Dict[str, str]:
    """Get the differing attributes between datasets.
    Attributes are considered "different" if they are in both datasets and
    aren't equal. The different attributes are usually just "creation_date",
    "history", and "tracking_id"
    Parameters
    ----------
    attrs1 : Dict[str, str]
        The first dataset's attributes.
    attrs2 : Dict[str, str]
        The second dataset's attributes.
    Returns
    -------
    Dict[str, str]
        A dictionary of differing attributes.
    Example
    -------
    {'creation_date': '2023-02-02T20:15:18Z',
     'history': '2023-02-02T20:15:18Z ;rewrote data to be consistent with CMIP
                for variable pr found in table 3hr.;# \nOutput from
                20180129.DECKv1b_piControl.ne30_oEC.edison',
     'tracking_id': 'hdl:21.14100/509df588-5944-4ab4-b8e7-6238fe0f94f7'
    }
    """
    diff_attrs = {
        k: (attrs1[k], attrs2[k])
        for k in attrs1
        if k in attrs2 and attrs1[k] != attrs2[k]
    }
    return diff_attrs


def _get_dv_diff_attrs(ds1: xr.Dataset, ds2: xr.Dataset) -> Dict[str, Dict[str, str]]:
    """Gets the differing attributes between the data variables in the datasets.

    Only data variables present in both datasets are compared. A variable that
    exists in just one dataset is skipped here; it is still reported by the
    later `xr.testing.assert_identical` check instead of raising `KeyError`.

    Parameters
    ----------
    ds1 : xr.Dataset
        The first dataset.
    ds2 : xr.Dataset
        The second dataset.

    Returns
    -------
    Dict[str, Dict[str, str]]
        A dictionary with the key being the name of the data variable and the
        value being that variable's differing attributes, in the format
        returned by `_get_ds_diff_attrs` (name -> `(value1, value2)`).
        Variables without differing attributes are omitted.

    Examples
    --------
    {'areacello':
        {'history': ('2023-02-02T20:19:54Z altered by CMOR: replaced missing '
                     'value flag (9.96921e+36) ...',
                     '2023-11-01T18:25:29Z altered by CMOR: ...')
        }
    }
    """
    ds2_dv = ds2.data_vars
    diff_attrs = {}

    for key, dv1 in ds1.data_vars.items():
        if key not in ds2_dv:
            continue

        attrs = _get_ds_diff_attrs(dv1.attrs, ds2_dv[key].attrs)
        if attrs:
            diff_attrs[key] = attrs

    return diff_attrs


def _remove_diff_attrs(
    ds: xr.Dataset,
    ds_diff_attrs: Dict[str, str],
    dv_diff_attrs: Dict[str, Dict[str, str]],
) -> xr.Dataset:
    """Return a copy of ``ds`` with every known-differing attribute dropped.

    The input dataset itself is left untouched; attributes are removed from
    the copy returned by ``ds.copy()``.

    Parameters
    ----------
    ds : xr.Dataset
        The dataset to clean.
    ds_diff_attrs : Dict[str, str]
        Global attributes to drop; only the keys are used. Every key must exist
        in ``ds.attrs`` (a missing one raises ``KeyError``).
    dv_diff_attrs : Dict[str, Dict[str, str]]
        Per data variable, the attributes to drop; only the keys are used.

    Returns
    -------
    xr.Dataset
        The copied dataset without the listed attributes.
    """
    cleaned = ds.copy()

    # Global (dataset-level) attributes.
    for attr_name in list(ds_diff_attrs):
        cleaned.attrs.pop(attr_name)

    # Attributes attached to individual data variables.
    for var_name, var_attrs in dv_diff_attrs.items():
        for attr_name in var_attrs:
            cleaned[var_name].attrs.pop(attr_name)

    return cleaned

In [3]:
# 1. Get the output file paths.
m_paths = get_filepaths(MASTER_OUTPUT)
d_paths = get_filepaths(DEV_OUTPUT)

# 2. Make sure the filepaths are aligned, which means both the master and dev
# branches successfully produced the same variable datasets. Paths are compared
# relative to each branch's output root; splitting on the branch name would break
# whenever that name also appears later in the path.
m_paths_rel = {os.path.relpath(path, MASTER_OUTPUT) for path in m_paths.values()}
d_paths_rel = {os.path.relpath(path, DEV_OUTPUT) for path in d_paths.values()}

# Sorted so the error message is deterministic.
diff_files = sorted(m_paths_rel ^ d_paths_rel)

if diff_files:
    raise ValueError(f"There are unique files generated between branches: {diff_files}")

if len(m_paths) != len(d_paths):
    raise ValueError(
        "The number of files generated does not align between both branchs. "
        f"master ({len(m_paths)}) vs. dev ({len(d_paths)})"
    )

## 1. Compare the datasets


In [4]:
# Compare every master/dev file pair and collect one summary row per variable.
df = compare_results(master_paths=m_paths)

100%|██████████| 109/109 [01:13<00:00,  1.48it/s]


In [5]:
# Tally how many file pairs fall into each comparison outcome.
df.value_counts(["result"])

result   
identical    104
close          4
not_close      1
Name: count, dtype: int64

In [7]:
# Variables whose datasets are not identical but pass the allclose check.
df.loc[df["result"] == "close", "var_key"].tolist()

['cl', 'pfull', 'clw', 'cli']

## 2. Checking diffs for 1/109 datasets that are `not_close`


In [8]:
# Datasets that fail the allclose check, ordered by variable name.
is_not_close = df["result"] == "not_close"
df_not_close = df[is_not_close].sort_values(by="var_key")
print(df_not_close["var_key"].tolist())

['mrso']


In [9]:
# Show the full rows (paths and differing attributes) for the not-close datasets.
df_not_close

Unnamed: 0,var_key,master_path,dev_path,result,ds_diff_attrs,dv_diff_attrs
63,mrso,/p/user_pub/e3sm/e3sm_to_cmip/test-cases/maste...,/p/user_pub/e3sm/e3sm_to_cmip/test-cases/115-b...,not_close,"{'creation_date': ('2023-11-01T18:25:29Z', '20...",{}


In [11]:
# For each not-close dataset, print dev vs. master summary statistics and the
# detailed `assert_allclose` report.
for row in df_not_close.itertuples(index=False):
    v_key = row.var_key
    print(v_key)
    print("-----")

    # Context managers close both files once the comparison is done.
    with xr.open_dataset(row.dev_path) as ds_d, xr.open_dataset(row.master_path) as ds_m:
        # Dev variable.
        var_d = ds_d[v_key]
        # Master variable.
        var_m = ds_m[v_key]

        try:
            np.testing.assert_allclose(
                var_d.values, var_m.values, atol=1e-7, rtol=1e-7, equal_nan=True
            )
        except AssertionError as e:
            # Summary statistics are only needed (and only computed) on failure.
            mean_d = var_d.mean(dim=None, skipna=True).item()
            mean_m = var_m.mean(dim=None, skipna=True).item()
            sum_d = var_d.sum(dim=None, skipna=True).item()
            sum_m = var_m.sum(dim=None, skipna=True).item()

            print("MEAN")
            print(f"    dev: {mean_d}, master: {mean_m}")
            print("SUM")
            print(f"    dev: {sum_d}, master: {sum_m}")
            print("COMPARISON")
            print(e)
            print("\n")

mrso
-----
MEAN
    dev: 998.3673095703125, master: 998.3673095703125
SUM
    dev: 277705856.0, master: 277705856.0
COMPARISON

Not equal to tolerance rtol=1e-07, atol=1e-07

Mismatched elements: 25556 / 777600 (3.29%)
Max absolute difference: 0.00073242
Max relative difference: 6.726486e-07
 x: array([[[nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],...
 y: array([[[nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],...




### CONCLUSION:

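- 104/109 are identical
- 4/109 are close and within the tolerance of `rtol=1e-7`, `atol=1e-7` (`cl`, `clw`, `cli`, `pfull`)
- 1/109 is not close (`mrso`)
  - `mrso` is fine: the dev and master means and sums printed above are identical, and the max relative difference (6.726486e-07) is only a few times float32 machine epsilon (~1.19e-07), which is consistent with float32 round-off rather than a real regression.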
