In [1]:
import numpy as np
import pandas as pd
import xarray as xr
import holoviews as hv
import hvplot.xarray
import cftime

import ismip6_index, grid_utils

In [2]:
hv.extension('bokeh')

### Loading ISMIP6 Antarctica outputs

* The ISMIP6 Antarctica outputs total ~1.1 TB. Officially, they are available through Globus, but we've pulled the whole dataset and put it on Google Cloud Storage at `gs://ismip6/`.
* Following CMIP conventions, every variable is stored in a separate NetCDF file. Nominally, these files are CF-compliant and follow a standardized set of file and variable naming rules, but, also following CMIP conventions 🙂, there is a scattering of errors. See the [ISMIP6 output specifications](https://theghub.org/groups/ismip6/wiki/MainPage/ISMIP6ProjectionsAntarctica).
* All of the outputs are on uniform rectangular grids in the EPSG:3031 (Antarctic Polar Stereographic) projection, but there are multiple resolutions.
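As a rough illustration of how the resolutions relate: if every grid spans the same ±3,040 km extent that the grid-correction log reports for an 8 km output (an assumption for the other resolutions), then the number of cell centres per axis follows directly from the spacing, and the spacing can be recovered from the number of points. A minimal NumPy sketch; `grid_axis`, `infer_resolution`, and `HALF_WIDTH_M` are hypothetical names, not part of `grid_utils`:

```python
import numpy as np

# Assumed half-width of the grids in metres (±3,040 km), taken from the
# grid-correction log for an 8 km output; check it against each model's files.
HALF_WIDTH_M = 3_040_000.0


def grid_axis(resolution_m: float, half_width_m: float = HALF_WIDTH_M) -> np.ndarray:
    """Cell-centre coordinates from -half_width_m to +half_width_m inclusive."""
    n_points = int(round(2 * half_width_m / resolution_m)) + 1
    return np.linspace(-half_width_m, half_width_m, n_points)


def infer_resolution(n_points: int, half_width_m: float = HALF_WIDTH_M) -> float:
    """Grid spacing implied by a dimension length, assuming the same extent."""
    return 2 * half_width_m / (n_points - 1)


for res_km in (32, 16, 8, 4):
    print(f"{res_km:>2} km -> {grid_axis(res_km * 1e3).size} points per axis")
# 32 km -> 191, 16 km -> 381, 8 km -> 761, 4 km -> 1521 points per axis

print(infer_resolution(761))  # 8000.0 (metres)
```

Under these assumptions an 8 km grid has 761 points per axis, which is the size the grid correction detects for the outputs loaded in this notebook.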

Ideally, (lazy) loading the outputs might look something like this, leveraging the patterns established by [xarray](https://github.com/pydata/xarray), [intake-esm](https://github.com/intake/intake-esm), and [xMIP](https://github.com/jbusecke/xMIP):

Note: This is not runnable code. This is a concept of what it *should* look like. See the `In [3]` cell below for currently-runnable code.

> ```python
> ismip6_cat = intake.open_esm_datastore('gs://ismip6/ismip6.json')
> ismip6_dt = ismip6_cat.to_datatree(preprocess=xmip.fix_ismip6)
> my_variable = ismip6_dt['JPL1/ISSM']['exp05']['lithk']
> ```

To actually do this today, you could run:

In [3]:
ismip6_df = ismip6_index.get_file_index() # This is ~200 lines of code to build an index of ISMIP6 data files from the filenames

# Find a specific dataset and load it as an xarray Dataset
path = ismip6_df.query('institution == "JPL1" and model_name == "ISSM" and experiment == "exp05" and variable == "lithk"')['url'].values[0]
ds = xr.open_dataset(path, engine='h5netcdf', decode_cf=True, decode_times=True)

# Some of the ISMIP6 datasets have inconsistent grid coordinate definitions, so this helper function fixes that
ds = grid_utils.correct_grid_coordinates(ds, data_var="lithk") # Another ~300 lines for fixing gridding inconsistencies

ds

Loading index from cache: .cache/ismip6_index.parquet
⚠️  Grid correction: Dataset missing x/y coordinates for 'lithk'
   Detected dimensions: y=761, x=761
   Estimated resolution: dx=8.0 km, dy=8.0 km
   Creating coordinates: x=[-3040.0, 3040.0] km, y=[-3040.0, 3040.0] km
   ✓ Grid correction complete
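The `ismip6_index` module isn't shown here. Part of why an index takes real code is that filenames of the form `<variable>_<ice sheet>_<institution>_<model>_<experiment>.nc` (our reading of the ISMIP6 convention; check it against the output specifications) can't be split on underscores alone, because names such as `fETISh_32km` contain underscores themselves. A minimal sketch of one workaround, matching against a list of known institution/model pairs; `parse_ismip6_filename` and `known_models` are hypothetical, the filenames are illustrative, and the real index may work quite differently:

```python
from pathlib import PurePosixPath
from typing import Iterable, Optional, Tuple

import pandas as pd


def parse_ismip6_filename(url: str, known_models: Iterable[Tuple[str, str]]) -> Optional[dict]:
    """Split an ISMIP6-style filename into index fields, or return None.

    Assumes <variable>_<ice sheet>_<institution>_<model>_<experiment>.nc, where the
    variable and ice-sheet tokens contain no underscores.
    """
    name = PurePosixPath(url).name
    if not name.endswith(".nc"):
        return None
    parts = name[: -len(".nc")].split("_", 2)
    if len(parts) != 3:
        return None
    variable, ice_sheet, rest = parts
    # Try longer institution/model prefixes first so the most specific match wins
    candidates = sorted(known_models, key=lambda im: len(im[0]) + len(im[1]), reverse=True)
    for institution, model_name in candidates:
        prefix = f"{institution}_{model_name}_"
        if rest.startswith(prefix) and len(rest) > len(prefix):
            return {
                "variable": variable,
                "ice_sheet": ice_sheet,
                "institution": institution,
                "model_name": model_name,
                "experiment": rest[len(prefix):],  # whatever remains, underscores included
                "url": url,
            }
    return None


# Illustrative inputs, not a listing of gs://ismip6/
known_models = [("JPL1", "ISSM"), ("AWI", "PISM1"), ("ULB", "fETISh_32km")]
urls = [
    "lithk_AIS_JPL1_ISSM_exp05.nc",
    "lithk_AIS_ULB_fETISh_32km_exp05.nc",
    "orog_AIS_AWI_PISM1_exp05.nc",
]
rows = [parse_ismip6_filename(u, known_models) for u in urls]
index = pd.DataFrame([r for r in rows if r is not None])
print(index.query('variable == "lithk"')[["institution", "model_name", "experiment"]])
```

Once the index is a DataFrame, the `query(...)` lookups used in the cell above work unchanged.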

### Regridding multiple models to a common comparison grid

We think it's important for there to be a "batteries-included" way to regrid to a common grid.
Regridding should be lazy and should eventually support a wide range of target grids, including rectilinear grids in lat/lon or projected coordinates, HEALPix, and unstructured meshes.

[xESMF](https://xesmf.readthedocs.io/en/stable/) is probably the closest to being what we're looking for, though its dependence on the Fortran/C++ ESMF makes it an annoying piece of the stack to include.

Eventually, we'd want regridding to look something like this:

> ```python
> comparison_grid = xr.Dataset({
>     'x': (['x'], np.arange(-3040e3, 3040e3, 16e3)),
>     'y': (['y'], np.arange(-3040e3, 3040e3, 16e3)),
> })
> 
> model_outputs = regrid(model_outputs, target=comparison_grid, func=np.mean)
> ```
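One way to read `func=np.mean` in that concept is block averaging: each source cell is assigned to the target cell whose centre is nearest, and the values landing in each target cell are averaged. This is a different technique from the linear interpolation used in the `In [4]` cell, and it only makes sense when the target grid is coarser than the source. A minimal NumPy sketch, assuming 1-D, ascending, uniformly spaced coordinates in the same projection; `block_mean_regrid` is a hypothetical helper, not an existing library function:

```python
import numpy as np


def block_mean_regrid(src_x, src_y, values, tgt_x, tgt_y):
    """Average a (y, x) field onto a coarser grid with centres tgt_x, tgt_y.

    Assumes all coordinates are 1-D, ascending, uniformly spaced and in the
    same projection. NaNs in `values` are ignored; target cells that receive
    no valid source cells are NaN.
    """
    src_x, src_y = np.asarray(src_x, float), np.asarray(src_y, float)
    tgt_x, tgt_y = np.asarray(tgt_x, float), np.asarray(tgt_y, float)
    values = np.asarray(values, float)
    dx, dy = tgt_x[1] - tgt_x[0], tgt_y[1] - tgt_y[0]

    # Index of the nearest target centre for each source column and row.
    # A source centre exactly halfway between two target centres goes to the upper one.
    ix = np.floor((src_x - tgt_x[0]) / dx + 0.5).astype(int)
    iy = np.floor((src_y - tgt_y[0]) / dy + 0.5).astype(int)
    iy2d, ix2d = np.meshgrid(iy, ix, indexing="ij")  # shape (ny_src, nx_src)

    valid = (
        (ix2d >= 0) & (ix2d < tgt_x.size)
        & (iy2d >= 0) & (iy2d < tgt_y.size)
        & np.isfinite(values)
    )
    flat = iy2d[valid] * tgt_x.size + ix2d[valid]  # flattened target cell index
    n_tgt = tgt_y.size * tgt_x.size
    sums = np.bincount(flat, weights=values[valid], minlength=n_tgt)
    counts = np.bincount(flat, minlength=n_tgt)

    out = np.full(n_tgt, np.nan)
    np.divide(sums, counts, out=out, where=counts > 0)
    return out.reshape(tgt_y.size, tgt_x.size)


src = np.arange(4.0)          # source centres 0, 1, 2, 3 (spacing 1)
tgt = np.array([0.5, 2.5])    # target centres (spacing 2)
field = np.arange(16.0).reshape(4, 4)
print(block_mean_regrid(src, src, field, tgt, tgt))
# expected block means: [[2.5, 4.5], [10.5, 12.5]]
```

On the ±3,040 km grids, 8 km centres fall exactly on the edges of 16 km cells, so the tie-breaking rule noted in the code decides which target cell those rows and columns land in; a proper conservative regridder would split them between both neighbours instead.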

In [4]:
matching_files = ismip6_df.query('experiment == "exp05" and variable == "lithk"').iloc[::5] # Take a reduced set for this demo
datasets = {f'{row["institution"]}_{row["model_name"]}': xr.open_dataset(row["url"], engine='h5netcdf', decode_cf=True, decode_times=True) for _, row in matching_files.iterrows()}
lithk_dt = xr.DataTree.from_dict(datasets)

comparison_grid = xr.Dataset({
    'x': (['x'], np.arange(-3040e3, 3040e3, 16e3)),
    'y': (['y'], np.arange(-3040e3, 3040e3, 16e3)),
})

# Regrid all datasets to comparison_grid using xarray interpolation
regridded_children = {}

for child_name, child_node in lithk_dt.children.items():
    print(f"Regridding {child_name}...")
    
    # Get the dataset from the DataTree node (because DataTree does not support interp directly)
    child_ds = child_node.ds
    
    # Fix grid coordinates if needed
    child_ds = grid_utils.correct_grid_coordinates(child_ds, data_var="lithk")

    # Check the time coordinate type and convert any np.datetime to cftime
    if np.issubdtype(child_ds['time'].dtype, np.datetime64):
        time_vals = pd.to_datetime(child_ds['time'].values)
        cftime_vals = [cftime.DatetimeNoLeap(t.year, t.month, t.day, t.hour, t.minute, t.second) for t in time_vals]
        child_ds = child_ds.assign_coords(time=('time', cftime_vals))
    
    # Interpolate to the comparison grid
    regridded = child_ds.interp(
        x=comparison_grid.x,
        y=comparison_grid.y,
        method='linear',
        kwargs={'fill_value': np.nan}
    )
    
    regridded_children[child_name] = regridded

# Create new DataTree with regridded datasets sharing the same grid
lithk_dt_regridded = xr.DataTree.from_dict(regridded_children)

lithk_dt_regridded

Regridding AWI_PISM1...
Regridding JPL1_ISSM...
⚠️  Grid correction: Dataset missing x/y coordinates for 'lithk'
   Detected dimensions: y=761, x=761
   Estimated resolution: dx=8.0 km, dy=8.0 km
   Creating coordinates: x=[-3040.0, 3040.0] km, y=[-3040.0, 3040.0] km
   ✓ Grid correction complete
Regridding ULB_fETISh_32km...
⚠️  Grid correction: Dataset missing x/y coordinates for 'lithk'
   Detected dimensions: y=761, x=761
   Estimated resolution: dx=8.0 km, dy=8.0 km
   Creating coordinates: x=[-3040.0, 3040.0] km, y=[-3040.0, 3040.0] km
   Verifying consistency with existing lat/lon coordinates...
   ✓ Coordinates are consistent with lat/lon
   ✓ Grid correction complete



### Taking the standard deviation across models

Once we have a unified DataTree on a common grid, taking the standard deviation across models works pretty well.

In [5]:
# Time index selection to get the closest timestamp to 2100
time = cftime.DatetimeNoLeap(2100, 1, 1)
lithk_dt_regridded_timestep = lithk_dt_regridded.sel(time=time, method='nearest')

lithk_var = xr.concat(
    [child['lithk'] for child in lithk_dt_regridded_timestep.children.values()],
    dim='m'
).std(dim='m')

lithk_var.hvplot.image(
    x='x',
    y='y',
    cmap='viridis',
    colorbar=True,
    title='Standard deviation of ice thickness'
).opts(aspect='equal')
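Two details of `std` matter for a spread map like this one: xarray's `std` uses `ddof=0` (population standard deviation) by default, and it skips NaNs. Since each regridded model is NaN outside its own domain, a cell covered by a single model shows a spread of exactly 0, which can read as perfect agreement. A sketch on synthetic values (the three "models" and their numbers are made up) that masks cells with fewer than two valid models using `count`:

```python
import numpy as np
import xarray as xr

# Three synthetic "models" on a 1 x 2 grid; the second cell is covered by only one model
models = [
    xr.DataArray([[1.0, 5.0]], dims=("y", "x")),
    xr.DataArray([[2.0, np.nan]], dims=("y", "x")),
    xr.DataArray([[3.0, np.nan]], dims=("y", "x")),
]
stack = xr.concat(models, dim="model")
n_valid = stack.count(dim="model")  # number of non-NaN models per cell

# Default ddof=0; the single-model cell comes out as 0.0
spread = stack.std(dim="model")
# Hide cells where fewer than two models contribute
spread_masked = spread.where(n_valid >= 2)
# ddof=1 (sample std) is noticeably larger when only a handful of models are compared
spread_sample = stack.isel(x=0).std(dim="model", ddof=1)

print(spread.values, spread_masked.values, spread_sample.values)
```

Applying the same mask in the cell above means keeping the concatenated array in a variable before calling `.std(dim='m')`, so that its `.count(dim='m')` is available.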