# Finding particle locations using an R-tree in Parcels

This notebook uses the pyindex library by fbriol to efficiently find the mesh cell in which a particle lies.

It is based on this documentation:
https://gist.github.com/fbriol/026f9cbe38a60ecceeefa8fe899368ca

While the pyindex code itself is very fast, most of that efficiency is lost here because we loop sequentially over the particles in Python.

This notebook also shows which features are still missing from pyindex (or exist but are unknown to us) for it to be used directly by Parcels: namely, exploiting the structured mesh topology.

The cell search itself is implemented in the `find_cell` function defined below.


In [1]:
import os
import time

import numpy as np
import xarray as xr
import pyindex.core as core

In [2]:
# Loading NEMO ORCA0083 data (1/12 deg resolution)
# The run directory defaults to the author's local copy; set the
# ORCA0083_DATA_PATH environment variable to point somewhere else.

data_path = os.environ.get('ORCA0083_DATA_PATH',
                           '/Users/delandmeter/data/NEMO-MEDUSA/ORCA0083-N006/')
# os.path.join works whether or not data_path ends with a separator
mesh_mask = os.path.join(data_path, 'domain', 'coordinates.nc')
dataset = xr.open_dataset(mesh_mask, decode_times=False)

In [3]:
# Grid-node longitudes/latitudes as plain 2-D numpy arrays. Later cells use
# glon/glat[yi:yi+2, xi:xi+2] as the four corners of a cell.
glon, glat = (np.array(dataset[name]).squeeze() for name in ('glamf', 'gphif'))

ndim, mdim = glon.shape  # number of rows (indexed by yi), number of columns (xi)
print(ndim, mdim)

3059 4322


In [4]:
# Geodetic system and an (initially empty) R-tree spatial index from pyindex.
# NOTE(review): `system` is not used by any cell shown here — possibly meant to
# be passed to RTree; confirm.
system, tree = core.geodetic.System(), core.geodetic.RTree()

In [5]:
# Bulk-load the tree with every grid node as one (lon, lat) row. Node i of the
# C-order flattened grid corresponds to glon/glat[i // mdim, i % mdim].
tree.packing(np.column_stack((glon.ravel(), glat.ravel())))

In [6]:
# Testing the algorithm: for a sample point close to the north pole, the first
# neighbour returned by the query (presumably the nearest node) must lie within
# about one grid spacing (1/12 deg) of the point in both lon and lat.

lon = 74.
lat = 89.

distance, index = tree.query([[lon, lat]], k=4, within=True)
# The tree was packed from glon/glat flattened in C order, so a flat node index
# decomposes as (row, column) = divmod(index, mdim).
row_col = [divmod(int(node), mdim) for node in index[0]]
yi = [row for row, _ in row_col]
xi = [col for _, col in row_col]

assert np.allclose([lon, lat], [glon[yi[0], xi[0]], glat[yi[0], xi[0]]], atol=1/12.)

In [7]:
def interpolate_lonlat(glon, glat, lon, lat, xi, yi, xsi, eta):
    '''Check that relative coordinates (xsi, eta) map back onto (lon, lat).

    Bi-linearly interpolates the corner coordinates of the cell
    glon/glat[yi:yi+2, xi:xi+2] at (xsi, eta) and asserts that the result
    reproduces the original point.

    Parameters
    ----------
    glon, glat : 2-D arrays of grid-node longitudes / latitudes (degrees)
    lon, lat : coordinates of the point (degrees)
    xi, yi : column / row index of the cell's first corner glon[yi, xi]
    xsi, eta : relative coordinates of the point in the unit square [0,1] x [0,1]

    Returns
    -------
    (lon_test, lat_test) : the interpolated coordinates. lon_test is expressed
        in the same 360-degree window as lon, because corner longitudes are
        shifted by +-360 to lie within 180 degrees of lon.

    Raises
    ------
    AssertionError
        If the interpolated point does not match (lon, lat).
    '''
    # Bilinear shape functions for the corners taken in the order
    # (yi, xi) -> (yi, xi+1) -> (yi+1, xi+1) -> (yi+1, xi)
    phi = [(1-xsi)*(1-eta), xsi*(1-eta), xsi*eta, (1-xsi)*eta]
    px = np.array([glon[yi, xi], glon[yi, xi+1], glon[yi+1, xi+1], glon[yi+1, xi]])
    # unwrap corner longitudes around lon so cells straddling the dateline interpolate correctly
    px = np.where(px - lon > 180, px-360, px)
    px = np.where(px - lon < -180, px+360, px)
    py = np.array([glat[yi, xi], glat[yi, xi+1], glat[yi+1, xi+1], glat[yi+1, xi]])
    lon_test = np.dot(phi, px)
    lat_test = np.dot(phi, py)
    assert np.allclose([lon_test, lat_test], [lon, lat]), \
        'interpolated point (%g, %g) does not match (%g, %g)' % (lon_test, lat_test, lon, lat)
    return lon_test, lat_test

In [8]:
def get_relative_coordinates(glon, glat, lon, lat, xi, yi):
    '''Return the relative coordinates (xsi, eta) of (lon, lat) in cell (yi, xi).

    The cell has corners glon/glat[yi:yi+2, xi:xi+2]. (xsi, eta) are the
    coordinates of the point remapped onto the unit square [0,1] x [0,1] by
    inverting the bi-linear map

        lon = a0 + a1*xsi + a2*eta + a3*xsi*eta
        lat = b0 + b1*xsi + b2*eta + b3*xsi*eta

    The point lies in the cell iff both values are in [0, 1]. When the map
    cannot be inverted (no real root, or a degenerate cell), eta is set to the
    sentinel 1e6 so that the point is reported as outside the cell.
    '''
    # maps the 4 corner values onto the bilinear coefficients a0..a3 / b0..b3
    invA = np.array([[1, 0, 0, 0],
                     [-1, 1, 0, 0],
                     [-1, 0, 0, 1],
                     [1, -1, 1, -1]])
    px = np.array([glon[yi, xi], glon[yi, xi+1], glon[yi+1, xi+1], glon[yi+1, xi]])
    # unwrap corner longitudes so they lie within 180 degrees of lon
    px = np.where(px - lon > 180, px-360, px)
    px = np.where(px - lon < -180, px+360, px)
    py = np.array([glat[yi, xi], glat[yi, xi+1], glat[yi+1, xi+1], glat[yi+1, xi]])
    a = np.dot(invA, px)
    b = np.dot(invA, py)

    # Eliminating xsi gives the quadratic aa*eta**2 + bb*eta + cc = 0
    aa = a[3]*b[2] - a[2]*b[3]
    bb = a[3]*b[0] - a[0]*b[3] + a[1]*b[2] - a[2]*b[1] + lon*b[3] - lat*a[3]
    cc = a[1]*b[0] - a[0]*b[1] + lon*b[1] - lat*a[1]
    if abs(aa) < 1e-12:  # Rectilinear cell, or quasi: the equation is linear in eta
        # bb ~ 0 only for a degenerate (zero-area) cell
        eta = -cc / bb if abs(bb) >= 1e-12 else 1e6
    else:
        det2 = bb*bb - 4*aa*cc
        if det2 >= 0:  # det2 == 0 is a valid (double) root
            det = np.sqrt(det2)
            eta = (-bb + det) / (2*aa)
            # Which root is the physical one depends on the cell's orientation:
            # use the other root when only that one falls inside [0, 1].
            eta_other = (-bb - det) / (2*aa)
            if not 0 <= eta <= 1 and 0 <= eta_other <= 1:
                eta = eta_other
        else:  # no real root: should not happen, apart from singularities
            eta = 1e6
    if abs(a[1]+a[3]*eta) < 1e-12:  # this happens when recti cell rotated of 90deg
        xsi = ((lat-py[0])/(py[1]-py[0]) + (lat-py[3])/(py[2]-py[3])) * .5
    else:
        xsi = (lon-a[0]-a[2]*eta) / (a[1]+a[3]*eta)
    return (xsi, eta)

In [9]:
def find_cell(lon, lat, k=8):
    '''Locate the grid cell containing the point (lon, lat).

    Queries the rtree `tree` for the k grid nodes nearest to the point and,
    for each node, tests the cell glon/glat[yi:yi+2, xi:xi+2] anchored at it.
    If none contains the point, the query is repeated with twice as many
    neighbours. Uses the notebook globals tree, glon, glat, ndim and mdim.

    Parameters
    ----------
    lon, lat : float
        Point coordinates in degrees.
    k : int
        Initial number of nearest neighbours to query (default 8).

    Returns
    -------
    (xi, yi, xsi, eta) : column and row index of the cell's first corner, and
        the relative coordinates of the point in that cell (both in [0, 1]).

    Raises
    ------
    RuntimeError
        If no containing cell is found before k exceeds 2000.
    AssertionError
        If interpolate_lonlat cannot reproduce (lon, lat) from (xsi, eta).
    '''
    max_k = 2000
    while k <= max_k:
        _, index = tree.query([[lon, lat]], k=k)
        # The tree was packed from the C-order flattened (ndim, mdim) grid, so
        # node = yi * mdim + xi; integer division avoids float round-off.
        nodes = np.asarray(index[0]).astype(np.int64)
        yi = nodes // mdim
        xi = nodes % mdim
        for x, y in zip(xi, yi):
            # A cell anchored on the last row/column would extend past the array.
            # NOTE(review): if the grid is periodic in x, cells spanning the
            # last/first column are never tested — confirm this is acceptable.
            if y == ndim - 1 or x == mdim - 1:
                continue
            xsi, eta = get_relative_coordinates(glon, glat, lon, lat, x, y)
            if 0 <= xsi <= 1 and 0 <= eta <= 1:
                interpolate_lonlat(glon, glat, lon, lat, x, y, xsi, eta)
                return int(x), int(y), xsi, eta
        k *= 2  # widen the search instead of recursing
    raise RuntimeError('find_cell: no cell containing (%g, %g) found '
                       '(neighbour limit k=%d exceeded)' % (lon, lat, max_k))

In [10]:
# Locate a single sample point (degrees); find_cell raises if no enclosing
# cell is found within its neighbour limit.
find_cell(lon=-107.00708753636134, lat=65.7313690045531)

In [11]:
def find_particles(n, seed=None):
    '''Draw n random particle positions and locate each one with find_cell.

    Longitudes are uniform in [-180, 180) and latitudes uniform in [-70, 88)
    degrees (hard-coded bounds). Progress is printed at most about 10 times.

    Parameters
    ----------
    n : int
        Number of particles.
    seed : int, optional
        Seed for a private RandomState so the run is reproducible. When None
        (default), numpy's global random state is used.

    Returns
    -------
    (plon, plat) : arrays of the n sampled longitudes and latitudes.
    '''
    rng = np.random if seed is None else np.random.RandomState(seed)
    plon = rng.uniform(-180.0, 180.0, n)
    plat = rng.uniform(-70.0, 88.0, n)
    # integer step ceil(n/10) (at least 1): a float test like i % (n/10) == 0
    # reports irregularly when n is not a multiple of 10
    report_every = max(1, (n + 9) // 10)
    for i, (lon, lat) in enumerate(zip(plon, plat)):
        if i % report_every == 0:
            print('iter %d/%d' % (i, n))
        find_cell(lon, lat)
    return (plon, plat)
    
# Benchmark: locate N_PARTICLES random particles one at a time.
N_PARTICLES = 100000

tic = time.perf_counter()
plon, plat = find_particles(N_PARTICLES)
tac = time.perf_counter()

# perf_counter measures elapsed wall-clock time, not CPU time
print('Wall-clock time %g s' % (tac-tic))

iter 0/100000
iter 10000/100000
iter 20000/100000
iter 30000/100000
iter 40000/100000
iter 50000/100000
iter 60000/100000
iter 70000/100000
iter 80000/100000
iter 90000/100000
CPU time 29.7793 s
