# Working with Structured Grids as Unstructured Grids

In this notebook, we explore working with a curvilinear structured grid through an unstructured representation. Specifically, we want to express the interpolation of c-grid velocities on a curvilinear grid using a `uxarray.Grid` object. The purpose is to explore whether all data in Parcels can be represented as unstructured grid data. For validation, we aim to reproduce the interpolation of a velocity field on an aqua-planet NEMO data set (see https://docs.oceanparcels.org/en/latest/examples/tutorial_nemo_curvilinear.html for a complete description).

## Loading in the example data set

In [None]:
from datetime import timedelta

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

import parcels

example_dataset_folder = parcels.download_example_dataset("NemoCurvilinear_data")
dsm = xr.open_dataset(f"{example_dataset_folder}/mesh_mask.nc4")
dsm = dsm.squeeze(drop=True)
# make sure there are actual labels for the original x and y
dsm = dsm.assign_coords(x=list(range(dsm.sizes["x"])), y=list(range(dsm.sizes["y"])))
dsm


## Describing a structured curvilinear grid as an unstructured grid

In our initial attempts to read in structured grid data with `uxarray.from_structured_grid`, we found that `uxarray` does not support curvilinear grids, where the `x` and `y` coordinates each depend on two indices. Because of this, we spend some time mapping the curvilinear grid coordinates to a UGRID-compliant `xarray.Dataset`; see the [xugrid documentation](https://deltares.github.io/xugrid/examples/quick_overview.html#from-xarray-dataset) for more background on this. Essentially, we need to:

* Flatten the latitude and longitude of the vorticity points to 1-D arrays and store them as data variables `node_lat` and `node_lon`, respectively, with dimension `node`. In NEMO, these are the `gphif` and `glamf` data variables.
* Create a `node` coordinate that is an integer list of node ids.
* Define the connectivity between "faces", which are synonymous with tracer cells on a c-grid, and their corner nodes. Ultimately, this information is stored in a data variable called `face_node_connectivity` with dimensions `(n_face,nmax_face)`, where `n_face` is the number of "faces" (tracer points) and `nmax_face` is the maximum number of nodes per face. For the structured grid, `nmax_face=4`.
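
The three steps above can be sketched on a tiny grid with plain NumPy. This is only an illustration under stated assumptions: the 3 x 4 node array, the x-major flattening order (chosen to mirror `stack(node=["x", "y"])`), and the helper name `structured_face_node_connectivity` are hypothetical and are not part of the notebook or of `uxarray`.

```python
import numpy as np


def structured_face_node_connectivity(nx, ny):
    """Return an (n_face, 4) array of node ids for an nx-by-ny array of corner nodes.

    Nodes are numbered x-major (node_id = ix * ny + iy). Each face lists its
    corners counter-clockwise: lower-left, lower-right, upper-right, upper-left.
    """
    node_id = np.arange(nx * ny).reshape(nx, ny)
    ll = node_id[:-1, :-1]  # (ix,     iy)
    lr = node_id[1:, :-1]   # (ix + 1, iy)
    ur = node_id[1:, 1:]    # (ix + 1, iy + 1)
    ul = node_id[:-1, 1:]   # (ix,     iy + 1)
    return np.stack([ll, lr, ur, ul], axis=-1).reshape(-1, 4)


conn = structured_face_node_connectivity(nx=3, ny=4)
print(conn.shape)  # (6, 4)
print(conn[0])     # [0 4 5 1]
```

A grid with `nx` by `ny` corner nodes has `(nx - 1) * (ny - 1)` faces, which is why the notebook drops the first row and column (`dsm_minus_WS`) before building the connectivity.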


In [72]:
xuds = xr.Dataset()
xuds = xuds.assign(node_lon=(("node", ),dsm.glamf.stack(node=["x", "y"]).data))
xuds = xuds.assign(node_lat=(("node", ),dsm.gphif.stack(node=["x", "y"]).data))
xuds = xuds.assign_coords(node=(("node", ), list(range(xuds.sizes["node"]))))

# Now we construct face node connectivity
dsm_minus_WS = dsm.isel(x=slice(1, None), y=slice(1, None))

# Index offsets from a face's upper-right node to each of its four corners
# (ll = lower left, lr = lower right, ur = upper right, ul = upper left)
x_corner_offsets = xr.DataArray([-1, 0, 0, -1], dims=("corner", ), coords={"corner": ["ll", "lr", "ur", "ul"]})
y_corner_offsets = xr.DataArray([-1, -1, 0, 0], dims=("corner", ), coords={"corner": ["ll", "lr", "ur", "ul"]})
x_corners = (xr.broadcast(dsm_minus_WS.x, dsm_minus_WS.y)[0] + x_corner_offsets).rename("x_corners")
y_corners = (xr.broadcast(dsm_minus_WS.x, dsm_minus_WS.y)[1] + y_corner_offsets).rename("y_corners")

mesh_nodes = xr.Dataset()
mesh_nodes = mesh_nodes.assign(glamf=dsm.glamf.stack(node=["x", "y"]))
mesh_nodes = mesh_nodes.assign(gphif=dsm.gphif.stack(node=["x", "y"]))
mesh_nodes = mesh_nodes.assign_coords(node_id=(("node", ), list(range(mesh_nodes.sizes["node"]))))
mesh_nodes

mesh_elements = xr.Dataset()
mesh_elements = mesh_elements.assign(node_id=mesh_nodes.unstack().node_id.sel(x=x_corners, y=y_corners))
mesh_elements = mesh_elements.rename({"node_id": "node_id_in_element"})
mesh_elements = mesh_elements.assign_coords(
    x=mesh_elements.x.sel(corner="ur", drop=True).isel(y=0, drop=True),
    y=mesh_elements.y.sel(corner="ur", drop=True).isel(x=0, drop=True),
)
mesh_elements = mesh_elements.stack(element=["x", "y"])
mesh_elements = mesh_elements.drop_vars(["x", "y"])

xuds = xuds.assign(face_node_connectivity=(("n_face","nmax_face", ),mesh_elements.node_id_in_element.transpose().data))
xuds


## Converting to uxarray.Grid

Once we have the UGRID-compliant `xarray.Dataset`, we can construct a `uxarray.Grid` from it. This lets us use the generated connectivity fields, Cartesian coordinate calculations, etc. that come with `uxarray`. Additionally, we now have a data structure that can be used with our previously constructed spatial hashing methods.

In [73]:
import uxarray

uxgrid = uxarray.Grid(xuds)
uxgrid

## Work in progress below this point...
From here, we need to verify that the spatial hashing method works with quadrilateral grids.
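
As a reference point for that check, the idea behind the spatial hash can be sketched independently of `uxarray`: each face's bounding box is registered in every hash cell it overlaps, so a point query only has to inspect the faces stored in a single cell. The sketch below is a minimal illustration that assumes a non-periodic planar domain. The function names `build_hash_table` and `candidate_faces` and the two-quadrilateral test data are hypothetical.

```python
import numpy as np


def build_hash_table(face_x, face_y, x0, y0, cell_size, nxh, nyh):
    """Map each hash cell to the faces whose bounding box overlaps it.

    face_x, face_y : (n_face, n_corner) arrays of corner coordinates.
    x0, y0         : lower-left corner of the domain (origin of the hash grid).
    """
    table = [[] for _ in range(nxh * nyh)]
    for fid in range(face_x.shape[0]):
        # Indices are measured relative to the domain origin
        i_min = int(np.floor((face_x[fid].min() - x0) / cell_size))
        i_max = int(np.floor((face_x[fid].max() - x0) / cell_size))
        j_min = int(np.floor((face_y[fid].min() - y0) / cell_size))
        j_max = int(np.floor((face_y[fid].max() - y0) / cell_size))
        # Clamp so faces touching the domain edge stay inside the table
        for j in range(max(j_min, 0), min(j_max, nyh - 1) + 1):
            for i in range(max(i_min, 0), min(i_max, nxh - 1) + 1):
                table[i + nxh * j].append(fid)
    return table


def candidate_faces(table, x, y, x0, y0, cell_size, nxh):
    """Return the faces registered in the hash cell that contains (x, y)."""
    i = int(np.floor((x - x0) / cell_size))
    j = int(np.floor((y - y0) / cell_size))
    return table[i + nxh * j]


# Two unit quadrilaterals side by side, with the domain origin at (-1, -1)
face_x = np.array([[-1.0, 0.0, 0.0, -1.0], [0.0, 1.0, 1.0, 0.0]])
face_y = np.array([[-1.0, -1.0, 0.0, 0.0], [-1.0, -1.0, 0.0, 0.0]])
table = build_hash_table(face_x, face_y, x0=-1.0, y0=-1.0, cell_size=0.5, nxh=4, nyh=2)

print(candidate_faces(table, -0.75, -0.25, x0=-1.0, y0=-1.0, cell_size=0.5, nxh=4))  # [0]
print(candidate_faces(table, 0.25, -0.75, x0=-1.0, y0=-1.0, cell_size=0.5, nxh=4))   # [0, 1]
```

The second query returns both faces because the bounding box of face 0 touches the boundary of that hash cell. The hash only narrows the search, so an exact point-in-face test is still needed on the candidates.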

In [74]:
from typing import Union

import numpy as np
import uxarray


def calculate_hash_cell_size(uxobj: Union[uxarray.UxDataset, uxarray.Grid], scalefac: float = 1.0):
    """
    Calculate the hash cell size using the median edge length as a characteristic length scale. The characteristic
    length scale is optionally scaled by the provided `scalefac` parameter to obtain the hash cell grid size.
    At the moment, the hash cell size is returned in units of degrees.
    
    Parameters:
    - uxobj (uxarray.UxDataset or uxarray.Grid) : uxarray object containing the grid, either as a Grid or a UxDataset
    - scalefac (float): Multiplier for the median edge length to set the hash cell size
    
    Returns:
    - cell_size (float): The recommended hash cell size, in degrees.
    """

    if isinstance(uxobj, uxarray.UxDataset):
        grid = uxobj.uxgrid
    elif isinstance(uxobj, uxarray.Grid):
        grid = uxobj
    else:
        raise TypeError(f"Unsupported type: {type(uxobj)}")

    # The values from grid.edge_node_distances are treated here as radians
    # (arc lengths on the unit sphere), so we convert the median to degrees
    # before applying the scale factor.
    return grid.edge_node_distances.median().to_numpy()*180.0/np.pi*scalefac


# Calculate hash cell size
hash_cell_size = calculate_hash_cell_size(uxgrid,0.5)
print( f"Hash cell size : {hash_cell_size}")

# Get the bounding box for the domain
x_max = uxgrid.node_lon.max().to_numpy()
x_min = uxgrid.node_lon.min().to_numpy()
y_max = uxgrid.node_lat.max().to_numpy()
y_min = uxgrid.node_lat.min().to_numpy()

# To determine how many hash cells we need, we divide the domain length by the hash_cell_size
Lx = (x_max-x_min)
Ly = (y_max-y_min)
print( f"Domain size (Lx,Ly) : ({Lx},{Ly})")
nxh = int(np.ceil(Lx/hash_cell_size))
nyh = int(np.ceil(Ly/hash_cell_size))
print( f"Number of hash cells (nxh,nyh) : ({nxh},{nyh})")


Hash cell size : 0.07429025100845717
Domain size (Lx,Ly) : (359.9960174560547,166.91999053955078)
Number of hash cells (nxh,nyh) : (4846,2247)
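
The numbers printed above can be cross-checked with a few lines of arithmetic. This sketch only reuses the printed values. The variable `n_cells` is introduced here for illustration and does not appear in the notebook.

```python
import math

hash_cell_size = 0.07429025100845717            # degrees, from the output above
Lx, Ly = 359.9960174560547, 166.91999053955078  # degrees, from the output above

nxh = math.ceil(Lx / hash_cell_size)
nyh = math.ceil(Ly / hash_cell_size)
n_cells = nxh * nyh  # number of entries in a dense hash table

print(nxh, nyh, n_cells)  # 4846 2247 10888962
```

With `scalefac=0.5`, a dense table therefore holds roughly 10.9 million lists, which is worth keeping in mind when choosing the scale factor.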


In [75]:
import numpy as np

def get_faces_in_hash_cells(uxobj: Union[uxarray.UxDataset, uxarray.Grid], xwest, xeast, nxh, nyh, cell_size):
    """
    Efficiently find the list of faces whose bounding box overlaps with each hash cell.
    
    Parameters:
    - uxobj (uxarray.UxDataset or uxarray.Grid) : uxarray object containing the grid, either as a Grid or a UxDataset
    - xwest (float) : longitude of the western extent of the model domain
    - xeast (float) : longitude of the eastern extent of the model domain
    - nxh (int) : The number of hash cells in the x-direction
    - nyh (int) : The number of hash cells in the y-direction
    - cell_size (float): The size of each hash cell (assumed square, with width and height equal to cell_size).
    
    Returns:
    - overlapping_faces (list): A list of length nxh*nyh, indexed by the flattened hash cell index (i + nxh*j);
      each entry is the list of face indices overlapping that hash cell.
    """
    if isinstance(uxobj, uxarray.UxDataset):
        grid = uxobj.uxgrid
    elif isinstance(uxobj, uxarray.Grid):
        grid = uxobj
    else:
        raise TypeError(f"Unsupported type: {type(uxobj)}")
    
    overlapping_faces = [[] for i in range(nxh*nyh)]

    lon_bounds = grid.face_bounds_lon.to_numpy()
    # NOTE: `periodic_image` is not defined anywhere in this notebook; it must be
    # provided before this function can run.
    lon_bounds_periodic_image = periodic_image(lon_bounds,xwest,xeast)
    lat_bounds = grid.face_bounds_lat.to_numpy()

    # Loop over each face
    for eid in range(grid.n_face):
        
        # Calculate the bounding box of the face
        x_min = lon_bounds[eid,0]
        x_max = lon_bounds[eid,1]
        dx = x_max - x_min

        # Here, we need to make adjustments for potentially periodic boundaries.
        # We can look at calculating the bounding box using the reported value for x_max
        # from `grid.face_bounds_lon` or using its periodic image.
        x_max_periodic_image = lon_bounds_periodic_image[eid,1]
        dx_p = x_min - x_max_periodic_image

        # If the difference between x_min and the periodic image of x_max is
        # less than the difference between x_min and x_max, then we set x_max = x_min
        # and x_min = x_max_periodic_image
        if( dx_p < dx ):
            x_max = x_min
            x_min = x_max_periodic_image
        

        y_min = lat_bounds[eid,0]
        y_max = lat_bounds[eid,1]

        # Find the hash cell range that overlaps with the face's bounding box.
        # NOTE: the indices are computed from absolute coordinates (no offset by the
        # domain's lower-left corner), so negative longitudes or latitudes yield
        # negative indices.
        i_min = int(np.floor(x_min / cell_size))
        i_max = int(np.floor(x_max / cell_size))
        j_min = int(np.floor(y_min / cell_size))
        j_max = int(np.floor(y_max / cell_size))
        
        # Iterate over all hash cells that intersect the bounding box
        for j in range(j_min, j_max + 1):
            for i in range(i_min, i_max + 1):
                overlapping_faces[i+nxh*j].append(eid)
                    
    return overlapping_faces

import matplotlib.pyplot as plt


hashmap = get_faces_in_hash_cells(uxgrid,x_min,x_max,nxh,nyh,hash_cell_size)

## Count how many faces are in each hash cell.
hashmap_face_count = np.array([len(faces) for faces in hashmap], dtype=float)

nnonzero = np.count_nonzero(hashmap_face_count)
print(f"Minimum face count per hash cell : {np.min(hashmap_face_count)}")
print(f"Maximum face count per hash cell : {np.max(hashmap_face_count)}")
print(f"Median face count per hash cell  : {np.median(hashmap_face_count)}")
print(f"Number of non-empty hash cells   : {nnonzero} ( {nnonzero/len(hashmap)*100.0} %)")

# Plotting a basic histogram, with one bin centred on each integer count
bin_edges = np.arange( np.min(hashmap_face_count)-0.5, np.max(hashmap_face_count)+1.5,1)
values, bins, bars = plt.hist(hashmap_face_count, bins=bin_edges, color='skyblue', edgecolor='black',align='mid')
plt.bar_label(bars, fontsize=13, color='navy')
# Adding labels and title
plt.xlabel('Faces per hash cell')
plt.ylabel('Number of hash cells')
plt.title('Histogram of face count per hash cell')
 
# Display the plot
plt.show()

AssertionError: 

In [None]:
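import math

import dask.array as da
# Assumption: `i_u` refers to the Parcels interpolation utilities, as in Parcels' field.py
import parcels.tools.interpolation_utils as i_u


def get_element_id(uxgrid,y,x):
    """Placeholder: should return (eta, xsi, eid) for the face that contains the point (x, y)."""
    raise NotImplementedError("get_element_id has not been implemented yet")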

def spatial_c_grid_interpolation2D(uxarray, uxgrid, ti, z, y, x, time, particle=None, applyConversion=True):

    #(_, eta, xsi, zi, yi, xi) = self.U._search_indices(time, z, y, x, ti, particle=particle)

    #(eta, xsi, eid) = self.U._get_element_id(y,x,particle=particle) # TO DO : method to get the element id, with the barycentric coordinates
    (eta, xsi, eid) = get_element_id(uxgrid,y,x) # Example/demo routine
    
    px = grid.lon[grid.face_node_connectivity[:,eid]]
    py = grid.lat[grid.face_node_connectivity[:,eid]]

    # if grid._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]:
    #     px = np.array([grid.lon[xi], grid.lon[xi + 1], grid.lon[xi + 1], grid.lon[xi]])
    #     py = np.array([grid.lat[yi], grid.lat[yi], grid.lat[yi + 1], grid.lat[yi + 1]])
    # else:
    #     px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
    #     py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])

    if grid.mesh == "spherical":
        px[0] = px[0] + 360 if px[0] < x - 225 else px[0]
        px[0] = px[0] - 360 if px[0] > x + 225 else px[0]
        px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:])
        px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:])
        
    xx = (1 - xsi) * (1 - eta) * px[0] + xsi * (1 - eta) * px[1] + xsi * eta * px[2] + (1 - xsi) * eta * px[3]
    assert abs(xx - x) < 1e-4
    c1 = i_u._geodetic_distance(py[0], py[1], px[0], px[1], grid.mesh, np.dot(i_u.phi2D_lin(0.0, xsi), py))
    c2 = i_u._geodetic_distance(py[1], py[2], px[1], px[2], grid.mesh, np.dot(i_u.phi2D_lin(eta, 1.0), py))
    c3 = i_u._geodetic_distance(py[2], py[3], px[2], px[3], grid.mesh, np.dot(i_u.phi2D_lin(1.0, xsi), py))
    c4 = i_u._geodetic_distance(py[3], py[0], px[3], px[0], grid.mesh, np.dot(i_u.phi2D_lin(eta, 0.0), py))

    U0 = self.U.data[ti, grid.face_node_connectivity[3,eid]] * c4
    U1 = self.U.data[ti, grid.face_node_connectivity[2,eid]] * c2
    V0 = self.V.data[ti, grid.face_node_connectivity[1,eid]] * c1
    V1 = self.V.data[ti, grid.face_node_connectivity[2,eid]] * c3

    # if grid.zdim == 1:
    #     if self.gridindexingtype == "nemo":
    #         U0 = self.U.data[ti, yi + 1, xi] * c4
    #         U1 = self.U.data[ti, yi + 1, xi + 1] * c2
    #         V0 = self.V.data[ti, yi, xi + 1] * c1
    #         V1 = self.V.data[ti, yi + 1, xi + 1] * c3
    #     elif self.gridindexingtype in ["mitgcm", "croco"]:
    #         U0 = self.U.data[ti, yi, xi] * c4
    #         U1 = self.U.data[ti, yi, xi + 1] * c2
    #         V0 = self.V.data[ti, yi, xi] * c1
    #         V1 = self.V.data[ti, yi + 1, xi] * c3
    # else:
    #     if self.gridindexingtype == "nemo":
    #         U0 = self.U.data[ti, zi, yi + 1, xi] * c4
    #         U1 = self.U.data[ti, zi, yi + 1, xi + 1] * c2
    #         V0 = self.V.data[ti, zi, yi, xi + 1] * c1
    #         V1 = self.V.data[ti, zi, yi + 1, xi + 1] * c3
    #     elif self.gridindexingtype in ["mitgcm", "croco"]:
    #         U0 = self.U.data[ti, zi, yi, xi] * c4
    #         U1 = self.U.data[ti, zi, yi, xi + 1] * c2
    #         V0 = self.V.data[ti, zi, yi, xi] * c1
    #         V1 = self.V.data[ti, zi, yi + 1, xi] * c3
    U = (1 - xsi) * U0 + xsi * U1
    V = (1 - eta) * V0 + eta * V1
    rad = np.pi / 180.0
    deg2m = 1852 * 60.0
    if applyConversion:
        meshJac = (deg2m * deg2m * math.cos(rad * y)) if grid.mesh == "spherical" else 1
    else:
        meshJac = deg2m if grid.mesh == "spherical" else 1

    jac = i_u._compute_jacobian_determinant(py, px, eta, xsi) * meshJac

    u = (
        (-(1 - eta) * U - (1 - xsi) * V) * px[0]
        + ((1 - eta) * U - xsi * V) * px[1]
        + (eta * U + xsi * V) * px[2]
        + (-eta * U + (1 - xsi) * V) * px[3]
    ) / jac
    v = (
        (-(1 - eta) * U - (1 - xsi) * V) * py[0]
        + ((1 - eta) * U - xsi * V) * py[1]
        + (eta * U + xsi * V) * py[2]
        + (-eta * U + (1 - xsi) * V) * py[3]
    ) / jac
    if isinstance(u, da.core.Array):
        u = u.compute()
        v = v.compute()
    return (u, v)
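
The position check `xx` in the function above relies on the bilinear shape functions of a quadrilateral. As a minimal, self-contained sketch, the weights in the order used for `px[0]` to `px[3]` can be written as below. The helper name `bilinear_weights` and the sample corner coordinates are hypothetical, and this is not the Parcels `i_u.phi2D_lin` implementation.

```python
import numpy as np


def bilinear_weights(eta, xsi):
    """Weights of the four corners (ll, lr, ur, ul) at local coordinates (xsi, eta) in [0, 1]."""
    return np.array([
        (1 - xsi) * (1 - eta),  # lower left
        xsi * (1 - eta),        # lower right
        xsi * eta,              # upper right
        (1 - xsi) * eta,        # upper left
    ])


px = np.array([10.0, 12.0, 13.0, 11.0])  # corner longitudes (made-up values)
py = np.array([50.0, 50.5, 52.5, 52.0])  # corner latitudes (made-up values)

w = bilinear_weights(eta=0.5, xsi=0.5)
x_interp = np.dot(w, px)  # longitude of the point at the centre of the face
y_interp = np.dot(w, py)  # latitude of the point at the centre of the face
print(w.sum(), x_interp, y_interp)
```

The weights always sum to one, and at a corner (for example `eta=0, xsi=0`) the interpolated position reduces to that corner's coordinates. This is why a correct `(eta, xsi)` pair from the element search should reproduce the query longitude `x` to within the tolerance in the assertion.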