## Regridding eORCA025 data onto eORCA0083 using xESMF

What follows is a rehash of a few examples of how to use xESMF to regrid data (mostly from [roocs](https://github.com/roocs)). The objective is to regrid an eORCA025 data set onto the eORCA0083 grid to produce initial conditions for use in the CANARI project. My usual go-to tool for regridding data is SOSIE/SCRIP, but I thought I'd venture into the Python realm.

In [8]:
import os; os.environ['PROJ_LIB'] = '/dssgfs01/working/jdha/miniforge3/envs/analysis/share/proj' # avoid basemap import error

In [9]:
%matplotlib inline
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np
import xarray as xr
import cf_xarray as cfxr
import xesmf as xe
import scipy.sparse as sps
import clisops.core as clore
import clisops as cl
import textwrap
import math
import copy as cp
import warnings

xr.set_options(display_style='html');

In [3]:
def infill(arr_in, n_iter=None, bathy=None):
    """
    Returns data with NaNs replaced by iteratively taking the arithmetic
    mean (NaNs ignored) of the four neighbouring points (east, west, north,
    south) until all NaNs are removed, n_iter iterations have been
    performed, or an iteration fails to remove any further NaNs. Input data
    must be 2D and can include a bathymetry array to provide land barriers
    to the infilling: only NaNs where bathy > 0 are filled.

    Note that arr_in is modified in place and the same array is returned.

    Args:
        arr_in          (ndarray): data array 2D (modified in place)
        n_iter              (int): maximum number of infilling iterations
                                   (default: iterate until no further
                                   progress can be made)
        bathy           (ndarray): bathymetry array (land set to zero)

    Returns:
        arr_mod         (ndarray): modified data array (same object as
                                   arr_in)

    Raises:
        ValueError: if arr_in is not two dimensional
    """

    # Check number of dims
    if arr_in.ndim != 2:
        raise ValueError("Array must have two dimensions")

    # Initial setup to prime things for the averaging loop
    if bathy is None:
        bathy = np.ones_like(arr_in, dtype=float)
    if n_iter is None:
        n_iter = np.inf

    wet = np.greater(bathy, 0.)
    jpj, jpi = arr_in.shape

    ind = np.where(np.logical_and(np.isnan(arr_in), wet))
    # Count the NaNs themselves: summing the index values (as np.sum(ind)
    # would) is zero when the only remaining NaN sits at (0, 0)
    n_nan = ind[0].size
    counter = 0

    # Infill until all NaNs are removed or n_iter iterations completed
    while n_nan > 0 and counter < n_iter:

        # Indices of neighbouring points, clamped at the array edges (a
        # clamped neighbour is the NaN point itself and so is ignored by
        # nanmean)
        jj, ji = ind
        ind_e = (jj, np.minimum(ji + 1, jpi - 1))
        ind_w = (jj, np.maximum(ji - 1, 0))
        ind_n = (np.minimum(jj + 1, jpj - 1), ji)
        ind_s = (np.maximum(jj - 1, 0), ji)

        # Replace NaNs; nanmean warns (and returns NaN) where all four
        # neighbours are NaN, which is expected here
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            arr_in[ind] = np.nanmean(np.vstack((arr_in[ind_e],
                                                arr_in[ind_w],
                                                arr_in[ind_n],
                                                arr_in[ind_s])), axis=0)

        # Find new indices for next loop
        ind = np.where(np.logical_and(np.isnan(arr_in), wet))
        counter += 1

        # Stop if no NaN was filled this pass: the array is unchanged, so
        # further iterations cannot make progress (avoids an infinite loop
        # when n_iter is unbounded)
        if ind[0].size >= n_nan:
            break
        n_nan = ind[0].size

    return arr_in

### Specify source grid

In [4]:
# source data (fields to be regridded)
src_fname = './src_data.nc'
# source grid (domain configuration file)
src_dom = './src_domain_cfg.nc'

# open the source grid, take the first time_counter record and rename the
# glamt/gphit variables to lon/lat, then flag them as coordinates
ds_src_grid = xr.open_dataset(src_dom).isel(time_counter=0).rename({'glamt': 'lon', 'gphit': 'lat'})
ds_src_grid = ds_src_grid.set_coords(("lat", "lon"))

# open the source data, renaming nav_lon/nav_lat to lon/lat to match the grid
ds_src_data = xr.open_dataset(src_fname).rename({'nav_lon': 'lon', 'nav_lat': 'lat'})


### Specify destination grid

In [5]:
# destination grid (domain configuration file)
dst_dom = './dst_domain_cfg.nc'

# open the destination grid, take the first time_counter record, rename
# 'x'/'y' to 'lon'/'lat' and flag lat/lon as coordinates
# NOTE(review): the source grid renames glamt/gphit whereas this renames
# 'x'/'y' -- confirm these hold the intended longitudes/latitudes in this file
ds_dst_grid = xr.open_dataset(dst_dom).isel(time_counter=0).rename({'x': 'lon', 'y': 'lat'})
ds_dst_grid = ds_dst_grid.set_coords(("lat", "lon"))

In [None]:
# create the regridding weights (bilinear, source grid -> destination grid)
# and save them to a netCDF file so they can be reloaded rather than
# recomputed
regridder_eORCA0083=xe.Regridder(ds_src_grid, ds_dst_grid, 'bilinear', periodic=True, ignore_degenerate=True, unmapped_to_nan=True)
fn = '/dssgfs01/working/jdha/regridder_eORCA025_to_eORCA0083.nc'
regridder_eORCA0083.to_netcdf(fn)

In [None]:
# rebuild the regridder with the same options, passing the previously saved
# weights file via the weights argument
fn = '/dssgfs01/working/jdha/regridder_eORCA025_to_eORCA0083.nc'
regridder_eORCA0083_loaded=xe.Regridder(ds_src_grid, ds_dst_grid, 'bilinear', periodic=True, ignore_degenerate=True, unmapped_to_nan=True, weights=fn)

In [None]:
# numpy 2.0 workaround: convert every entry of the regridder's input and
# output shapes to a plain Python int
regridder_eORCA0083_loaded.shape_in = tuple(map(int, regridder_eORCA0083_loaded.shape_in))
regridder_eORCA0083_loaded.shape_out = tuple(map(int, regridder_eORCA0083_loaded.shape_out))

In [None]:
# regridding temperature and salinity: take index 0 along the leading
# dimension (presumably time -- TODO confirm) of thetao_con and so_abs,
# regrid, and wrap each result in a dataset named 'toce' / 'so'
ds_toce = regridder_eORCA0083_loaded(ds_src_data.thetao_con[0,:,:,:]).to_dataset(name='toce')
ds_soce = regridder_eORCA0083_loaded(ds_src_data.so_abs[0,:,:,:]).to_dataset(name='so')

In [None]:
# combine the regridded temperature and salinity into a single dataset
ds = xr.merge([ds_toce, ds_soce])

In [27]:
# rechunk: 75 along deptht and 128 x 128 along lat/lon
ds=ds.chunk({"deptht": 75, "lat": 128, "lon": 128})

In [30]:
# remove the time_centered and time_counter variables from the dataset
ds = ds.drop_vars({"time_centered","time_counter"})

In [31]:
# Define compression settings for each variable
compression = {"zlib": True, "complevel": 1}  # Enable deflation with level 1 compression
encoding = {
    "toce": compression,
    "so": compression,
}

# Write to a compressed NetCDF file
ds.to_netcdf("ICs_y1979m01.nc", encoding=encoding)