In [None]:
!pip install gcm-filters -q

In [None]:
!pip install gsw -q

In [None]:
!pip install numpy_groupies -q

# Dask start

In [None]:
from dask.distributed import Client

# Attach to an already-running Dask scheduler at a hard-coded local address.
# NOTE(review): the port looks session-specific — update it to the current
# scheduler address before running on a fresh session.
client = Client("tcp://127.0.0.1:36939")
client  # display the client summary as the cell output

In [None]:
from dask.distributed import Client, progress

# Import

In [None]:
import s3fs
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import gcm_filters
import gsw
from xhistogram.xarray import histogram
import eos_module
from dask import compute
from dask.diagnostics import ProgressBar
from dask import delayed
import xrft_module
import pandas as pd
import xesmf as xe
import cartopy.crs as ccrs

In [None]:
# import importlib  # `imp` is deprecated and was removed in Python 3.12
# importlib.reload(xrft_module)

# Load data

Load Pangeo catalog

In [None]:
# all_params() comes from the project-local validate_catalog module; it returns
# a parameter dictionary and the catalog object used below to open eNATL60 data.
from validate_catalog import all_params
params_dict, cat = all_params()
params_dict.keys()  # show which parameter entries are available

Winter data:

In [None]:
# Winter dataset (season='fma'): hourly surface fields for region 1, opened via
# the catalog as a Dask-backed dataset and rechunked to 100 time steps per chunk
# with the full x/y extent in each chunk.
enatl01w = (
    cat.eNATL60(region='1', datatype='surface_hourly', season='fma')
    .to_dask()
    .chunk({'time_counter': 100, 'x': -1, 'y': -1})
)

Summer data:

In [None]:
# Summer dataset (season='aso'): hourly surface fields for region 1, opened via
# the catalog as a Dask-backed dataset and rechunked to 100 time steps per chunk
# with the full x/y extent in each chunk.
enatl01s = (
    cat.eNATL60(region='1', datatype='surface_hourly', season='aso')
    .to_dask()
    .chunk({'time_counter': 100, 'x': -1, 'y': -1})
)

In [None]:
# Dataset used by all downstream cells; point this at `enatl01w` to process the
# winter data instead. Note that the output file names written near the end of
# the notebook are hard-coded with a "summer" suffix.
selected_data_set = enatl01s

In [None]:
selected_data_set  # display the dataset repr (variables, coordinates, chunking)

# Prepare data for filter

## Slicing

In [None]:
selected_data_set.time_counter.size  # number of time steps in the selected dataset

In [None]:
# x_slice = slice(400,700)
# y_slice = slice(500,800)

In [None]:
# Select the full horizontal extent. `slice(None)` keeps every index, whereas
# the previous `slice(0, -1)` silently dropped the last column and row.
# Use explicit bounds (e.g. slice(400, 700)) to work on a sub-region instead.
x_slice = slice(None)
y_slice = slice(None)

In [None]:
# Index selection applied to the time-dependent variables: every time step plus
# the horizontal window given by y_slice / x_slice. `slice(None)` is used for
# time because `slice(0, -1)` would drop the final time step.
data_slice = dict(time_counter=slice(None), y=y_slice, x=x_slice)

In [None]:
# Horizontal-only selection, applied to the grid variables (metrics, mask) below.
coord_slice = dict(y=y_slice,x=x_slice)

In [None]:
# Quick-look map of the first SST time step over the selected horizontal window.
fig, ax = plt.subplots(
    ncols=1,
    figsize=(5, 4),
    constrained_layout=True,
    subplot_kw={'projection': ccrs.PlateCarree()},
)
(
    selected_data_set
    .isel(time_counter=0)
    .sosstsst[coord_slice]
    .plot(x='lon', y='lat', ax=ax)
)
#ax.coastlines()

## Grid

In [None]:
# Cell area: product of the e1t and e2t grid metrics over the selected window,
# re-labelled with 0-based integer y/x coordinates and loaded into memory.
area = xr.DataArray(
    selected_data_set.e1t[coord_slice] * selected_data_set.e2t[coord_slice],
    coords=dict(
        y=np.arange(len(selected_data_set.y[y_slice])),
        x=np.arange(len(selected_data_set.x[x_slice])),
    ),
    dims=['y', 'x'],
).load()
# area.plot()

Spacings:

In [None]:
def _on_index_grid(field):
    """Return `field` cut to the selected window, on 0-based integer y/x coordinates, loaded in memory."""
    return xr.DataArray(
        field[coord_slice],
        dims=['y', 'x'],
        coords={
            'y': np.arange(len(selected_data_set.y[y_slice])),
            'x': np.arange(len(selected_data_set.x[x_slice])),
        },
    ).load()


# Grid spacings handed to the filter; all four share the same (y, x) index grid.
dxw = _on_index_grid(selected_data_set.e1u)  # x-spacing centered at western cell edge
dyw = _on_index_grid(selected_data_set.e2u)  # y-spacing centered at western cell edge
dxs = _on_index_grid(selected_data_set.e1v)  # x-spacing centered at southern cell edge
dys = _on_index_grid(selected_data_set.e2v)  # y-spacing centered at southern cell edge

In [None]:
# Smallest grid spacing found in any of the four spacing arrays, as a plain
# NumPy value; it is passed to the filter as `dx_min`.
dx_min = min(spacing.min() for spacing in (dxw, dyw, dxs, dys)).values
dx_min

Diffusivities for the filter in the east-west and north-south directions

In [None]:
# Uniform (all-ones) diffusivities; both arrays take their shape and coordinates from dxw.
kappa_w = xr.ones_like(dxw)
kappa_s = xr.ones_like(dxw)

## Mask

In [None]:
# Mask handed to the filter and to the coarsening step, taken from the dataset's
# `umask` field over the selected window and put on 0-based integer y/x coordinates.
# NOTE(review): the cell area above is built from e1t/e2t while this mask comes
# from umask — confirm that umask is the intended mask for the sst/sal fields.
wet_mask = xr.DataArray(
    selected_data_set.umask[coord_slice],
    coords=dict(
        y=np.arange(len(selected_data_set.y[y_slice])),
        x=np.arange(len(selected_data_set.x[x_slice])),
    ),
    dims=['y', 'x'],
).load()

In [None]:
wet_mask.plot()  # visual check of the mask over the selected window

## High-resolution dataset

In [None]:
# High-resolution working dataset: surface temperature and salinity over the
# selected time/space window, plus the mask, carrying the source coordinates.
dataHR = xr.Dataset(
    data_vars=dict(
        sst=selected_data_set.sosstsst[data_slice],
        sal=selected_data_set.sosaline[data_slice],
        wet_mask=wet_mask,
    ),
    coords=selected_data_set[data_slice].coords,
)

# Filter

## Define filter

Parameters (to determine) :

In [None]:
# Filter length scale passed to gcm_filters.
# NOTE(review): presumably in the same units as the grid spacings (metres, i.e. 30 km) — confirm.
filter_scale = 30000

In [None]:
# Gaussian filter on an irregular grid with land, configured from the mask,
# the four spacing arrays, the cell area and the two diffusivity arrays.
gaussian_filter = gcm_filters.Filter(
    filter_scale=filter_scale,
    dx_min=dx_min,
    filter_shape=gcm_filters.FilterShape.GAUSSIAN,
    grid_type=gcm_filters.GridType.IRREGULAR_WITH_LAND,
    grid_vars=dict(
        wet_mask=wet_mask,
        dxw=dxw, dyw=dyw,
        dxs=dxs, dys=dys,
        area=area,
        kappa_w=kappa_w, kappa_s=kappa_s,
    ),
)

## Apply filter to data

In [None]:
# Intermediate dataset: HR data after filtering (the result still has the
# spatial resolution of the HR data). Created empty here, with the same
# coordinates as the sliced source data; variables are added in the next cell.
data_filtered = xr.Dataset({}, coords=selected_data_set[data_slice].coords)

"Lazy" filter execution with Dask

In [None]:
# Apply the Gaussian filter over the horizontal (y, x) dimensions of each field
# and store the results under the same variable names in data_filtered.
data_filtered["sst"] = gaussian_filter.apply(dataHR.sst, dims=['y', 'x'])
data_filtered["sal"] = gaussian_filter.apply(dataHR.sal, dims=['y', 'x'])

# Spatial coarsening (downsampling)

In [None]:
# Number of high-resolution cells merged along each horizontal direction when coarsening.
coarsen_window_size=12

In [None]:
# def downsample(data, window_size) :
#     return data.coarsen(x=window_size, boundary='pad').mean().coarsen(y=window_size,boundary='pad').mean()*
def downsample(data, window_size, area_2d, wet_mask):
    """Coarsen `data` horizontally with an area-weighted, masked block average.

    Each output cell is sum(area * mask * data) / sum(area * mask) over a
    `window_size` x `window_size` block of input cells. Incomplete blocks at
    the domain edge are trimmed, and because the sums use ``skipna=False`` a
    block containing any NaN yields NaN.

    Parameters
    ----------
    data : xarray object with ``x``/``y`` dimensions that exposes ``e1t`` and
        ``e2t`` through ``data[...]``.
    window_size : int
        Number of input cells merged along both ``x`` and ``y``. This argument
        is what drives the coarsening (previously the body read the global
        ``coarsen_window_size`` and ignored it).
    area_2d : array of cell areas, multiplied with ``wet_mask`` to form the weights.
    wet_mask : mask multiplied into the area weights.

    Returns
    -------
    The coarsened object, with ``x``/``y`` replaced by 0-based integer indices
    and ``e1t``/``e2t`` replaced by the coarse-cell metrics.
    """
    block = dict(x=window_size, y=window_size)

    area_masked = area_2d * wet_mask
    data_weighted = area_masked * data
    data_weighted_sums = data_weighted.coarsen(
        **block,
        boundary='trim',
        coord_func={'x': 'mean', 'y': 'mean', 'lat': 'mean', 'lon': 'mean', 'e1t': 'sum', 'e2t': 'sum'},
    ).sum(skipna=False)
    area_window_sums = area_masked.coarsen(**block, boundary='trim').sum(skipna=False)

    result = data_weighted_sums / area_window_sums
    result['x'] = np.arange(len(result['x']))
    result['y'] = np.arange(len(result['y']))

    # Coarsened metrics: mean over the x-window followed by a sum over the
    # y-window, i.e. the block total divided by `window_size`.
    metric_coord_func = {'x': 'sum', 'y': 'sum', 'e1t': 'sum', 'e2t': 'sum'}

    def _coarsen_metric(name):
        """Return the coarse-grid values of the metric `name` as a plain array."""
        coarse = (
            data[name]
            .coarsen(x=window_size, boundary='trim', coord_func=metric_coord_func)
            .mean(skipna=False)
            .coarsen(y=window_size, boundary='trim', coord_func=metric_coord_func)
            .sum(skipna=False)
        )
        # Only the raw values are kept, so no auxiliary coordinates need dropping.
        return coarse.values

    result['e1t'] = (['y', 'x'], _coarsen_metric('e1t'))
    result['e2t'] = (['y', 'x'], _coarsen_metric('e2t'))
    return result

In [None]:
# Low-resolution dataset: filtered fields, then area-weighted coarsening.
dataLR = downsample(data_filtered, coarsen_window_size, area, wet_mask)

In [None]:
# Unfiltered high-resolution fields coarsened the same way, for comparison with dataLR.
dataHR_ds = downsample(dataHR, coarsen_window_size, area, wet_mask)

In [None]:
# choose the index of time_counter for 2D spatial plots
it = [0]

In [None]:
# Temperature at the chosen time index for each processing stage, side by side.
fig, ax = plt.subplots(ncols=4, figsize=(18, 4), constrained_layout=True)
fig.suptitle('Temperature')

dataHR.isel(time_counter=it).sst.plot(ax=ax[0], x='lon', y='lat')
ax[0].set_title('Unfiltered')

data_filtered.isel(time_counter=it).sst.plot(ax=ax[1], x='lon', y='lat')
ax[1].set_title('Filtered')

dataHR_ds.isel(time_counter=it).sst.plot(ax=ax[2], x='lon', y='lat')
ax[2].set_title('Coarsened')

dataLR.isel(time_counter=it).sst.plot(ax=ax[3], x='lon', y='lat')
ax[3].set_title('Filtered + Coarsened')

plt.show()

## ReGridder

In [None]:
#.where((ds.nav_lon < -68) & (ds.nav_lon > -78) & (ds.nav_lat < 40) & (ds.nav_lat > 30), drop=True)

In [None]:
# Global target grid with a spacing of 0.25 in both directions, built by xESMF.
data_regrid = xe.util.grid_global(0.25, 0.25)
data_regrid  # contains lat/lon values of cell centers and boundaries.

In [None]:
# Bilinear regridding of the high-resolution dataset onto the target grid.
# NOTE(review): `data_regrid` is the target grid on entry and is rebound to the
# regridded data on exit, so re-running this cell alone would build a regridder
# onto already-regridded data. Re-run the grid-definition cell first.
regridder = xe.Regridder(dataHR, data_regrid, "bilinear")
data_regrid = regridder(dataHR)

In [None]:
# Keep only the regridded points inside the lon (-78, -68) / lat (30, 39) box,
# dropping rows and columns that fall entirely outside it.
data_regrid = data_regrid.where(
    (data_regrid.lon > -78) & (data_regrid.lon < -68)
    & (data_regrid.lat > 30) & (data_regrid.lat < 39),
    drop=True,
)

In [None]:
data_regrid.isel(time_counter=it).sst.plot()  # regridded SST at the chosen time index

# Spatial spectra

## Interpolation on a spatially uniform grid 

In [None]:
def uniform_grid_interpolation(data):
    """Fill NaNs in `data` and resample it at positions derived from the smallest grid spacing.

    Steps:
      1. NaNs are filled by linear interpolation along ``x``, then along ``y``,
         and finally by nearest-neighbour interpolation with extrapolation
         along ``x``.
      2. ``delta_min`` is the smallest value found in ``data['e1t']`` and
         ``data['e2t']``. The new fractional ``x``/``y`` sample positions are
         the cumulative sums of ``delta_min / e1t`` (along axis 1) and
         ``delta_min / e2t`` (along axis 0).
      3. The filled field is interpolated at those positions and the auxiliary
         grid coordinates are removed.

    Parameters
    ----------
    data : xarray object with ``x``/``y`` dimensions, exposing ``e1t``/``e2t``
        through ``data[...]`` and carrying the auxiliary coordinates listed in
        the body.

    Returns
    -------
    The interpolated object, with ``attrs['delta_min']`` set to the spacing used.
    """
    # first fill in nans
    data_filled_nans = (
        data.interpolate_na(dim="x", method="linear")
        .interpolate_na(dim="y", method="linear")
        .interpolate_na(dim="x", method="nearest", fill_value="extrapolate")
    )

    # transform coordinates to uniform spacings
    delta_x_min_m = data['e1t'].min().values
    delta_y_min_m = data['e2t'].min().values
    delta_min = np.min([delta_x_min_m, delta_y_min_m])
    new_x = np.cumsum(delta_min/data['e1t'], axis=1)
    new_y = np.cumsum(delta_min/data['e2t'], axis=0)

    # `drop_vars` replaces the deprecated `drop` for removing named variables.
    interpolated_field = data_filled_nans.interp(x=new_x, y=new_y).drop_vars(
        ['depth', 'e1f', 'e1t', 'e1u', 'e1v', 'e2f', 'e2t', 'e2u', 'e2v', 'lat', 'lon', 'nav_lat', 'nav_lon']
    )
    interpolated_field.attrs['delta_min'] = delta_min
    return interpolated_field

In [None]:
# Compare the coarsened filtered SST with its NaN-filled, resampled version.
fig, ax = plt.subplots(ncols=2, figsize=(10, 4), constrained_layout=True)

dataLR.isel(time_counter=it).sst.plot(ax=ax[0])
uniform_grid_interpolation(dataLR).isel(time_counter=it).sst.plot(ax=ax[1])

ax[0].set_title('Original filtered data')
ax[1].set_title('Interpolated and filled data')
plt.show()

In [None]:
dataLR  # display the coarsened, filtered dataset

## Isotropic power spectra

In [None]:
def wrapper_xrtf_isotropic_PS(data, variable, binning_factor=4):
    """Isotropic power spectrum of `data[variable]` with a `wavenumber` coordinate attached.

    The field is first passed through `uniform_grid_interpolation`, then
    `xrft_module.isotropic_power_spectrum` is applied over ``x``/``y`` with
    linear detrending, "spectrum" scaling and truncation.

    Parameters
    ----------
    data : xarray dataset holding `variable` and the ``x``/``y`` dimensions.
    variable : str
        Name of the field to analyse.
    binning_factor : int, default 4
        Forwarded as ``nfactor`` to the isotropic spectrum routine. It was
        previously ignored in favour of a hard-coded 4, so calls that pass a
        different value now actually use it.

    Returns
    -------
    The spectrum, renamed 'Spatial spectrum', with an extra ``wavenumber``
    coordinate along ``freq_r``.
    """
    interpolated_data = uniform_grid_interpolation(data[variable])
    isotropic_power_spec = xrft_module.isotropic_power_spectrum(
        interpolated_data,
        dim=['x', 'y'],
        detrend='linear',
        scaling="spectrum",
        nfactor=binning_factor,
        truncate=True,
    )
    Nmin = min([len(data['x']), len(data['y'])])
    # freq_r is rescaled by the shorter side length times the grid spacing.
    # NOTE(review): the 0.001 factor presumably converts metres to kilometres — confirm.
    domain_length = Nmin * interpolated_data.attrs['delta_min'] * 0.001
    isotropic_power_spec.coords['wavenumber'] = (
        'freq_r',
        isotropic_power_spec['freq_r'] / domain_length,
    )
    return isotropic_power_spec.rename('Spatial spectrum')

In [None]:
# spec_temp_variance_HR = wrapper_xrtf_isotropic_PS(dataHR, 'sst', binning_factor = 4)
# %time mean_spec_temp_variance_HR = spec_temp_variance_HR.mean(dim='time_counter').compute()

In [None]:
# spec_temp_variance_filtered = wrapper_xrtf_isotropic_PS(data_filtered, 'sst', binning_factor = 4)
# %time mean_spec_temp_variance_filtered = spec_temp_variance_filtered.mean(dim='time_counter').compute()

In [None]:
# spec_temp_variance_LR = wrapper_xrtf_isotropic_PS(dataLR, 'sst', binning_factor = 2)
# %time mean_spec_temp_variance_LR = spec_temp_variance_LR.mean(dim='time_counter').compute()

In [None]:
# Spectrum of the coarsened, unfiltered SST, then its time mean; %time reports
# how long the computation of the mean takes.
spec_temp_variance_HR_ds = wrapper_xrtf_isotropic_PS(dataHR_ds, 'sst', binning_factor = 2)
%time mean_spec_temp_variance_HR_ds = spec_temp_variance_HR_ds.mean(dim='time_counter').compute()

In [None]:
# Save the time-mean spectra as netCDF. Only the spectrum computed in the
# previous cell is written; uncomment the other lines after computing them.
# NOTE(review): the "_summer" suffix is hard-coded — keep it in sync with the
# dataset chosen in `selected_data_set`.
# mean_spec_temp_variance_LR.to_netcdf("Data/region1_filter_30km_coarse_12/mean_spec_temp_variance_LR_summer.nc")
# mean_spec_temp_variance_filtered.to_netcdf("Data/region1_filter_30km_coarse_12/mean_spec_temp_variance_filtered_summer.nc")
# mean_spec_temp_variance_HR.to_netcdf("Data/region1_filter_30km_coarse_12/mean_spec_temp_variance_HR_summer.nc")
mean_spec_temp_variance_HR_ds.to_netcdf("Data/region1_filter_30km_coarse_12/mean_spec_temp_variance_HR_ds_summer.nc")

In [None]:
# The HR / filtered / LR mean spectra are only computed in cells that are
# currently commented out, so on a fresh kernel they are not in memory and the
# original cell raised NameError. Reuse them when they exist; otherwise read the
# netCDF copies at the paths used by the (commented) to_netcdf lines of the save cell.
SPECTRA_DIR = "Data/region1_filter_30km_coarse_12"


def _mean_spectrum(name):
    """Return the mean spectrum `name` from the notebook namespace, else from its cached netCDF file."""
    if name in globals():
        return globals()[name]
    return xr.open_dataarray(f"{SPECTRA_DIR}/{name}_summer.nc")


spectra = {
    'HR': _mean_spectrum('mean_spec_temp_variance_HR'),
    'Filtered': _mean_spectrum('mean_spec_temp_variance_filtered'),
    'LR': _mean_spectrum('mean_spec_temp_variance_LR'),
}

fig, ax = plt.subplots(ncols=1, constrained_layout=True, figsize=(5,4))
for label, spectrum in spectra.items():
    spectrum.plot(x='wavenumber', ax=ax, ls='-', marker='', label=label)

# Reference k^-2 slope, drawn on the wavenumbers of the HR spectrum.
reference_k = spectra['HR']['wavenumber']
ax.plot(reference_k, 5e-6*reference_k**(-2), ls='--', label=r'$k^{-2}$')

ax.set(yscale='log', xscale='log', ylabel='Temperature variance')
ax.grid(True)
ax.legend()
plt.show()