In [1]:
# general python
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import numpy as np
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
# general eWC
import ewatercycle
import ewatercycle.forcing
import ewatercycle.models
import ewatercycle.analysis

In [3]:
from copy import copy

#### set up paths

In [4]:
path = Path.cwd()
forcing_path = path / "Forcing"

In [5]:
import xarray as xr

In [6]:
from ewatercycle.util import (
    find_closest_point,
    fit_extents_to_grid,
    get_time,
    merge_esvmaltool_datasets,
    reindex,
    to_absolute_path,
)

## Test fail like in the unit test

In [7]:
def test_merge_esmvaltool_datasets():
    files = list((forcing_path).glob("OBS6_ERA5*.nc"))
    datasets = [xr.open_dataset(file) for file in files]
    ds = merge_esvmaltool_datasets(datasets)
    for var in ["tas", "pr", "rsds"]:
        assert not ds[var].mean(dim=["lat", "lon"]).isnull().any("time")

    assert "height" in ds["tas"].attrs


In [8]:
test_merge_esmvaltool_datasets()

AssertionError: 

In [9]:
files = list((forcing_path).glob("OBS6_ERA5*.nc"))
datasets = [xr.open_dataset(file) for file in files]
ds = merge_esvmaltool_datasets(datasets)

In [10]:
ds['tas']

## But if we load it as in the code

In [11]:
def merge_esvmaltool_datasets(datasets: list[xr.Dataset]) -> xr.Dataset:
    """Merge the separate output datasets from an ESMValTool recipe into one dataset.

    ESMValTool has bad management of floating point precision in coordinates. Every
    CMORized file can have different rounding errors in the values of its coordinates.
    This will prevent easy merging with xarray's open_mfdataset, or combine_by_coords.
    By rounding to the 7th decimal place, more than sufficient precision is preserved
    ('waldo-on-a-page' precision)[1], while solving the floating point inprecision
    issue.

    References:
        [1] Randall Monroe, 2019. xkcd: Coordinate Precision. https://xkcd.com/2170/
    """
    TOLERANCE = 1e-7
    datasets = copy(datasets)

    # First check that the coordinates all line up before merging.
    for coord in ["lat", "lon"]:
        coords = [ds[coord].to_numpy() for ds in datasets]
        if len(set([c.size for c in coords])) > 1:
            msg = f"The coordinate '{coord}' is not of the same size in every dataset."
            raise ValueError(msg)
        all_coords = np.array(coords)
        if not np.all((all_coords - all_coords.mean(axis=0)) < TOLERANCE):
            msg = f"Coordinate {coord} deviates more than {TOLERANCE}. Merging failed."
            raise ValueError(msg)

    removed = {
        "lat_bnds": False,
        "lon_bnds": False,
    }
    for i in range(len(datasets)):
        # Bounds are not aligned, and can be missing in derived vars,
        #  so we remove all except the first lat/lon bounds we encounter.
        for coord in ["lat_bnds", "lon_bnds"]:
            if coord in datasets[i]:
                if removed[coord]:
                    datasets[i] = datasets[i].drop_vars(coord)
                removed[coord] = True

        # xr.align doesn't work for lumped forcing. this works for both lumped and dist:
        for coord in ["lat", "lon"]:
            datasets[i][coord] = datasets[0][coord]

        # the time coordinates are messed up for some files, see:
        #   https://github.com/eWaterCycle/infra/issues/157
        #   the following is a workaround.
        if "time_bnds" in datasets[i] and xr.infer_freq(datasets[i]["time"]) == "D":
            datasets[i]["time"] = datasets[i]["time_bnds"].isel(
                bnds=0
            ) + np.timedelta64(12, "h")
            datasets[i] = datasets[i].drop_vars("time_bnds")

        # A "height" coordinate can be present, which will result in conflicts.
        #   Instead, we move it to the variable's attributes.
        if "height" in datasets[i].variables:
            data_vars = list(datasets[i].data_vars)
            if "time_bnds" in data_vars:
                data_vars.remove("time_bnds")
            # var = data_vars[-1]
            var = "tas"
            datasets[i][var].attrs.update(
                {
                    "height": float(datasets[i]["height"]),
                    "height_units": datasets[i]["height"].attrs["units"],
                }
            )
            datasets[i] = datasets[i].drop_vars(("height",))

    return xr.combine_by_coords(datasets, combine_attrs="drop_conflicts")  # type: ignore[return-value]


## Suddendly it does work

In [12]:
ds = merge_esvmaltool_datasets(datasets)

In [13]:
ds['tas']

In [14]:
ds['tas'].attrs.update()

In [15]:
datasets = []
i = 0
datasets.append(xr.open_dataset(files[i]))

In [16]:
'height' in datasets[0].variables.keys()

True

In [17]:
removed = {
    "lat_bnds": False,
    "lon_bnds": False,
}
for i in range(len(datasets)):
    # Bounds are not aligned, and can be missing in derived vars,
    #  so we remove all except the first lat/lon bounds we encounter.
    for coord in ["lat_bnds", "lon_bnds"]:
        if coord in datasets[i]:
            if removed[coord]:
                datasets[i] = datasets[i].drop_vars(coord)
            removed[coord] = True

    # xr.align doesn't work for lumped forcing. this works for both lumped and dist:
    for coord in ["lat", "lon"]:
        datasets[i][coord] = datasets[0][coord]

    # the time coordinates are messed up for some files, see:
    #   https://github.com/eWaterCycle/infra/issues/157
    #   the following is a workaround.
    if "time_bnds" in datasets[i] and xr.infer_freq(datasets[i]["time"]) == "D":
        datasets[i]["time"] = datasets[i]["time_bnds"].isel(
            bnds=0
        ) + np.timedelta64(12, "h")
        datasets[i] = datasets[i].drop_vars("time_bnds")

    # A "height" coordinate can be present, which will result in conflicts.
    #   Instead, we move it to the variable's attributes.
    if "height" in datasets[i].variables:
        data_vars = list(datasets[i].data_vars)
        if "time_bnds" in data_vars:
            data_vars.remove("time_bnds")
        # var = data_vars[-1]
        var = "tas"
        datasets[i][var].attrs.update(
            {
                "height": float(datasets[i]["height"]),
                "height_units": datasets[i]["height"].attrs["units"],
            }
        )
        datasets[i] = datasets[i].drop_vars(("height",))

In [18]:
datasets[0]['tas']

In [19]:
datasets[i]['tas'].attrs.update({"Height":2})

In [20]:
datasets[i]

In [21]:
datasets[i]['tas']

In [22]:
if "height" in datasets[i].variables:
    data_vars = list(datasets[i].data_vars)
    if "time_bnds" in data_vars:
        data_vars.remove("time_bnds")
    var = data_vars[-1]
    datasets[i][var].attrs.update(
        {
            "height": float(datasets[i]["height"]),
            "height_units": datasets[i]["height"].attrs["units"],
        }
    )
    print( datasets[i][var].attrs)
    datasets[i] = datasets[i].drop_vars(("height",))


In [23]:
datasets[i].data_vars

Data variables:
    lon_bnds  (lon, bnds) float64 528B ...
    lat_bnds  (lat, bnds) float64 400B ...
    tas       (time, lat, lon) float32 99kB ...

In [24]:
ds

In [25]:
def test_merge_esmvaltool_datasets():
    files = list((forcing_path).glob("OBS6_ERA5*.nc"))
    datasets = [xr.open_dataset(file) for file in files]
    ds = merge_esvmaltool_datasets(datasets)
    for var in ["tas", "pr", "rsds"]:
        assert not ds[var].mean(dim=["lat", "lon"]).isnull().any("time")

    assert "height" in ds["tas"].attrs

In [26]:
test_merge_esmvaltool_datasets()