# Regridding scalar fields from 0.1 to 0.4 degrees (`ocn` to `ocn_rect`)
This notebook regrids scalar data from the high-resolution 0.1 degree grid to the low-resolution 0.4 degree rectangular grid.
We use xESMF's `conservative` regridding method.

In [None]:
import os
import sys
sys.path.append("..")
from shutil import copyfile
import numpy as np
import xesmf as xe
import xarray as xr
import matplotlib
import matplotlib.pyplot as plt

In [None]:
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
%matplotlib inline
%load_ext autoreload
%autoreload 2
matplotlib.rc_file('../rc_file')

In [None]:
from paths import path_samoc, file_ex_ocn_ctrl, file_ex_ocn_rect
from grid import generate_lats_lons

In [None]:
# Yearly TEMP output on the native high-resolution (curvilinear) ocean grid.
ds = xr.open_dataset('/projects/0/samoc/andre/CESM/ctrl/ocn_yrly_TEMP_PD_0001.nc')
# Example files on the `ocn` and `ocn_rect` grids, opened without decoding times.
ds_ocn, ds_rect = (
    xr.open_dataset(path, decode_times=False)
    for path in (file_ex_ocn_ctrl, file_ex_ocn_rect)
)

# Using `nccurv2ncrect.sc` from Michael
See `src/regrid/regrid_yrly_TEMP_PD.py`.

In [None]:
# Native-grid TEMP, for comparison with the nccurv2ncrect output plotted below.
# `ds1` was never defined (NameError); presumably the yearly TEMP file opened
# as `ds` is the intended input.
ds.TEMP[0, :, :].plot(vmin=0, vmax=30)

In [None]:
# Output of the curvilinear-to-rectangular test regridding (900x602 grid).
ds2 = xr.open_dataset(
    f'{path_samoc}/ctrl_rect/TEMP_new_test.interp900x602.nc'
)

In [None]:
# Same field after regridding, on the same colour scale as the native plot.
ds2['TEMP'][0, :, :].plot(vmin=0, vmax=30)

In [None]:
# Scatter the mean of the first TEMP slice for every regridded year on disk.
# Years without an output file are skipped; unreadable files are reported
# instead of being silently swallowed by a bare `except`.
plt.figure()
for t in range(1, 301):
    fn = f'{path_samoc}/ctrl_rect/TEMP_PD_yrly_{t:04d}.interp900x602.nc'
    if not os.path.exists(fn):
        continue  # this year has not been regridded (yet)
    try:
        # The `with` closes each file; `.load()` reads the slice first so that
        # `da` stays usable after the file is closed.
        with xr.open_dataset(fn, decode_times=False) as ds_yr:
            da = ds_yr.TEMP[0, :, :].load()
    except (OSError, ValueError) as err:
        print(f'skipping {fn}: {err}')
        continue
    plt.scatter(t, da.mean())

## xESMF (does not work)

In [None]:
# Inspect the coordinates of the 0.4 degree rectangular target grid.
ds_rect.coords

We need to copy the `t_lon`, `t_lat`, and `depth_t` coordinates to the new file and retain the time coordinate.

### Renaming coordinates

In [None]:
# Give both grids the same `lat`/`lon` coordinate names for xESMF.
ds = ds.rename({'TLONG': 'lon', 'TLAT': 'lat'})
ds_rect = ds_rect.rename({'t_lon': 'lon', 't_lat': 'lat'})

### Filling in the missing lat/lon values over the continents and creating bounding lats/lons

In [None]:
# Replace the model lat/lon (missing over the continents) with complete fields.
# NOTE(review): assumes generate_lats_lons('ocn') returns arrays with the same
# shape as ds.lat / ds.lon — verify in grid.py.
lats,lons = generate_lats_lons('ocn')
# Assigning to `.values` writes into the dataset's coordinate variables in place.
ds['lat'].values = lats
ds['lon'].values = lons

In [None]:
# Check the filled-in longitudes.
ds.lon.values

In [None]:
def _edges_rect(lons, lats):
    """Return 1D ``(lon_b, lat_b)`` edge vectors for a rectangular 0.4 degree grid."""
    # Longitude: assumes uniform 0.4 deg spacing, so each western edge is the
    # centre minus 0.2 deg; the final eastern edge is hard-coded to 359.8.
    lon_b = np.append(lons - 0.2, 359.8)
    # Latitude: midpoints between neighbouring centres, extrapolated by half a
    # cell at the southern end; the northern edge is pinned just below the pole.
    lat_mid = (lats[:-1] + lats[1:]) / 2
    lat_S = np.array([lats[0] - (lats[1] - lats[0]) / 2])
    lat_N = np.array([89.99])
    lat_b = np.concatenate([lat_S, lat_mid, lat_N])
    return lon_b, lat_b


def _edges_curvilinear(lons, lats):
    """Return 2D ``(lon_b, lat_b)`` corner arrays of shape (ny+1, nx+1)."""
    # Longitude edges are averaged along axis 1, so pad one extra row (a copy
    # of the first row) to get ny+1 rows.
    lons = np.concatenate([lons[:1, :], lons], axis=0)
    lon_mid = (lons[:, 1:] + lons[:, :-1]) / 2
    # Extrapolate half a cell *outward* at both ends. Subtracting at the eastern
    # end would just reproduce the last midpoint (a zero-width final column).
    lon_W = lons[:, :1] - (lons[:, 1:2] - lons[:, :1]) / 2
    lon_E = lons[:, -1:] + (lons[:, -1:] - lons[:, -2:-1]) / 2
    lon_b = np.concatenate([lon_W, lon_mid, lon_E], axis=1)

    # Latitude edges are averaged along axis 0, so pad one extra column.
    lats = np.concatenate([lats[:, :1], lats], axis=1)
    lat_mid = (lats[1:, :] + lats[:-1, :]) / 2
    lat_S = lats[:1, :] - (lats[1:2, :] - lats[:1, :]) / 2
    lat_N = lats[-1:, :] + (lats[-1:, :] - lats[-2:-1, :]) / 2
    lat_b = np.concatenate([lat_S, lat_mid, lat_N], axis=0)
    # Extrapolation must never push an edge past a pole.
    lat_b = np.clip(lat_b, -90.0, 90.0)
    # NOTE(review): plain averaging ignores the 0/360 longitude seam, so edges
    # of cells straddling it will be wrong; this still needs separate handling.
    return lon_b, lat_b


def add_bounding_lat_lon(ds):
    """Add cell-edge coordinates ``lon_b`` and ``lat_b`` to ``ds`` in place.

    The edges are meant for xESMF's conservative regridding. Two layouts are
    supported:

    * 1D ``lat``/``lon`` (rectangular grid): ``lon_b`` on dimension ``nlon_b``
      and ``lat_b`` on ``nlat_b``, each one element longer than the centres.
    * 2D ``lat``/``lon`` (curvilinear grid): ``lon_b`` and ``lat_b`` on
      ``['nlat_b', 'nlon_b']`` with shape ``(ny+1, nx+1)``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with ``lat`` and ``lon`` coordinates of the same rank.

    Returns
    -------
    xarray.Dataset
        The same ``ds`` object, with ``lon_b`` and ``lat_b`` added.

    Raises
    ------
    ValueError
        If ``lat`` and ``lon`` are not both 1D or both 2D.
    """
    lats = ds.lat.values
    lons = ds.lon.values

    if np.ndim(ds.lat) == 1 and np.ndim(ds.lon) == 1:
        lon_b, lat_b = _edges_rect(lons, lats)
        ds['lon_b'] = ('nlon_b', lon_b)
        ds['lat_b'] = ('nlat_b', lat_b)
    elif np.ndim(ds.lat) == 2 and np.ndim(ds.lon) == 2:
        lon_b, lat_b = _edges_curvilinear(lons, lats)
        ds['lon_b'] = (['nlat_b', 'nlon_b'], lon_b)
        ds['lat_b'] = (['nlat_b', 'nlon_b'], lat_b)
    else:
        # Previously this fell through silently, leaving no bounds and
        # causing an obscure failure later in the regridder.
        raise ValueError(
            f'lat/lon must both be 1D or both 2D, got '
            f'{np.ndim(ds.lat)}D lat and {np.ndim(ds.lon)}D lon'
        )
    return ds

# Attach cell edges to the target (rectangular) grid first, then the source grid.
ds_rect, ds = map(add_bounding_lat_lon, (ds_rect, ds))

In [None]:
# Visual check of the latitude cell edges.
ds['lat_b'].plot()

In [None]:
# Visual check of the longitude cell edges.
ds['lon_b'].plot()

The problem with the `conservative` option appears to be the non-monotonic longitudes of the grid, as illustrated by `pcolormesh` failing to handle the input.

In [None]:
# Diagnostic: try to draw the field on the trimmed corner arrays.
plt.pcolormesh(
    ds['lon_b'][1:, 1:],
    ds['lat_b'][1:, 1:],
    ds['TEMP'][0, :, :],
)

In [None]:
# Sanity check: True means some latitude edge lies beyond the North Pole.
np.any(ds.lat_b>90.)

### creating regridder object

In [None]:
%%time
# Build the conservative ds -> ds_rect regridder (took ~40 sec).
# NOTE(review): `reuse_weights=True` is meant to reuse a previously written
# weight file; per the xESMF docs `periodic` only matters for non-conservative
# methods, so the commented-out option below is moot — confirm against the
# installed xESMF version.
regridder_rect = xe.Regridder(ds, ds_rect, 'conservative', reuse_weights=True)#, periodic=True)

### regridding

In [None]:
# Regrid every variable of `ds` onto the rectangular grid and restore the
# `t_lat`/`t_lon` coordinate names of the rectangular-grid files. The former
# `.astype('int64')` is dropped: it truncated regridded temperatures to whole
# degrees and turned any NaN values into meaningless integers.
da_rect = regridder_rect(ds).rename({'lat': 't_lat', 'lon': 't_lon'})

## Testing whether the OHC (ocean heat content) is the same after regridding