## Overview

This notebook compares the test, regridded reference, and difference variables produced by xCDAT/xESMF
and CDAT/ESMF. It uses `GPCP_v3.2-PRECT-ANN-polar_N` as the example.

### Issue

We are finding that the xCDAT/xESMF "Model - Observation" plots have some problems at
the latitude edge (see the related GitHub [comment](https://github.com/E3SM-Project/e3sm_diags/pull/931#issuecomment-2652403024)).

<div style="display: flex; justify-content: space-around; align-items: center; border: 1px solid #ccc; padding: 10px;">
    <div style="text-align: center;">
        <h4>Example old (v2.12.1)</h4>
        <img src="https://web.lcrc.anl.gov/public/e3sm/diagnostic_output/ac.forsyth2/zppy_weekly_comprehensive_v2_www/test_pr651_both_commits_20250117/v2.LR.historical_0201/image_check_failures_comprehensive_v2/e3sm_diags/atm_monthly_180x360_aave/model_vs_obs_1982-1983/polar/GPCP_v3.2/GPCP_v3.2-PRECT-ANN-polar_N.png_actual.png" alt="Example old (v2.12.1)">
    </div>
    <div style="text-align: center;">
        <h4>Example new (v3.0.0rc2)</h4>
        <img src="https://web.lcrc.anl.gov/public/e3sm/diagnostic_output/ac.forsyth2/zppy_weekly_comprehensive_v2_www/test_pr651_both_commits_20250117/v2.LR.historical_0201/image_check_failures_comprehensive_v2/e3sm_diags/atm_monthly_180x360_aave/model_vs_obs_1982-1983/polar/GPCP_v3.2/GPCP_v3.2-PRECT-ANN-polar_N.png_actual.png" alt="Example new (v3.0.0rc2)">
    </div>
    <div style="text-align: center;">
        <h4>Example difference of both</h4>
        <img src="https://web.lcrc.anl.gov/public/e3sm/diagnostic_output/ac.forsyth2/zppy_weekly_comprehensive_v2_www/test_pr651_both_commits_20250117/v2.LR.historical_0201/image_check_failures_comprehensive_v2/e3sm_diags/atm_monthly_180x360_aave/model_vs_obs_1982-1983/polar/GPCP_v3.2/GPCP_v3.2-PRECT-ANN-polar_N.png_diff.png" alt="Example difference of both">
    </div>
</div>

### My Theory

My theory is that there are differences in how xESMF and ESMF regrid, as mentioned in
pull request [#931](https://github.com/E3SM-Project/e3sm_diags/pull/931#issue-2828706137). These differences propagate to the "Model - Observation" plots.


### 0. Notebook Setup Code


In [41]:
import cdms2 as cd
import numpy as np
import xarray as xr
import xcdat as xc # noqa: F401

# The variable in the dataset to compare.
VAR_KEY = "PRECT"

def print_stats(arr1, arr2, label1="Array 1", label2="Array 2"):
    """Print the min, max, mean, sum, and standard deviation of two arrays side by side."""
    stats = {
        "Min": (np.min(arr1), np.min(arr2)),
        "Max": (np.max(arr1), np.max(arr2)),
        "Mean": (np.mean(arr1), np.mean(arr2)),
        "Sum": (np.sum(arr1), np.sum(arr2)),
        "Std": (np.std(arr1), np.std(arr2)),
    }

    print(f"{'Stat':<10} {label1:<15} {label2:<15}")
    print("-" * 40)
    for stat, values in stats.items():
        print(f"{stat:<10} {values[0]:<15.6f} {values[1]:<15.6f}")

### 1. First, regrid the variables using xCDAT and xESMF.

The test and reference datasets are a result of Xarray/xCDAT's subsetting.

- Test filepath: `/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/GPCP_v3.2-PRECT-ANN-polar_N_test.nc`
- Ref filepath: `/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/GPCP_v3.2-PRECT-ANN-polar_N_ref.nc`


In [59]:
# 1. Open the datasets with Xarray and add any missing coordinate bounds
# (bounds are required for "conservative" regridding).
ds_test = xr.open_dataset("/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/GPCP_v3.2-PRECT-ANN-polar_N_test.nc").bounds.add_missing_bounds()
ds_ref = xr.open_dataset("/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-after/polar/GPCP_v3.2/GPCP_v3.2-PRECT-ANN-polar_N_ref.nc").bounds.add_missing_bounds()

# Rename the test dataset's bounds dimension from "bnds" to "nbnd".
ds_test = ds_test.rename_dims({"bnds": "nbnd"})

# 2. Regrid the reference variable to the test variable using xESMF and "conservative".
test_grid_xc = ds_test.regridder.grid
ds_ref_regrid = ds_ref.regridder.horizontal(
    VAR_KEY, test_grid_xc, tool="xesmf", method="conservative"
)

# 3. Get the difference between the test and regridded reference variables.
test_var_xc = ds_test[VAR_KEY].values.copy()
ref_var_xc_reg = ds_ref_regrid[VAR_KEY].values.copy()
diff_var_xc = test_var_xc - ref_var_xc_reg
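Both xESMF and ESMF are called with the `"conservative"` method, which weights each source cell by how much it overlaps each destination cell. This is why cell bounds matter for this method. The sketch below illustrates the idea in 1D with NumPy. It is only an illustration under simplifying assumptions: it ignores spherical cell areas and is not how xESMF/ESMF compute their weights internally, and `conservative_regrid_1d` is a hypothetical helper.

```python
import numpy as np


def conservative_regrid_1d(src_bnds, src_vals, dst_bnds):
    """Hypothetical helper: first-order conservative remap along one axis.

    src_bnds: (n_src, 2) cell edges, src_vals: (n_src,) cell means,
    dst_bnds: (n_dst, 2) cell edges. Assumes every destination cell
    overlaps at least one source cell (otherwise it divides by zero).
    """
    src_bnds = np.asarray(src_bnds, dtype=float)
    dst_bnds = np.asarray(dst_bnds, dtype=float)
    src_vals = np.asarray(src_vals, dtype=float)

    # Overlap length between each destination cell (rows) and source cell (cols).
    lo = np.maximum(dst_bnds[:, None, 0], src_bnds[None, :, 0])
    hi = np.minimum(dst_bnds[:, None, 1], src_bnds[None, :, 1])
    overlap = np.clip(hi - lo, 0.0, None)

    # Each destination value is the overlap-weighted mean of the source values.
    return (overlap @ src_vals) / overlap.sum(axis=1)


# Toy example: four 1-unit source cells remapped onto two 2-unit cells.
src_bnds = [[0, 1], [1, 2], [2, 3], [3, 4]]
src_vals = [1.0, 2.0, 3.0, 4.0]
dst_bnds = [[0, 2], [2, 4]]
dst = conservative_regrid_1d(src_bnds, src_vals, dst_bnds)
print(dst)  # [1.5 3.5]
```

The integral is conserved: the source total is 1 + 2 + 3 + 4 = 10, and the destination total is 1.5 × 2 + 3.5 × 2 = 10.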

In [60]:
ds_ref_regrid["PRECT"].sum()

### 2. Regrid using CDAT + ESMF

The test and reference datasets are a result of CDAT's subsetting on `polar_N`.

- Test filepath: `/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-main/polar/GPCP_v3.2/GPCP_v3.2-PRECT-ANN-polar_N_test.nc`
- Ref filepath: `/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-main/polar/GPCP_v3.2/GPCP_v3.2-PRECT-ANN-polar_N_ref.nc`


In [44]:
# 1. Open the datasets with cdms2.
test_var_cd = cd.open("/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-main/polar/GPCP_v3.2/GPCP_v3.2-PRECT-ANN-polar_N_test.nc")(VAR_KEY)
ref_var_cd = cd.open("/lcrc/group/e3sm/public_html/cdat-migration-fy24/25-02-14-branch-930-polar-main/polar/GPCP_v3.2/GPCP_v3.2-PRECT-ANN-polar_N_ref.nc")(VAR_KEY)

# 2. Regrid the reference variable to the test variable using ESMF and "conservative".
test_grid_cd = test_var_cd.getGrid()
ref_var_cd_reg = ref_var_cd.regrid(
    test_grid_cd, regridTool="esmf", regridMethod="conservative"
)

# 3. Get the difference between the test and regridded reference variables.
diff_var_cd = test_var_cd - ref_var_cd_reg


#### 2a. Compare the grids used for regridding.

**RESULT: The grids are the same.**


In [32]:
np.testing.assert_allclose(test_grid_xc.lat.values, test_grid_cd.getLatitude(), rtol=1e-5, atol=0)
np.testing.assert_allclose(test_grid_xc.lon.values, test_grid_cd.getLongitude(), rtol=1e-5, atol=0)

print("xCDAT Grid Latitude Values:", test_grid_xc.lat.values)
print("CDAT Grid Latitude Values:", test_grid_cd.getLatitude()[:])

print("xCDAT Grid Longitude Values:", test_grid_xc.lon.values)
print("CDAT Grid Longitude Values:", test_grid_cd.getLongitude()[:])

xCDAT Grid Latitude Values: [50.5 51.5 52.5 53.5 54.5 55.5 56.5 57.5 58.5 59.5 60.5 61.5 62.5 63.5
 64.5 65.5 66.5 67.5 68.5 69.5 70.5 71.5 72.5 73.5 74.5 75.5 76.5 77.5
 78.5 79.5 80.5 81.5 82.5 83.5 84.5 85.5 86.5 87.5 88.5 89.5]
CDAT Grid Latitude Values: [50.5 51.5 52.5 53.5 54.5 55.5 56.5 57.5 58.5 59.5 60.5 61.5 62.5 63.5
 64.5 65.5 66.5 67.5 68.5 69.5 70.5 71.5 72.5 73.5 74.5 75.5 76.5 77.5
 78.5 79.5 80.5 81.5 82.5 83.5 84.5 85.5 86.5 87.5 88.5 89.5]
xCDAT Grid Longitude Values: [  0.5   1.5   2.5   3.5   4.5   5.5   6.5   7.5   8.5   9.5  10.5  11.5
  12.5  13.5  14.5  15.5  16.5  17.5  18.5  19.5  20.5  21.5  22.5  23.5
  24.5  25.5  26.5  27.5  28.5  29.5  30.5  31.5  32.5  33.5  34.5  35.5
  36.5  37.5  38.5  39.5  40.5  41.5  42.5  43.5  44.5  45.5  46.5  47.5
  48.5  49.5  50.5  51.5  52.5  53.5  54.5  55.5  56.5  57.5  58.5  59.5
  60.5  61.5  62.5  63.5  64.5  65.5  66.5  67.5  68.5  69.5  70.5  71.5
  72.5  73.5  74.5  75.5  76.5  77.5  78.5  79.5  80.5  81.5  82.5  83

#### 2b. Compare the test variables.

**RESULT: The test variables are within `rtol=1e-5` and their stats are close, so they are good to go.**


In [33]:
try:
    np.testing.assert_allclose(test_var_xc, test_var_cd.data, rtol=1e-5, atol=0)
except AssertionError as e:
    print("Arrays are not within relative tolerance (1e-5).")
    print(e)
else:
    print("Arrays are within relative tolerance (1e-5).")


Arrays are within relative tolerance (1e-5).


#### 2c. Compare the _regridded_ ref variables.

**RESULT: The regridded ref variables are within `rtol=1e-5` and their stats align.**


In [34]:
try:
    np.testing.assert_allclose(ref_var_xc_reg, ref_var_cd_reg.data, rtol=1e-5, atol=0)
except AssertionError as e:
    print("Arrays are not within relative tolerance (1e-5).")
    print(e)
else:
    print("Arrays are within relative tolerance (1e-5).")

Arrays are within relative tolerance (1e-5).


In [35]:
print_stats(ref_var_xc_reg, ref_var_cd_reg.data, label1="xCDAT Ref", label2="CDAT Ref")

Stat       xCDAT Ref       CDAT Ref       
----------------------------------------
Min        0.164364        0.164364       
Max        10.046874       10.046873      
Mean       1.398573        1.398573       
Sum        20139.449219    20139.449219   
Std        1.039437        1.039437       


#### 2d. Compare the difference variables (test - regridded reference).


In [36]:
try:
    np.testing.assert_allclose(diff_var_xc, diff_var_cd.data, rtol=1e-5, atol=0)
except AssertionError as e:
    print("Arrays are not within relative tolerance (1e-5).")
    print(e)
else:
    print("Arrays are within relative tolerance (1e-5).")

Arrays are not within relative tolerance (1e-5).

Not equal to tolerance rtol=1e-05, atol=0

Mismatched elements: 178 / 14400 (1.24%)
Max absolute difference among violations: 4.42087185e-07
Max relative difference among violations: 0.00327499
 ACTUAL: array([[-0.23122 , -0.524473, -0.434358, ..., -0.489304, -0.160534,
        -0.113844],
       [-0.025744, -0.188844, -0.063587, ...,  0.008939, -0.029974,...
 DESIRED: array([[-0.23122 , -0.524473, -0.434358, ..., -0.489304, -0.160534,
        -0.113844],
       [-0.025744, -0.188844, -0.063587, ...,  0.008939, -0.029974,...
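The mismatches in the report above are small in absolute terms. The element with the largest relative difference (0.00327) is itself a violation, so its absolute difference is at most 4.42e-07. That means its expected value satisfies |desired| ≤ 4.42e-07 / 0.00327 ≈ 1.35e-4. In other words, it is a diff value near zero, where a pure relative tolerance is very strict. The sketch below uses made-up values, not this notebook's data, to show how `np.testing.assert_allclose` behaves with and without an `atol`. The `atol=1e-6` value and the helper `passes_allclose` are illustrative assumptions, not thresholds or functions used by e3sm_diags.

```python
import numpy as np


def passes_allclose(actual, desired, rtol, atol):
    """Hypothetical helper: True if np.testing.assert_allclose passes."""
    try:
        np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol)
    except AssertionError:
        return False
    return True


# Made-up diff values; the second element is close to zero.
desired = np.array([-0.23122, 1.35e-4, 3.53858])
# Perturb only the near-zero element by an absolute amount like the one reported above.
actual = desired + np.array([0.0, 4.4e-7, 0.0])

# |actual - desired| is ~4.4e-7, but rtol * |desired| is only ~1.35e-9, so rtol alone fails.
rtol_only = passes_allclose(actual, desired, rtol=1e-5, atol=0)
# A small absolute tolerance (assumed value) absorbs differences that are tiny in absolute terms.
with_atol = passes_allclose(actual, desired, rtol=1e-5, atol=1e-6)
print(rtol_only, with_atol)  # False True
```

`assert_allclose` checks `|actual - desired| <= atol + rtol * |desired|`, so with `atol=0` the allowed error shrinks toward zero as `desired` approaches zero.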


In [37]:
print_stats(diff_var_xc, diff_var_cd.data, label1="xCDAT Diff", label2="CDAT Diff")

Stat       xCDAT Diff      CDAT Diff      
----------------------------------------
Min        -6.194607       -6.194607      
Max        3.538580        3.538580       
Mean       -0.133721       -0.133721      
Sum        -1925.586426    -1925.586490   
Std        0.459430        0.459430       


## Conclusion

### Root Cause

The root cause of the differences can be found in the [`align_grids_to_lower_res()`](https://github.com/E3SM-Project/e3sm_diags/blob/main/e3sm_diags/driver/utils/regrid.py#L342) function.

Before this function is called, the test and reference datasets contain the `lev` axis.
As a result, the `output_grid` being produced includes the `lev` axis. This larger grid
influences the regridding results in an undesirable way.

```python
<xarray.Dataset> Size: 11kB
Dimensions:   (lon: 360, nbnd: 2, lat: 40, lev: 72, bnds: 2)
Coordinates:
  * lon       (lon) float64 3kB 0.5 1.5 2.5 3.5 4.5 ... 356.5 357.5 358.5 359.5
  * lat       (lat) float64 320B 50.5 51.5 52.5 53.5 ... 86.5 87.5 88.5 89.5
  * lev       (lev) float64 576B 0.1238 0.1828 0.2699 ... 986.2 993.8 998.5
Dimensions without coordinates: nbnd, bnds
Data variables:
    lon_bnds  (lon, nbnd) float64 6kB 0.0 1.0 1.0 2.0 ... 359.0 359.0 360.0
    lat_bnds  (lat, nbnd) float64 640B 50.0 51.0 51.0 52.0 ... 89.0 89.0 90.0
    lev_bnds  (lev, bnds) float64 1kB 0.09432 0.1533 0.1533 ... 996.1 1.001e+03
```

### Fix

We need to drop unnecessary variables and keep only the ones needed (e.g., "PRECT" and the bounds variables). This produces the correct output grid, with just "lat" and "lon", which is then fed into xESMF for regridding.

```python
<xarray.Dataset> Size: 11kB
Dimensions:   (lon: 360, nbnd: 2, lat: 40)
Coordinates:
  * lon       (lon) float64 3kB 0.5 1.5 2.5 3.5 4.5 ... 356.5 357.5 358.5 359.5
  * lat       (lat) float64 320B 50.5 51.5 52.5 53.5 ... 86.5 87.5 88.5 89.5
Dimensions without coordinates: nbnd
Data variables:
    lon_bnds  (lon, nbnd) float64 6kB 0.0 1.0 1.0 2.0 ... 359.0 359.0 360.0
    lat_bnds  (lat, nbnd) float64 640B 50.0 51.0 51.0 52.0 ... 89.0 89.0 90.0
```
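One way to perform this subsetting with plain Xarray is sketched below. It is a minimal sketch on a made-up toy dataset that mimics the structure shown above. It is not the actual change made in e3sm_diags, and `keep_horizontal_vars` is a hypothetical helper that assumes the bounds variables are named `lat_bnds` and `lon_bnds`. Selecting a list of data variables also drops coordinates whose dimensions are not used by the kept variables, so `lev` disappears along with `lev_bnds`.

```python
import numpy as np
import xarray as xr


def keep_horizontal_vars(ds, var_key):
    """Hypothetical helper: keep only `var_key` and its lat/lon bounds.

    Assumes the bounds variables are named "lat_bnds" and "lon_bnds".
    """
    keep = [name for name in (var_key, "lat_bnds", "lon_bnds") if name in ds]
    # Indexing with a list keeps only coordinates whose dims are used by `keep`.
    return ds[keep]


# Toy dataset mimicking the structure above (values are made up).
lat = np.arange(50.5, 90.5, 1.0)
lon = np.arange(0.5, 360.5, 1.0)
lev = np.array([100.0, 500.0, 1000.0])
ds = xr.Dataset(
    data_vars={
        "PRECT": (("lat", "lon"), np.ones((lat.size, lon.size))),
        "lat_bnds": (("lat", "nbnd"), np.stack([lat - 0.5, lat + 0.5], axis=1)),
        "lon_bnds": (("lon", "nbnd"), np.stack([lon - 0.5, lon + 0.5], axis=1)),
        "lev_bnds": (("lev", "bnds"), np.stack([lev - 1.0, lev + 1.0], axis=1)),
    },
    coords={"lat": lat, "lon": lon, "lev": lev},
)

ds_subset = keep_horizontal_vars(ds, "PRECT")
print(sorted(ds_subset.sizes))  # ['lat', 'lon', 'nbnd']
```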

### Next steps

I am confirming whether this finding holds for all other cases by running a more comprehensive test of the diffs you found in your PR description.
