diff --git a/python/sedonadb/tests/functions/test_functions.py b/python/sedonadb/tests/functions/test_functions.py index 21b715c4e..6b5a42ec2 100644 --- a/python/sedonadb/tests/functions/test_functions.py +++ b/python/sedonadb/tests/functions/test_functions.py @@ -1313,6 +1313,25 @@ def test_st_point(eng, x, y, expected): ) +@pytest.mark.parametrize("eng", [SedonaDB, PostGIS]) +@pytest.mark.parametrize( + ("x", "y", "srid", "expected"), + [ + (None, None, None, None), + (1, 1, None, None), + (1, 1, 0, 0), + (1, 1, 4326, 4326), + (1, 1, "4326", 4326), + ], +) +def test_st_point_with_srid(eng, x, y, srid, expected): + eng = eng.create_or_skip() + eng.assert_query_result( + f"SELECT ST_SRID(ST_Point({val_or_null(x)}, {val_or_null(y)}, {val_or_null(srid)}))", + expected, + ) + + @pytest.mark.parametrize("eng", [SedonaDB, PostGIS]) @pytest.mark.parametrize( ("x", "y", "z", "expected"), diff --git a/rust/sedona-functions/src/st_point.rs b/rust/sedona-functions/src/st_point.rs index 42f893ae0..aedf5e346 100644 --- a/rust/sedona-functions/src/st_point.rs +++ b/rust/sedona-functions/src/st_point.rs @@ -30,18 +30,21 @@ use sedona_schema::{ matchers::ArgMatcher, }; -use crate::executor::WkbExecutor; +use crate::{executor::WkbExecutor, st_setsrid::SRIDifiedKernel}; /// ST_Point() scalar UDF implementation /// /// Native implementation to create geometries from coordinates. /// See [`st_geogpoint_udf`] for the corresponding geography constructor. pub fn st_point_udf() -> SedonaScalarUDF { + let kernel = Arc::new(STGeoFromPoint { + out_type: WKB_GEOMETRY, + }); + let sridified_kernel = Arc::new(SRIDifiedKernel::new(kernel.clone())); + SedonaScalarUDF::new( "st_point", - vec![Arc::new(STGeoFromPoint { - out_type: WKB_GEOMETRY, - })], + vec![sridified_kernel, kernel], Volatility::Immutable, Some(doc("ST_Point", "Geometry")), ) @@ -52,11 +55,14 @@ pub fn st_point_udf() -> SedonaScalarUDF { /// Native implementation to create geometries from coordinates. /// See [`st_geogpoint_udf`] for the corresponding geography constructor. pub fn st_geogpoint_udf() -> SedonaScalarUDF { + let kernel = Arc::new(STGeoFromPoint { + out_type: WKB_GEOGRAPHY, + }); + let sridified_kernel = Arc::new(SRIDifiedKernel::new(kernel.clone())); + SedonaScalarUDF::new( "st_geogpoint", - vec![Arc::new(STGeoFromPoint { - out_type: WKB_GEOGRAPHY, - })], + vec![sridified_kernel, kernel], Volatility::Immutable, Some(doc("st_geogpoint", "Geography")), ) @@ -73,6 +79,7 @@ fn doc(name: &str, out_type_name: &str) -> Documentation { ) .with_argument("x", "double: X value") .with_argument("y", "double: Y value") + .with_argument("srid", "srid: EPSG code to set (e.g., 4326)") .with_sql_example(format!("{name}(-64.36, 45.09)")) .build() } @@ -157,8 +164,11 @@ mod tests { use arrow_array::create_array; use arrow_array::ArrayRef; use arrow_schema::DataType; + use datafusion_expr::Literal; use datafusion_expr::ScalarUDF; use rstest::rstest; + use sedona_schema::crs::lnglat; + use sedona_schema::datatypes::Edges; use sedona_testing::compare::assert_array_equal; use sedona_testing::{create::create_array, testers::ScalarUdfTester}; @@ -247,6 +257,56 @@ mod tests { ); } + #[rstest] + #[case(DataType::UInt32, 4326)] + #[case(DataType::Int32, 4326)] + #[case(DataType::Utf8, "4326")] + #[case(DataType::Utf8, "EPSG:4326")] + fn udf_invoke_with_srid(#[case] srid_type: DataType, #[case] srid_value: impl Literal + Copy) { + let udf = st_point_udf(); + let tester = ScalarUdfTester::new( + udf.into(), + vec![ + SedonaType::Arrow(DataType::Float64), + SedonaType::Arrow(DataType::Float64), + SedonaType::Arrow(srid_type), + ], + ); + + let return_type = tester + .return_type_with_scalar_scalar_scalar(Some(1.0), Some(2.0), Some(srid_value)) + .unwrap(); + assert_eq!(return_type, SedonaType::Wkb(Edges::Planar, lnglat())); + + let result = tester + .invoke_scalar_scalar_scalar(1.0, 2.0, srid_value) + .unwrap(); + tester.assert_scalar_result_equals_with_return_type(result, "POINT (1 2)", return_type); + } + + #[test] + fn udf_invoke_with_invalid_srid() { + let udf = st_point_udf(); + let tester = ScalarUdfTester::new( + udf.into(), + vec![ + SedonaType::Arrow(DataType::Float64), + SedonaType::Arrow(DataType::Float64), + SedonaType::Arrow(DataType::Utf8), + ], + ); + + let return_type = tester.return_type_with_scalar_scalar_scalar( + Some(1.0), + Some(2.0), + Some("gazornenplat"), + ); + assert!(return_type.is_err()); + + let result = tester.invoke_scalar_scalar_scalar(1.0, 2.0, "gazornenplat"); + assert!(result.is_err()); + } + #[test] fn geog() { let udf = st_geogpoint_udf(); diff --git a/rust/sedona-functions/src/st_setsrid.rs b/rust/sedona-functions/src/st_setsrid.rs index 68d916eb4..591b7e9c9 100644 --- a/rust/sedona-functions/src/st_setsrid.rs +++ b/rust/sedona-functions/src/st_setsrid.rs @@ -16,13 +16,14 @@ // under the License. use std::{sync::Arc, vec}; +use arrow_array::builder::BinaryBuilder; use arrow_schema::DataType; use datafusion_common::{error::Result, DataFusionError, ScalarValue}; use datafusion_expr::{ scalar_doc_sections::DOC_SECTION_OTHER, ColumnarValue, Documentation, Volatility, }; use sedona_common::sedona_internal_err; -use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF}; +use sedona_expr::scalar_udf::{ScalarKernelRef, SedonaScalarKernel, SedonaScalarUDF}; use sedona_geometry::transform::CrsEngine; use sedona_schema::{crs::deserialize_crs, datatypes::SedonaType, matchers::ArgMatcher}; @@ -227,6 +228,119 @@ fn determine_return_type( sedona_internal_err!("Unexpected argument types: {}, {}", args[0], args[1]) } +/// [SedonaScalarKernel] wrapper that handles the SRID argument for constructors like ST_Point +#[derive(Debug)] +pub(crate) struct SRIDifiedKernel { + inner: ScalarKernelRef, +} + +impl SRIDifiedKernel { + pub(crate) fn new(inner: ScalarKernelRef) -> Self { + Self { inner } + } +} + +impl SedonaScalarKernel for SRIDifiedKernel { + fn return_type_from_args_and_scalars( + &self, + args: &[SedonaType], + scalar_args: &[Option<&ScalarValue>], + ) -> Result> { + // args should consist of the original args and one extra arg for + // specifying CRS. So, first, validate the length and separate these. + // + // [arg0, arg1, ..., crs_arg]; + // ^^^^^^^^^^^^^^^ + // orig_args + let orig_args_len = match (args.len(), scalar_args.len()) { + (0, 0) => return Ok(None), + (l1, l2) if l1 == l2 => l1 - 1, + _ => return sedona_internal_err!("Arg types and arg values have different lengths"), + }; + + let orig_args = &args[..orig_args_len]; + let orig_scalar_args = &scalar_args[..orig_args_len]; + + // Invoke the original return_type_from_args_and_scalars() first before checking the CRS argument + let mut inner_result = match self + .inner + .return_type_from_args_and_scalars(orig_args, orig_scalar_args)? + { + Some(sedona_type) => sedona_type, + // if no match, quit here. Since the CRS arg is also an unintended + // one, validating it would be a cryptic error to the user. + None => return Ok(None), + }; + + let crs = match scalar_args[orig_args_len] { + Some(crs) => crs, + None => return Ok(None), + }; + let new_crs = match crs.cast_to(&DataType::Utf8) { + Ok(ScalarValue::Utf8(Some(crs))) => { + if crs == "0" { + None + } else { + validate_crs(&crs, None)?; + deserialize_crs(&serde_json::Value::String(crs))? + } + } + Ok(ScalarValue::Utf8(None)) => None, + Ok(_) | Err(_) => return sedona_internal_err!("Can't cast Crs {crs:?} to Utf8"), + }; + + match &mut inner_result { + SedonaType::Wkb(_, crs) => *crs = new_crs, + SedonaType::WkbView(_, crs) => *crs = new_crs, + _ => { + return sedona_internal_err!("Return type must be Wkb or WkbView"); + } + } + + Ok(Some(inner_result)) + } + + fn invoke_batch( + &self, + arg_types: &[SedonaType], + args: &[ColumnarValue], + ) -> Result { + let orig_args_len = arg_types.len() - 1; + let orig_arg_types = &arg_types[..orig_args_len]; + let orig_args = &args[..orig_args_len]; + + // Invoke the inner UDF first to propagate any errors even when the CRS is NULL. + // Note that, this behavior is different from PostGIS. + let result = self.inner.invoke_batch(orig_arg_types, orig_args)?; + + // If the specified SRID is NULL, the result is also NULL. + if let ColumnarValue::Scalar(sc) = &args[orig_args_len] { + if sc.is_null() { + // Create the same length of NULLs as the original result. + let len = match &result { + ColumnarValue::Array(array) => array.len(), + ColumnarValue::Scalar(_) => 1, + }; + + let mut builder = BinaryBuilder::with_capacity(len, 0); + for _ in 0..len { + builder.append_null(); + } + let new_array = builder.finish(); + return Ok(ColumnarValue::Array(Arc::new(new_array))); + } + } + + Ok(result) + } + + fn return_type(&self, _args: &[SedonaType]) -> Result> { + sedona_internal_err!( + "Should not be called because return_type_from_args_and_scalars() is implemented" + ) + } +} + #[cfg(test)] mod test { use std::rc::Rc; diff --git a/rust/sedona-testing/src/testers.rs b/rust/sedona-testing/src/testers.rs index b6580a0e7..fdded3b3a 100644 --- a/rust/sedona-testing/src/testers.rs +++ b/rust/sedona-testing/src/testers.rs @@ -184,7 +184,28 @@ impl ScalarUdfTester { /// Both actual and expected are interpreted according to the calculated /// return type (notably, WKT is interpreted as geometry or geography output). pub fn assert_scalar_result_equals(&self, actual: impl Literal, expected: impl Literal) { - let return_type = self.return_type().unwrap(); + self.assert_scalar_result_equals_inner(actual, expected, None); + } + + /// Assert the result of invoking this function with the return type specified + /// + /// This is for UDFs implementing `SedonaScalarKernel::return_type_from_args_and_scalars()`. + pub fn assert_scalar_result_equals_with_return_type( + &self, + actual: impl Literal, + expected: impl Literal, + return_type: SedonaType, + ) { + self.assert_scalar_result_equals_inner(actual, expected, Some(return_type)); + } + + fn assert_scalar_result_equals_inner( + &self, + actual: impl Literal, + expected: impl Literal, + return_type: Option, + ) { + let return_type = return_type.unwrap_or_else(|| self.return_type().unwrap()); let actual = Self::scalar_lit(actual, &return_type).unwrap(); let expected = Self::scalar_lit(expected, &return_type).unwrap(); assert_scalar_equal(&actual, &expected); @@ -192,16 +213,72 @@ impl ScalarUdfTester { /// Compute the return type pub fn return_type(&self) -> Result { + let scalar_arguments = vec![None; self.arg_types.len()]; + self.return_type_with_scalars_inner(&scalar_arguments) + } + + /// Compute the return type from one scalar argument + /// + /// This is for UDFs implementing `SedonaScalarKernel::return_type_from_args_and_scalars()`. + pub fn return_type_with_scalar(&self, arg0: Option) -> Result { + let scalar_arguments = vec![arg0 + .map(|x| Self::scalar_lit(x, &self.arg_types[0])) + .transpose()?]; + self.return_type_with_scalars_inner(&scalar_arguments) + } + + /// Compute the return type from two scalar arguments + /// + /// This is for UDFs implementing `SedonaScalarKernel::return_type_from_args_and_scalars()`. + pub fn return_type_with_scalar_scalar( + &self, + arg0: Option, + arg1: Option, + ) -> Result { + let scalar_arguments = vec![ + arg0.map(|x| Self::scalar_lit(x, &self.arg_types[0])) + .transpose()?, + arg1.map(|x| Self::scalar_lit(x, &self.arg_types[1])) + .transpose()?, + ]; + self.return_type_with_scalars_inner(&scalar_arguments) + } + + /// Compute the return type from three scalar arguments + /// + /// This is for UDFs implementing `SedonaScalarKernel::return_type_from_args_and_scalars()`. + pub fn return_type_with_scalar_scalar_scalar( + &self, + arg0: Option, + arg1: Option, + arg2: Option, + ) -> Result { + let scalar_arguments = vec![ + arg0.map(|x| Self::scalar_lit(x, &self.arg_types[0])) + .transpose()?, + arg1.map(|x| Self::scalar_lit(x, &self.arg_types[1])) + .transpose()?, + arg2.map(|x| Self::scalar_lit(x, &self.arg_types[2])) + .transpose()?, + ]; + self.return_type_with_scalars_inner(&scalar_arguments) + } + + fn return_type_with_scalars_inner( + &self, + scalar_arguments: &[Option], + ) -> Result { let arg_fields = self .arg_types .iter() .map(|sedona_type| sedona_type.to_storage_field("", true).map(Arc::new)) .collect::>>()?; - let scalar_arguments = (0..arg_fields.len()).map(|_| None).collect::>(); + let scalar_arguments_ref: Vec> = + scalar_arguments.iter().map(|x| x.as_ref()).collect(); let args = ReturnFieldArgs { arg_fields: &arg_fields, - scalar_arguments: &scalar_arguments, + scalar_arguments: &scalar_arguments_ref, }; let return_field = self.udf.return_field_from_args(args)?; SedonaType::from_storage_field(&return_field) @@ -209,9 +286,15 @@ impl ScalarUdfTester { /// Invoke this function with a scalar pub fn invoke_scalar(&self, arg: impl Literal) -> Result { - let args = vec![Self::scalar_arg(arg, &self.arg_types[0])?]; + let scalar_arg = Self::scalar_lit(arg, &self.arg_types[0])?; - if let ColumnarValue::Scalar(scalar) = self.invoke(args)? { + // Some UDF calculate the return type from the input scalar arguments, so try it first. + let return_type = self + .return_type_with_scalars_inner(&[Some(scalar_arg.clone())]) + .ok(); + + let args = vec![ColumnarValue::Scalar(scalar_arg)]; + if let ColumnarValue::Scalar(scalar) = self.invoke_with_return_type(args, return_type)? { Ok(scalar) } else { sedona_internal_err!("Expected scalar result from scalar invoke") @@ -229,12 +312,19 @@ impl ScalarUdfTester { arg0: T0, arg1: T1, ) -> Result { + let scalar_arg0 = Self::scalar_lit(arg0, &self.arg_types[0])?; + let scalar_arg1 = Self::scalar_lit(arg1, &self.arg_types[1])?; + + // Some UDF calculate the return type from the input scalar arguments, so try it first. + let return_type = self + .return_type_with_scalars_inner(&[Some(scalar_arg0.clone()), Some(scalar_arg1.clone())]) + .ok(); + let args = vec![ - Self::scalar_arg(arg0, &self.arg_types[0])?, - Self::scalar_arg(arg1, &self.arg_types[1])?, + ColumnarValue::Scalar(scalar_arg0), + ColumnarValue::Scalar(scalar_arg1), ]; - - if let ColumnarValue::Scalar(scalar) = self.invoke(args)? { + if let ColumnarValue::Scalar(scalar) = self.invoke_with_return_type(args, return_type)? { Ok(scalar) } else { sedona_internal_err!("Expected scalar result from binary scalar invoke") @@ -248,13 +338,25 @@ impl ScalarUdfTester { arg1: T1, arg2: T2, ) -> Result { + let scalar_arg0 = Self::scalar_lit(arg0, &self.arg_types[0])?; + let scalar_arg1 = Self::scalar_lit(arg1, &self.arg_types[1])?; + let scalar_arg2 = Self::scalar_lit(arg2, &self.arg_types[2])?; + + // Some UDF calculate the return type from the input scalar arguments, so try it first. + let return_type = self + .return_type_with_scalars_inner(&[ + Some(scalar_arg0.clone()), + Some(scalar_arg1.clone()), + Some(scalar_arg2.clone()), + ]) + .ok(); + let args = vec![ - Self::scalar_arg(arg0, &self.arg_types[0])?, - Self::scalar_arg(arg1, &self.arg_types[1])?, - Self::scalar_arg(arg2, &self.arg_types[2])?, + ColumnarValue::Scalar(scalar_arg0), + ColumnarValue::Scalar(scalar_arg1), + ColumnarValue::Scalar(scalar_arg2), ]; - - if let ColumnarValue::Scalar(scalar) = self.invoke(args)? { + if let ColumnarValue::Scalar(scalar) = self.invoke_with_return_type(args, return_type)? { Ok(scalar) } else { sedona_internal_err!("Expected scalar result from binary scalar invoke") @@ -386,6 +488,13 @@ impl ScalarUdfTester { } pub fn invoke(&self, args: Vec) -> Result { + self.invoke_with_return_type(args, None) + } + pub fn invoke_with_return_type( + &self, + args: Vec, + return_type: Option, + ) -> Result { assert_eq!(args.len(), self.arg_types.len(), "Unexpected arg length"); let mut number_rows = 1; @@ -399,11 +508,16 @@ impl ScalarUdfTester { } } + let return_type = match return_type { + Some(return_type) => return_type, + None => self.return_type()?, + }; + let args = ScalarFunctionArgs { args, arg_fields: self.arg_fields(), number_rows, - return_field: self.return_type()?.to_storage_field("", true)?.into(), + return_field: return_type.to_storage_field("", true)?.into(), // TODO: Consider piping actual ConfigOptions for more realistic testing // See: https://github.com/apache/sedona-db/issues/248 config_options: Arc::new(ConfigOptions::default()),