# Example notebook to map coarse grid nodes to fine grid nodes and calculate inverse distance weights

This Jupyter notebook shows how to calculate the information - node mapping and
weights - needed to interpolate values from a coarse grid to a fine grid.

In [None]:
# Import modules
import numpy as np
import xarray as xr
from shapely.geometry import Point
import schismviz.suxarray as sx

## Read a coarse and a fine grid

In [None]:
# Read the fine grid from a gr3 file. This is the target grid that receives
# the interpolated values.
path_hgrid_gr3 = 'hgrid.gr3'
grid_fine = sx.read_hgrid_gr3(path_hgrid_gr3)

In [None]:
# Read the coarse grid from a gr3 file. This is the source (donor) grid that
# provides the values to interpolate.
path_coarse_gr3 = 'bay_delta_coarse_v4.gr3'
grid_coarse = sx.read_hgrid_gr3(path_coarse_gr3)

## Create Shapely points of the fine grid

In [None]:
# Get the fine grid nodes as Shapely points from the grid's `node_points`
# attribute.
points_fine = grid_fine.node_points

## Find the element indices of the coarse grid where the fine grid points are located.

The element indices can be searched quickly using the STRtree algorithm in Shapely.
Note that the STRtree query returns a result only for a fine grid node that
intersects a coarse grid element. The first element of the return value holds
the indices of the fine grid nodes that have search results, and the second
element holds the indices of the matching coarse grid elements.

In [None]:
# Query the coarse grid element STRtree with all fine grid points at once.
# Row 0 of the result indexes the fine grid nodes, row 1 the coarse elements.
face_ind_coarse_from_node_fine = grid_coarse.elem_strtree.query(
    points_fine, predicate='intersects')

In [None]:
# Get the face-node table of the coarse grid as a plain array.
# NOTE(review): presumably one row of node indices per element - confirm.
face_nodes_coarse = grid_coarse.Mesh2_face_nodes.values

## Create a node mapping matrix

We want to collect the element node indices of the coarse grid at each fine
grid node. When a fine grid node does not belong to any coarse grid element,
we will find the nearest coarse grid node.

In [None]:
# Allocate the node mapping: one row per fine grid node and three columns for
# coarse grid node indices. Every entry starts as the fill value, `-1`.
map_to_nodes_coarse = np.full(shape=(grid_fine.nMesh2_node, 3),
                              fill_value=-1,
                              dtype=int)

### Copy over the element indices from the coarse.

Because the STRtree search does not return results when a fine grid node is
not within a coarse grid element, we need to copy only those indices.

In [None]:
# For every fine grid node with a search result, store the first three node
# indices of the matching coarse grid element. Subtracting one makes the
# stored indices zero-based.
map_to_nodes_coarse[face_ind_coarse_from_node_fine[0], :] = (
    face_nodes_coarse[face_ind_coarse_from_node_fine[1], :3] - 1
)

### Collect the fine node indices not mapped to coarse grid elements

In [None]:
# Fine grid node indices without any STRtree search result, in ascending order.
nodes_not_found = sorted(
    set(range(grid_fine.nMesh2_node)).difference(face_ind_coarse_from_node_fine[0]))

### Find the nearest coarse nodes from the fine nodes not mapped to the coarse grid elements

Note that the node indices are zero-based.

In [None]:
# Look up the nearest coarse grid node for every fine grid node that is not
# inside a coarse grid element. The STRtree is queried one point at a time
# through `vectorize=True`.
# `output_dtypes` is given explicitly because the vectorized call cannot infer
# the result type from an empty selection (all fine nodes already mapped), and
# because `dask='parallelized'` needs it when the input is dask-backed.
nodes_coarse_nearest = xr.apply_ufunc(
    lambda p: grid_coarse.node_strtree.nearest(p),
    points_fine.isel(nSCHISM_hgrid_node=nodes_not_found),
    vectorize=True,
    dask='parallelized',
    output_dtypes=[int])

In [None]:
# Save the nearest node indices in the first column. The other two columns of
# these rows keep the fill value, -1.
map_to_nodes_coarse[nodes_not_found, 0] = nodes_coarse_nearest

### Save the result into a DataArray

In [None]:
# Wrap the node mapping in a named DataArray indexed by the fine grid nodes.
# The attributes record the fill value and that the indices are zero-based.
da_map_to_nodes_coarse = xr.DataArray(
    data=map_to_nodes_coarse,
    coords={'nSCHISM_hgrid_node': grid_fine.ds.nSCHISM_hgrid_node},
    dims=('nSCHISM_hgrid_node', 'three'),
    name='map_to_nodes_coarse',
    attrs=dict(_FillValue=-1, start_index=0),
)

## Calculate mapping weights

### Calculate mapping weights using the inverse distance to mapped nodes

When a fine grid node coincides with a coarse grid node, the distance becomes
zero and the inverse distance becomes inf. We need to deal with this case: we
set the weight to 1 for the coincident coarse node, and to 0 for the others.

Also, note that the fill value, -1, in the node mapping is not filtered out here, so mapping entries equal to the fill value must be ignored.

In [None]:
def _calculate_weight(conn, points):
    """ Calculate inverse-distance weights of the mapped coarse grid nodes.

    Parameters
    ----------
    conn : array_like of int, shape (n_points, n_mapped)
        Zero-based coarse grid node indices mapped to each point. Negative
        entries are treated as fill values.
    points : iterable of shapely.geometry.Point, length n_points
        Locations of the fine grid nodes, one per row of `conn`.

    Returns
    -------
    numpy.ndarray of float, shape (n_points, n_mapped)
        Reciprocal of the distance between each point and its mapped coarse
        grid nodes. Fill-value entries get a weight of 0. When a point
        coincides with a mapped coarse grid node, that node gets a weight of
        1 and the other entries of the row get 0.
    """
    conn = np.asarray(conn)
    is_valid = conn >= 0
    # Look up fill-value slots at index 0 instead of letting -1 wrap around to
    # the last coarse node; their weights are zeroed below.
    conn_safe = np.where(is_valid, conn, 0)
    node_x = grid_coarse.Mesh2_node_x.values[conn_safe]
    node_y = grid_coarse.Mesh2_node_y.values[conn_safe]
    point_x = np.array([p.x for p in points], dtype=float)
    point_y = np.array([p.y for p in points], dtype=float)
    dist = np.hypot(node_x - point_x[:, np.newaxis],
                    node_y - point_y[:, np.newaxis])
    # A zero distance is expected for coincident nodes and is handled below,
    # so the divide-by-zero warning is silenced here.
    with np.errstate(divide='ignore'):
        weight = np.reciprocal(dist)
    # Fill-value slots must not contribute to the interpolation.
    weight[~is_valid] = 0.
    # Rows with an infinite weight: 1 for the coincident node(s), 0 elsewhere.
    is_coincident = np.isinf(weight)
    rows_coincident = is_coincident.any(axis=1)
    weight[rows_coincident] = is_coincident[rows_coincident].astype(float)
    return weight

# Chunk size along the fine grid node dimension. `None` keeps the whole
# dimension in one chunk; set an integer to split the work into smaller pieces.
chunk_size = None
# Apply the weight calculation chunk by chunk. The 'three' dimension is a core
# dimension, so each call receives whole rows of the node mapping.
# `output_dtypes` is passed as a list, as documented for `xr.apply_ufunc`.
da_weight = xr.apply_ufunc(_calculate_weight,
               da_map_to_nodes_coarse.chunk({'nSCHISM_hgrid_node': chunk_size}),
               points_fine.chunk({'nSCHISM_hgrid_node': chunk_size}),
               input_core_dims=[['three'], []],
               output_core_dims=[['three']],
               dask='parallelized',
               output_dtypes=[float]).persist()

## Save the results

In [None]:
# Optional step: rescale the weights so that each row sums to one.
da_weight = xr.apply_ufunc(
    lambda w: w / w.sum(axis=1, keepdims=True),
    da_weight,
    dask='parallelized').persist()

In [None]:
# Bundle the node mapping and the weights into one dataset and write it to a
# netCDF file.
path_map_and_weight = 'map_and_weight.nc'
ds_map_and_weight = da_map_to_nodes_coarse.to_dataset(
    name=da_map_to_nodes_coarse.name).assign(weight=da_weight)
ds_map_and_weight.to_netcdf(path_map_and_weight)