In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import dask
import holoviews as hv
import hvplot.xarray
import cftime

import ismip6_helper

In [None]:
# Activate the Bokeh backend for HoloViews (and therefore for the hvplot figures below).
hv.extension('bokeh')

## Loading ISMIP6 Antarctica outputs

* ISMIP6 Antarctica outputs are ~1.1 TB in total. Officially, they are available through Globus, but we've pulled the whole dataset and put it on GCloud at `gs://ismip6/` (byte-for-byte identical -- we don't want to be responsible for hosting modified versions of things like this).
* Following CMIP conventions, every variable is a separate NetCDF file. Nominally, these are CF-compliant and follow a standardized set of file and variable naming rules, but, also following CMIP conventions 🙂, there is a scattering of errors. See the [ISMIP6 output specifications](https://theghub.org/groups/ismip6/wiki/MainPage/ISMIP6ProjectionsAntarctica).
* All of the outputs are uniform rectangular grids in EPSG:3031 projection, but there are multiple resolutions.


Ideally, lazy loading of this dataset should be easy and concise. We would like loading a dataset like this into an Xarray DataTree to be a two line operation:

```python
catalog = helper_library.open("gs://lightweight-reference-file")
# If desired, filter the catalog
dt = catalog.to_datatree()
```

There are two important properties we care about here:

1. We want to be able to create `lightweight-reference-file` for an existing dataset without needing to change the underlying bytes
2. We want to be able to encode "fixes" somewhere before it becomes an Xarray DataTree -- "fixes" are things like different variable naming conventions, misspelled files, incorrect time axes, etc.

The code below actually loads the ISMIP6 outputs into a DataTree. It's a bit more than two lines.

The main issues are inconsistencies in the output files. For example:
* A few files are misnamed (missing an underscore) → corrected by `ismip6_helper.get_file_index()`
* Some grids were defined with `x` and `y` coordinates (in EPSG:3031 projection) while others were specified by `lat`, `lon` points → corrected by `ismip6_helper.correct_grid_coordinates()`
* Timestamps are specified in a variety of different formats with various CF-compliance issues → fixed by `ismip6_helper.open_ismip6_dataset()`, which automatically detects and corrects:
  - Typo: `unit` instead of `units` attribute
  - Invalid use of `MM-DD-YYYY` instead of `YYYY-MM-DD`
  - Invalid dates (e.g., day 0 → day 1)
  - etc...

In [None]:
# Create a dataframe index by scanning the ISMIP6 Antarctica output files
# This is ~200 lines of code to build an index of ISMIP6 data files from the filenames
ismip6_df = ismip6_helper.get_file_index()

# For the purposes of this demo, filter down the number of files we have to load
ismip6_df = ismip6_df.query(
    'experiment in ["ctrl_proj_std", "exp05", "ctrl_proj"]'
    ' and variable in ["lithk", "base", "sftgrf"]'
    ' and institution in ["JPL1", "AWI", "DOE"]'
)

# Build a DataTree of the outputs: one node per file, at <institution>_<model>/<experiment>/<variable>.
# Loading is best-effort: a file that fails to open is reported and recorded in
# `load_failures` instead of aborting the whole load.
datasets = {}
load_failures = {}
for _, row in ismip6_df.iterrows():
    # Build the path before the `try`, so the failure message below always names
    # the file that actually failed (inside the `try`, `p` could be unbound or
    # left over from the previous row).
    p = f'{row["institution"]}_{row["model_name"]}/{row["experiment"]}/{row["variable"]}'  # DataTree path
    try:
        # Helper that automatically fixes time encoding issues while opening lazily
        ds = ismip6_helper.open_ismip6_dataset(row["url"], chunks={'time': 1})
        # Helper that normalizes grids given as x/y vs lat/lon
        ds = ismip6_helper.correct_grid_coordinates(ds, data_var=row["variable"])
    except Exception as e:
        load_failures[p] = e
        print(f"Failed to load {p}: {e}")
        continue
    datasets[p] = ds

ismip6_dt = xr.DataTree.from_dict(datasets)

#ismip6_dt

### Select and plot one variable

Once we have the DataTree loaded, we can easily filter down to variables of interest: `ismip6_dt['JPL1_ISSM']['exp05']['lithk']`

**This part works well enough.**

The example below produces a plot of the change in ice thickness since the beginning of the simulation. So far, we've only lazily loaded the data, so the actual data hasn't been downloaded. We call `.compute()` on the thickness change variable to force loading of the data, which makes the interactive plot responsive.

In [None]:
# Thickness change for a single model/experiment: JPL1_ISSM, exp05.
dt = ismip6_dt['JPL1_ISSM']['exp05']['lithk']
jpl_exp05_lithk = dt['lithk']

# Difference from the first time step. Everything so far is lazy; .compute()
# pulls the result into memory so the interactive slider stays responsive.
delta_thickness = (
    (jpl_exp05_lithk - jpl_exp05_lithk.isel(time=0))
    .rename('delta_lithk')
    .compute()
)

# Symmetric colour limits taken from the 1st/99th percentiles of the change
percentile_bounds = delta_thickness.quantile([0.01, 0.99]).values
vmag = np.abs(percentile_bounds).max()

# Image with a time slider
delta_thickness_plot_opts = dict(
    aspect='equal',
    title="Change in ice thickness relative to the first timestep",
    colorbar_opts={'title': 'Change in thickness (m)'},
)
delta_thickness.hvplot.image(x='x', y='y', clim=(-vmag, vmag), cmap='RdBu').opts(**delta_thickness_plot_opts)

### Regridding multiple models to a common comparison grid

While all of the ISMIP6 outputs were interpolated to a regular grid, these grids have different resolutions. So if we want to do any cross-model comparison, we need to get things onto a common grid.

We have some more complicated ideas about how to do regridding, but we also want to make sure that simple things work.

Ideally, it would be possible to call `interp` on a DataTree like this:

```python
comparison_grid = xr.Dataset({
    'x': (['x'], np.arange(-3040e3, 3040e3, 16e3)),
    'y': (['y'], np.arange(-3040e3, 3040e3, 16e3)),
    'time': (['time'], xr.date_range('2016-01-01', '2100-12-31', freq='10Y').values),
})

ismip6_dt_regridded = ismip6_dt.interp(x=comparison_grid.x, y=comparison_grid.y)
```

This doesn't actually work yet, but we can use `map_over_datasets` to do the same thing. Not terrible, but could be cleaner.

In [None]:
# Common comparison grid: 16 km spacing in EPSG:3031 x/y, plus a 10-year time axis.
comparison_grid = xr.Dataset({
    'x': (['x'], np.arange(-3040e3, 3040e3, 16e3)),
    'y': (['y'], np.arange(-3040e3, 3040e3, 16e3)),
    'time': (['time'], xr.date_range('2016-01-01', '2100-12-31', freq='10Y').values),
})


def regrid_to_comparison_grid(ds):
    """Nearest-neighbour regrid of one node's dataset onto ``comparison_grid``.

    Datasets without both ``x`` and ``y`` dimensions (e.g. the empty root and
    the institution/experiment grouping nodes) are returned unchanged.
    ``time`` is interpolated only when the dataset actually has a time
    dimension: ``Dataset.interp`` raises for dimensions that are not present.
    Target points outside a model's range are filled with NaN.
    """
    if not ('x' in ds.dims and 'y' in ds.dims):
        return ds
    targets = {'x': comparison_grid.x, 'y': comparison_grid.y}
    if 'time' in ds.dims:
        targets['time'] = comparison_grid.time
    return ds.interp(targets, method='nearest', kwargs={'fill_value': np.nan})


regridded = ismip6_dt.map_over_datasets(regrid_to_comparison_grid)

Now that we're working on a common comparison grid, we can do some cross-model comparison. As an example, we'll plot the standard deviation of the change in ice thickness since the first timestep of each model.

```python
exp05 = ismip6_dt.subtree['exp05']
lithk_subset = exp05.subtree['lithk']

# or...

lithk_subset = ismip6_dt.subtree['exp05', 'lithk']

lithk_subset.std()
lithk_subset.interp(args).std()
```

In [None]:
def _change_since_first_step(thickness):
    """Thickness at every time after the first, minus the thickness at the first time step."""
    return thickness.isel(time=slice(1, None)) - thickness.isel(time=0)


# exp05 ice-thickness leaves of the regridded tree (one per model that has data)
exp05_lithk_leaves = [
    leaf for leaf in regridded.subtree
    if leaf.path.endswith('exp05/lithk') and leaf.has_data
]

# Stack the per-model changes along a new 'model' axis and take the spread across models
delta_lithk_all = (
    xr.concat([_change_since_first_step(leaf.ds['lithk']) for leaf in exp05_lithk_leaves], dim='model')
    .std(dim='model')
    .compute()
)

delta_lithk_all.hvplot.image(
    x='x',
    y='y',
    clim=(0, 200),
    cmap='gray_r',
    clabel='Std dev of thickness change (m)',
).opts(aspect='equal', title='Standard deviation of ice thickness change across models')

### Computed scalars: mass above flotation
**TODO**

In [None]:
# Published IVAF time series for JPL1_ISSM exp05, both absolute and relative to
# ctrl_proj. Times are left undecoded (raw numbers).
reference_open_kwargs = {'engine': 'h5netcdf', 'decode_times': False}
tmp_ivaf_reference = xr.open_dataset(
    'external_data/ismip6_computed_scalars/computed_ivaf_AIS_JPL1_ISSM_exp05.nc',
    **reference_open_kwargs,
)
tmp_ivaf_minus_ctrl_reference = xr.open_dataset(
    'external_data/ismip6_computed_scalars/computed_ivaf_minus_ctrl_proj_AIS_JPL1_ISSM_exp05.nc',
    **reference_open_kwargs,
)

# Combine both series into one Dataset: 'ivaf' and 'ivaf_minus_ctrl'
ivaf_reference = xr.merge([
    tmp_ivaf_reference['ivaf'],
    tmp_ivaf_minus_ctrl_reference['ivaf'].rename('ivaf_minus_ctrl'),
])

reference_ivaf_change = ivaf_reference - ivaf_reference.isel(time=0)
reference_ivaf_change.hvplot.line(
    x='time',
    ylabel='Cumulative change in volume above flotation (m^3)',
    title='Ice Volume Above Flotation for JPL1_ISSM exp05',
)


```matlab
ivaf_total_region=sum((thickness_i(pos_region)+ocean_density/ice_density*min(bed_i(pos_region),0)).*groundmask_i(pos_region).*mask_i(pos_region).*scalefac_model(pos_region))*(resolution*1000)^2; %in m^3
```

In [None]:
# Densities in kg/m^3, used in the volume-above-flotation formula
ice_density, ocean_density = 917, 1028

In [None]:
# Map-projection area distortion on the comparison grid.
# EPSG:3031 (Antarctic Polar Stereographic) does not preserve area. PROJ's
# "areal scale" is (area on the map) / (true area on the ellipsoid), so the true
# area of a projected cell is its projected area divided by this factor (i.e.
# multiplied by 1/scale_factor, as the IVAF calculation does).
import pyproj

# 2D grids of projected x/y coordinates (metres)
xx, yy = np.meshgrid(comparison_grid.x.values, comparison_grid.y.values)

# Set up the projection for EPSG:3031 (Antarctic Polar Stereographic)
proj = pyproj.Proj('EPSG:3031')

# Inverse projection: projected x/y -> lon/lat in degrees
lons, lats = proj(xx, yy, inverse=True)

# get_factors returns a Factors record with fields meridional_scale,
# parallel_scale, areal_scale, angular_distortion, meridian_parallel_angle,
# meridian_convergence, tissot_semimajor, tissot_semiminor. We need areal_scale.
factors = proj.get_factors(lons.ravel(), lats.ravel(), radians=False)
areal_scale = factors.areal_scale.reshape(xx.shape)

# Wrap as a DataArray on the comparison grid
scale_factor = xr.DataArray(
    areal_scale,
    coords={'y': comparison_grid.y.values, 'x': comparison_grid.x.values},
    dims=['y', 'x'],
    name='scalefac',
    attrs={
        'long_name': 'Area scaling factor',
        'description': (
            'Ratio of projected (map) area to true area (areal_scale from pyproj.get_factors); '
            'divide projected area by this factor to obtain true area'
        ),
        'units': '1',
        'projection': 'EPSG:3031',
    }
)

scale_factor.hvplot.image().opts(aspect='equal', title='Areal scale factor (EPSG:3031)')

In [None]:
# Reference area factor (af2) from the ISMIP6 computed-scalars data, linearly
# interpolated onto the comparison grid
scale_file = xr.open_dataset('external_data/ismip6_computed_scalars/af2_el_ismip6_ant_01.nc')
scale_file_interp = scale_file['af2'].interp(
    x=comparison_grid.x,
    y=comparison_grid.y,
    method='linear',
)

# af2 is expected to match the inverse of our areal scale; plot the residual
scale_factor_residual = scale_file_interp - 1 / scale_factor
scale_factor_residual.hvplot.image().opts(
    aspect='equal',
    clim=(-0.1, 0.1),
    title='Difference between computed and reference scaling factor',
)

In [None]:
# Ice volume above flotation (IVAF) per model/experiment, following the MATLAB
# reference formula above:
#   sum((thickness + ocean_density/ice_density * min(bed, 0)) * grounded_frac * area_factor) * cell_area
#
# The sum is done on each model's NATIVE grid. `scale_factor` lives on the 16 km
# comparison grid, so it is interpolated onto the native x/y first, and the native
# cell size is used for the cell area. Multiplying native-grid fields directly by
# a comparison-grid field would make xarray inner-join on x/y. That keeps only the
# points the two grids share and weights them with the 16 km cell area, which is
# wrong for coarser or offset grids.
ivaf_inputs = ('lithk', 'base', 'sftgrf')
inverse_areal_scale = 1 / scale_factor

for model in ismip6_dt.children.keys():
    for experiment in ismip6_dt[model].children.keys():
        print(f"{model} - {experiment}")
        experiment_node = ismip6_dt[model][experiment]

        # A file may have failed to load; skip this experiment instead of
        # aborting the whole loop with a KeyError.
        missing = [name for name in ivaf_inputs if name not in experiment_node.children]
        if missing:
            print(f"  skipping: missing {missing}")
            continue

        lithk = experiment_node['lithk']['lithk']
        base = experiment_node['base']['base']
        sftgrf = experiment_node['sftgrf']['sftgrf']

        # Native grid spacing in metres (ISMIP6 grids are uniform rectangular grids)
        dx = float(abs(lithk.x.values[1] - lithk.x.values[0]))
        dy = float(abs(lithk.y.values[1] - lithk.y.values[0]))

        # Area factor on the native grid. Linear interpolation of a smooth field;
        # native points beyond the comparison grid's extent become NaN and are
        # skipped by `sum`.
        area_factor = inverse_areal_scale.interp(x=lithk.x, y=lithk.y, method='linear')

        height_above_flotation = lithk + ocean_density / ice_density * np.minimum(base, 0)
        ivaf = (height_above_flotation * sftgrf * area_factor).sum(dim=['x', 'y']) * (dx * dy)  # in m^3
        ismip6_dt[model][experiment]['ivaf'] = ivaf.rename('ivaf')

In [None]:
# JPL1_ISSM IVAF series: exp05 itself, and exp05 minus the ctrl_proj control run
jpl1_issm = ismip6_dt['JPL1_ISSM']
ivaf = jpl1_issm['exp05']['ivaf'].compute()
ivaf_minus_ctrl = (ivaf - jpl1_issm['ctrl_proj']['ivaf']).compute()

In [None]:
# Side-by-side: our computed IVAF change (left) vs the published reference (right).
ivaf_ylabel = 'Cumulative change in\nvolume above flotation (m^3)'

calc = (
    (ivaf - ivaf.isel(time=0)).hvplot.line(x='time', label='exp05')
    * (ivaf_minus_ctrl - ivaf_minus_ctrl.isel(time=0)).hvplot.line(x='time', label='exp05 - ctrl_proj')
).opts(
    ylabel=ivaf_ylabel,
    title='Computed change in ice volume above flotation for JPL1_ISSM exp05',
    legend_position='bottom_left',
    show_grid=True,
)
ref = (ivaf_reference - ivaf_reference.isel(time=0)).hvplot.line(
    x='time',
    ylabel=ivaf_ylabel,
    title='Reference (from doi.org/10.5281/zenodo.3940766)',
).opts(legend_position='bottom_left', show_grid=True)

calc + ref

In [None]:
# Wish-list API sketch. These DataTree operations do not exist: `.subtree` is a
# plain iterator of nodes, so calling e.g. `.as_dataset()` on it would raise
# AttributeError and stop "Run All". They are kept as comments to show the kind
# of one-call operations we would like:
#
#   ismip6_dt['AWI_PISM1/ctrl_proj_std'].subtree.as_dataset()
#   ismip6_dt['AWI_PISM1/ctrl_proj_std'].subtree.condense()
#   ismip6_dt['AWI_PISM1/ctrl_proj_std'].subtree.gather_leaves()  # gather_siblings
#
#   dt_of_exp05_datasets = ismip6_dt.subtree['exp05'].gather_siblings()


In [None]:
# Names of the child nodes (one per loaded variable) under AWI_PISM1 / ctrl_proj_std
[*ismip6_dt['AWI_PISM1/ctrl_proj_std'].children]

In [None]:
# Collapse each experiment node whose children are exactly the three loaded
# variables (base, lithk, sftgrf) into a single merged Dataset. All other nodes
# are copied over unchanged.
#
# The tree is walked by hand rather than with `map_over_datasets`. That function
# hands the callback each node's *dataset*, and a Dataset has no
# `.subtree`/`.children` to inspect, so the previous lambda raised
# AttributeError. The child-name check is also order-insensitive, because the
# children's order depends on the order in which the files were loaded.
variables_to_merge = {'base', 'lithk', 'sftgrf'}

merged_node_datasets = {}
collapsed_paths = []
for node in ismip6_dt.subtree:
    # Descendants of a collapsed node are already part of its merged dataset
    if any(node.path.startswith(parent + '/') for parent in collapsed_paths):
        continue
    if set(node.children) == variables_to_merge:
        merged_node_datasets[node.path] = xr.merge(
            [node.to_dataset(inherit=False)]
            + [child.to_dataset() for child in node.children.values()]
        )
        collapsed_paths.append(node.path)
    else:
        merged_node_datasets[node.path] = node.to_dataset(inherit=False)

dt = xr.DataTree.from_dict(merged_node_datasets)
dt

In [None]:
# The node's own data as a Dataset (its child nodes are not included).
# DataTree has no `as_dataset` method; `to_dataset` is the conversion method.
ismip6_dt['AWI_PISM1/ctrl_proj_std'].to_dataset()

In [None]:
# Merge the node and every node below it into a single Dataset
ctrl_proj_std_node = ismip6_dt['AWI_PISM1/ctrl_proj_std']
xr.merge([subtree_node.to_dataset() for subtree_node in ctrl_proj_std_node.subtree])