diff --git a/rust/sedona-expr/src/spatial_filter.rs b/rust/sedona-expr/src/spatial_filter.rs index f4d97815..246362b4 100644 --- a/rust/sedona-expr/src/spatial_filter.rs +++ b/rust/sedona-expr/src/spatial_filter.rs @@ -177,7 +177,7 @@ impl SpatialFilter { let args = parse_args(raw_args); let fun_name = scalar_fun.fun().name(); match fun_name { - "st_intersects" | "st_equals" | "st_touches" | "st_crosses" | "st_overlaps" => { + "st_intersects" | "st_touches" | "st_crosses" | "st_overlaps" => { if args.len() != 2 { return sedona_internal_err!("unexpected argument count in filter evaluation"); } @@ -199,6 +199,28 @@ impl SpatialFilter { _ => Ok(Some(Self::Unknown)), } } + "st_equals" => { + if args.len() != 2 { + return sedona_internal_err!("unexpected argument count in filter evaluation"); + } + + match (&args[0], &args[1]) { + (ArgRef::Col(column), ArgRef::Lit(literal)) + | (ArgRef::Lit(literal), ArgRef::Col(column)) => { + if !is_prunable_geospatial_literal(literal) { + return Ok(Some(Self::Unknown)); + } + match literal_bounds(literal) { + Ok(literal_bounds) => { + Ok(Some(Self::Covers(column.clone(), literal_bounds))) + } + Err(e) => Err(DataFusionError::External(Box::new(e))), + } + } + // Not between a literal and a column + _ => Ok(Some(Self::Unknown)), + } + } "st_within" | "st_covered_by" | "st_coveredby" => { if args.len() != 2 { return sedona_internal_err!("unexpected argument count in filter evaluation"); @@ -575,15 +597,8 @@ mod test { } #[rstest] - fn predicate_from_expr_commutative_functions( - #[values( - "st_intersects", - "st_equals", - "st_touches", - "st_crosses", - "st_overlaps" - )] - func_name: &str, + fn predicate_from_expr_commutative_intersects_functions( + #[values("st_intersects", "st_touches", "st_crosses", "st_overlaps")] func_name: &str, ) { let column: Arc = Arc::new(Column::new("geometry", 0)); let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap(); @@ -620,6 +635,43 @@ mod test { ); } + #[rstest] + fn predicate_from_expr_equals_function(#[values("st_equals")] func_name: &str) { + let column: Arc = Arc::new(Column::new("geometry", 0)); + let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap(); + let literal: Arc = Arc::new(Literal::new_with_metadata( + create_scalar(Some("POLYGON ((0 0, 2 0, 2 2, 0 2, 0 0))"), &WKB_GEOMETRY), + Some(storage_field.metadata().into()), + )); + + // Test functions that should result in Covers filter + let func = create_dummy_spatial_function(func_name, 2); + let expr: Arc = Arc::new(ScalarFunctionExpr::new( + func_name, + Arc::new(func.clone()), + vec![column.clone(), literal.clone()], + Arc::new(Field::new("", DataType::Boolean, true)), + )); + let predicate = SpatialFilter::try_from_expr(&expr).unwrap(); + assert!( + matches!(predicate, SpatialFilter::Covers(_, _)), + "Function {func_name} should produce Covers filter" + ); + + // Test reversed argument order + let expr_reversed: Arc = Arc::new(ScalarFunctionExpr::new( + func_name, + Arc::new(func), + vec![literal.clone(), column.clone()], + Arc::new(Field::new("", DataType::Boolean, true)), + )); + let predicate_reversed = SpatialFilter::try_from_expr(&expr_reversed).unwrap(); + assert!( + matches!(predicate_reversed, SpatialFilter::Covers(_, _)), + "Function {func_name} with reversed args should produce Covers filter" + ); + } + #[rstest] fn predicate_from_expr_within_covered_by_functions( #[values("st_within", "st_covered_by", "st_coveredby")] func_name: &str,