In [1]:
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from matplotlib import collections as mc
import numpy as np
import os
import pylab as pl
from scipy.spatial import cKDTree
import sys
import xarray as xr
import yaml

### Files paths

In [2]:
yaml_file = './config.yaml'
params = yaml.safe_load(open(yaml_file))

INPUT_GRID = params['input_grid_nc']
INPUT_PATH_SSH_FILTERED = params['input_path_data_filtered']
INPUT_PATH_SEG_MASKS = params['input_path_data_seg_masks']

OUTPUT_SUBSET_PRE_PROCESSED = params['output_subset_pre_processed']

# TODO: for every path in SSH, read the file
# concatenate them, or read them with open_mfdataset()

### Reduce the mesh to a subset

In [3]:
data_mesh = xr.open_dataset(INPUT_GRID, engine='netcdf4')

# Deletion of useless fields
vars_keys = data_mesh.data_vars
for key in vars_keys:
    if key != 'lat' and key != 'lon' and key != 'edges' and key != 'nodes':
        data_mesh = data_mesh.drop_vars(key)

data_mesh = data_mesh.drop_vars('nz')
data_mesh = data_mesh.drop_vars('nz1')

# RoI definition
model_lon = data_mesh.lon.values
model_lat = data_mesh.lat.values
left = params['input_left']
right = params['input_right']
bottom = params['input_bottom']
top = params['input_top']
region_mask = (model_lon > left) & (model_lon < right) & (model_lat < top) & (model_lat > bottom)

# RoI: edges extraction
edge_0 = data_mesh.edges[0].values
edge_1 = data_mesh.edges[1].values
edge_0 -= 1
edge_1 -= 1

edges_subset = []
for i in range(len(edge_0)):
    if region_mask[edge_0[i]] & region_mask[edge_1[i]]:
        edges_subset.append([edge_0[i], edge_1[i]])

edges_subset = np.array(edges_subset, dtype="int32")
data_mesh = data_mesh.drop_vars('edges')
data_mesh['edges'] = (('edges_subset', 'n2'), edges_subset)

# Nodes extraction
nodes_subset = []
for i in range(len(region_mask)):
    if region_mask[i]:
        nodes_subset.append(i)

nodes_subset = np.array(nodes_subset, dtype="int32")
data_mesh['nodes'] =(('nodes_subset'), nodes_subset)

print(data_mesh)

<xarray.Dataset>
Dimensions:  (nod2: 8852366, nodes_subset: 757747, edges_subset: 2268763, n2: 2)
Dimensions without coordinates: nod2, nodes_subset, edges_subset, n2
Data variables:
    lon      (nod2) float64 -177.4 -177.2 -177.3 -177.3 ... 178.4 178.4 178.3
    lat      (nod2) float64 -78.05 -78.05 -78.08 -78.02 ... -77.85 -77.82 -77.85
    nodes    (nodes_subset) int32 1035320 1035324 1035345 ... 7853535 7853537
    edges    (edges_subset, n2) int32 1035320 1035378 ... 7106912 7106914


### Interpolation of SSH and segmentation mask to the unstructured subset grid

#### First, some helper functions from https://github.com/nextGEMS/nextGEMS_Cycle3/blob/main/FESOM/STARTHERE_FESOM.ipynb

In [4]:
def lon_lat_to_cartesian(lon, lat, R=6371000):
    
    lon_r = np.radians(lon)
    lat_r = np.radians(lat)

    x = R * np.cos(lat_r) * np.cos(lon_r)
    y = R * np.cos(lat_r) * np.sin(lon_r)
    z = R * np.sin(lat_r)
    return x, y, z

In [5]:
def create_indexes_and_distances(model_lon, model_lat, lons, lats, k=1, workers=2):
    
    xs, ys, zs = lon_lat_to_cartesian(lons.flatten(), lats.flatten())
    xt, yt, zt = lon_lat_to_cartesian(model_lon, model_lat)

    tree = cKDTree(list(zip(xs, ys, zs)))
    distances, inds = tree.query(list(zip(xt, yt, zt)), k=k, workers=workers)

    return distances, inds

### The actual interpolation process

In [6]:
# SSH and segmentation information files
data_files = []
for file in os.listdir(INPUT_PATH_SEG_MASKS):
    data_files.append(file)

# Target coordinates(the unstructured FESOM mesh)
model_lon = data_mesh.lon[data_mesh.nodes].values
model_lat = data_mesh.lat[data_mesh.nodes].values

for file in data_files:
    month = xr.open_dataset(INPUT_PATH_SEG_MASKS + '/' + str(file))
    
    # Source coordinates(the SSH and segmentation masks information)
    data_lon, data_lat = np.meshgrid(month['LONGITUDE'], month['LATITUDE'])
    data_lon = data_lon.T
    data_lat = data_lat.T
    
    for day in range(month.dims['TIME']):
        # SSH values
        data_ssh_values = month.ssh.values[day]
        
        # To avoid problems with stretched SSH values
        data_ssh_values[np.abs(data_ssh_values) > 100] = np.nan
        
        # Segmentation mask
        data_seg_mask_values = month.seg_mask.values[day]
        
        # Perform a K-Nearest Neighbors between matrix and mesh points
        distances, inds = create_indexes_and_distances(model_lon, model_lat, data_lon, data_lat, k=1, workers=-1)
        
        flat_ssh = data_ssh_values.flatten()
        flat_seg_mask = data_seg_mask_values.flatten()
        
        ssh = flat_ssh[inds]
        seg_mask = flat_seg_mask[inds]
        
        # Area of influence to take care of KNN indexes crossing land
        radius_of_influence = 10000
        
        ssh[distances >= radius_of_influence] = np.nan
        ssh = np.ma.masked_invalid(ssh)
        
        seg_mask[distances >= radius_of_influence] = 0
        seg_mask = np.ma.masked_invalid(seg_mask)
        
        print(np.min(ssh), np.max(ssh), min(distances), max(distances))
        
        # Store SSH and seg_mask in the subset mesh file - TODO: I might create a 'graphs_n' variable as big as the total sum of timestamps in the data, and append everything in the end(NEED RAM FOR THIS)
        ssh = np.array(ssh, dtype="float64")
        seg_mask = np.array(seg_mask, dtype="float64")
        
    
        '''continue here'''
        
        '''
        data_mesh['ssh'] = (('nodes_subset'), ssh)
        data_mesh['seg_mask'] = (('nodes_subset'), seg_mask)
        data_mesh
        '''
        
        break
    
    print(month)
    break

-1.8223326206207275 1.3817895650863647 2.996811811857859 9859.58150920928
<xarray.Dataset>
Dimensions:    (LONGITUDE: 1200, LATITUDE: 480, TIME: 28)
Coordinates:
  * LONGITUDE  (LONGITUDE) float32 -70.0 -69.92 -69.83 ... 29.75 29.83 29.92
  * LATITUDE   (LATITUDE) float32 -60.0 -59.92 -59.83 ... -20.25 -20.17 -20.08
  * TIME       (TIME) float32 5.706e+05 5.706e+05 ... 5.712e+05 5.712e+05
Data variables:
    ssh        (TIME, LONGITUDE, LATITUDE) float64 -0.7358 -0.719 ... 9.969e+36
    seg_mask   (TIME, LONGITUDE, LATITUDE) int64 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0


### Loss of information when going from structured SSH/seg_mask to unstructured

In [7]:
not_nan_ssh = np.count_nonzero(~np.isnan(data_mesh.ssh.values))
not_nan_ssh_actual_proportion = not_nan_ssh*100/data_mesh.dims['nodes_subset']

print(f"Nodes with initial SSH data:\t\t\t\t{data_mesh.dims['nodes_subset']}")
print(f"Nodes with SSH/seg_mask data after pre-processing:\t{not_nan_ssh}")
print(f"Actual % of SSH/seg_mask info after pre-processing:\t {not_nan_ssh_actual_proportion:.3f}%")

AttributeError: 'Dataset' object has no attribute 'ssh'

### Group of plots to understand where we started and where we got

In [None]:
model_lon_roi = data_mesh.lon[data_mesh.nodes].values
model_lat_roi = data_mesh.lat[data_mesh.nodes].values
uns_ssh = data_mesh.ssh.values
str_ssh = month.ssh.values[0]
uns_seg_mask = data_mesh.seg_mask.values
str_seg_mask = month.seg_mask.values[0]

fig, axes = plt.subplots(2, 2, figsize=(18, 9))

im = axes[0][0].imshow(np.flipud(str_ssh.T), cmap=cm.seismic, vmin=-1, vmax=1)
axes[0][0].set_title("SSH(SLA) values after the interpolation to regular grid")
cb = plt.colorbar(im, orientation='horizontal', pad=0.1, ax=axes[0][0])

im2 = axes[0][1].imshow(np.flipud(str_seg_mask.T), cmap=cm.viridis, vmin=0, vmax=2)
axes[0][1].set_title("Segmentation mask")
cb = plt.colorbar(im2, orientation='horizontal', pad=0.1, ax=axes[0][1])

im3 = axes[1][0].scatter(model_lon_roi, model_lat_roi, c=uns_ssh, s=0.1, cmap=cm.seismic, vmin=-1, vmax=1)
axes[1][0].set_title(f"The remaining nodes with SSH data after the interpolation to unstructured mesh")
plt.colorbar(im3, orientation='horizontal', pad=0.1, ax=axes[1][0])

im4 = axes[1][1].scatter(model_lon_roi, model_lat_roi, c=uns_seg_mask, s=0.1, cmap=cm.viridis, vmin=0, vmax=2)
axes[1][1].set_title(f"The remaining nodes with seg_mask data after the interpolation to unstructured mesh")
plt.colorbar(im4, orientation='horizontal', pad=0.1, ax=axes[1][1])

#plt.close()