diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs index a67738ac7b79..561ef1dc15e7 100644 --- a/datafusion/expr/src/async_udf.rs +++ b/datafusion/expr/src/async_udf.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use crate::ptr_eq::{arc_ptr_eq, arc_ptr_hash}; use crate::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; use arrow::datatypes::{DataType, FieldRef}; use async_trait::async_trait; @@ -62,17 +61,18 @@ pub struct AsyncScalarUDF { impl PartialEq for AsyncScalarUDF { fn eq(&self, other: &Self) -> bool { + // Deconstruct to catch any new fields added in future let Self { inner } = self; - // TODO when MSRV >= 1.86.0, switch to `inner.equals(other.inner.as_ref())` leveraging trait upcasting. - arc_ptr_eq(inner, &other.inner) + inner.dyn_eq(other.inner.as_any()) } } impl Eq for AsyncScalarUDF {} impl Hash for AsyncScalarUDF { fn hash(&self, state: &mut H) { + // Deconstruct to catch any new fields added in future let Self { inner } = self; - arc_ptr_hash(inner, state); + inner.dyn_hash(state); } } @@ -132,3 +132,129 @@ impl Display for AsyncScalarUDF { write!(f, "AsyncScalarUDF: {}", self.inner.name()) } } + +#[cfg(test)] +mod tests { + use std::{ + hash::{DefaultHasher, Hash, Hasher}, + sync::Arc, + }; + + use arrow::datatypes::DataType; + use async_trait::async_trait; + use datafusion_common::error::Result; + use datafusion_expr_common::{columnar_value::ColumnarValue, signature::Signature}; + + use crate::{ + async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}, + ScalarFunctionArgs, ScalarUDFImpl, + }; + + #[derive(Debug, PartialEq, Eq, Hash, Clone)] + struct TestAsyncUDFImpl1 { + a: i32, + } + + impl ScalarUDFImpl for TestAsyncUDFImpl1 { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + todo!() + } + + fn signature(&self) -> &Signature { + todo!() + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + todo!() + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + todo!() + } + } + + #[async_trait] + impl AsyncScalarUDFImpl for TestAsyncUDFImpl1 { + async fn invoke_async_with_args( + &self, + _args: ScalarFunctionArgs, + ) -> Result { + todo!() + } + } + + #[derive(Debug, PartialEq, Eq, Hash, Clone)] + struct TestAsyncUDFImpl2 { + a: i32, + } + + impl ScalarUDFImpl for TestAsyncUDFImpl2 { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + todo!() + } + + fn signature(&self) -> &Signature { + todo!() + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + todo!() + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + todo!() + } + } + + #[async_trait] + impl AsyncScalarUDFImpl for TestAsyncUDFImpl2 { + async fn invoke_async_with_args( + &self, + _args: ScalarFunctionArgs, + ) -> Result { + todo!() + } + } + + fn hash(value: &T) -> u64 { + let hasher = &mut DefaultHasher::new(); + value.hash(hasher); + hasher.finish() + } + + #[test] + fn test_async_udf_partial_eq_and_hash() { + // Inner is same cloned arc -> equal + let inner = Arc::new(TestAsyncUDFImpl1 { a: 1 }); + let a = AsyncScalarUDF::new(Arc::clone(&inner) as Arc); + let b = AsyncScalarUDF::new(inner); + assert_eq!(a, b); + assert_eq!(hash(&a), hash(&b)); + + // Inner is distinct arc -> still equal + let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); + let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); + assert_eq!(a, b); + assert_eq!(hash(&a), hash(&b)); + + // Negative case: inner is different value -> not equal + let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); + let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 2 })); + assert_ne!(a, b); + assert_ne!(hash(&a), hash(&b)); + + // Negative case: different functions -> not equal + let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); + let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl2 { a: 1 })); + assert_ne!(a, b); + assert_ne!(hash(&a), hash(&b)); + } +}