From 8e463e45760257017963bdc95f62b3e2e3e8d93e Mon Sep 17 00:00:00 2001 From: Kristin Cowalcijk Date: Tue, 9 Sep 2025 12:47:23 +0800 Subject: [PATCH 1/9] Initial impl of bounds contains --- rust/sedona-expr/src/spatial_filter.rs | 13 +++++++++++++ rust/sedona-geometry/src/bounding_box.rs | 19 +++++++++++++++++++ rust/sedona-geometry/src/interval.rs | 19 +++++++++++++++++++ rust/sedona-geoparquet/src/format.rs | 2 +- 4 files changed, 52 insertions(+), 1 deletion(-) diff --git a/rust/sedona-expr/src/spatial_filter.rs b/rust/sedona-expr/src/spatial_filter.rs index 4923c413..a3270f28 100644 --- a/rust/sedona-expr/src/spatial_filter.rs +++ b/rust/sedona-expr/src/spatial_filter.rs @@ -42,6 +42,8 @@ use crate::statistics::GeoStatistics; pub enum SpatialFilter { /// ST_Intersects(\, \) or ST_Intersects(\, \) Intersects(Column, BoundingBox), + /// ST_CoveredBy(\, \) or ST_CoveredBy(\, \) + CoveredBy(Column, BoundingBox), /// ST_HasZ(\) HasZ(Column), /// Logical AND @@ -64,6 +66,9 @@ impl SpatialFilter { SpatialFilter::Intersects(column, bounds) => { Self::evaluate_intersects_bbox(&table_stats[column.index()], bounds) } + SpatialFilter::CoveredBy(column, bounds) => { + Self::evaluate_covered_by_bbox(&table_stats[column.index()], bounds) + } SpatialFilter::HasZ(column) => Self::evaluate_has_z(&table_stats[column.index()]), SpatialFilter::And(lhs, rhs) => Self::evaluate_and(lhs, rhs, table_stats), SpatialFilter::Or(lhs, rhs) => Self::evaluate_or(lhs, rhs, table_stats), @@ -80,6 +85,14 @@ impl SpatialFilter { } } + fn evaluate_covered_by_bbox(column_stats: &GeoStatistics, bounds: &BoundingBox) -> bool { + if let Some(bbox) = column_stats.bbox() { + bounds.contains(&bbox) + } else { + true + } + } + fn evaluate_has_z(column_stats: &GeoStatistics) -> bool { if let Some(bbox) = column_stats.bbox() { if let Some(z) = bbox.z() { diff --git a/rust/sedona-geometry/src/bounding_box.rs b/rust/sedona-geometry/src/bounding_box.rs index 4d50cd7f..68a898bd 100644 --- a/rust/sedona-geometry/src/bounding_box.rs +++ b/rust/sedona-geometry/src/bounding_box.rs @@ -108,6 +108,25 @@ impl BoundingBox { intersects_xy && may_intersect_z && may_intersect_m } + /// Calculate whether this bounding box contains another BoundingBox + /// + /// Returns true if this bounding box contains other or false otherwise. + /// This method will consider Z and M dimension if and only if those dimensions are present + /// in both bounding boxes. + pub fn contains(&self, other: &Self) -> bool { + let contains_xy = self.x.contains_interval(&other.x) && self.y.contains_interval(&other.y); + let may_contain_z = match (self.z, other.z) { + (Some(z), Some(other_z)) => z.contains_interval(&other_z), + _ => true, + }; + let may_contain_m = match (self.m, other.m) { + (Some(m), Some(other_m)) => m.contains_interval(&other_m), + _ => true, + }; + + contains_xy && may_contain_z && may_contain_m + } + /// Update this BoundingBox to include the bounds of another /// /// This method will propagate missingness of Z or M dimensions from the two boxes diff --git a/rust/sedona-geometry/src/interval.rs b/rust/sedona-geometry/src/interval.rs index b87d0e65..eb74d8bb 100644 --- a/rust/sedona-geometry/src/interval.rs +++ b/rust/sedona-geometry/src/interval.rs @@ -73,6 +73,15 @@ pub trait IntervalTrait: std::fmt::Debug + PartialEq { /// `is_wraparound()` when not required for an implementation. fn intersects_interval(&self, other: &Self) -> bool; + /// Check for potential containment of an interval + /// + /// Note that intervals always contain their endpoints (for both the wraparound and + /// non-wraparound case). + /// + /// This method accepts Self for performance reasons to prevent unnecessary checking of + /// `is_wraparound()` when not required for an implementation. + fn contains_interval(&self, other: &Self) -> bool; + /// The width of the interval /// /// For the non-wraparound case, this is the distance between lo and hi. For the wraparound @@ -204,6 +213,10 @@ impl IntervalTrait for Interval { self.lo <= other.hi && other.lo <= self.hi } + fn contains_interval(&self, other: &Self) -> bool { + self.lo <= other.lo && self.hi >= other.hi + } + fn width(&self) -> f64 { self.hi - self.lo } @@ -316,6 +329,12 @@ impl IntervalTrait for WraparoundInterval { || right.intersects_interval(&other_right) } + fn contains_interval(&self, other: &Self) -> bool { + let (left, right) = self.split(); + let (other_left, other_right) = other.split(); + left.contains_interval(&other_left) && right.contains_interval(&other_right) + } + fn width(&self) -> f64 { if self.is_wraparound() { f64::INFINITY diff --git a/rust/sedona-geoparquet/src/format.rs b/rust/sedona-geoparquet/src/format.rs index 74c6fcfd..eb37d172 100644 --- a/rust/sedona-geoparquet/src/format.rs +++ b/rust/sedona-geoparquet/src/format.rs @@ -352,7 +352,7 @@ impl GeoParquetFileSource { if let Some(parquet_source) = inner.as_any().downcast_ref::() { let mut parquet_source = parquet_source.clone(); - // Extract the precicate from the existing source if it exists so we can keep a copy of it + // Extract the predicate from the existing source if it exists so we can keep a copy of it let new_predicate = match (parquet_source.predicate().cloned(), predicate) { (None, None) => None, (None, Some(specified_predicate)) => Some(specified_predicate), From 5a5ead4621edf5422009e00f916ccda5fe7c2327 Mon Sep 17 00:00:00 2001 From: Kristin Cowalcijk Date: Tue, 9 Sep 2025 14:17:50 +0800 Subject: [PATCH 2/9] Add tests for bbox contains and interval contains --- rust/sedona-expr/src/spatial_filter.rs | 2 +- rust/sedona-geometry/src/bounding_box.rs | 82 ++++++++++++++++++++++++ rust/sedona-geometry/src/interval.rs | 63 ++++++++++++++++++ 3 files changed, 146 insertions(+), 1 deletion(-) diff --git a/rust/sedona-expr/src/spatial_filter.rs b/rust/sedona-expr/src/spatial_filter.rs index a3270f28..3f9ee645 100644 --- a/rust/sedona-expr/src/spatial_filter.rs +++ b/rust/sedona-expr/src/spatial_filter.rs @@ -87,7 +87,7 @@ impl SpatialFilter { fn evaluate_covered_by_bbox(column_stats: &GeoStatistics, bounds: &BoundingBox) -> bool { if let Some(bbox) = column_stats.bbox() { - bounds.contains(&bbox) + bounds.contains(bbox) } else { true } diff --git a/rust/sedona-geometry/src/bounding_box.rs b/rust/sedona-geometry/src/bounding_box.rs index 68a898bd..a781b3ca 100644 --- a/rust/sedona-geometry/src/bounding_box.rs +++ b/rust/sedona-geometry/src/bounding_box.rs @@ -207,6 +207,88 @@ mod test { ))); } + #[test] + fn bounding_box_contains() { + let xyzm = BoundingBox::xyzm( + (10, 20), + (30, 40), + Some((50, 60).into()), + Some((70, 80).into()), + ); + + // Should contain a smaller box completely within bounds + assert!(xyzm.contains(&BoundingBox::xy((14, 16), (34, 36)))); + + // Should contain itself + assert!(xyzm.contains(&xyzm)); + + // Should contain a box without z or m information if xy is contained + assert!(xyzm.contains(&BoundingBox::xy((12, 18), (32, 38)))); + + // Should contain without z information but with contained m + assert!(xyzm.contains(&BoundingBox::xyzm( + (14, 16), + (34, 36), + None, + Some((74, 76).into()) + ))); + + // Should contain without m information but with contained z + assert!(xyzm.contains(&BoundingBox::xyzm( + (14, 16), + (34, 36), + Some((54, 56).into()), + None, + ))); + + // Should contain boxes that touch the boundaries + assert!(xyzm.contains(&BoundingBox::xy((10, 20), (30, 40)))); + assert!(xyzm.contains(&BoundingBox::xy((10, 15), (30, 35)))); + assert!(xyzm.contains(&BoundingBox::xy((15, 20), (35, 40)))); + + // Should *not* contain if x or y extends beyond bounds + assert!(!xyzm.contains(&BoundingBox::xy((4, 16), (34, 36)))); // x extends below + assert!(!xyzm.contains(&BoundingBox::xy((14, 26), (34, 36)))); // x extends above + assert!(!xyzm.contains(&BoundingBox::xy((14, 16), (24, 36)))); // y extends below + assert!(!xyzm.contains(&BoundingBox::xy((14, 16), (34, 46)))); // y extends above + + // Should *not* contain if z is provided but extends beyond bounds + assert!(!xyzm.contains(&BoundingBox::xyzm( + (14, 16), + (34, 36), + Some((44, 56).into()), // z extends below + None + ))); + + assert!(!xyzm.contains(&BoundingBox::xyzm( + (14, 16), + (34, 36), + Some((54, 66).into()), // z extends above + None + ))); + + // Should *not* contain if m is provided but extends beyond bounds + assert!(!xyzm.contains(&BoundingBox::xyzm( + (14, 16), + (34, 36), + None, + Some((64, 76).into()) // m extends below + ))); + + assert!(!xyzm.contains(&BoundingBox::xyzm( + (14, 16), + (34, 36), + None, + Some((74, 86).into()) // m extends above + ))); + + // Should *not* contain boxes that are completely outside + assert!(!xyzm.contains(&BoundingBox::xy((0, 5), (30, 40)))); // x completely below + assert!(!xyzm.contains(&BoundingBox::xy((25, 30), (30, 40)))); // x completely above + assert!(!xyzm.contains(&BoundingBox::xy((10, 20), (0, 25)))); // y completely below + assert!(!xyzm.contains(&BoundingBox::xy((10, 20), (45, 50)))); // y completely above + } + #[test] fn bounding_box_update() { let xyzm = BoundingBox::xyzm( diff --git a/rust/sedona-geometry/src/interval.rs b/rust/sedona-geometry/src/interval.rs index eb74d8bb..f893c7ed 100644 --- a/rust/sedona-geometry/src/interval.rs +++ b/rust/sedona-geometry/src/interval.rs @@ -470,6 +470,13 @@ mod test { // ...except the full interval assert!(empty.intersects_interval(&T::full())); + // Empty contains no intervals + assert!(!empty.contains_interval(&T::new(-10.0, 10.0))); + assert!(!empty.contains_interval(&T::full())); + + // ...except empty itself (empty set is subset of itself) + assert!(empty.contains_interval(&T::empty())); + // Merging NaN is still empty assert_eq!(empty.merge_value(f64::NAN), empty); @@ -547,6 +554,21 @@ mod test { assert!(!finite.intersects_interval(&T::new(25.0, 30.0))); assert!(!finite.intersects_interval(&T::empty())); + // Intervals that are contained + assert!(finite.contains_interval(&T::new(14.0, 16.0))); + assert!(finite.contains_interval(&T::new(10.0, 15.0))); + assert!(finite.contains_interval(&T::new(15.0, 20.0))); + assert!(finite.contains_interval(&T::new(10.0, 20.0))); // itself + assert!(finite.contains_interval(&T::empty())); + + // Intervals that are not contained + assert!(!finite.contains_interval(&T::new(5.0, 15.0))); // extends below + assert!(!finite.contains_interval(&T::new(15.0, 25.0))); // extends above + assert!(!finite.contains_interval(&T::new(5.0, 25.0))); // extends both ways + assert!(!finite.contains_interval(&T::new(0.0, 5.0))); // completely below + assert!(!finite.contains_interval(&T::new(25.0, 30.0))); // completely above + assert!(!finite.contains_interval(&T::full())); // full interval is larger + // Merging NaN assert_eq!(finite.merge_value(f64::NAN), finite); @@ -679,6 +701,47 @@ mod test { assert!(wraparound.intersects_interval(&WraparoundInterval::new(30.0, 25.0))); } + #[test] + fn wraparound_interval_actually_wraparound_contains_interval() { + // Everything *except* the interval (10, 20) + let wraparound = WraparoundInterval::new(20.0, 10.0); + + // Contains itself + assert!(wraparound.contains_interval(&wraparound)); + + // Empty is contained by everything + assert!(wraparound.contains_interval(&WraparoundInterval::empty())); + + // Does not contain the full interval + assert!(!wraparound.contains_interval(&WraparoundInterval::full())); + + // Regular interval completely between endpoints is not contained + assert!(!wraparound.contains_interval(&WraparoundInterval::new(14.0, 16.0))); + + // Wraparound intervals that exclude more (narrower included regions) are contained + assert!(wraparound.contains_interval(&WraparoundInterval::new(22.0, 8.0))); // excludes (8,22) which is larger than (10,20) + assert!(!wraparound.contains_interval(&WraparoundInterval::new(18.0, 12.0))); // excludes (12,18) which is smaller than (10,20) + + // Regular intervals don't work the same way due to the split logic + // For a regular interval (a, b), split gives (left=(a,b), right=empty) + // For wraparound to contain it, we need both parts to be contained + // This means (-inf, 10] must contain (a,b) AND [20, inf) must contain empty + // The second is always true, but the first requires b <= 10 + assert!(wraparound.contains_interval(&WraparoundInterval::new(0.0, 5.0))); // completely within left part + assert!(wraparound.contains_interval(&WraparoundInterval::new(-5.0, 10.0))); // fits in left part + assert!(!wraparound.contains_interval(&WraparoundInterval::new(25.0, 30.0))); // doesn't fit in left part + assert!(!wraparound.contains_interval(&WraparoundInterval::new(20.0, 25.0))); // doesn't fit in left part + + // Regular intervals that overlap the excluded zone are not contained + assert!(!wraparound.contains_interval(&WraparoundInterval::new(5.0, 15.0))); // overlaps excluded zone + assert!(!wraparound.contains_interval(&WraparoundInterval::new(15.0, 25.0))); // overlaps excluded zone + + // Wraparound intervals that exclude less (wider included regions) are not contained + assert!(!wraparound.contains_interval(&WraparoundInterval::new(15.0, 5.0))); // excludes (5,15) which is smaller + assert!(!wraparound.contains_interval(&WraparoundInterval::new(25.0, 15.0))); + // excludes (15,25) which is smaller + } + #[test] fn wraparound_interval_actually_wraparound_merge_value() { // Everything *except* the interval (10, 20) From b406319289268faabbf93e065d0eaf6b77463cc2 Mon Sep 17 00:00:00 2001 From: Kristin Cowalcijk Date: Tue, 9 Sep 2025 15:37:06 +0800 Subject: [PATCH 3/9] Support various spatial range predicates and distance predicate --- rust/sedona-expr/src/lib.rs | 1 + rust/sedona-expr/src/spatial_filter.rs | 135 ++++++++++++------ rust/sedona-expr/src/utils.rs | 62 ++++++++ .../src/operand_evaluator.rs | 2 + rust/sedona-spatial-join/src/optimizer.rs | 59 ++------ 5 files changed, 170 insertions(+), 89 deletions(-) create mode 100644 rust/sedona-expr/src/utils.rs diff --git a/rust/sedona-expr/src/lib.rs b/rust/sedona-expr/src/lib.rs index d242625f..c200b8e5 100644 --- a/rust/sedona-expr/src/lib.rs +++ b/rust/sedona-expr/src/lib.rs @@ -19,3 +19,4 @@ pub mod function_set; pub mod scalar_udf; pub mod spatial_filter; pub mod statistics; +pub mod utils; diff --git a/rust/sedona-expr/src/spatial_filter.rs b/rust/sedona-expr/src/spatial_filter.rs index 3f9ee645..1123cd3f 100644 --- a/rust/sedona-expr/src/spatial_filter.rs +++ b/rust/sedona-expr/src/spatial_filter.rs @@ -16,7 +16,7 @@ // under the License. use std::sync::Arc; -use arrow_schema::Schema; +use arrow_schema::{DataType, Schema}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::Operator; use datafusion_physical_expr::{ @@ -28,7 +28,10 @@ use sedona_common::sedona_internal_err; use sedona_geometry::{bounding_box::BoundingBox, bounds::wkb_bounds_xy, interval::IntervalTrait}; use sedona_schema::datatypes::SedonaType; -use crate::statistics::GeoStatistics; +use crate::{ + statistics::GeoStatistics, + utils::{parse_distance_predicate, ParsedDistancePredicate}, +}; /// Simplified parsed spatial filter /// @@ -132,46 +135,10 @@ impl SpatialFilter { /// /// Parses expr to extract known expressions we can evaluate against statistics. pub fn try_from_expr(expr: &Arc) -> Result { - if let Some(scalar_fun) = expr.as_any().downcast_ref::() { - let raw_args = scalar_fun.args(); - let args = parse_args(raw_args); - match scalar_fun.fun().name() { - "st_intersects" => { - if args.len() != 2 { - return sedona_internal_err!( - "unexpected argument count in filter evaluation" - ); - } - - match (&args[0], &args[1]) { - (ArgRef::Col(column), ArgRef::Lit(literal)) - | (ArgRef::Lit(literal), ArgRef::Col(column)) => { - match literal_bounds(literal) { - Ok(literal_bounds) => { - Ok(Self::Intersects(column.clone(), literal_bounds)) - } - Err(e) => Err(DataFusionError::External(Box::new(e))), - } - } - // Not between a literal and a column - _ => Ok(Self::Unknown), - } - } - "st_hasz" => { - if args.len() != 1 { - return sedona_internal_err!( - "unexpected argument count in filter evaluation" - ); - } - - match &args[0] { - ArgRef::Col(column) => Ok(Self::HasZ(column.clone())), - _ => Ok(Self::Unknown), - } - } - // Not a function we know about - _ => Ok(Self::Unknown), - } + if let Some(spatial_filter) = Self::try_from_range_predicate(expr)? { + Ok(spatial_filter) + } else if let Some(spatial_filter) = Self::try_from_distance_predicate(expr)? { + Ok(spatial_filter) } else if let Some(binary_expr) = expr.as_any().downcast_ref::() { match binary_expr.op() { Operator::And => Ok(Self::And( @@ -200,6 +167,90 @@ impl SpatialFilter { Ok(Self::Unknown) } } + + fn try_from_range_predicate(expr: &Arc) -> Result> { + let Some(scalar_fun) = expr.as_any().downcast_ref::() else { + return Ok(None); + }; + + let raw_args = scalar_fun.args(); + let args = parse_args(raw_args); + let fun_name = scalar_fun.fun().name(); + match fun_name { + "st_intersects" | "st_contains" | "st_covers" | "st_equals" | "st_touches" + | "st_within" | "st_covered_by" => { + if args.len() != 2 { + return sedona_internal_err!("unexpected argument count in filter evaluation"); + } + + match (&args[0], &args[1]) { + (ArgRef::Col(column), ArgRef::Lit(literal)) + | (ArgRef::Lit(literal), ArgRef::Col(column)) => { + match literal_bounds(literal) { + Ok(literal_bounds) => { + if matches!(fun_name, "st_within" | "st_covered_by") { + Ok(Some(Self::CoveredBy(column.clone(), literal_bounds))) + } else { + Ok(Some(Self::Intersects(column.clone(), literal_bounds))) + } + } + Err(e) => Err(DataFusionError::External(Box::new(e))), + } + } + // Not between a literal and a column + _ => Ok(Some(Self::Unknown)), + } + } + "st_hasz" => { + if args.len() != 1 { + return sedona_internal_err!("unexpected argument count in filter evaluation"); + } + + match &args[0] { + ArgRef::Col(column) => Ok(Some(Self::HasZ(column.clone()))), + _ => Ok(Some(Self::Unknown)), + } + } + _ => Ok(None), + } + } + + fn try_from_distance_predicate(expr: &Arc) -> Result> { + let Some(ParsedDistancePredicate { + arg0, + arg1, + arg_distance, + }) = parse_distance_predicate(expr) + else { + return Ok(None); + }; + + let raw_args = [arg0, arg1, arg_distance]; + let args = parse_args(&raw_args); + + match (&args[0], &args[1], &args[2]) { + (ArgRef::Col(column), ArgRef::Lit(literal), ArgRef::Lit(distance)) + | (ArgRef::Lit(literal), ArgRef::Col(column), ArgRef::Lit(distance)) => { + match ( + literal_bounds(literal), + distance.value().cast_to(&DataType::Float64)?, + ) { + (Ok(literal_bounds), distance_scalar_value) => { + if let ScalarValue::Float64(Some(dist)) = distance_scalar_value { + // let expanded_bounds = literal_bounds.expand_by(dist); + let expanded_bounds = literal_bounds; + Ok(Some(Self::Intersects(column.clone(), expanded_bounds))) + } else { + sedona_internal_err!("Unexpected distance type in filter expression ({distance_scalar_value:?})") + } + } + (Err(e), _) => Err(DataFusionError::External(Box::new(e))), + } + } + // Not between a literal and a column + _ => Ok(Some(Self::Unknown)), + } + } } /// Internal utility to help match physical expression types diff --git a/rust/sedona-expr/src/utils.rs b/rust/sedona-expr/src/utils.rs new file mode 100644 index 00000000..d949f8fa --- /dev/null +++ b/rust/sedona-expr/src/utils.rs @@ -0,0 +1,62 @@ +use std::sync::Arc; + +use datafusion_expr::Operator; +use datafusion_physical_expr::{expressions::BinaryExpr, PhysicalExpr, ScalarFunctionExpr}; + +pub struct ParsedDistancePredicate { + pub arg0: Arc, + pub arg1: Arc, + pub arg_distance: Arc, +} + +pub fn parse_distance_predicate(expr: &Arc) -> Option { + // There are 3 forms of distance predicates: + // 1. st_dwithin(geom1, geom2, distance) + // 2. st_distance(geom1, geom2) <= distance or st_distance(geom1, geom2) < distance + // 3. distance >= st_distance(geom1, geom2) or distance > st_distance(geom1, geom2) + if let Some(binary_expr) = expr.as_any().downcast_ref::() { + // handle case 2. and 3. + let left = binary_expr.left(); + let right = binary_expr.right(); + let (st_distance_expr, distance_bound_expr) = match *binary_expr.op() { + Operator::Lt | Operator::LtEq => (left, right), + Operator::Gt | Operator::GtEq => (right, left), + _ => return None, + }; + + if let Some(st_distance_expr) = st_distance_expr + .as_any() + .downcast_ref::() + { + if st_distance_expr.fun().name() != "st_distance" { + return None; + } + + let args = st_distance_expr.args(); + assert!(args.len() >= 2); + Some(ParsedDistancePredicate { + arg0: Arc::clone(&args[0]), + arg1: Arc::clone(&args[1]), + arg_distance: Arc::clone(distance_bound_expr), + }) + } else { + None + } + } else if let Some(st_dwithin_expr) = expr.as_any().downcast_ref::() { + // handle case 1. + if st_dwithin_expr.fun().name() != "st_dwithin" { + return None; + } + + let args = st_dwithin_expr.args(); + assert!(args.len() >= 3); + // Some((&args[0], &args[1], &args[2])) + Some(ParsedDistancePredicate { + arg0: Arc::clone(&args[0]), + arg1: Arc::clone(&args[1]), + arg_distance: Arc::clone(&args[2]), + }) + } else { + None + } +} diff --git a/rust/sedona-spatial-join/src/operand_evaluator.rs b/rust/sedona-spatial-join/src/operand_evaluator.rs index d945f710..3696698e 100644 --- a/rust/sedona-spatial-join/src/operand_evaluator.rs +++ b/rust/sedona-spatial-join/src/operand_evaluator.rs @@ -18,6 +18,7 @@ use core::fmt; use std::{mem::transmute, sync::Arc}; use arrow_array::{Array, ArrayRef, Float64Array, RecordBatch}; +use arrow_schema::DataType; use datafusion_common::{ utils::proxy::VecAllocExt, DataFusionError, JoinSide, Result, ScalarValue, }; @@ -240,6 +241,7 @@ impl DistanceOperandEvaluator { // Expand the vec by distance let distance_columnar_value = self.inner.distance.evaluate(batch)?; + let distance_columnar_value = distance_columnar_value.cast_to(&DataType::Float64, None)?; match &distance_columnar_value { ColumnarValue::Scalar(ScalarValue::Float64(Some(distance))) => { result.rects.iter_mut().for_each(|(_, rect)| { diff --git a/rust/sedona-spatial-join/src/optimizer.rs b/rust/sedona-spatial-join/src/optimizer.rs index b45d14cd..a576014c 100644 --- a/rust/sedona-spatial-join/src/optimizer.rs +++ b/rust/sedona-spatial-join/src/optimizer.rs @@ -40,6 +40,7 @@ use datafusion_physical_plan::joins::utils::ColumnIndex; use datafusion_physical_plan::joins::{HashJoinExec, NestedLoopJoinExec}; use datafusion_physical_plan::{joins::utils::JoinFilter, ExecutionPlan}; use sedona_common::{option::SedonaOptions, sedona_internal_err}; +use sedona_expr::utils::{parse_distance_predicate, ParsedDistancePredicate}; /// Physical planner extension for spatial joins /// @@ -594,60 +595,24 @@ fn match_distance_predicate( expr: &Arc, column_indices: &[ColumnIndex], ) -> Option { - // There are 3 forms of distance predicates: - // 1. st_dwithin(geom1, geom2, distance) - // 2. st_distance(geom1, geom2) <= distance or st_distance(geom1, geom2) < distance - // 3. distance >= st_distance(geom1, geom2) or distance > st_distance(geom1, geom2) - let (arg0, arg1, distance_bound_expr) = - if let Some(binary_expr) = expr.as_any().downcast_ref::() { - // handle case 2. and 3. - let left = binary_expr.left(); - let right = binary_expr.right(); - let (st_distance_expr, distance_bound_expr) = match *binary_expr.op() { - Operator::Lt | Operator::LtEq => (left, right), - Operator::Gt | Operator::GtEq => (right, left), - _ => return None, - }; - - if let Some(st_distance_expr) = st_distance_expr - .as_any() - .downcast_ref::() - { - if st_distance_expr.fun().name() != "st_distance" { - return None; - } - - let args = st_distance_expr.args(); - assert!(args.len() >= 2); - (&args[0], &args[1], distance_bound_expr) - } else { - return None; - } - } else if let Some(st_dwithin_expr) = expr.as_any().downcast_ref::() { - // handle case 1. - if st_dwithin_expr.fun().name() != "st_dwithin" { - return None; - } - - let args = st_dwithin_expr.args(); - assert!(args.len() >= 3); - (&args[0], &args[1], &args[2]) - } else { - return None; - }; + let ParsedDistancePredicate { + arg0, + arg1, + arg_distance, + } = parse_distance_predicate(expr)?; // Try to find the expressions that evaluates to the arguments of the spatial function - let arg0_refs = collect_column_references(arg0, column_indices); - let arg1_refs = collect_column_references(arg1, column_indices); - let arg_dist_refs = collect_column_references(distance_bound_expr, column_indices); + let arg0_refs = collect_column_references(&arg0, column_indices); + let arg1_refs = collect_column_references(&arg1, column_indices); + let arg_dist_refs = collect_column_references(&arg_distance, column_indices); let arg_dist_side = side_of_column_references(&arg_dist_refs)?; let (arg0_side, arg1_side) = resolve_column_reference_sides(&arg0_refs, &arg1_refs)?; - let arg0_reprojected = reproject_column_references_for_side(arg0, column_indices, arg0_side); - let arg1_reprojected = reproject_column_references_for_side(arg1, column_indices, arg1_side); + let arg0_reprojected = reproject_column_references_for_side(&arg0, column_indices, arg0_side); + let arg1_reprojected = reproject_column_references_for_side(&arg1, column_indices, arg1_side); let arg_dist_reprojected = - reproject_column_references_for_side(distance_bound_expr, column_indices, arg_dist_side); + reproject_column_references_for_side(&arg_distance, column_indices, arg_dist_side); match (arg0_side, arg1_side) { (JoinSide::Left, JoinSide::Right) => Some(DistancePredicate::new( From b09fcce479c504c2edf94af2fe9e0982072112a8 Mon Sep 17 00:00:00 2001 From: Kristin Cowalcijk Date: Tue, 9 Sep 2025 18:30:15 +0800 Subject: [PATCH 4/9] Implement bounding box expansion --- rust/sedona-expr/src/spatial_filter.rs | 13 +-- rust/sedona-geometry/src/bounding_box.rs | 68 +++++++++++++ rust/sedona-geometry/src/interval.rs | 120 +++++++++++++++++++++++ 3 files changed, 195 insertions(+), 6 deletions(-) diff --git a/rust/sedona-expr/src/spatial_filter.rs b/rust/sedona-expr/src/spatial_filter.rs index 1123cd3f..319f6063 100644 --- a/rust/sedona-expr/src/spatial_filter.rs +++ b/rust/sedona-expr/src/spatial_filter.rs @@ -236,13 +236,14 @@ impl SpatialFilter { distance.value().cast_to(&DataType::Float64)?, ) { (Ok(literal_bounds), distance_scalar_value) => { - if let ScalarValue::Float64(Some(dist)) = distance_scalar_value { - // let expanded_bounds = literal_bounds.expand_by(dist); - let expanded_bounds = literal_bounds; - Ok(Some(Self::Intersects(column.clone(), expanded_bounds))) - } else { - sedona_internal_err!("Unexpected distance type in filter expression ({distance_scalar_value:?})") + let ScalarValue::Float64(Some(dist)) = distance_scalar_value else { + return Ok(None); + }; + if dist.is_nan() || dist < 0.0 { + return Ok(None); } + let expanded_bounds = literal_bounds.expand_by(dist); + Ok(Some(Self::Intersects(column.clone(), expanded_bounds))) } (Err(e), _) => Err(DataFusionError::External(Box::new(e))), } diff --git a/rust/sedona-geometry/src/bounding_box.rs b/rust/sedona-geometry/src/bounding_box.rs index a781b3ca..fb5b9777 100644 --- a/rust/sedona-geometry/src/bounding_box.rs +++ b/rust/sedona-geometry/src/bounding_box.rs @@ -127,6 +127,20 @@ impl BoundingBox { contains_xy && may_contain_z && may_contain_m } + /// Expand this BoundingBox by a given distance in x and y dimensions only + /// + /// Returns a new BoundingBox where x and y intervals are expanded by the given distance. + /// The x dimension (which may wrap around) is handled correctly. + /// Z and M dimensions are left unchanged. + pub fn expand_by(&self, distance: f64) -> Self { + Self { + x: self.x.expand_by(distance), + y: self.y.expand_by(distance), + z: self.z, + m: self.m, + } + } + /// Update this BoundingBox to include the bounds of another /// /// This method will propagate missingness of Z or M dimensions from the two boxes @@ -400,4 +414,58 @@ mod test { assert!(bbox_nan2.y().lo().is_nan()); assert!(bbox_nan2.y().hi().is_nan()); } + + #[test] + fn bounding_box_expand_by() { + let xyzm = BoundingBox::xyzm( + (10, 20), + (30, 40), + Some((50, 60).into()), + Some((70, 80).into()), + ); + + // Expand by a positive distance - only x and y should change + let expanded = xyzm.expand_by(5.0); + assert_eq!(expanded.x(), &WraparoundInterval::new(5.0, 25.0)); + assert_eq!(expanded.y(), &Interval::new(25.0, 45.0)); + assert_eq!(expanded.z(), &Some(Interval::new(50.0, 60.0))); // unchanged + assert_eq!(expanded.m(), &Some(Interval::new(70.0, 80.0))); // unchanged + + // Expand by zero does nothing + let unchanged = xyzm.expand_by(0.0); + assert_eq!(unchanged, xyzm); + + // Expand by negative distance does nothing + let unchanged_neg = xyzm.expand_by(-2.0); + assert_eq!(unchanged_neg, xyzm); + + // Expand by NaN does nothing + let unchanged_nan = xyzm.expand_by(f64::NAN); + assert_eq!(unchanged_nan, xyzm); + + // Test with missing z and m dimensions + let xy_only = BoundingBox::xy((10, 20), (30, 40)); + let expanded_xy = xy_only.expand_by(3.0); + assert_eq!(expanded_xy.x(), &WraparoundInterval::new(7.0, 23.0)); + assert_eq!(expanded_xy.y(), &Interval::new(27.0, 43.0)); + assert!(expanded_xy.z().is_none()); + assert!(expanded_xy.m().is_none()); + + // Test with empty intervals + let bbox_with_empty = BoundingBox::xy((10, 20), Interval::empty()); + let expanded_empty = bbox_with_empty.expand_by(5.0); + assert_eq!(expanded_empty.x(), &WraparoundInterval::new(5.0, 25.0)); + assert_eq!(expanded_empty.y(), &Interval::empty()); + + // Test with wraparound x interval + let wraparound_x = BoundingBox::xy(WraparoundInterval::new(170.0, -170.0), (30, 40)); + let expanded_wraparound = wraparound_x.expand_by(10.0); + // Original excludes (-170, 170), expanding by 10 should exclude (-160, 160) + // So the new interval should be (160, -160) + assert_eq!( + expanded_wraparound.x(), + &WraparoundInterval::new(160.0, -160.0) + ); + assert_eq!(expanded_wraparound.y(), &Interval::new(20.0, 50.0)); + } } diff --git a/rust/sedona-geometry/src/interval.rs b/rust/sedona-geometry/src/interval.rs index f893c7ed..1037bf7e 100644 --- a/rust/sedona-geometry/src/interval.rs +++ b/rust/sedona-geometry/src/interval.rs @@ -107,6 +107,13 @@ pub trait IntervalTrait: std::fmt::Debug + PartialEq { /// /// When accumulating intervals in a loop, use [Interval::update_value]. fn merge_value(&self, other: f64) -> Self; + + /// Expand this interval by a given distance + /// + /// Returns a new interval where both endpoints are moved outward by the given distance. + /// For regular intervals, this expands both lo and hi by the distance. + /// For wraparound intervals, this may result in the full interval if expansion is large enough. + fn expand_by(&self, distance: f64) -> Self; } /// 1D Interval that never wraps around @@ -240,6 +247,14 @@ impl IntervalTrait for Interval { out.update_value(other); out } + + fn expand_by(&self, distance: f64) -> Self { + if self.is_empty() || distance.is_nan() || distance < 0.0 { + return *self; + } + + Self::new(self.lo - distance, self.hi + distance) + } } #[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] @@ -442,6 +457,35 @@ impl IntervalTrait for WraparoundInterval { } } } + + fn expand_by(&self, distance: f64) -> Self { + if self.is_empty() || distance.is_nan() || distance < 0.0 { + return *self; + } + + if !self.is_wraparound() { + // For non-wraparound, just expand the inner interval + return Self { + inner: self.inner.expand_by(distance), + }; + } + + // For wraparound intervals, expanding means including more values + // Wraparound interval (a, b) where a > b excludes the region (b, a) + // To expand by distance d, we shrink the excluded region from (b, a) to (b+d, a-d) + // This means the new wraparound interval becomes (a-d, b+d) + let excluded_lo = self.inner.hi + distance; // b + d + let excluded_hi = self.inner.lo - distance; // a - d + + // If the excluded region disappears (excluded_lo >= excluded_hi), we get the full interval + if excluded_lo >= excluded_hi { + return Self::full(); + } + + // The new wraparound interval excludes (excluded_lo, excluded_hi) + // So the interval itself is (excluded_hi, excluded_lo) + Self::new(excluded_hi, excluded_lo) + } } #[cfg(test)] @@ -491,6 +535,12 @@ mod test { empty.merge_interval(&T::new(10.0, 20.0)), T::new(10.0, 20.0) ); + + // Expanding empty interval keeps it empty + assert_eq!(empty.expand_by(5.0), empty); + assert_eq!(empty.expand_by(0.0), empty); + assert_eq!(empty.expand_by(-1.0), empty); + assert_eq!(empty.expand_by(f64::NAN), empty); } #[test] @@ -620,6 +670,19 @@ mod test { finite.merge_interval(&T::new(25.0, 30.0)), T::new(10.0, 30.0) ); + + // Expanding by positive distance + assert_eq!(finite.expand_by(2.0), T::new(8.0, 22.0)); + assert_eq!(finite.expand_by(5.0), T::new(5.0, 25.0)); + + // Expanding by zero does nothing + assert_eq!(finite.expand_by(0.0), finite); + + // Expanding by negative distance does nothing + assert_eq!(finite.expand_by(-1.0), finite); + + // Expanding by NaN does nothing + assert_eq!(finite.expand_by(f64::NAN), finite); } #[test] @@ -916,6 +979,63 @@ mod test { ); } + #[test] + fn wraparound_interval_actually_wraparound_expand_by() { + // Everything *except* the interval (10, 20) + let wraparound = WraparoundInterval::new(20.0, 10.0); + + // Expanding by a small amount shrinks the excluded region + // Original excludes (10, 20), expanding by 2 should exclude (12, 18) + // So the new interval should be (18, 12) = everything except (12, 18) + assert_eq!( + wraparound.expand_by(2.0), + WraparoundInterval::new(18.0, 12.0) + ); // now excludes (12, 18) + + // Expanding by 4 should exclude (14, 16) + assert_eq!( + wraparound.expand_by(4.0), + WraparoundInterval::new(16.0, 14.0) + ); // now excludes (14, 16) + + // Expanding by 5.0 should exactly eliminate the excluded region + // excluded region (10, 20) shrinks to (15, 15) which is empty + assert_eq!(wraparound.expand_by(5.0), WraparoundInterval::full()); // excluded region disappears + + // Any expansion greater than 5.0 should also give full interval + assert_eq!(wraparound.expand_by(6.0), WraparoundInterval::full()); + + assert_eq!(wraparound.expand_by(100.0), WraparoundInterval::full()); + + // Expanding by zero does nothing + assert_eq!(wraparound.expand_by(0.0), wraparound); + + // Expanding by negative distance does nothing + assert_eq!(wraparound.expand_by(-1.0), wraparound); + + // Expanding by NaN does nothing + assert_eq!(wraparound.expand_by(f64::NAN), wraparound); + + // Test a finite (non-wraparound) wraparound interval + let non_wraparound = WraparoundInterval::new(10.0, 20.0); + assert!(!non_wraparound.is_wraparound()); + assert_eq!( + non_wraparound.expand_by(2.0), + WraparoundInterval::new(8.0, 22.0) + ); + + // Test another wraparound case - excludes (5, 15) with width 10 + let wraparound2 = WraparoundInterval::new(15.0, 5.0); + // Expanding by 3 should shrink excluded region from (5, 15) to (8, 12) + assert_eq!( + wraparound2.expand_by(3.0), + WraparoundInterval::new(12.0, 8.0) + ); + + // Expanding by 5 should make excluded region disappear: (5+5, 15-5) = (10, 10) + assert_eq!(wraparound2.expand_by(5.0), WraparoundInterval::full()); + } + #[test] fn wraparound_interval_actually_wraparound_convert() { // Everything *except* the interval (10, 20) From a2a9a5c518f125189fcc566bd153bbde4011085e Mon Sep 17 00:00:00 2001 From: Kristin Cowalcijk Date: Tue, 9 Sep 2025 18:35:26 +0800 Subject: [PATCH 5/9] Comment the distance predicate parsing util --- rust/sedona-expr/src/utils.rs | 55 ++++++++++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 7 deletions(-) diff --git a/rust/sedona-expr/src/utils.rs b/rust/sedona-expr/src/utils.rs index d949f8fa..25081a9a 100644 --- a/rust/sedona-expr/src/utils.rs +++ b/rust/sedona-expr/src/utils.rs @@ -3,19 +3,62 @@ use std::sync::Arc; use datafusion_expr::Operator; use datafusion_physical_expr::{expressions::BinaryExpr, PhysicalExpr, ScalarFunctionExpr}; +/// Represents a parsed distance predicate with its constituent parts. +/// +/// Distance predicates are spatial operations that determine whether two geometries +/// are within a specified distance of each other. This struct holds the parsed +/// components of such predicates for further processing. +/// +/// ## Supported Distance Predicate Forms +/// +/// This struct can represent the parsed components from any of these distance predicate forms: +/// +/// 1. **Direct distance function**: +/// - `st_dwithin(geom1, geom2, distance)` - Returns true if geometries are within the distance +/// +/// 2. **Distance comparison (left-to-right)**: +/// - `st_distance(geom1, geom2) <= distance` - Distance is less than or equal to threshold +/// - `st_distance(geom1, geom2) < distance` - Distance is strictly less than threshold +/// +/// 3. **Distance comparison (right-to-left)**: +/// - `distance >= st_distance(geom1, geom2)` - Threshold is greater than or equal to distance +/// - `distance > st_distance(geom1, geom2)` - Threshold is strictly greater than distance +/// +/// All forms are logically equivalent but may appear differently in SQL queries. The parser +/// normalizes them into this common structure for uniform processing. pub struct ParsedDistancePredicate { + /// The first geometry argument in the distance predicate pub arg0: Arc, + /// The second geometry argument in the distance predicate pub arg1: Arc, + /// The distance threshold argument (as a physical expression) pub arg_distance: Arc, } +/// Parses a physical expression to extract distance predicate components. +/// +/// This function recognizes and parses distance predicates in spatial queries. +/// See [`ParsedDistancePredicate`] documentation for details on the supported +/// distance predicate forms. +/// +/// # Arguments +/// +/// * `expr` - A physical expression that potentially represents a distance predicate +/// +/// # Returns +/// +/// * `Some(ParsedDistancePredicate)` - If the expression is a recognized distance predicate, +/// returns the parsed components (two geometry arguments and the distance threshold) +/// * `None` - If the expression is not a distance predicate or cannot be parsed +/// +/// # Examples +/// +/// The function can parse expressions like: +/// - `st_dwithin(geometry_column, POINT(0 0), 100.0)` +/// - `st_distance(geom_a, geom_b) <= 50.0` +/// - `25.0 >= st_distance(geom_x, geom_y)` pub fn parse_distance_predicate(expr: &Arc) -> Option { - // There are 3 forms of distance predicates: - // 1. st_dwithin(geom1, geom2, distance) - // 2. st_distance(geom1, geom2) <= distance or st_distance(geom1, geom2) < distance - // 3. distance >= st_distance(geom1, geom2) or distance > st_distance(geom1, geom2) if let Some(binary_expr) = expr.as_any().downcast_ref::() { - // handle case 2. and 3. let left = binary_expr.left(); let right = binary_expr.right(); let (st_distance_expr, distance_bound_expr) = match *binary_expr.op() { @@ -43,14 +86,12 @@ pub fn parse_distance_predicate(expr: &Arc) -> Option() { - // handle case 1. if st_dwithin_expr.fun().name() != "st_dwithin" { return None; } let args = st_dwithin_expr.args(); assert!(args.len() >= 3); - // Some((&args[0], &args[1], &args[2])) Some(ParsedDistancePredicate { arg0: Arc::clone(&args[0]), arg1: Arc::clone(&args[1]), From a1b37d9624c7f4ebc8add47fccf8bebd586d6bba Mon Sep 17 00:00:00 2001 From: Kristin Cowalcijk Date: Tue, 9 Sep 2025 19:13:35 +0800 Subject: [PATCH 6/9] Add tests for PhysicalExpr to SpatialFilter conversion and SpatialFilter evaluation function --- Cargo.lock | 1 + rust/sedona-expr/Cargo.toml | 1 + rust/sedona-expr/src/spatial_filter.rs | 223 ++++++++++++++++++++++--- 3 files changed, 203 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 999f2c0a..0984483a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4880,6 +4880,7 @@ dependencies = [ "datafusion-expr", "datafusion-physical-expr", "geo-traits 0.2.0", + "rstest", "sedona-common", "sedona-geometry", "sedona-schema", diff --git a/rust/sedona-expr/Cargo.toml b/rust/sedona-expr/Cargo.toml index 2fef1be3..0cb72603 100644 --- a/rust/sedona-expr/Cargo.toml +++ b/rust/sedona-expr/Cargo.toml @@ -29,6 +29,7 @@ result_large_err = "allow" [dev-dependencies] sedona-testing = { path = "../sedona-testing" } +rstest = { workspace = true } [dependencies] arrow-array = { workspace = true } diff --git a/rust/sedona-expr/src/spatial_filter.rs b/rust/sedona-expr/src/spatial_filter.rs index 319f6063..d591c048 100644 --- a/rust/sedona-expr/src/spatial_filter.rs +++ b/rust/sedona-expr/src/spatial_filter.rs @@ -298,15 +298,16 @@ mod test { use arrow_schema::{DataType, Field}; use datafusion_expr::{ScalarUDF, Signature, SimpleScalarUDF, Volatility}; + use rstest::rstest; use sedona_geometry::{bounding_box::BoundingBox, interval::Interval}; use sedona_schema::datatypes::WKB_GEOMETRY; use sedona_testing::create::create_scalar; use super::*; - fn dummy_st_intersects() -> ScalarUDF { + fn dummy_st_hasz() -> ScalarUDF { SimpleScalarUDF::new_with_signature( - "st_intersects", + "st_hasz", Signature::any(2, Volatility::Immutable), DataType::Boolean, Arc::new(|_args| Ok(ScalarValue::Boolean(Some(true)).into())), @@ -314,9 +315,9 @@ mod test { .into() } - fn dummy_st_hasz() -> ScalarUDF { + fn dummy_unrelated() -> ScalarUDF { SimpleScalarUDF::new_with_signature( - "st_hasz", + "st_not_a_predicate", Signature::any(2, Volatility::Immutable), DataType::Boolean, Arc::new(|_args| Ok(ScalarValue::Boolean(Some(true)).into())), @@ -324,19 +325,16 @@ mod test { .into() } - fn dummy_unrelated() -> ScalarUDF { + fn create_dummy_spatial_function(name: &str, arg_count: usize) -> ScalarUDF { SimpleScalarUDF::new_with_signature( - "st_not_a_predicate", - Signature::any(2, Volatility::Immutable), + name, + Signature::any(arg_count, Volatility::Immutable), DataType::Boolean, Arc::new(|_args| Ok(ScalarValue::Boolean(Some(true)).into())), ) .into() } - #[test] - fn spatial_filters() {} - #[test] fn predicate_intersects() { let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap(); @@ -372,6 +370,32 @@ mod test { .contains("Unexpected scalar type in filter expression")); } + #[test] + fn predicate_covered_by() { + let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap(); + let literal = Literal::new_with_metadata( + create_scalar(Some("POLYGON ((0 0, 4 0, 4 4, 0 4, 0 0))"), &WKB_GEOMETRY), + Some(storage_field.metadata().into()), + ); + let bounds = literal_bounds(&literal).unwrap(); + + let stats_no_info = [GeoStatistics::unspecified()]; + let stats_covered = [ + GeoStatistics::unspecified().with_bbox(Some(BoundingBox::xy((1.0, 1.0), (2.0, 2.0)))) + ]; + let stats_not_covered = [ + GeoStatistics::unspecified().with_bbox(Some(BoundingBox::xy((3.0, 3.0), (5.0, 5.0)))) + ]; + let col0 = Column::new("col0", 0); + + // CoveredBy should return true when column bbox is fully contained in literal bounds + assert!(SpatialFilter::CoveredBy(col0.clone(), bounds.clone()).evaluate(&stats_no_info)); + assert!(SpatialFilter::CoveredBy(col0.clone(), bounds.clone()).evaluate(&stats_covered)); + assert!( + !SpatialFilter::CoveredBy(col0.clone(), bounds.clone()).evaluate(&stats_not_covered) + ); + } + #[test] fn predicate_has_z() { let col0 = Column::new("col0", 0); @@ -470,39 +494,194 @@ mod test { )); } - #[test] - fn predicate_from_expr_intersects() { + #[rstest] + fn predicate_from_expr_commutative_functions( + #[values("st_intersects", "st_contains", "st_covers", "st_equals", "st_touches")] func_name: &str, + ) { let column: Arc = Arc::new(Column::new("geometry", 0)); let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap(); let literal: Arc = Arc::new(Literal::new_with_metadata( - create_scalar(Some("POINT (1 2)"), &WKB_GEOMETRY), + create_scalar(Some("POLYGON ((0 0, 2 0, 2 2, 0 2, 0 0))"), &WKB_GEOMETRY), Some(storage_field.metadata().into()), )); - let st_intersects = dummy_st_intersects(); + // Test functions that should result in Intersects filter + let func = create_dummy_spatial_function(func_name, 2); let expr: Arc = Arc::new(ScalarFunctionExpr::new( - "intersects", - Arc::new(st_intersects.clone()), + func_name, + Arc::new(func.clone()), vec![column.clone(), literal.clone()], Arc::new(Field::new("", DataType::Boolean, true)), )); let predicate = SpatialFilter::try_from_expr(&expr).unwrap(); - assert!(matches!(predicate, SpatialFilter::Intersects(_, _))); + assert!( + matches!(predicate, SpatialFilter::Intersects(_, _)), + "Function {} should produce Intersects filter", + func_name + ); - let expr: Arc = Arc::new(ScalarFunctionExpr::new( - "intersects", - Arc::new(st_intersects.clone()), + // Test reversed argument order + let expr_reversed: Arc = Arc::new(ScalarFunctionExpr::new( + func_name, + Arc::new(func), vec![literal.clone(), column.clone()], Arc::new(Field::new("", DataType::Boolean, true)), )); + let predicate_reversed = SpatialFilter::try_from_expr(&expr_reversed).unwrap(); + assert!( + matches!(predicate_reversed, SpatialFilter::Intersects(_, _)), + "Function {} with reversed args should produce Intersects filter", + func_name + ); + } + + #[rstest] + fn predicate_from_expr_non_commutative_functions( + #[values("st_within", "st_covered_by")] func_name: &str, + ) { + let column: Arc = Arc::new(Column::new("geometry", 0)); + let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap(); + let literal: Arc = Arc::new(Literal::new_with_metadata( + create_scalar(Some("POLYGON ((0 0, 2 0, 2 2, 0 2, 0 0))"), &WKB_GEOMETRY), + Some(storage_field.metadata().into()), + )); + + // Test functions that should result in CoveredBy filter + let func = create_dummy_spatial_function(func_name, 2); + let expr: Arc = Arc::new(ScalarFunctionExpr::new( + func_name, + Arc::new(func.clone()), + vec![column.clone(), literal.clone()], + Arc::new(Field::new("", DataType::Boolean, true)), + )); let predicate = SpatialFilter::try_from_expr(&expr).unwrap(); - assert!(matches!(predicate, SpatialFilter::Intersects(_, _))) + assert!( + matches!(predicate, SpatialFilter::CoveredBy(_, _)), + "Function {} should produce CoveredBy filter", + func_name + ); + + // Test reversed argument order: should be converted to Intersects filter since + // within/covered_by are not commutative + let expr_reversed: Arc = Arc::new(ScalarFunctionExpr::new( + func_name, + Arc::new(func), + vec![literal.clone(), column.clone()], + Arc::new(Field::new("", DataType::Boolean, true)), + )); + let predicate_reversed = SpatialFilter::try_from_expr(&expr_reversed).unwrap(); + assert!( + matches!(predicate_reversed, SpatialFilter::Intersects(_, _)), + "Function {} with reversed args should produce Intersects filter", + func_name + ); + } + + #[test] + fn predicate_from_expr_distance_functions() { + let column: Arc = Arc::new(Column::new("geometry", 0)); + let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap(); + let literal: Arc = Arc::new(Literal::new_with_metadata( + create_scalar(Some("POINT (1 2)"), &WKB_GEOMETRY), + Some(storage_field.metadata().into()), + )); + let distance_literal: Arc = + Arc::new(Literal::new(ScalarValue::Float64(Some(100.0)))); + + // Test ST_DWithin function + let st_dwithin = create_dummy_spatial_function("st_dwithin", 3); + let dwithin_expr: Arc = Arc::new(ScalarFunctionExpr::new( + "st_dwithin", + Arc::new(st_dwithin.clone()), + vec![column.clone(), literal.clone(), distance_literal.clone()], + Arc::new(Field::new("", DataType::Boolean, true)), + )); + let predicate = SpatialFilter::try_from_expr(&dwithin_expr).unwrap(); + assert!( + matches!(predicate, SpatialFilter::Intersects(_, _)), + "ST_DWithin should produce Intersects filter with expanded bounds" + ); + + // Test ST_DWithin with reversed geometry arguments + let dwithin_expr_reversed: Arc = Arc::new(ScalarFunctionExpr::new( + "st_dwithin", + Arc::new(st_dwithin), + vec![literal.clone(), column.clone(), distance_literal.clone()], + Arc::new(Field::new("", DataType::Boolean, true)), + )); + let predicate_reversed = SpatialFilter::try_from_expr(&dwithin_expr_reversed).unwrap(); + assert!( + matches!(predicate_reversed, SpatialFilter::Intersects(_, _)), + "ST_DWithin with reversed args should produce Intersects filter" + ); + + // Test ST_Distance <= threshold + let st_distance = create_dummy_spatial_function("st_distance", 2); + let distance_expr: Arc = Arc::new(ScalarFunctionExpr::new( + "st_distance", + Arc::new(st_distance.clone()), + vec![column.clone(), literal.clone()], + Arc::new(Field::new("", DataType::Boolean, true)), + )); + let comparison_expr: Arc = Arc::new(BinaryExpr::new( + distance_expr.clone(), + Operator::LtEq, + distance_literal.clone(), + )); + let predicate = SpatialFilter::try_from_expr(&comparison_expr).unwrap(); + assert!( + matches!(predicate, SpatialFilter::Intersects(_, _)), + "ST_Distance <= threshold should produce Intersects filter" + ); + + // Test threshold >= ST_Distance + let comparison_expr_reversed: Arc = Arc::new(BinaryExpr::new( + distance_literal.clone(), + Operator::GtEq, + distance_expr.clone(), + )); + let predicate_reversed = SpatialFilter::try_from_expr(&comparison_expr_reversed).unwrap(); + assert!( + matches!(predicate_reversed, SpatialFilter::Intersects(_, _)), + "threshold >= ST_Distance should produce Intersects filter" + ); + + // Test with negative distance (should be treated as Unknown) + let negative_distance: Arc = + Arc::new(Literal::new(ScalarValue::Float64(Some(-10.0)))); + let st_dwithin = create_dummy_spatial_function("st_dwithin", 3); + let dwithin_expr: Arc = Arc::new(ScalarFunctionExpr::new( + "st_dwithin", + Arc::new(st_dwithin.clone()), + vec![column.clone(), literal.clone(), negative_distance], + Arc::new(Field::new("", DataType::Boolean, true)), + )); + let predicate = SpatialFilter::try_from_expr(&dwithin_expr).unwrap(); + assert!( + matches!(predicate, SpatialFilter::Unknown), + "Negative distance should result in Unknown filter" + ); + + // Test with NaN distance (should be treated as Unknown) + let nan_distance: Arc = + Arc::new(Literal::new(ScalarValue::Float64(Some(f64::NAN)))); + let dwithin_expr_nan: Arc = Arc::new(ScalarFunctionExpr::new( + "st_dwithin", + Arc::new(st_dwithin), + vec![column.clone(), literal.clone(), nan_distance], + Arc::new(Field::new("", DataType::Boolean, true)), + )); + let predicate_nan = SpatialFilter::try_from_expr(&dwithin_expr_nan).unwrap(); + assert!( + matches!(predicate_nan, SpatialFilter::Unknown), + "NaN distance should result in Unknown filter" + ); } #[test] fn predicate_from_intersects_errors() { let literal: Arc = Arc::new(Literal::new(ScalarValue::Null)); - let st_intersects = dummy_st_intersects(); + let st_intersects = create_dummy_spatial_function("st_intersects", 2); // Wrong number of args let expr_no_args: Arc = Arc::new(ScalarFunctionExpr::new( From 6dd47eec6766dd5bc562b7859cca74a4e5eb56a8 Mon Sep 17 00:00:00 2001 From: Kristin Cowalcijk Date: Tue, 9 Sep 2025 20:29:58 +0800 Subject: [PATCH 7/9] Fix conversion from st_contains/st_coveres to SpatialFilter --- rust/sedona-expr/src/spatial_filter.rs | 139 ++++++++++++++++++++++--- 1 file changed, 123 insertions(+), 16 deletions(-) diff --git a/rust/sedona-expr/src/spatial_filter.rs b/rust/sedona-expr/src/spatial_filter.rs index d591c048..71f29496 100644 --- a/rust/sedona-expr/src/spatial_filter.rs +++ b/rust/sedona-expr/src/spatial_filter.rs @@ -177,8 +177,7 @@ impl SpatialFilter { let args = parse_args(raw_args); let fun_name = scalar_fun.fun().name(); match fun_name { - "st_intersects" | "st_contains" | "st_covers" | "st_equals" | "st_touches" - | "st_within" | "st_covered_by" => { + "st_intersects" | "st_equals" | "st_touches" => { if args.len() != 2 { return sedona_internal_err!("unexpected argument count in filter evaluation"); } @@ -188,11 +187,65 @@ impl SpatialFilter { | (ArgRef::Lit(literal), ArgRef::Col(column)) => { match literal_bounds(literal) { Ok(literal_bounds) => { - if matches!(fun_name, "st_within" | "st_covered_by") { - Ok(Some(Self::CoveredBy(column.clone(), literal_bounds))) - } else { - Ok(Some(Self::Intersects(column.clone(), literal_bounds))) - } + Ok(Some(Self::Intersects(column.clone(), literal_bounds))) + } + Err(e) => Err(DataFusionError::External(Box::new(e))), + } + } + // Not between a literal and a column + _ => Ok(Some(Self::Unknown)), + } + } + "st_within" | "st_covered_by" | "st_coveredby" => { + if args.len() != 2 { + return sedona_internal_err!("unexpected argument count in filter evaluation"); + } + + match (&args[0], &args[1]) { + (ArgRef::Col(column), ArgRef::Lit(literal)) => { + // column within/covered_by literal -> CoveredBy filter + match literal_bounds(literal) { + Ok(literal_bounds) => { + Ok(Some(Self::CoveredBy(column.clone(), literal_bounds))) + } + Err(e) => Err(DataFusionError::External(Box::new(e))), + } + } + (ArgRef::Lit(literal), ArgRef::Col(column)) => { + // literal within/covered_by column -> Intersects filter + match literal_bounds(literal) { + Ok(literal_bounds) => { + Ok(Some(Self::Intersects(column.clone(), literal_bounds))) + } + Err(e) => Err(DataFusionError::External(Box::new(e))), + } + } + // Not between a literal and a column + _ => Ok(Some(Self::Unknown)), + } + } + "st_contains" | "st_covers" => { + if args.len() != 2 { + return sedona_internal_err!("unexpected argument count in filter evaluation"); + } + + match (&args[0], &args[1]) { + (ArgRef::Col(column), ArgRef::Lit(literal)) => { + // column contains/covers literal -> Intersects filter + // (column must potentially intersect literal to contain it) + match literal_bounds(literal) { + Ok(literal_bounds) => { + Ok(Some(Self::Intersects(column.clone(), literal_bounds))) + } + Err(e) => Err(DataFusionError::External(Box::new(e))), + } + } + (ArgRef::Lit(literal), ArgRef::Col(column)) => { + // literal contains/covers column -> CoveredBy filter + // (equivalent to st_within(column, literal)) + match literal_bounds(literal) { + Ok(literal_bounds) => { + Ok(Some(Self::CoveredBy(column.clone(), literal_bounds))) } Err(e) => Err(DataFusionError::External(Box::new(e))), } @@ -496,7 +549,7 @@ mod test { #[rstest] fn predicate_from_expr_commutative_functions( - #[values("st_intersects", "st_contains", "st_covers", "st_equals", "st_touches")] func_name: &str, + #[values("st_intersects", "st_equals", "st_touches")] func_name: &str, ) { let column: Arc = Arc::new(Column::new("geometry", 0)); let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap(); @@ -536,8 +589,8 @@ mod test { } #[rstest] - fn predicate_from_expr_non_commutative_functions( - #[values("st_within", "st_covered_by")] func_name: &str, + fn predicate_from_expr_within_covered_by_functions( + #[values("st_within", "st_covered_by", "st_coveredby")] func_name: &str, ) { let column: Arc = Arc::new(Column::new("geometry", 0)); let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap(); @@ -546,7 +599,7 @@ mod test { Some(storage_field.metadata().into()), )); - // Test functions that should result in CoveredBy filter + // Test functions that should result in CoveredBy filter when column is first arg let func = create_dummy_spatial_function(func_name, 2); let expr: Arc = Arc::new(ScalarFunctionExpr::new( func_name, @@ -561,8 +614,7 @@ mod test { func_name ); - // Test reversed argument order: should be converted to Intersects filter since - // within/covered_by are not commutative + // Test reversed argument order: should be converted to Intersects filter let expr_reversed: Arc = Arc::new(ScalarFunctionExpr::new( func_name, Arc::new(func), @@ -577,6 +629,49 @@ mod test { ); } + #[rstest] + fn predicate_from_expr_contains_covers_functions( + #[values("st_contains", "st_covers")] func_name: &str, + ) { + let column: Arc = Arc::new(Column::new("geometry", 0)); + let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap(); + let literal: Arc = Arc::new(Literal::new_with_metadata( + create_scalar(Some("POLYGON ((0 0, 2 0, 2 2, 0 2, 0 0))"), &WKB_GEOMETRY), + Some(storage_field.metadata().into()), + )); + + // Test functions that should result in Intersects filter when column is first arg + // (column contains/covers literal -> column must intersect literal) + let func = create_dummy_spatial_function(func_name, 2); + let expr: Arc = Arc::new(ScalarFunctionExpr::new( + func_name, + Arc::new(func.clone()), + vec![column.clone(), literal.clone()], + Arc::new(Field::new("", DataType::Boolean, true)), + )); + let predicate = SpatialFilter::try_from_expr(&expr).unwrap(); + assert!( + matches!(predicate, SpatialFilter::Intersects(_, _)), + "Function {} should produce Intersects filter", + func_name + ); + + // Test reversed argument order: should be converted to CoveredBy filter + // (literal contains/covers column -> equivalent to st_within(column, literal)) + let expr_reversed: Arc = Arc::new(ScalarFunctionExpr::new( + func_name, + Arc::new(func), + vec![literal.clone(), column.clone()], + Arc::new(Field::new("", DataType::Boolean, true)), + )); + let predicate_reversed = SpatialFilter::try_from_expr(&expr_reversed).unwrap(); + assert!( + matches!(predicate_reversed, SpatialFilter::CoveredBy(_, _)), + "Function {} with reversed args should produce CoveredBy filter", + func_name + ); + } + #[test] fn predicate_from_expr_distance_functions() { let column: Arc = Arc::new(Column::new("geometry", 0)); @@ -678,10 +773,22 @@ mod test { ); } - #[test] - fn predicate_from_intersects_errors() { + #[rstest] + fn predicate_from_spatial_relation_function_errors( + #[values( + "st_intersects", + "st_equals", + "st_touches", + "st_contains", + "st_covers", + "st_within", + "st_covered_by", + "st_coveredby" + )] + func_name: &str, + ) { let literal: Arc = Arc::new(Literal::new(ScalarValue::Null)); - let st_intersects = create_dummy_spatial_function("st_intersects", 2); + let st_intersects = create_dummy_spatial_function(func_name, 2); // Wrong number of args let expr_no_args: Arc = Arc::new(ScalarFunctionExpr::new( From ec1a245eb960037b29269cc8ceb5f5c2a38e1e8b Mon Sep 17 00:00:00 2001 From: Kristin Cowalcijk Date: Tue, 9 Sep 2025 20:59:57 +0800 Subject: [PATCH 8/9] Extend geoparquet metadata pruning test for st_contains --- Cargo.lock | 1 + rust/sedona-geoparquet/Cargo.toml | 1 + rust/sedona-geoparquet/src/format.rs | 17 ++++++++++++----- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0984483a..4003484e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4993,6 +4993,7 @@ dependencies = [ "geo-traits 0.2.0", "object_store", "parquet", + "rstest", "sedona-common", "sedona-expr", "sedona-geometry", diff --git a/rust/sedona-geoparquet/Cargo.toml b/rust/sedona-geoparquet/Cargo.toml index ee5ffe6b..6c1ffa1e 100644 --- a/rust/sedona-geoparquet/Cargo.toml +++ b/rust/sedona-geoparquet/Cargo.toml @@ -33,6 +33,7 @@ default = [] [dev-dependencies] sedona-testing = { path = "../sedona-testing" } url = { workspace = true } +rstest = { workspace = true } [dependencies] async-trait = { workspace = true } diff --git a/rust/sedona-geoparquet/src/format.rs b/rust/sedona-geoparquet/src/format.rs index eb37d172..93346d26 100644 --- a/rust/sedona-geoparquet/src/format.rs +++ b/rust/sedona-geoparquet/src/format.rs @@ -530,6 +530,7 @@ mod test { use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; use datafusion_physical_expr::PhysicalExpr; + use rstest::rstest; use sedona_schema::crs::lnglat; use sedona_schema::datatypes::{Edges, SedonaType, WKB_GEOMETRY}; use sedona_testing::create::create_scalar; @@ -675,21 +676,24 @@ mod test { assert_eq!(total_size, 244); } + #[rstest] #[tokio::test] - async fn pruning_geoparquet_metadata() { + async fn pruning_geoparquet_metadata(#[values("st_intersects", "st_within")] udf_name: &str) { let data_dir = geoarrow_data_dir().unwrap(); let ctx = setup_context(); let udf: ScalarUDF = SimpleScalarUDF::new_with_signature( - "st_intersects", + udf_name, Signature::any(2, Volatility::Immutable), DataType::Boolean, Arc::new(|_args| Ok(ScalarValue::Boolean(Some(true)).into())), ) .into(); - let definitely_non_intersecting_scalar = - create_scalar(Some("POINT (100 200)"), &WKB_GEOMETRY); + let definitely_non_intersecting_scalar = create_scalar( + Some("POLYGON ((100 200), (100 300), (200 300), (100 200))"), + &WKB_GEOMETRY, + ); let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap(); let df = ctx @@ -708,7 +712,10 @@ mod test { let batches_out = df.collect().await.unwrap(); assert!(batches_out.is_empty()); - let definitely_intersecting_scalar = create_scalar(Some("POINT (30 10)"), &WKB_GEOMETRY); + let definitely_intersecting_scalar = create_scalar( + Some("POLYGON ((30 10), (30 20), (40 20), (40 10), (30 10))"), + &WKB_GEOMETRY, + ); let df = ctx .table(format!("{data_dir}/example/files/*_geo.parquet")) .await From 34d88f9206522b8d6ec0ae8a9aad0d567c32bae0 Mon Sep 17 00:00:00 2001 From: Kristin Cowalcijk Date: Tue, 9 Sep 2025 21:52:48 +0800 Subject: [PATCH 9/9] Fix review comments --- rust/sedona-spatial-join/src/operand_evaluator.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/rust/sedona-spatial-join/src/operand_evaluator.rs b/rust/sedona-spatial-join/src/operand_evaluator.rs index 3696698e..56dca647 100644 --- a/rust/sedona-spatial-join/src/operand_evaluator.rs +++ b/rust/sedona-spatial-join/src/operand_evaluator.rs @@ -241,6 +241,7 @@ impl DistanceOperandEvaluator { // Expand the vec by distance let distance_columnar_value = self.inner.distance.evaluate(batch)?; + // No timezone conversion needed for distance; pass None as cast_options explicitly. let distance_columnar_value = distance_columnar_value.cast_to(&DataType::Float64, None)?; match &distance_columnar_value { ColumnarValue::Scalar(ScalarValue::Float64(Some(distance))) => {