# Exploring and regridding the EN4 Salinity dataset
downloaded 11.2.2020 from https://www.metoffice.gov.uk/hadobs/en4/download-en4-2-1.html
via download.txt, then unzipped, downloaded 2020.12 manually

In [None]:
import sys
import dask
import xesmf as xe
import numpy as np
import xarray as xr
import cartopy.crs as ccrs
import matplotlib
import matplotlib.pyplot as plt

In [None]:
%matplotlib inline
matplotlib.rc_file('../rc_file')
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
%load_ext autoreload
%autoreload 2

In [None]:
# %%time
# %cd /projects/0/prace_imau/prace_2013081679/andre/EN4
# !pwd
# !wget -i download.txt
# for y in np.arange(1990,2020):
#     !unzip EN.4.2.1.analyses.g10.{y}.zip
#     !rm EN.4.2.1.analyses.g10.{y}.zip

In [None]:
sys.path.append('..')
from paths import path_prace, file_ex_ocn_ctrl, file_ex_ocn_lpd

In [None]:
do = xr.open_dataset(file_ex_ocn_ctrl, decode_times=False)
dl = xr.open_dataset(file_ex_ocn_lpd, decode_times=False)

In [None]:
# da = xr.open_mfdataset(f'{path_prace}/EN4/EN.4.2.1.f.analysis.g10.*.nc', chunks={'depth':42})
# da_mean = da.mean(dim='time').compute()
# da_mean.to_netcdf(f'{path_prace}/EN4/EN4_mean.nc')
da_mean = xr.open_dataset(f'{path_prace}/EN4/EN4_mean.nc')

In [None]:
X, Y = np.meshgrid(da_mean.lon, da_mean.lat)
plt.contourf(X, Y, da_mean.salinity.isel(depth=0), vmin=30, vmax=40, levels=40)
plt.colorbar()
plt.title(r'EN4 surface salinity')
plt.xlabel(r'longitude [$^\circ$E]')
plt.ylabel(r'latitude [$^\circ$N]')

In [None]:
X, Y = np.meshgrid(da_mean.lat, -da_mean.depth/1e3)
im = plt.contourf(X, Y, da_mean.salinity.sel(lon=330), vmin=34, vmax=37.5, levels=28)
plt.colorbar(im)
plt.title(r'EN4 transect 30$^\circ$W')
plt.xlabel(r'latitude [$^\circ$N]')
plt.ylabel('depth [km]')

In [None]:
X, Y = np.meshgrid(da_mean.lon, -da_mean.depth/1e3)
plt.contourf(X, Y, da_mean.salinity.sel(lat=-34), vmin=34, vmax=37.5, levels=28)
plt.colorbar()
plt.title(r'EN4 transect 34$^\circ$S')
plt.xlabel(r'latitude [$^\circ$E]')
plt.ylabel('depth [km]')

In [None]:
plt.title('depths')
plt.plot(np.arange(42), da_mean.depth)
plt.plot(np.arange(42), do.z_t/1e2)

### interpolating to high and low res CESM ocean grids
1. rectangular 1x1 grid of EN4 to high res grid
2. interpolate along depth to high and low res depth coordinates

In [None]:
da_mean

In [None]:
# adding lat/lon bounds
da_mean = da_mean.assign_coords({'lon':da.lon-.5})
da_mean = da_mean.assign_coords({'lat':da.lat+.5})
da_mean = da_mean.assign_coords({'lon_b':np.arange(0,361)})
da_mean = da_mean.assign_coords({'lat_b':np.arange(-83,91)})

### horizontal interpolation

In [None]:
%%time
# 1x1 to high; 3min 52s
regridder_high = xe.Regridder(da_mean.salinity, do.SALT.rename({'TLAT':'lat','TLONG':'lon'}), 'bilinear', periodic=True)
# 1x1 to low; 3.84 s
regridder_low = xe.Regridder(da_mean.salinity, dl.SALT.rename({'TLAT':'lat','TLONG':'lon'}), 'bilinear', periodic=True)

In [None]:
%%time
# high: 26.6 s
# low 21.3 s
da_mean_salt_high = regridder_high(da_mean.salinity).rename({'lat':'TLAT', 'lon':'TLONG'})
da_mean_salt_low  = regridder_low (da_mean.salinity).rename({'lat':'TLAT', 'lon':'TLONG'})

In [None]:
%%time
# high: 26.6 s
# low 21.3 s
da_mean_temp_high = regridder_high(da_mean.temperature).rename({'lat':'TLAT', 'lon':'TLONG'})
da_mean_temp_low  = regridder_low (da_mean.temperature).rename({'lat':'TLAT', 'lon':'TLONG'})

### vertical interpolation

In [None]:
%%time
# 16.9 s
salinity_high = da_mean_salt_high.interp(depth=do.z_t/1e2)
salinity_low  = da_mean_salt_low .interp(depth=dl.z_t/1e2)
salinity_high.values[0,:,:] = da_mean_salt_high.isel(depth=0).values
salinity_low .values[0,:,:] = da_mean_salt_low .isel(depth=0).values
salinity_high.values[-3,:,:] = da_mean_salt_high.isel(depth=-1).values
salinity_low .values[-1,:,:] = da_mean_salt_low .isel(depth=-1).values

In [None]:
%%time
# 16.9 s
temp_high = da_mean_temp_high.interp(depth=do.z_t/1e2)
temp_low  = da_mean_temp_low .interp(depth=dl.z_t/1e2)
temp_high.values[0,:,:]  = da_mean_temp_high.isel(depth=0).values
temp_low .values[0,:,:]  = da_mean_temp_low .isel(depth=0).values
temp_high.values[-3,:,:] = da_mean_temp_high.isel(depth=-1).values
temp_low .values[-1,:,:] = da_mean_temp_low .isel(depth=-1).values

In [None]:
salinity_high.to_netcdf(f'{path_prace}/EN4/EN4_mean_salinity_high.nc')
salinity_low .to_netcdf(f'{path_prace}/EN4/EN4_mean_salinity_low.nc' )

In [None]:
(temp_high-273.15).to_netcdf(f'{path_prace}/EN4/EN4_mean_temperature_high.nc')
(temp_low -273.15).to_netcdf(f'{path_prace}/EN4/EN4_mean_temperature_low.nc' )

### inspecting fields

In [None]:
da_mean_temp_low.isel(depth=0).plot()

In [None]:
da_mean_salt_low.isel(depth=0).plot()

In [None]:
da_mean_salt_high.isel(depth=0).plot()

In [None]:
f = plt.figure()
ax = f.add_subplot(111, projection=ccrs.PlateCarree(central_longitude=0))
ax.pcolormesh(dl.TLONG, dl.TLAT, da_mean_salt_low.where(dl.REGION_MASK>0).isel(depth=0),
               transform=ccrs.PlateCarree(), vmin=33, vmax=38)

In [None]:
f = plt.figure()
ax = f.add_subplot(111, projection=ccrs.PlateCarree(central_longitude=0))
ax.pcolormesh(do.TLONG, do.TLAT, da_mean_salt_high.where(do.REGION_MASK>0).isel(depth=0),
               transform=ccrs.PlateCarree(), vmin=33, vmax=38)

In [None]:
plt.figure(figsize=(12,1))
plt.scatter(da_mean.depth, [0]*42, marker='|')
plt.scatter(dl.z_t/1e2   , [1]*60, marker='|')
plt.scatter(do.z_t/1e2   , [2]*42, marker='|')
plt.xlim(-10,6000)
# plt.xlim(-1,6)

### comparing a vertical profile

In [None]:
profile = da_mean.salinity.sel({'lat':30, 'lon':330}, method='nearest')
profile_high = profile.interp(depth=do.z_t/1e2)
profile_low = profile.interp(depth=dl.z_t/1e2)

In [None]:
plt.scatter(profile, -profile.depth/1e3, marker='x', s=10)
plt.scatter(profile_high+.1, -profile_high.depth/1e3, marker='x', s=10)
plt.scatter(profile_low+.2, -profile_low.depth/1e3, marker='x', s=10)

plt.scatter(salinity_high.sel(nlon=800, nlat=1500)+.3, -salinity_high.z_t/1e5)
# plt.scatter(salinity_low.sel(nlon=800, nlat=1500)+.3, -salinity_low.z_t/1e5)
# plt.scatter(da_mean.salinity.sel({'lat':30, 'lon':330}, method='nearest'), -da_mean.depth/1e3, marker='x', s=10)