# Nested Grids documentation

This how-to guide explains how to use nested grids in Parcels v4, using the new `uxarray` integration. We will demonstrate how to set up a simulation with multiple nested grids, and how to handle particle transitions between these grids.

We will base this on the LOCATE benchmark dataset ([Hernandez et al., 2024](https://gmd.copernicus.org/articles/17/2221/2024/)), which contains a regional grid, a nested coastal grid and a nested harbour grid for the Barcelona region.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.tri as mtri
import numpy as np
import uxarray as ux
import xarray as xr
from shapely.geometry import MultiPoint, Point, Polygon
from triangle import triangulate

import parcels

# Local path to the LOCATE nested benchmark data. It must contain the sub-directories
# "TAR/harbour", "TAR/coastal" and "regional", each holding NetCDF (*.nc) files.
# Change this to the location where you downloaded the data.
data_dir = "/Users/erik/Desktop/Parcelsv4_test/Parcels_Benchmarks_Nested_LOCATE"

## Setting up the individual nest domains
We first load the three grids using `xarray`. Since the variable names differ between the regional and nested grids, we rename the temperature variable in the regional grid for consistency.

In [None]:
# Nests are listed from smallest (harbour) to largest (regional); later cells rely on this order.
nests = ["TAR/harbour", "TAR/coastal", "regional"]
ds_in = {nest: xr.open_mfdataset(f"{data_dir}/{nest}/*.nc") for nest in nests}

# The regional product calls its temperature variable "thetao"; align it with the nested grids.
ds_in["regional"] = ds_in["regional"].rename({"thetao": "temperature"})

The two nested grids are rectangular, but rotated with respect to the regional grid. We will first identify the corners of the rectangles by finding the minimum-area rotated bounding box around the grid points using the `find_rotated_rectangle()` function. Of course, in a real-world application you would typically have the polygon coordinates available from the grid generation process so this step would not be necessary.

In [None]:
def find_rotated_rectangle(da):
    """
    Return the 4 corner coordinates (lon, lat) of the minimum-area rotated rectangle
    that encloses the finite values in an xarray DataArray.

    Args:
      da: 2D DataArray with 1D ``latitude`` and ``longitude`` coordinates. Cells
        holding NaN/inf are ignored. The dimension order of ``da`` does not matter.

    Returns:
      (4, 2) numpy array of (lon, lat) corners, in the ring order produced by
      shapely. The closing point of the ring is not repeated.

    Raises:
      ValueError: if ``da`` has no finite values, or if its finite values do not
        span a 2D area (e.g. they are all collinear), so no rectangle exists.
    """
    # Put the latitude dimension first, so that the row index returned by np.nonzero
    # indexes latitude and the column index indexes longitude, whatever the input order.
    da = da.transpose(*da.latitude.dims, *da.longitude.dims)
    i, j = np.nonzero(np.asarray(np.isfinite(da)))
    if i.size == 0:
        raise ValueError("find_rotated_rectangle: DataArray has no finite values")
    pts = np.column_stack((da.longitude.values[j], da.latitude.values[i]))

    rect = MultiPoint(pts).convex_hull.minimum_rotated_rectangle
    # Degenerate point sets yield a Point/LineString, which has no exterior ring.
    if rect.is_empty or rect.geom_type != "Polygon":
        raise ValueError(
            "find_rotated_rectangle: finite values must span a 2D area, "
            f"got a {rect.geom_type}"
        )
    coords = np.array(rect.exterior.coords)[:-1]  # last point repeats the first
    return coords

We plot the three grids and overlay the identified rectangles to verify that they correctly capture the extent of the nested grids.

In [None]:
n_fieldsets = len(ds_in)
# squeeze=False always returns a 2D array of Axes. Without it, plt.subplots returns a
# bare Axes when there is only one nest, and zip() below would fail.
fig, axes = plt.subplots(
    1, n_fieldsets, figsize=(5 * n_fieldsets, 4), squeeze=False
)

# nest name -> (4, 2) array of (lon, lat) corners of the enclosing rotated rectangle
rectangle = {}
for ax, (name, ds) in zip(axes.ravel(), ds_in.items()):
    da = ds.temperature.isel(time=0)
    da.plot(ax=ax)
    rect = find_rotated_rectangle(da)
    rectangle[name] = rect
    # Close the outline by repeating the first corner
    ax.plot(
        np.append(rect[:, 0], rect[0, 0]), np.append(rect[:, 1], rect[0, 1]), "-r", lw=2
    )
    ax.set_title(name)

plt.tight_layout()
plt.show()

## Creating a Delaunay triangulation of the nests

Now comes the important part: we need to create a Delaunay triangulation of the three nests, so that we can efficiently determine in which nest a particle is located at any given time. We use the `triangle` package to perform the triangulation, and `shapely` to handle the geometric operations. 

Note that we need to keep the edges of the rectangles in the triangulation, so we need a [constrained (PSLG) Delaunay triangulation](https://en.wikipedia.org/wiki/Constrained_Delaunay_triangulation).

The result is a set of triangles covering the three nests, which we can use to determine in which nest a particle is located at any given time. Each triangle is assigned to the *first* polygon in the list that contains its centroid. It is therefore important that the list of polygons is ordered from the smallest (innermost) nest to the largest. This way, triangles in overlapping areas are assigned to the innermost nest that covers them.

In [None]:
def constrained_triangulate_keep_edges(polygons):
    """
    Build a PSLG from the polygon boundaries, run constrained Delaunay (triangle)
    so original polygon edges are preserved, then assign triangles to polygons
    by centroid containment.

    Args:
      polygons: sequence of (Ni, 2) array-likes (polygon boundary points in order).
        Order matters: a triangle is assigned to the FIRST polygon that contains
        its centroid, so list nested (smaller) polygons before the ones enclosing them.

    Returns:
      pts: (P x 2) vertices returned by triangle
      tris: (n_face x 3) array of triangle vertex indices (into pts)
      face_poly: (n_face,) array mapping each triangle -> polygon index (or -1 if outside)

    Raises:
      ValueError: if ``polygons`` is empty.
    """
    if len(polygons) == 0:
        raise ValueError("constrained_triangulate_keep_edges needs at least one polygon")

    # Flatten all vertices into one list, and connect consecutive vertices of each
    # polygon (closing the ring) with segments, so that triangle keeps these edges.
    verts = []
    segments = []
    offset = 0
    for poly in polygons:
        poly = np.asarray(poly, dtype=float)
        n_vert = len(poly)
        verts.extend(poly.tolist())
        segments.extend(
            [offset + j, offset + (j + 1) % n_vert] for j in range(n_vert)
        )
        offset += n_vert
    verts = np.asarray(verts, dtype=float)
    segments = np.asarray(segments, dtype=int)

    mode = "p"  # "p" = PSLG (constrained triangulation)
    B = triangulate({"vertices": verts, "segments": segments}, mode)

    pts = B["vertices"]
    tris = B["triangles"].astype(int)

    # Assign each triangle to the first polygon that contains its centroid.
    shapely_polys = [Polygon(p) for p in polygons]
    centers = pts[tris].mean(axis=1)
    face_poly = np.full(len(tris), -1, dtype=int)
    for ti, center in enumerate(centers):
        centroid = Point(center)  # build once per triangle, not once per polygon
        for ip, shapely_poly in enumerate(shapely_polys):
            if shapely_poly.contains(centroid):
                face_poly[ti] = ip
                break

    return pts, tris, face_poly

We can then run the triangulation and plot the resulting triangles to verify that they correctly cover the three nests.

In [None]:
# Polygons keep the order of ds_in (harbour, coastal, regional), i.e. smallest nest
# first. constrained_triangulate_keep_edges relies on this order for overlapping areas.
polygons = list(rectangle.values())

points, face_tris, face_poly = constrained_triangulate_keep_edges(polygons)

# Triangles belonging to each nest. Boolean-mask indexing already yields an empty
# (0, 3) array for a nest that received no triangles, so no special case is needed.
triangles_by_poly = {i: face_tris[face_poly == i] for i in range(len(polygons))}

fig, ax = plt.subplots()
for (i, tris), name in zip(triangles_by_poly.items(), rectangle):
    if tris.size:
        ax.triplot(points[:, 0], points[:, 1], tris, label=f"Nest {i} ({name})")
ax.scatter(points[:, 0], points[:, 1], s=10, c="k")
ax.set_aspect("equal")
ax.legend()
plt.show()

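Then, we convert the triangulation into a Parcels `Field` on a `parcels.UxGrid()`. We store the nest index of each triangle in a UGRID-compliant `xarray` Dataset, load it with `uxarray`, and wrap the resulting `NestID` Field in a `FieldSet`.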

In [None]:
# build an xarray dataset compatible with UGRID / uxarray
n_node = points.shape[0]
n_face = face_tris.shape[0]
n_max_face_nodes = face_tris.shape[1]

ds_tri = xr.Dataset(
    {
        "node_lon": (("n_node",), points[:, 0]),
        "node_lat": (("n_node",), points[:, 1]),
        "face_node_connectivity": (("n_face", "n_max_face_nodes"), face_tris),
        "face_polygon": (
            (
                "time",
                "nz",
                "n_face",
            ),
            face_poly[np.newaxis, np.newaxis, :],
            {
                "long_name": "Nest id",
                "description": "2=regional, 1=coastal, 0=harbour",
                "location": "face",
                "mesh": "delaunay",
            },
        ),
    },
    coords={
        "time": np.array([np.timedelta64(0, "ns")]),
        "nz": np.array([0]),
        "n_node": np.arange(n_node),
        "n_face": np.arange(n_face),
    },
    attrs={"Conventions": "UGRID-1.0"},
)

uxda = ux.UxDataArray(ds_tri["face_polygon"], uxgrid=ux.Grid(ds_tri))

NestID = parcels.Field(
    "NestID",
    uxda,
    parcels.UxGrid(uxda.uxgrid, z=uxda["nz"], mesh="spherical"),
    interp_method=parcels.interpolators.UxPiecewiseConstantFace,
)
fieldset = parcels.FieldSet([NestID])

We can confirm that the FieldSet has been created correctly by running a Parcels simulation where particles sample the `NestID` field, which indicates in which nest each particle is located at any given time.

In [None]:
# Regular 50 x 40 grid of release points spanning the triangulated area
lon_grid, lat_grid = np.meshgrid(
    np.linspace(0.3, 2.9, 50),
    np.linspace(40.25, 41.7, 40),
)

# Extend the default particle class with an integer variable holding the nest index
NestParticle = parcels.Particle.add_variable(parcels.Variable("nestID", dtype=np.int32))
pset = parcels.ParticleSet(
    fieldset,
    pclass=NestParticle,
    lon=lon_grid.flatten(),
    lat=lat_grid.flatten(),
)


# Kernel: store the NestID value at each particle's position in particles.nestID
def SampleNestID(particles, fieldset):
    particles.nestID = fieldset.NestID[particles]


# NestID has a single time step, so one one-second step is enough to sample it
one_second = np.timedelta64(1, "s")
pset.execute(SampleNestID, runtime=one_second, dt=one_second)

Indeed, the visualisation below shows that particles correctly identify the nest they are in based on their location.

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 5))

tri_grid = uxda.uxgrid
triang = mtri.Triangulation(
    tri_grid.node_lon.values,
    tri_grid.node_lat.values,
    triangles=tri_grid.face_node_connectivity.values,
)

# Shared colour/edge settings, so triangles and particles use the same colour scale
plot_args = dict(cmap="viridis", edgecolors="k", linewidth=0.5, vmin=0, vmax=2.0)

# Colour each triangle by its nest index ...
face_values = np.squeeze(uxda[0, :].values)
ax.tripcolor(triang, facecolors=face_values, shading="flat", **plot_args)
# ... and each particle by the nest index it sampled
ax.scatter(pset.lon, pset.lat, c=pset.nestID, **plot_args)

ax.set_aspect("equal")
ax.set_title("Nesting visualisation (triangulation and interpolated particle values)")
plt.tight_layout()
plt.show()

## Advecting particles with nest transitions

We can now set up a particle advection simulation using the nested grids. We first combine all the Fields into a single FieldSet.

We rename the individual Fields by appending the nest index to their names, so that we can easily identify which Field belongs to which nest. We also add the `NestID` Field to the FieldSet (note that Parcels v4 supports combining structured and unstructured Fields into one FieldSet, which is very convenient for this use case).

In [None]:
# Start from the NestID Field, then add every Field of every nest with the nest index
# appended to its name (e.g. nest 0's "UV" becomes "UV0").
fields = [NestID]
for nest_index, nest_ds in enumerate(ds_in.values()):
    # TODO : remove depth dimension when Parcels supports 2D copernicusmarine datasets
    nest_ds = nest_ds.assign_coords(depth=np.array([0]))
    nest_ds["depth"].attrs["axis"] = "Z"

    nest_fieldset = parcels.FieldSet.from_copernicusmarine(nest_ds)
    for field in nest_fieldset.fields.values():
        field.name = f"{field.name}{nest_index}"
        fields.append(field)

fieldset = parcels.FieldSet(fields)

We then define a custom advection kernel that advects particles using the appropriate velocity Field, chosen from the `NestID` at the particle's location. Note that for simplicity we use a forward-Euler advection scheme here; in a real-world application you would typically use a higher-order scheme.

In [None]:
# Forward-Euler advection kernel for nested grids. For each particle, it samples the
# nest index from the NestID Field and uses that nest's velocity Field ``UV{nestID}``.
# Particles whose sampled velocity is NaN get StatusCode.ErrorInterpolation.
# NOTE(review): assumes every sampled nestID has a matching ``UV{nid}`` Field on the
# FieldSet (0, 1, 2 here). A particle in an unassigned triangle (NestID -1) would
# presumably make the getattr lookup below fail.
def AdvectEE_Nests(particles, fieldset):
    particles.nestID = fieldset.NestID[particles]

    # TODO because of KernelParticle bug (GH #2143), we need to copy lon/lat/time to local variables
    time = particles.time
    z = particles.z
    lat = particles.lat
    lon = particles.lon
    u = np.zeros_like(particles.lon)
    v = np.zeros_like(particles.lat)

    # Evaluate each nest's velocity Field only for the particles currently in that nest
    unique_ids = np.unique(particles.nestID)
    for nid in unique_ids:
        mask = particles.nestID == nid
        UVField = getattr(fieldset, f"UV{nid}")
        (u[mask], v[mask]) = UVField[time[mask], z[mask], lat[mask], lon[mask]]

    # Forward-Euler step: displacement = velocity at the current position * dt
    particles.dlon += u * particles.dt
    particles.dlat += v * particles.dt

    # TODO particle states have to be updated manually because UVField is not called with `particles` argument (because of GH #2143)
    particles.state = np.where(
        np.isnan(u) | np.isnan(v),
        parcels.StatusCode.ErrorInterpolation,
        particles.state,
    )


def DeleteErrorParticles(particles, fieldset):
    # Mark every particle in an error state for deletion (all error StatusCodes are >= 50)
    particles[particles.state >= 50].state = parcels.StatusCode.Delete


# Release a 5 x 4 block of particles and advect them for (5 days - 1 hour) with
# hourly time steps, writing hourly output to a zarr store.
release_lon, release_lat = np.meshgrid(
    np.linspace(1.22, 1.26, 5),
    np.linspace(41.02, 41.08, 4),
)

pset = parcels.ParticleSet(
    fieldset,
    pclass=NestParticle,
    lon=release_lon.flatten(),
    lat=release_lat.flatten(),
)
one_hour = np.timedelta64(1, "h")
ofile = parcels.ParticleFile("nest_particles.zarr", outputdt=one_hour)
pset.execute(
    [AdvectEE_Nests, DeleteErrorParticles],
    runtime=np.timedelta64(5, "D") - one_hour,
    dt=one_hour,
    output_file=ofile,
)

And finally we plot the particles moving through the nested grids, confirming that they correctly transition between the nests based on their location.

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 3))

ds_out = xr.open_zarr("nest_particles.zarr")

cmap = plt.get_cmap("viridis", 3)  # one discrete colour per nest index (0, 1, 2)
# Draw trajectories explicitly on `ax` (not on pyplot's "current" Axes). Transposing
# makes each particle one line, assuming the output dims are (trajectory, obs) - TODO confirm.
ax.plot(ds_out.lon.T, ds_out.lat.T, "k", linewidth=0.5)
sc = ax.scatter(ds_out.lon, ds_out.lat, c=ds_out.nestID, s=4, cmap=cmap, vmin=0, vmax=2)
# Remember the particle-based view, so drawing the nest outlines does not rescale the axes
xl, yl = ax.get_xlim(), ax.get_ylim()

for rect in rectangle.values():
    # Close each outline by repeating the first corner
    ax.plot(
        np.append(rect[:, 0], rect[0, 0]), np.append(rect[:, 1], rect[0, 1]), "-r", lw=1
    )
ax.set_xlim(xl)
ax.set_ylim(yl)
ax.set_aspect("equal")

cbar = fig.colorbar(sc, ticks=[0, 1, 2], ax=ax)
cbar.set_label("Nest ID")
ax.set_title("Particle advection through nests")
plt.tight_layout()
plt.show()