# UXArray with Parcels for FESOM2

This notebook is used to "dreamscape" how we want to leverage UXArray and Parcels together to implement particle simulations on unstructured grids.

## Implementing spatial hashing in UXArray
We want to implement spatial hashing inside UXArray in a manner similar to the `get_ball_tree` and `get_kd_tree` methods, which are documented at https://uxarray.readthedocs.io/en/latest/user-guide/tree_structures.html.

It would be nice to be able to do something like the following:

```python
hash_grid = uxgrid.get_hash_grid()
```

UXArray currently implements `KDTree` and `BallTree` classes in the `uxarray/grid/neighbors.py` module. In the `Grid` class (in `uxarray/grid/grid.py`), the methods `get_kd_tree` and `get_ball_tree` construct the `Grid._kd_tree` and `Grid._ball_tree` attributes, respectively.

### The HashGrid class
The main purpose of the `HashGrid` is to accelerate the lookup of unstructured grid elements by physical location. This is done primarily through a hash table, in which a single integer index returns a short list of elements. The key to a good hash table here is that the integer index is tied to physical location. In this implementation, the "hash grid" is a structured, uniformly spaced grid of $N_x \times N_y$ square cells of width $\Delta h$ that covers the domain $[x_{min},x_{max}] \times [y_{min},y_{max}]$. Given a coordinate $(x,y)$, we can then compute the integer indices $(i,j)$ as

$$
\begin{aligned}
i &= \left\lfloor \frac{x-x_{min}}{\Delta h} \right\rfloor \\
j &= \left\lfloor \frac{y-y_{min}}{\Delta h} \right\rfloor
\end{aligned}
$$
The single-valued hash index for the position $(x,y)$ is then $k = i + N_x j$. To build a useful hash table for looking up unstructured grid elements near $(x,y)$, we store, for each hash cell $(i,j)$, a list of the elements that overlap it; this "list of lists" gives us a list of elements for each hash index $k$. Since $k$ can be computed directly from an arbitrary position $(x,y)$, we can quickly return the elements near that point. These elements are then candidates for a further search that determines which of them, if any, contains $(x,y)$.
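A minimal, self-contained sketch of this indexing on a toy domain follows; the domain, the cell width and the two element bounding boxes are made-up values, not FESOM2 data. Like `_initialize_hash_to_faces` below, it registers each element in every hash cell touched by the element's bounding box, so a candidate list can include elements that do not actually cover the query point. This is why a containment test is still needed afterwards.

```python
import numpy as np

# Toy hash grid (assumed values for illustration): domain [0, 4] x [0, 2], cell width 1
xmin, ymin, dh = 0.0, 0.0, 1.0
nx, ny = 4, 2

def hash_index(x, y):
    """Flattened hash index k = i + nx*j for points inside the domain."""
    i = np.floor((np.asarray(x) - xmin) / dh).astype(int)
    j = np.floor((np.asarray(y) - ymin) / dh).astype(int)
    return i + nx * j

# Hypothetical elements described only by their bounding boxes (xlo, xhi, ylo, yhi)
boxes = [(0.2, 1.5, 0.1, 0.8),   # element 0 touches hash cells i=0..1, j=0
         (2.5, 3.5, 0.5, 1.5)]   # element 1 touches hash cells i=2..3, j=0..1

# "List of lists": one (possibly empty) list of element ids per hash cell
table = [[] for _ in range(nx * ny)]
for eid, (xlo, xhi, ylo, yhi) in enumerate(boxes):
    i0, i1 = int((xlo - xmin) // dh), int((xhi - xmin) // dh)
    j0, j1 = int((ylo - ymin) // dh), int((yhi - ymin) // dh)
    for j in range(j0, j1 + 1):
        for i in range(i0, i1 + 1):
            table[i + nx * j].append(eid)

k = int(hash_index(3.2, 1.7))   # i = 3, j = 1, so k = 3 + 4*1 = 7
print(k, table[k])              # -> 7 [1]  (candidate elements for the point (3.2, 1.7))
```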

We can implement a framework similar to the one used for the KD tree and ball tree classes by defining a `HashGrid` class with:

1. A constructor 
2. Methods for defining the hash grid from an unstructured grid 
3. Methods for querying the hash grid for hash indices and elements from physical positions

The `HashGrid` class also stores attributes that are relevant for quick lookup of frequently used information in each of its methods.

```python
import uxarray as ux
import numpy as np

class HashGrid:
    """Data structure that provides the attributes and methods necessary for performing
    O(1) complexity lookups of positions on unstructured grids"""

    def __init__(self, grid):
        self._source_grid = grid
        self._nelements = self._source_grid.n_face

        # Width of a (square) hash cell, in degrees
        self.dh = self._hash_cell_size()
        # Bounding box of the hash grid: lower-left (xmin, ymin) and upper-right (xmax, ymax)
        self.xmin = self._source_grid.node_lon.min().to_numpy()
        self.ymin = self._source_grid.node_lat.min().to_numpy()
        self.xmax = self._source_grid.node_lon.max().to_numpy()
        self.ymax = self._source_grid.node_lat.max().to_numpy()
        # Number of hash cells in x and y; nx is also used for array flattening.
        # floor(L/dh) + 1 (rather than ceil(L/dh)) guarantees that points lying
        # exactly on xmax or ymax still map to a valid hash cell.
        Lx = self.xmax - self.xmin
        Ly = self.ymax - self.ymin
        self.nx = int(np.floor(Lx / self.dh)) + 1
        self.ny = int(np.floor(Ly / self.dh)) + 1

        # Generate the mapping from the hash indices to unstructured grid elements
        self.faces = self._initialize_hash_to_faces()

    def _hash_cell_size(self):
        """Computes the size of the hash cells from the source grid.
        The hash cell size is set to 1/2 of the median edge length in the grid (in degrees)"""
        return np.rad2deg(self._source_grid.edge_node_distances.median().to_numpy() * 0.5)

    def _hash_index2d(self, x, y):
        """Computes the 2-d hash index (i,j) for the location (x,y), where x and y are given in
        spherical coordinates (in degrees). Accepts scalars or arrays."""
        i = ((np.asarray(x) - self.xmin) / self.dh).astype(int)
        j = ((np.asarray(y) - self.ymin) / self.dh).astype(int)
        return i, j

    def _hash_index(self, x, y):
        """Computes the flattened hash index k = i + nx*j for the location (x,y), where x and y
        are given in spherical coordinates (in degrees). The flattened index lists all i-points
        of row j = 0 first, then those of row j = 1, and so on."""
        i, j = self._hash_index2d(x, y)
        return i + self.nx * j

    def _initialize_hash_to_faces(self):
        """Create a mapping that relates unstructured grid faces to hash indices by determining
        which hash cells each face's lon/lat bounding box overlaps"""
        index_to_face = [[] for _ in range(self.nx * self.ny)]
        lon_bounds = self._source_grid.face_bounds_lon.to_numpy()
        lat_bounds = self._source_grid.face_bounds_lat.to_numpy()
        # Hash indices of the lower ([:, 0]) and upper ([:, 1]) bounds of each face
        ib, jb = self._hash_index2d(lon_bounds, lat_bounds)
        # Face bounds may extend slightly beyond the node extent; clip to the hash grid
        ib = np.clip(ib, 0, self.nx - 1)
        jb = np.clip(jb, 0, self.ny - 1)

        for eid in range(self._nelements):
            for j in range(jb[eid, 0], jb[eid, 1] + 1):
                for i in range(ib[eid, 0], ib[eid, 1] + 1):
                    index_to_face[i + self.nx * j].append(eid)

        return index_to_face

    def get_faces(self, x, y):
        """Returns the list of candidate faces associated with the hash of the coordinate (x,y).
        For arrays of points, returns one such list per point."""
        k = self._hash_index(x, y)
        if np.ndim(k) == 0:
            return self.faces[int(k)]
        return [self.faces[kk] for kk in k.ravel()]

    def get_hash_indices(self, x, y):
        """For a list of points (x,y), return a list of hash cell indices"""
        return self._hash_index(x, y)
```
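When the extent $L_x = x_{max} - x_{min}$ is an exact multiple of $\Delta h$, sizing the grid with `ceil(Lx/dh)` leaves a point lying exactly on $x_{max}$ without a cell: truncation maps it to index $N_x$, one past the last valid index (the same applies to $y_{max}$ and to face bounds that touch the edge). The toy numbers below are made up purely to show the effect; `floor(Lx/dh) + 1` covers the closed interval, and it equals `ceil(Lx/dh)` whenever the ratio is not an integer.

```python
import numpy as np

# Toy numbers (assumed): an extent that is an exact multiple of the cell width
xmin, xmax, dh = 0.0, 2.0, 0.5
Lx = xmax - xmin

nx_ceil = int(np.ceil(Lx / dh))          # 4 cells, valid indices 0..3
nx_floor1 = int(np.floor(Lx / dh)) + 1   # 5 cells, valid indices 0..4

# Truncation maps a point lying exactly on xmax to index Lx/dh = 4
i_max = int((xmax - xmin) / dh)
print(i_max, i_max < nx_ceil, i_max < nx_floor1)   # -> 4 False True
```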



### Modifications to the `uxarray.Grid` class
In this notebook, we subclass `uxarray.Grid` to explore how `uxarray.Grid` might be modified to provide a `HashGrid` for conveniently searching for the elements that contain particles.

At the outset, we envision adding an attribute called `_hash_grid` that stores a reference to the `HashGrid` object associated with the parent unstructured grid. This attribute is created by a new method called `get_hash_grid`.

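```python
import uxarray as ux

class GridWithHash(ux.Grid):
    def get_hash_grid(self):
        """Construct the HashGrid on first use and store it in self._hash_grid"""
        if getattr(self, "_hash_grid", None) is None:
            self._hash_grid = HashGrid(self)
        return self._hash_grid

    def get_hash_indices(self, x, y):
        """Return the hash cell indices of the points (x,y)"""
        return self.get_hash_grid().get_hash_indices(x, y)

    def get_hash_faces(self, x, y):
        """Return the candidate faces for the hash cell(s) containing (x,y)"""
        return self.get_hash_grid().get_faces(x, y)
```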
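One design option, assumed here rather than taken from UXArray, is to build the hash grid lazily and route every query through the getter, so the (possibly expensive) table is built only once and queries never hit a missing `_hash_grid` attribute. The toy classes below (`ToyGrid` and `ExpensiveIndex` are hypothetical names, not UXArray classes) sketch that lazy-caching pattern without needing a mesh file.

```python
class ExpensiveIndex:
    """Hypothetical stand-in for HashGrid that counts how often it is built."""
    builds = 0

    def __init__(self, grid):
        ExpensiveIndex.builds += 1
        self.grid = grid

class ToyGrid:
    """Hypothetical stand-in for uxarray.Grid; not part of any library."""

    def get_index(self):
        # Build on the first call, then reuse the cached object
        if getattr(self, "_index", None) is None:
            self._index = ExpensiveIndex(self)
        return self._index

    def query(self):
        # Queries go through the getter, so they work even if get_index() was never called
        return self.get_index()

g = ToyGrid()
first = g.query()
second = g.get_index()
print(first is second, ExpensiveIndex.builds)   # -> True 1
```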

## Re-working our example

Here, we rework our [`UXArray-FESOM2-ParticlePushing.ipynb`](./UXArray-FESOM2-ParticlePushing.ipynb) example using this new structure.

```python
import uxarray as ux

# Mesh (grid) file and the u, v, w data files for the channel example
grid_path = "./data/channel_lizarbe/fesom.mesh.diag.nc"
data_path = ["./data/channel_lizarbe/u.fesom.2005_cut.nc",
             "./data/channel_lizarbe/v.fesom.2005_cut.nc",
             "./data/channel_lizarbe/w.fesom.2005_cut.nc"]

uxds = ux.open_mfdataset(grid_path, data_path)
# Hack for this example only - after integrating into uxarray, this won't be necessary
uxds.uxgrid.__class__ = GridWithHash
# end of hack

# Build the hash grid (stored in uxds.uxgrid._hash_grid)
_ = uxds.uxgrid.get_hash_grid()
```


### Hashing particle positions (example)

```python
num_particles = 10
hash_grid = uxds.uxgrid._hash_grid

# Random particle positions (in degrees) inside the bounding box of the hash grid
xp = np.random.uniform(hash_grid.xmin, hash_grid.xmax, size=num_particles)
yp = np.random.uniform(hash_grid.ymin, hash_grid.ymax, size=num_particles)

particle_hash_ids = uxds.uxgrid.get_hash_indices(xp, yp)

# From the flattened hash id k = i + nx*j, recover j (quotient) and i (remainder)
p_j, p_i = np.divmod(particle_hash_ids, hash_grid.nx)
for k in range(num_particles):
    print(f"Particle {k} hash id : {particle_hash_ids[k]} = ({p_i[k]},{p_j[k]})")
```

```
Particle 0 hash id : 21073 = (49,219)
Particle 1 hash id : 25314 = (66,263)
Particle 2 hash id : 18755 = (35,195)
Particle 3 hash id : 5638 = (70,58)
Particle 4 hash id : 935 = (71,9)
Particle 5 hash id : 15271 = (7,159)
Particle 6 hash id : 1597 = (61,16)
Particle 7 hash id : 18290 = (50,190)
Particle 8 hash id : 4770 = (66,49)
Particle 9 hash id : 12027 = (27,125)
```
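As a quick consistency check on the printed ids, the flattening $k = i + N_x j$ can be inverted with integer division and remainder. The value $N_x = 96$ used below is not printed by the notebook; it is inferred from the output above (every printed line satisfies it), so treat it as specific to this mesh and hash cell size.

```python
# nx = 96 is inferred from the printed output, e.g. 21073 = 49 + 96 * 219
nx = 96

def decode(k, nx):
    """Invert k = i + nx*j for a flattened hash index k: j = k // nx, i = k % nx."""
    j, i = divmod(k, nx)
    return i, j

for k in (21073, 935, 12027):
    print(k, decode(k, nx))   # -> (49, 219), (71, 9), (27, 125)
```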


## Putting this into an Interpolator protocol

See https://github.com/OceanParcels/Parcels/pull/1850 for details on the proposed Parcels v4 API.

```python
import numpy as np

def barycentric_coordinates(xP, yP, xv, yv, atol=1e-9):
    """
    Compute the barycentric coordinates of a particle in a triangular element.

    Parameters:
    - xP, yP: The coordinates of the particle
    - xv, yv (np.ndarray): The x and y coordinates of the triangle's vertices (A, B, C), as length-3 arrays
    - atol: Tolerance for accepting points on (or numerically just outside) an edge

    Returns:
    - The barycentric coordinates (l1,l2,l3)
    - True if the point is inside the triangle, False otherwise
    """
    # Twice the signed areas of triangle ABC and of the sub-triangles BCP, CAP and ABP
    A_ABC = xv[0]*(yv[1]-yv[2]) + xv[1]*(yv[2]-yv[0]) + xv[2]*(yv[0]-yv[1])
    A_BCP = xv[1]*(yv[2]-yP   ) + xv[2]*(yP   -yv[1]) + xP   *(yv[1]-yv[2])
    A_CAP = xv[2]*(yv[0]-yP   ) + xv[0]*(yP   -yv[2]) + xP   *(yv[2]-yv[0])
    A_ABP = xv[0]*(yv[1]-yP   ) + xv[1]*(yP   -yv[0]) + xP   *(yv[0]-yv[1])

    # Barycentric coordinates are ratios of sub-triangle areas to the full area
    # (the triangle must be non-degenerate, i.e. A_ABC != 0)
    l1 = A_BCP/A_ABC
    l2 = A_CAP/A_ABC
    l3 = A_ABP/A_ABC

    # l1 + l2 + l3 == 1 holds by construction, so the point is inside (or on an edge)
    # exactly when all three coordinates are non-negative, up to the tolerance atol
    inside_triangle = (l1 >= -atol) & (l2 >= -atol) & (l3 >= -atol)

    return l1, l2, l3, inside_triangle
```
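Putting the pieces together: the hash lookup returns a short list of candidate faces, a barycentric test picks the face that contains the particle, and the same coordinates serve as linear interpolation weights for data defined at the triangle's vertices. The sketch below is self-contained and uses a hypothetical two-triangle mesh and made-up names (`barycentric`, `locate_and_interpolate`, `face_nodes`, `node_field`); it is one plausible way to combine the steps, not the Parcels v4 interpolator API.

```python
import numpy as np

def barycentric(xP, yP, xv, yv):
    """Barycentric coordinates of (xP, yP) in the triangle with vertex arrays xv, yv."""
    area = xv[0]*(yv[1]-yv[2]) + xv[1]*(yv[2]-yv[0]) + xv[2]*(yv[0]-yv[1])
    l1 = (xv[1]*(yv[2]-yP) + xv[2]*(yP-yv[1]) + xP*(yv[1]-yv[2])) / area
    l2 = (xv[2]*(yv[0]-yP) + xv[0]*(yP-yv[2]) + xP*(yv[2]-yv[0])) / area
    return np.array([l1, l2, 1.0 - l1 - l2])

# Hypothetical mesh: the unit square split into two triangles
node_x = np.array([0.0, 1.0, 1.0, 0.0])
node_y = np.array([0.0, 0.0, 1.0, 1.0])
face_nodes = np.array([[0, 1, 2],    # lower-right triangle
                       [0, 2, 3]])   # upper-left triangle
node_field = np.array([0.0, 1.0, 2.0, 3.0])   # made-up values defined at the nodes

def locate_and_interpolate(xP, yP, candidates, atol=1e-9):
    """Search the candidate faces (e.g. from a hash-grid lookup) for the one containing
    (xP, yP), then linearly interpolate node_field with the barycentric weights."""
    for fid in candidates:
        nodes = face_nodes[fid]
        lam = barycentric(xP, yP, node_x[nodes], node_y[nodes])
        if np.all(lam >= -atol):
            return fid, float(lam @ node_field[nodes])
    return None, None   # the point lies in none of the candidate faces

print(locate_and_interpolate(0.25, 0.75, candidates=[0, 1]))   # -> (1, 2.0)
```

Computing $l_3$ as $1 - l_1 - l_2$ saves one area evaluation by relying on the coordinates summing to one; the notebook's `barycentric_coordinates` computes all three areas explicitly instead.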