## OctTree!

Testing the time to look-up nearby records with the `OctTree` implementation.

The `OctTree` is used to find records within a spatio-temporal range of a given point, or within a box defined by lon, lat, & time bounds.

In [1]:
import os


os.environ["POLARS_MAX_THREADS"] = "1"

import inspect
import random
from datetime import datetime, timedelta
from string import ascii_letters, digits

import numpy as np
import polars as pl

from geotrees.octtree import OctTree
from geotrees.record import SpaceTimeRecord as Record
from geotrees.shape import SpaceTimeRectangle as Rectangle

## Set-up functions

For comparisons using brute-force approach

In [2]:
def generate_uid(n: int) -> str:
    """Generates a pseudo uid by randomly selecting from characters"""
    chars = ascii_letters + digits
    return "".join(random.choice(chars) for _ in range(n))


def check_cols(
    df: pl.DataFrame | pl.LazyFrame,
    cols: list[str],
    var_name: str = "dataframe",
) -> None:
    """
    Check that a dataframe contains a list of columns. Raises an error if not.

    Parameters
    ----------
    df : polars Frame
        Dataframe to check
    cols : list[str]
        Required columns
    var_name : str
        Name of the Frame - used for displaying in any error.
    """
    calling_func = inspect.stack()[1][3]
    if isinstance(df, pl.DataFrame):
        have_cols = df.columns
    elif isinstance(df, pl.LazyFrame):
        have_cols = df.collect_schema().names()
    else:
        raise TypeError("Input Frame is not a polars Frame")

    cols_in_frame = intersect(cols, have_cols)
    missing = [c for c in cols if c not in cols_in_frame]

    if len(missing) > 0:
        err_str = f"({calling_func}) - {var_name} missing required columns. "
        err_str += f"Require: {', '.join(cols)}. "
        err_str += f"Missing: {', '.join(missing)}."
        raise ValueError(err_str)

    return


def haversine_df(
    df: pl.DataFrame | pl.LazyFrame,
    lon: float,
    lat: float,
    radius: float = 6371,
    lon_col: str = "lon",
    lat_col: str = "lat",
) -> pl.DataFrame | pl.LazyFrame:
    """
    Compute haversine distance on earth surface between lon-lat positions
    in a polars DataFrame and a lon-lat position.

    Parameters
    ----------
    df : polars.DataFrame
        The data, containing required columns:
            * lon_col
            * lat_col
            * date_var
    lon : float
        The longitude of the position.
    lat : float
        The latitude of the position.
    radius : float
        Radius of earth in km
    lon_col : str
        Name of the longitude column
    lat_col : str
        Name of the latitude column

    Returns
    -------
    polars.DataFrame
        With additional column specifying distances between consecutive points
        in the same units as 'R'. With colname defined by 'out_colname'.
    """
    required_cols = [lon_col, lat_col]

    check_cols(df, required_cols, "df")
    return (
        df.with_columns(
            [
                pl.col(lat_col).radians().alias("_lat0"),
                pl.lit(lat).radians().alias("_lat1"),
                (pl.col(lon_col) - lon).radians().alias("_dlon"),
                (pl.col(lat_col) - lat).radians().alias("_dlat"),
            ]
        )
        .with_columns(
            (
                (pl.col("_dlat") / 2).sin().pow(2)
                + pl.col("_lat0").cos()
                * pl.col("_lat1").cos()
                * (pl.col("_dlon") / 2).sin().pow(2)
            ).alias("_a")
        )
        .with_columns(
            (2 * radius * (pl.col("_a").sqrt().arcsin()))
            .round(2)
            .alias("_dist")
        )
        .drop(["_lat0", "_lat1", "_dlon", "_dlat", "_a"])
    )


def intersect(a, b) -> set:
    """Intersection of a and b, items in both a and b"""
    return set(a) & set(b)


def nearby_ships(
    lon: float,
    lat: float,
    pool: pl.DataFrame,
    max_dist: float,
    lon_col: str = "lon",
    lat_col: str = "lat",
    dt: datetime | None = None,
    date_col: str | None = None,
    dt_gap: timedelta | None = None,
    filter_datetime: bool = False,
) -> pl.DataFrame:
    """
    Find observations nearby to a position in space (and optionally time).

    Get a frame of all records that are within a maximum distance of the
    provided point.

    If filter_datetime is True, then only records from the same datetime will
    be returned. If a specific filter is desired this should be performed
    before calling this function and set filter_datetime to False.

    Parameters
    ----------
    lon : float
        The longitude of the position.
    lat : float
        The latitude of the position.
    pool : polars.DataFrame
        The pool of records to search. Can be pre-filtered and filter_datetime
        set to False.
    max_dist : float
        Will return records that have distance to the point <= this value.
    lon_col : str
        Name of the longitude column in the pool DataFrame
    lat_col : str
        Name of the latitude column in the pool DataFrame
    dt : datetime | None
        Datetime of the record. Must be set if filter_datetime is True.
    date_col : str | None
        Name of the datetime column in the pool. Must be set if filter_datetime
        is True.
    dt_gap : timedelta | None
        Allowed time-gap for records. Records that fall between
        dt - dt_gap and dt + dt_gap will be returned. If not set then only
        records at dt will be returned. Applies if filter_datetime is True.
    filter_datetime : bool
        Only return records at the same datetime record as the input value. If
        assessing multiple points with different datetimes, hence calling this
        function frequently it will be more efficient to partition the pool
        first, then set this value to False and only input the subset of data.

    Returns
    -------
    polars.DataFrame
        Containing only records from the pool within max_dist of the input
        point, optionally at the same datetime if filter_datetime is True.
    """
    required_cols = [lon_col, lat_col]
    check_cols(pool, required_cols, "pool")

    if filter_datetime:
        if not dt or not date_col:
            raise ValueError(
                "'dt' and 'date_col' must be provided if 'filter_datetime' "
                + "is True"
            )
        if date_col not in pool.columns:
            raise ValueError(f"'date_col' value {date_col} not found in pool.")
        if not dt_gap:
            pool = pool.filter(pl.col(date_col).eq(dt))
        else:
            pool = pool.filter(
                pl.col(date_col).is_between(
                    dt - dt_gap, dt + dt_gap, closed="both"
                )
            )

    return (
        pool.pipe(
            haversine_df,
            lon=lon,
            lat=lat,
            lon_col=lon_col,
            lat_col=lat_col,
        )
        .filter(pl.col("_dist").le(max_dist))
        .drop(["_dist"])
    )

In [3]:
N = 16_000
lons = pl.int_range(-180, 180, eager=True)
lats = pl.int_range(-90, 90, eager=True)
dates = pl.datetime_range(
    datetime(1900, 1, 1, 0),
    datetime(1900, 1, 31, 23),
    interval="1h",
    eager=True,
)

lons_use = lons.sample(N, with_replacement=True).alias("lon")
lats_use = lats.sample(N, with_replacement=True).alias("lat")
dates_use = dates.sample(N, with_replacement=True).alias("datetime")
uids = pl.Series("uid", [generate_uid(8) for _ in range(N)])

df = pl.DataFrame([lons_use, lats_use, dates_use, uids]).unique()

## Initialise the OctTree Object

In [4]:
otree = OctTree(
    Rectangle(
        -180, 180, -90, 90, datetime(1900, 1, 1, 0), datetime(1900, 1, 31, 23)
    ),
    capacity=10,
    # max_depth=25,
)

In [5]:
dts = pl.datetime_range(
    datetime(1900, 1, 1),
    datetime(1900, 2, 1),
    interval="1h",
    eager=True,
    closed="left",
)
n = dts.len()
lons = 180 - 360 * np.random.rand(n)
lats = 90 - 180 * np.random.rand(n)
test_df = pl.DataFrame({"lon": lons, "lat": lats, "datetime": dts})
test_recs = [Record(*r) for r in test_df.rows()]
dt = timedelta(days=1)
dist = 350

In [6]:
all_recs = [Record(**row) for row in df.rows(named=True)]

## Time Execution

Testing the identification of nearby points against the original full search.

### Initialisation

In [7]:
%%time
for r in df.rows():
    otree.insert(Record(*r))

CPU times: user 97.9 ms, sys: 2.38 ms, total: 100 ms
Wall time: 99.5 ms


### OctTree query

In [8]:
%%timeit test_record = random.choice(test_recs)  # noqa: F821
otree.nearby_points(test_record, dist=dist, t_dist=dt)

143 μs ± 66.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### Brute Force

In [9]:
%%timeit test_record = random.choice(test_recs)  # noqa: F821
[
    r
    for r in all_recs
    if r.distance(test_record) <= dist
    and r.datetime <= test_record.datetime + dt
    and r.datetime >= test_record.datetime - dt
]

9.14 ms ± 124 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
%%timeit test_record = random.choice(test_recs)  # noqa: F821
nearby_ships(
    lon=test_record.lon,
    lat=test_record.lat,
    dt=test_record.datetime,
    max_dist=dist,
    dt_gap=dt,
    date_col="datetime",
    pool=df,
    filter_datetime=True,
)

5.07 ms ± 110 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Verify

Check that records are the same

In [11]:
%%time
dist = 250
for _ in range(250):
    rec = Record(*random.choice(df.rows()))
    orig = nearby_ships(
        lon=rec.lon,
        lat=rec.lat,
        dt=rec.datetime,
        max_dist=dist,
        dt_gap=dt,
        date_col="datetime",
        pool=df,
        filter_datetime=True,
    )
    tree = otree.nearby_points(rec, dist=dist, t_dist=dt)
    if orig.height > 0:
        if not tree:
            print(rec)
            print("NO TREE!")
            print(f"{orig = }")
        else:
            tree = pl.from_records(
                [(r.lon, r.lat, r.datetime, r.uid) for r in tree], orient="row"
            ).rename(
                {
                    "column_0": "lon",
                    "column_1": "lat",
                    "column_2": "datetime",
                    "column_3": "uid",
                }
            )
            if tree.height != orig.height:
                print("Tree and Orig Heights Do Not Match")
                print(rec)
                print(f"{orig = }")
                print(f"{tree = }")
            else:
                # tree = tree.with_columns(pl.col("uid").str.slice(0, 6))
                if not tree.sort("uid").equals(orig.sort("uid")):
                    print("Tree and Orig Do Not Match")
                    print(rec)
                    print(f"{orig = }")
                    print(f"{tree = }")

CPU times: user 2.38 s, sys: 108 ms, total: 2.49 s
Wall time: 2.61 s


## Check -180/180 boundary

Expect to see `SpaceTimeRecord`s with both positive and negative `"x"` values

In [12]:
# Ensure some points are added
cross_points = [
    Record(179, -43, datetime(1900, 1, 14, 12), uid="e_1"),
    Record(-179, -42, datetime(1900, 1, 14, 15), uid="e_2"),
]

for c in cross_points:
    otree.insert(c)

out = otree.nearby_points(
    Record(179.5, -43.1, datetime(1900, 1, 14, 13)),
    dist=200,
    t_dist=timedelta(days=3),
)
for o in out:
    print(o)
assert all(c in out for c in cross_points)

SpaceTimeRecord(x = -179, y = -42, datetime = 1900-01-14 15:00:00, uid = e_2)
SpaceTimeRecord(x = 179, y = -43, datetime = 1900-01-14 12:00:00, uid = e_1)
