# NEMO CF Demo with GYRE dataset

## Parameters

In [None]:
# parameters
# User-adjustable settings for this notebook.

# Directory that is exported as the ESM_VFC_DATA_DIR environment variable
# further down (presumably where downloaded data files end up — TODO confirm).
esm_vfc_data_dir = "../esm-vfc-data/"
# URL of the YAML catalog that is opened with `intake.open_catalog` below.
nemo_catalog_url = "https://raw.githubusercontent.com/ESM-VFC/esm-vfc-catalogs/master/catalogs/NEMO_GYRE_Test.yaml"

## Tech preamble

In [None]:
import numpy as np
import xarray as xr

In [None]:
# set up intake catalog
import intake
from esmvfc_cattools import download_zenodo_files_for_entry
import os

# Publish the data directory through the process environment.
# NOTE(review): presumably the catalog entries and/or the download helper
# read ESM_VFC_DATA_DIR — confirm against esmvfc_cattools.
os.environ["ESM_VFC_DATA_DIR"] = esm_vfc_data_dir

In [None]:
# set up plotting
import hvplot.pandas
import hvplot.xarray
import geoviews.feature as gf
from cartopy import crs

## Load catalog and fetch data

In [None]:
# Open the catalog from its URL and show the names of its entries
# (the bare expression on the last line is the cell's output).
model_data_cat = intake.open_catalog(nemo_catalog_url)
list(model_data_cat)

In [None]:
import logging
# Configure the root logger to emit messages down to DEBUG level so that the
# following download / read steps are traceable.
logging.basicConfig(level=logging.DEBUG)

In [None]:
# Fetch the files behind every entry of the catalog.
# NOTE(review): force_download=True presumably re-downloads files even when
# they are already present locally — confirm this is wanted on repeated runs.
for entry in model_data_cat:
    download_zenodo_files_for_entry(model_data_cat[entry], force_download=True)

In [None]:
# Read a single entry with CF decoding switched off and display the result.
model_data_cat.NEMO_GYRE_Test_grid_T(xarray_kwargs={"decode_cf": False}).read()

In [None]:
# Data sets: every catalog entry whose name does not contain "mesh_mask",
# read with CF decoding switched off.
data_ds = {
    name: source(xarray_kwargs={"decode_cf": False}).read()
    for name, source in model_data_cat.items()
    if "mesh_mask" not in name
}
print(data_ds.keys())

# Auxiliary data sets: all remaining entries, i.e. those not in data_ds.
aux_ds = {
    name: source(xarray_kwargs={"decode_cf": False}).read()
    for name, source in model_data_cat.items()
    if name not in data_ds
}
print(aux_ds.keys())

In [None]:
# Install (or upgrade) nemo_cf straight from a branch of its GitHub repository.
# NOTE(review): a branch name is not a fixed revision, so the installed code
# may differ between runs — consider pinning a commit or tag.
%pip install git+https://github.com/willirath/nemo_cf@7-convert-coords-data-fields --upgrade

In [None]:
from nemo_cf.nemo_cf import update_mesh_mask_dataset

In [None]:
# Replace the raw mesh-mask data set by the one returned from
# nemo_cf's `update_mesh_mask_dataset`, then display it.
aux_ds["NEMO_GYRE_Test_mesh_mask"] = update_mesh_mask_dataset(aux_ds["NEMO_GYRE_Test_mesh_mask"])
aux_ds["NEMO_GYRE_Test_mesh_mask"]

In [None]:
def remove_time_counter(ds):
    """Return ``ds`` without the ``time_counter`` variable.

    Parameters
    ----------
    ds
        Data set (expected to offer xarray's ``drop_vars``) that may or may
        not contain a variable / coordinate called ``time_counter``.

    Returns
    -------
    Data set in which ``time_counter`` is no longer present. If ``ds`` has no
    such variable, its contents are returned unchanged.
    """
    # `drop_vars` is the non-deprecated way to drop variables; with
    # errors="ignore" a missing name is skipped instead of raising, so no
    # blanket `except ValueError` is needed.
    return ds.drop_vars(["time_counter"], errors="ignore")

In [None]:
# Strip "time_counter" from the mesh-mask data set where possible,
# then display the result.
aux_ds["NEMO_GYRE_Test_mesh_mask"] = remove_time_counter(aux_ds["NEMO_GYRE_Test_mesh_mask"])
aux_ds["NEMO_GYRE_Test_mesh_mask"]

## Grid detection

- vertical coords
    - in the data files (`..._grid_T.nc`, `..._icemod.nc`, etc.), the vertical dim is always exactly one of
    `["deptht", "depthu", "depthv", "depthf", "depthw"]`
    - in the mesh-mask files, there are two vars for the vertical axis: `["gdept_1d", "gdepw_1d"]`
    - we detect actual vertical grids by comparing `["deptht", "depthu", "depthv", "depthf", "depthw"]`
    to `["gdept_1d", "gdepw_1d"]`
    
- horizontal coords
    - in the data files (`..._grid_T.nc`, `..._icemod.nc`, etc.), the horizontal coords are always
    `["nav_lon", "nav_lat"]`
    - in the mesh-mask files, there are the latitude fields `["gphit", "gphiu", "gphiv", "gphif"]` and
    the longitude fields `["glamt", "glamu", "glamv", "glamf"]`
    - we detect actual horizontal grids by comparing `["nav_lon", "nav_lat"]` to `["gphit", "gphiu", "gphiv", "gphif"]` and
    `["glamt", "glamu", "glamv", "glamf"]`

In [None]:
def detect_horizontal_grid(ds, mesh_mask):
    """Identify the horizontal grid ("t", "u", "v" or "f") of a data set.

    ``nav_lat`` / ``nav_lon`` of ``ds`` are compared against the ``gphi*`` /
    ``glam*`` fields of ``mesh_mask`` for each candidate grid, in the order
    t, u, v, f.

    Parameters
    ----------
    ds
        Data set providing ``nav_lat`` and ``nav_lon``.
    mesh_mask
        Mesh-mask data set providing ``gphi{t,u,v,f}`` and ``glam{t,u,v,f}``;
        ``glamt`` is differenced along ``x`` and ``gphit`` along ``y``.

    Returns
    -------
    str or None
        Letter of the first candidate for which the largest absolute latitude
        difference and the largest absolute longitude difference are both
        below the tolerance; ``None`` if no candidate qualifies.
    """
    # Tolerance: one tenth of the smallest absolute T-grid step, taken over
    # longitude steps along x and latitude steps along y.
    min_lon_step = abs(mesh_mask["glamt"].diff("x")).min().data
    min_lat_step = abs(mesh_mask["gphit"].diff("y")).min().data
    tol = min(min_lon_step, min_lat_step) / 10.0

    for hgrid in ("t", "u", "v", "f"):
        lat_diff = abs(ds["nav_lat"] - mesh_mask[f"gphi{hgrid}"]).max()
        lon_diff = abs(ds["nav_lon"] - mesh_mask[f"glam{hgrid}"]).max()
        matches = (lat_diff < tol) & (lon_diff < tol)
        if matches:
            return hgrid
    return None

In [None]:
# Example: detect the horizontal grid of the grid_V data set.
detect_horizontal_grid(data_ds["NEMO_GYRE_Test_grid_V"], aux_ds["NEMO_GYRE_Test_mesh_mask"])

In [None]:
def detect_vertical_grid(ds, mesh_mask):
    """Identify the vertical grid ("t" or "w") of a data set.

    The first coordinate of ``ds`` whose name contains ``"depth"`` is compared
    (by absolute value) against ``gdept_1d`` and then ``gdepw_1d`` of
    ``mesh_mask``.

    Parameters
    ----------
    ds
        Data set with at least one coordinate whose name contains ``"depth"``.
    mesh_mask
        Mesh-mask data set providing ``gdept_1d`` (differenced along ``z``)
        and ``gdepw_1d``.

    Returns
    -------
    str or None
        ``"t"`` or ``"w"`` for the first candidate whose smallest absolute
        difference to the depth coordinate is below the tolerance; ``None``
        if neither qualifies.

    Raises
    ------
    IndexError
        If ``ds`` has no coordinate with ``"depth"`` in its name.
    """
    # Tolerance: one tenth of the smallest absolute step of gdept_1d along z.
    level_steps = abs(mesh_mask["gdept_1d"].diff("z"))
    tol = level_steps.min().data / 10.0

    depth_names = [name for name in ds.coords if "depth" in name]
    depth_coord = depth_names[0]
    # Signs are discarded on both sides before comparing.
    depth_magnitude = abs(ds[depth_coord])

    for vgrid in ("t", "w"):
        reference = abs(mesh_mask[f"gdep{vgrid}_1d"])
        z_diff = abs(depth_magnitude - reference).min()
        if z_diff < tol:
            return vgrid
    return None

In [None]:
# Example: detect the vertical grid of the grid_V data set.
detect_vertical_grid(data_ds["NEMO_GYRE_Test_grid_V"], aux_ds["NEMO_GYRE_Test_mesh_mask"])

In [None]:
def detect_grids(ds, mesh_mask):
    """Detect both the horizontal and the vertical grid of a data set.

    Parameters
    ----------
    ds
        Data set to classify.
    mesh_mask
        Mesh-mask data set used as the reference.

    Returns
    -------
    dict
        ``{"hgrid": ..., "vgrid": ...}`` holding the results of
        ``detect_horizontal_grid`` and ``detect_vertical_grid``.
    """
    hgrid = detect_horizontal_grid(ds, mesh_mask)
    vgrid = detect_vertical_grid(ds, mesh_mask)
    return dict(hgrid=hgrid, vgrid=vgrid)

In [None]:
# Detected horizontal / vertical grid for every data set, keyed by entry name.
grids = {k: detect_grids(data_ds[k], aux_ds["NEMO_GYRE_Test_mesh_mask"]) for k in data_ds}

In [None]:
# Display the detection results.
grids

In [None]:
def safely_rename_dims(dataset, rename_dims_dict=None):
    """Rename only those dimensions that are actually present in ``dataset``.

    Parameters
    ----------
    dataset
        Object offering ``dims`` (supporting ``in``) and ``rename_dims``.
    rename_dims_dict : mapping, optional
        Old dimension name -> new dimension name. Entries whose old name is
        not a dimension of ``dataset`` are ignored. ``None`` (the default)
        means there is nothing to rename.

    Returns
    -------
    Whatever ``dataset.rename_dims`` returns for the filtered mapping.
    """
    if rename_dims_dict is None:
        rename_dims_dict = {}

    valid_rename_dims_dict = {
        old: new
        for old, new in rename_dims_dict.items()
        if old in dataset.dims
    }

    return dataset.rename_dims(valid_rename_dims_dict)


def safely_rename_vars(dataset, rename_vars_dict=None):
    """Rename only those variables / coordinates present in ``dataset``.

    Parameters
    ----------
    dataset
        Object offering ``data_vars`` and ``coords`` (both supporting ``in``)
        and ``rename_vars``.
    rename_vars_dict : mapping, optional
        Old name -> new name. Entries whose old name is neither a data
        variable nor a coordinate of ``dataset`` are ignored. ``None`` (the
        default) means there is nothing to rename.

    Returns
    -------
    Whatever ``dataset.rename_vars`` returns for the filtered mapping.
    """
    if rename_vars_dict is None:
        rename_vars_dict = {}

    valid_rename_vars_dict = {
        old: new
        for old, new in rename_vars_dict.items()
        if old in dataset.data_vars or old in dataset.coords
    }

    return dataset.rename_vars(valid_rename_vars_dict)


def update_data_file_coords(dataset, hgrid=None, vgrid=None):
    """Rename spatial dims and coords of a data file to grid-specific names.

    Dims become ``z_<vgrid>``, ``y_<hgrid>`` and ``x_<hgrid>``; the depth
    coordinate, ``nav_lat`` and ``nav_lon`` become ``gdep<vgrid>_1d``,
    ``gphi<hgrid>`` and ``glam<hgrid>``. Names that are absent from
    ``dataset`` are skipped. The ``coordinates`` attribute of every data
    variable is rewritten with the same old -> new name substitutions.

    Parameters
    ----------
    dataset
        Data set to update.
    hgrid : str
        Horizontal grid letter. Although the default is ``None``, a string is
        required because ``.lower()`` is called on it.
    vgrid : str
        Vertical grid letter; same remark as for ``hgrid``.

    Returns
    -------
    The renamed data set.
    """
    v_suffix = vgrid.lower()
    h_suffix = hgrid.lower()

    # The depth dim/coord of the input carries the suffix "w" on the w grid
    # and otherwise the suffix of the horizontal grid.
    depth_suffix = v_suffix if v_suffix == "w" else h_suffix
    depth_name = f"depth{depth_suffix}"

    dataset = safely_rename_dims(
        dataset,
        rename_dims_dict={
            depth_name: f"z_{v_suffix}",
            "y": f"y_{h_suffix}",
            "x": f"x_{h_suffix}",
        },
    )

    rename_vars_dict = {
        depth_name: f"gdep{v_suffix}_1d",
        "nav_lat": f"gphi{h_suffix}",
        "nav_lon": f"glam{h_suffix}",
    }
    dataset = safely_rename_vars(dataset, rename_vars_dict=rename_vars_dict)

    # Keep the "coordinates" attributes in sync with the new names
    # (plain substring replacement, applied in the order of the mapping).
    for name in dataset.data_vars:
        if "coordinates" not in dataset[name].attrs:
            continue
        coordinates = dataset[name].attrs["coordinates"]
        for old, new in rename_vars_dict.items():
            coordinates = coordinates.replace(old, new)
        dataset[name].attrs["coordinates"] = coordinates

    return dataset

In [None]:
# Rename dims / coords of every data set according to its detected grids.
data_ds = {
    name: update_data_file_coords(
        dataset,
        **detect_grids(dataset, aux_ds["NEMO_GYRE_Test_mesh_mask"]),
    )
    for name, dataset in data_ds.items()
}

In [None]:
def fix_calendar(ds):
    """Replace the calendar name ``"360d"`` by ``"360_day"`` in place.

    Every data variable and coordinate of ``ds`` is inspected; entries
    without a ``calendar`` attribute, or with a different calendar, are left
    untouched.

    Parameters
    ----------
    ds
        Data set whose variable attributes are updated in place.

    Returns
    -------
    The same ``ds`` object.
    """
    names = [*ds.data_vars, *ds.coords]
    for name in names:
        if ds[name].attrs.get("calendar") == "360d":
            ds[name].attrs["calendar"] = "360_day"
    return ds

In [None]:
# Normalise the calendar attribute in every data set.
data_ds = {name: fix_calendar(dataset) for name, dataset in data_ds.items()}

In [None]:
def consolidate_dim_names(ds):
    """Map grid-specific dim names onto the shared ``_c`` / ``_r`` / ``_l`` names.

    Dims that are not present in ``ds`` are skipped.

    Parameters
    ----------
    ds
        Data set whose dims are named ``x_<grid>``, ``y_<grid>``, ``z_<grid>``.

    Returns
    -------
    The data set with renamed dims.
    """
    return safely_rename_dims(
        ds,
        rename_dims_dict={
            # horizontal dims of the T, U, V and F grids
            "x_t": "x_c",
            "y_t": "y_c",
            "x_u": "x_r",
            "y_u": "y_c",
            "x_v": "x_c",
            "y_v": "y_r",
            "x_f": "x_r",
            "y_f": "y_r",
            # vertical dims of the T and W grids
            "z_t": "z_c",
            "z_w": "z_l",
        },
    )


def add_dim_coords(ds):
    """Attach index coordinates to all ``*_c``, ``*_r`` and ``*_l`` dims.

    For a dim of length ``n`` the coordinate values are ``arange(n)`` plus a
    shift that depends on the suffix of the dim name: ``0`` for ``_c``,
    ``+0.5`` for ``_r`` and ``-0.5`` for ``_l``. Each coordinate gets the
    attributes ``axis`` (upper-cased first letter of the dim name) and
    ``c_grid_axis_shift`` (the shift). Other dims are left alone.

    Parameters
    ----------
    ds
        Data set offering ``sizes`` (dim name -> length) and an assignable
        ``coords`` mapping. It is modified in place.

    Returns
    -------
    The same ``ds`` object.
    """
    # (dim-name suffix, c_grid_axis_shift)
    shift_by_suffix = (("_c", 0), ("_r", +0.5), ("_l", -0.5))

    # `sizes` is the mapping from dim name to length. Take a snapshot before
    # looping because assigning coordinates modifies `ds`.
    for dim, size in list(ds.sizes.items()):
        for suffix, shift in shift_by_suffix:
            if dim.endswith(suffix):
                ds.coords[dim] = (
                    [dim],
                    np.arange(size) + shift,
                    {
                        "axis": dim[0].upper(),
                        "c_grid_axis_shift": shift,
                    },
                )
    return ds

In [None]:
# Switch every data set over to the shared dim names.
data_ds = {
    name: consolidate_dim_names(dataset)
    for name, dataset in data_ds.items()
}

In [None]:
from xgcm import Grid

In [None]:
#TODO: Fix for mesh-mask as well.

In [None]:
# Merge all data sets, decode CF metadata, add the dim coordinates and build
# an xgcm Grid from the result (displayed as the cell output).
# NOTE(review): X and Y are declared periodic here — confirm that this matches
# the GYRE configuration (presumably a closed basin).
Grid(
    add_dim_coords(xr.decode_cf(xr.merge(data_ds.values()))),
    periodic=['X', 'Y']
)