## Regridding MOM6 output to arbitrary grid

This notebook shows how to take an existing MOM6 output and regrid it to an arbitrary grid. It uses a MOM6 'native' output file (see [here](https://mom6.readthedocs.io/en/dev-gfdl/api/generated/pages/Diagnostics.html#the-diag-table) and [here](https://github.com/CROCODILE-CESM/MOM_interface/blob/main/param_templates/diag_table.yaml) for more information about it), and the python packages [xesmf](https://xesmf.readthedocs.io/en/stable/) (for horizontal regridding) and [xgcm](https://xgcm.readthedocs.io/en/latest/index.html) (for vertical regridding).

We start by importing the necessary modules and reading in MOM6 output:

In [None]:
import xarray as xr
import xesmf as xe
import xgcm
import numpy as np
import matplotlib.pyplot as plt

# Open the MOM6 'native' z-space history file.
mom6_ds = xr.open_dataset("./NWA-1deg-2007-2008.mom6.h.z._0001.2007-01.nc")
# Keep a single time record. NOTE: index 2 is the *third* record of the file,
# not the first one; use `time=0` if the first day is what is wanted.
mom6_ds = mom6_ds.isel(time=2)

### Horizontal regridding

Assuming you are regridding from a given mesh to a lower-resolution grid, it is generally faster to regrid first along the dimension(s) that coarsen the data the most. This reduces the amount of data early on, so that the later, less aggressive regridding steps operate on smaller datasets (and are thus more efficient).

We therefore start with the horizontal regridding, as it operates on two dimensions at once; if you are coarsening more aggressively along the vertical direction, you'll want to reverse the order.

For the horizontal regridding we need the target grid `grid_out` and the input grid. Note that xesmf requires them to have `lat` and `lon` variables. We also need to take into account that the variables do not all share the same grid: for example, temperature `thetao` is measured at `xh`,`yh`, but the X-velocity is measured at `xq`,`yh`. We then need to define multiple input grids for the regridding method.

In [None]:
roms_ds = xr.open_dataset("./ROMS_20060102_5depth.nc")

# 1-D coordinate vectors. The ROMS rho-point coordinates are 2-D; taking a
# single column/row assumes lat varies only along eta_rho and lon only along
# xi_rho (a rectilinear grid) -- NOTE(review): confirm for this ROMS grid.
roms_lat = roms_ds.lat_rho.values[:, 0]
roms_lon = roms_ds.lon_rho.values[0, :]
mom6_lat_yh = mom6_ds.yh.values
mom6_lon_xh = mom6_ds.xh.values
mom6_lat_yq = mom6_ds.yq.values
mom6_lon_xq = mom6_ds.xq.values

# Bounding box covered by ROMS and by both MOM6 staggerings: the smallest of
# the maxima and the largest of the minima.
max_lat = np.min([np.max(roms_lat), np.max(mom6_lat_yh), np.max(mom6_lat_yq)])
max_lon = np.min([np.max(roms_lon), np.max(mom6_lon_xh), np.max(mom6_lon_xq)])
min_lat = np.max([np.min(roms_lat), np.min(mom6_lat_yh), np.min(mom6_lat_yq)])
min_lon = np.max([np.min(roms_lon), np.min(mom6_lon_xh), np.min(mom6_lon_xq)])


def _inclusive_index_slice(values, lo, hi):
    """Return a slice selecting the entries of ``values`` between ``lo`` and ``hi``.

    Both bounds are inclusive: the slice starts at the first index whose value
    is ``>= lo`` and stops *after* the last index whose value is ``<= hi``.
    Assumes ``values`` is monotonically increasing (TODO confirm for the input
    grids). Raises ``ValueError`` when no value satisfies one of the bounds.
    """
    first = np.where(values >= lo)[0].min()
    last = np.where(values <= hi)[0].max()
    # Slice stops are exclusive, so +1 keeps the last in-range row/column.
    return slice(first, last + 1)


# ROMS rho points inside the overlap box.
subset_roms = roms_ds.isel(
    eta_rho=_inclusive_index_slice(roms_lat, min_lat, max_lat),
    xi_rho=_inclusive_index_slice(roms_lon, min_lon, max_lon),
)

# MOM6 tracer (h) points inside the overlap box.
subset_mom6 = mom6_ds.isel(
    yh=_inclusive_index_slice(mom6_lat_yh, min_lat, max_lat),
    xh=_inclusive_index_slice(mom6_lon_xh, min_lon, max_lon),
)

# Target grid for xesmf, which expects `lat`/`lon` variables: the 1-D ROMS
# rho-point coordinates of the overlap box.
grid_out = xr.Dataset(
    {
        "lat": (["lat"], subset_roms.lat_rho.values[:, 0], {"units": "degrees_north"}),
        "lon": (["lon"], subset_roms.lon_rho.values[0, :], {"units": "degrees_east"}),
    }
)

In [None]:
# Regridder for the tracer (h) points, which sit on (yh, xh). Variables on the
# other staggerings need a regridder of their own (templates further down).
grid_in = dict(lon=mom6_ds["xh"], lat=mom6_ds["yh"])
regridder_hh = xe.Regridder(
    grid_in,
    grid_out,
    "bilinear",
    extrap_method=None,
)
# Called on the whole Dataset, the regridder only regrids the variables that
# carry both the xh and yh dimensions.
mom6_hh_regridded = regridder_hh(mom6_ds)

# Disabled, untested templates for the staggered velocity points; they are
# built the same way from each staggering's own coordinates.
# (yq, xh) points:
# regridder_hq = xe.Regridder({'lon': mom6_ds['xh'], 'lat': mom6_ds['yq']}, grid_out, "bilinear", extrap_method=None)
# mom6_hq_regridded = regridder_hq(mom6_ds)
# (yh, xq) points:
# regridder_qh = xe.Regridder({'lon': mom6_ds['xq'], 'lat': mom6_ds['yh']}, grid_out, "bilinear", extrap_method=None)
# mom6_qh_regridded = regridder_qh(mom6_ds)

Comparing temperature data before and after regridding:

In [None]:
# Top MOM6 level before and after horizontal regridding, and the first ROMS
# depth level restricted to the overlap box.
temp_orig = mom6_ds.thetao.isel(z_l=0)
temp_regr = mom6_hh_regridded.thetao.isel(z_l=0)
# NOTE(review): ROMS depth index 0 is assumed to match the MOM6 level plotted
# here -- confirm against the ROMS `depth` values.
temp_roms = subset_roms.temp.isel(ocean_time=0, depth=0)

X_orig, Y_orig = np.meshgrid(temp_orig['xh'], temp_orig['yh'])
# ROMS and regridded MOM6 share the target grid built from subset_roms.
X_regr, Y_regr = np.meshgrid(temp_regr['lon'], temp_regr['lat'])

# Label taken from the MOM6 level actually plotted instead of hard-coded.
level_label = f"z*={float(temp_orig['z_l']):.2f}m"

fig, axes = plt.subplots(1, 4, figsize=(15, 5), sharey=True)

sc1 = axes[0].pcolormesh(X_orig, Y_orig, temp_orig.values, cmap='viridis', vmin=-5, vmax=30)
axes[0].set_title('MOM6')
axes[0].set_xlabel('Longitude')
axes[0].set_ylabel('Latitude')
fig.colorbar(sc1, ax=axes[0], orientation='vertical', label=f'thetao at {level_label}')

sc2 = axes[1].pcolormesh(X_regr, Y_regr, temp_roms.values, cmap='viridis', vmin=-5, vmax=30)
axes[1].set_title('ROMS')
axes[1].set_xlabel('Longitude')
fig.colorbar(sc2, ax=axes[1], orientation='vertical', label=f'thetao at {level_label}')

sc3 = axes[2].pcolormesh(X_regr, Y_regr, temp_regr.values, cmap='viridis', vmin=-5, vmax=30)
axes[2].set_title('MOM6_REGRIDDED')
axes[2].set_xlabel('Longitude')
fig.colorbar(sc3, ax=axes[2], orientation='vertical', label=f'thetao at {level_label}')

# Difference field: diverging colormap with symmetric limits so that zero sits
# at the neutral colour, matching the other ROMS-MOM6 difference plots.
sc4 = axes[3].pcolormesh(X_regr, Y_regr, temp_roms.values - temp_regr.values, cmap='bwr', vmin=-10, vmax=10)
axes[3].set_title('ROMS-MOM6_REGRIDDED')
axes[3].set_xlabel('Longitude')
fig.colorbar(sc4, ax=axes[3], orientation='vertical', label=f'delta thetao at {level_label}')

plt.tight_layout()
plt.show()

### Vertical regridding

We will first regrid the original data vertically, and then the data that were already regridded horizontally.

To regrid the data vertically, we define a `xgcm` vertical grid specifying the positions of the coordinate in the MOM6 discretization approach (see [here](https://xgcm.readthedocs.io/en/latest/api.html#grid) for xgcm manual and [here](https://mom6.readthedocs.io/en/dev-gfdl/api/generated/pages/Discrete_Grids.html) for MOM6 gridding).

In [None]:
# Single non-periodic vertical axis "Z": layer centres (z_l) are the inner
# positions and layer interfaces (z_i) the outer ones.
vertical_coords = {"Z": {"inner": "z_l", "outer": "z_i"}}
mom6_grid_z = xgcm.Grid(mom6_ds, coords=vertical_coords, periodic=False)

We then define the target array, i.e. the values at which we want to interpolate the fields that have a vertical coordinate. We can then transform each variable using [`xgcm.Grid`'s transform method](https://xgcm.readthedocs.io/en/latest/transform.html) and generate a dataset with the regridded fields:

In [None]:
# Depths onto which every z_l-dimensioned field is interpolated.
z_target = np.array([2.5, 500, 1000, 1500, 2000])


def _regrid_vertical(ds, grid, target):
    """Return a Dataset holding every variable of ``ds`` that has a ``z_l``
    dimension, transformed along ``grid``'s "Z" axis onto ``target``.

    Variables without a ``z_l`` dimension are skipped. Progress is printed
    per variable, followed by "Done.".
    """
    out = xr.Dataset()
    for var_name, da in ds.data_vars.items():
        if 'z_l' in da.dims:
            print(f"Re-gridding {var_name}")
            out[var_name] = grid.transform(da, 'Z', target)
    print('Done.')
    return out


# regridding original data
regridded_z = _regrid_vertical(mom6_ds, mom6_grid_z, z_target)

# regridding the horizontally regridded data, with the same xgcm grid
regridded_zh = _regrid_vertical(mom6_hh_regridded, mom6_grid_z, z_target)

# Same result as transforming thetao on its own; reused instead of recomputed.
thetao_transformed = regridded_z['thetao']

#### Profile plots

Plotting the temperature profile at the grid point nearest to (lon, lat) = (-70, 30):

In [None]:
target_x = -70
target_y = 30

# Temperature profiles at the grid column nearest to (target_x, target_y).
# The native and vertically-only regridded fields are indexed by (xh, yh);
# the horizontally + vertically regridded field by (lon, lat).
temp_orig = mom6_ds.thetao.sel(xh=target_x, yh=target_y, method="nearest")
temp_regr_z = regridded_z.thetao.sel(xh=target_x, yh=target_y, method="nearest")
temp_regr_zh = regridded_zh.thetao.sel(lon=target_x, lat=target_y, method="nearest")

colors = ['y', 'g', 'b']
labels = ['MOM6', 'MOM6 regridded (z)', 'MOM6 regridded (z + horizontal)']
profiles = [temp_orig, temp_regr_z, temp_regr_zh]

fig, axes = plt.subplots(1, 4, figsize=(15, 5), sharey=True)
# The y-axis is shared by all panels; invert it so depth increases downwards.
axes[0].set_ylim(0, 8000)
axes[0].invert_yaxis()

# Panel 0 overlays all three profiles; a legend identifies each line.
for profile, color, label in zip(profiles, colors, labels):
    axes[0].plot(profile.values, profile['z_l'], color=color, label=label)
axes[0].set_title('All profiles')
axes[0].set_ylabel('z_l')
axes[0].legend()

# Panels 1-3 show each profile on its own, with markers at the data levels.
for ax, profile, color, label in zip(axes[1:], profiles, colors, labels):
    ax.plot(profile.values, profile['z_l'], marker='o', color=color)
    ax.set_title(label)

for ax in axes:
    ax.set_xlim(0, 25)
    ax.set_xlabel('thetao')

plt.tight_layout()
plt.show()

#### Contour plots

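Plotting temperature values at the level nearest to the surface (`target_depth = 0`); the comparison is repeated further below at the level nearest to 500 m depth.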

In [None]:
target_depth = 0

# In each dataset, take the vertical level closest to target_depth.
temp_orig = mom6_ds.thetao.sel(z_l=target_depth, method="nearest")
temp_regr_zh = regridded_zh.thetao.sel(z_l=target_depth, method="nearest")
temp_roms = subset_roms.temp.isel(ocean_time=0).sel(depth=target_depth, method="nearest")

# 2-D coordinates for pcolormesh: native MOM6 (xh, yh) points, and the target
# (lon, lat) grid shared by ROMS and the regridded MOM6 fields.
Xo, Yo = np.meshgrid(temp_orig['xh'], temp_orig['yh'])
Xzh, Yzh = np.meshgrid(temp_regr_zh['lon'], temp_regr_zh['lat'])

fig, axes = plt.subplots(1, 2, figsize=(15, 5), sharey=False, constrained_layout=True)
ax_mom6, ax_roms = axes

# Left panel: native MOM6 field, zoomed to its own extent.
sc0 = ax_mom6.pcolormesh(Xo, Yo, temp_orig.values, cmap='viridis', vmin=-5, vmax=25)
ax_mom6.set(title='MOM6', xlabel='Longitude', ylabel='Latitude')
# A single colorbar serves both panels: same colormap and limits.
fig.colorbar(sc0, ax=axes, orientation='vertical', label=f'thetao at z*={target_depth}m')
ax_mom6.set_xlim(Xo.min(), Xo.max())
ax_mom6.set_ylim(Yo.min(), Yo.max())

# Right panel: ROMS on the target grid.
sc1 = ax_roms.pcolormesh(Xzh, Yzh, temp_roms.values, cmap='viridis', vmin=-5, vmax=25)
ax_roms.set(title='ROMS', xlabel='Longitude')
ax_roms.set_xlim(Xzh.min(), Xzh.max())
ax_roms.set_ylim(Yzh.min(), Yzh.max())

for ax in axes:
    ax.set_aspect('equal')

plt.savefig("mom6-roms_01.png", transparent=True, dpi=300)
plt.show()

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(15, 5), sharey=True, constrained_layout=True)
ax_model, ax_diff = axes
depth_tag = f"z*={target_depth}m"

# Left: MOM6 after horizontal + vertical regridding onto the ROMS grid.
sc2 = ax_model.pcolormesh(Xzh, Yzh, temp_regr_zh.values, cmap='viridis', vmin=-5, vmax=25)
ax_model.set(title='MOM6 regridded', xlabel='Longitude')
fig.colorbar(sc2, ax=axes, orientation='vertical', label=f'thetao at {depth_tag}')

# Right: ROMS minus regridded MOM6 on a diverging colormap with symmetric limits.
delta = temp_roms.values - temp_regr_zh.values
sc3 = ax_diff.pcolormesh(Xzh, Yzh, delta, cmap='bwr', vmin=-10, vmax=10)
ax_diff.set(title='ROMS-MOM6 regridded', xlabel='Longitude')
fig.colorbar(sc3, ax=axes, orientation='vertical', label=f'delta thetao at {depth_tag}')

for ax in axes:
    ax.set_aspect('equal')
    ax.set_xlim(Xzh.min(), Xzh.max())
    ax.set_ylim(Yzh.min(), Yzh.max())

plt.savefig("mom6-roms_02.png", transparent=True, dpi=300)
plt.show()

In [None]:
target_depth = 500

# Nearest vertical level to target_depth in each dataset.
temp_orig = mom6_ds.thetao.sel(z_l=target_depth, method="nearest")
temp_regr_zh = regridded_zh.thetao.sel(z_l=target_depth, method="nearest")
roms_first_time = subset_roms.temp.isel(ocean_time=0)
temp_roms = roms_first_time.sel(depth=target_depth, method="nearest")

# Plotting coordinates: native MOM6 grid vs. the shared target grid.
Xo, Yo = np.meshgrid(temp_orig['xh'], temp_orig['yh'])
Xzh, Yzh = np.meshgrid(temp_regr_zh['lon'], temp_regr_zh['lat'])

fig, axes = plt.subplots(1, 2, figsize=(15, 5), sharey=False, constrained_layout=True)
left, right = axes

sc0 = left.pcolormesh(Xo, Yo, temp_orig.values, cmap='viridis', vmin=-5, vmax=25)
left.set(title='MOM6', xlabel='Longitude', ylabel='Latitude')
# One colorbar for both panels, which share colormap and limits.
fig.colorbar(sc0, ax=axes, orientation='vertical', label=f'thetao at z*={target_depth}m')
left.set_xlim(Xo.min(), Xo.max())
left.set_ylim(Yo.min(), Yo.max())

sc1 = right.pcolormesh(Xzh, Yzh, temp_roms.values, cmap='viridis', vmin=-5, vmax=25)
right.set(title='ROMS', xlabel='Longitude')
right.set_xlim(Xzh.min(), Xzh.max())
right.set_ylim(Yzh.min(), Yzh.max())

for ax in axes:
    ax.set_aspect('equal')

plt.savefig("mom6-roms_03.png", transparent=True, dpi=300)
plt.show()

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(15, 5), sharey=True, constrained_layout=True)
regridded_ax, difference_ax = axes
suffix = f"z*={target_depth}m"

# Regridded MOM6 temperature on the ROMS target grid.
sc2 = regridded_ax.pcolormesh(Xzh, Yzh, temp_regr_zh.values, cmap='viridis', vmin=-5, vmax=25)
regridded_ax.set(title='MOM6 regridded', xlabel='Longitude')
fig.colorbar(sc2, ax=axes, orientation='vertical', label='thetao at ' + suffix)

# ROMS minus regridded MOM6; symmetric limits on a diverging colormap.
sc3 = difference_ax.pcolormesh(
    Xzh, Yzh, temp_roms.values - temp_regr_zh.values,
    cmap='bwr', vmin=-10, vmax=10,
)
difference_ax.set(title='ROMS-MOM6 regridded', xlabel='Longitude')
fig.colorbar(sc3, ax=axes, orientation='vertical', label='delta thetao at ' + suffix)

x_lo, x_hi = Xzh.min(), Xzh.max()
y_lo, y_hi = Yzh.min(), Yzh.max()
for ax in axes:
    ax.set_aspect('equal')
    ax.set_xlim(x_lo, x_hi)
    ax.set_ylim(y_lo, y_hi)

plt.savefig("mom6-roms_04.png", transparent=True, dpi=300)
plt.show()