From 3bea45f83086505b4e4b33e9fd21fece0b664d8a Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Thu, 2 Oct 2025 15:13:28 +0900 Subject: [PATCH 1/2] chore: utilize trait upcasting for AsyncScalarUDF PartialEq & Hash --- datafusion/expr/src/async_udf.rs | 133 +++++++++++++++++++++++++++++-- 1 file changed, 127 insertions(+), 6 deletions(-) diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs index a67738ac7b79..6f4ffedbda57 100644 --- a/datafusion/expr/src/async_udf.rs +++ b/datafusion/expr/src/async_udf.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use crate::ptr_eq::{arc_ptr_eq, arc_ptr_hash}; use crate::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; use arrow::datatypes::{DataType, FieldRef}; use async_trait::async_trait; @@ -62,17 +61,14 @@ pub struct AsyncScalarUDF { impl PartialEq for AsyncScalarUDF { fn eq(&self, other: &Self) -> bool { - let Self { inner } = self; - // TODO when MSRV >= 1.86.0, switch to `inner.equals(other.inner.as_ref())` leveraging trait upcasting. - arc_ptr_eq(inner, &other.inner) + self.inner.dyn_eq(other.inner.as_any()) } } impl Eq for AsyncScalarUDF {} impl Hash for AsyncScalarUDF { fn hash(&self, state: &mut H) { - let Self { inner } = self; - arc_ptr_hash(inner, state); + self.inner.dyn_hash(state); } } @@ -132,3 +128,128 @@ impl Display for AsyncScalarUDF { write!(f, "AsyncScalarUDF: {}", self.inner.name()) } } + +#[cfg(test)] +mod tests { + use std::{collections::HashSet, sync::Arc}; + + use arrow::datatypes::DataType; + use async_trait::async_trait; + use datafusion_common::error::Result; + use datafusion_expr_common::{columnar_value::ColumnarValue, signature::Signature}; + + use crate::{ + async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}, + ScalarFunctionArgs, ScalarUDFImpl, + }; + + #[derive(Debug, PartialEq, Eq, Hash, Clone)] + struct TestAsyncUDFImpl1 { + a: i32, + } + + impl ScalarUDFImpl for TestAsyncUDFImpl1 { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + todo!() + } + + fn signature(&self) -> &Signature { + todo!() + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + todo!() + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + todo!() + } + } + + #[async_trait] + impl AsyncScalarUDFImpl for TestAsyncUDFImpl1 { + async fn invoke_async_with_args( + &self, + _args: ScalarFunctionArgs, + ) -> Result { + todo!() + } + } + + #[derive(Debug, PartialEq, Eq, Hash, Clone)] + struct TestAsyncUDFImpl2 { + a: i32, + } + + impl ScalarUDFImpl for TestAsyncUDFImpl2 { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + todo!() + } + + fn signature(&self) -> &Signature { + todo!() + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + todo!() + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + todo!() + } + } + + #[async_trait] + impl AsyncScalarUDFImpl for TestAsyncUDFImpl2 { + async fn invoke_async_with_args( + &self, + _args: ScalarFunctionArgs, + ) -> Result { + todo!() + } + } + + #[test] + fn udf_equality_and_hash() { + // Inner is same cloned arc -> equal + let inner = Arc::new(TestAsyncUDFImpl1 { a: 1 }); + let a = AsyncScalarUDF::new(Arc::clone(&inner) as Arc); + let b = AsyncScalarUDF::new(inner); + assert_eq!(a, b); + let mut set = HashSet::new(); + set.insert(a); + assert!(set.contains(&b)); + + // Inner is distinct arc -> still equal + let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); + let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); + assert_eq!(a, b); + let mut set = HashSet::new(); + set.insert(a); + assert!(set.contains(&b)); + + // Negative case: inner is different value -> not equal + let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); + let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 2 })); + assert_ne!(a, b); + let mut set = HashSet::new(); + set.insert(a); + assert!(!set.contains(&b)); + + // Negative case: different functions -> not equal + let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); + let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl2 { a: 1 })); + assert_ne!(a, b); + let mut set = HashSet::new(); + set.insert(a); + assert!(!set.contains(&b)); + } +} From a25019f59ca018b1340a650e9bec082773c92d80 Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Fri, 3 Oct 2025 11:11:41 +0900 Subject: [PATCH 2/2] review comments --- datafusion/expr/src/async_udf.rs | 37 ++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs index 6f4ffedbda57..561ef1dc15e7 100644 --- a/datafusion/expr/src/async_udf.rs +++ b/datafusion/expr/src/async_udf.rs @@ -61,14 +61,18 @@ pub struct AsyncScalarUDF { impl PartialEq for AsyncScalarUDF { fn eq(&self, other: &Self) -> bool { - self.inner.dyn_eq(other.inner.as_any()) + // Deconstruct to catch any new fields added in future + let Self { inner } = self; + inner.dyn_eq(other.inner.as_any()) } } impl Eq for AsyncScalarUDF {} impl Hash for AsyncScalarUDF { fn hash(&self, state: &mut H) { - self.inner.dyn_hash(state); + // Deconstruct to catch any new fields added in future + let Self { inner } = self; + inner.dyn_hash(state); } } @@ -131,7 +135,10 @@ impl Display for AsyncScalarUDF { #[cfg(test)] mod tests { - use std::{collections::HashSet, sync::Arc}; + use std::{ + hash::{DefaultHasher, Hash, Hasher}, + sync::Arc, + }; use arrow::datatypes::DataType; use async_trait::async_trait; @@ -217,39 +224,37 @@ mod tests { } } + fn hash(value: &T) -> u64 { + let hasher = &mut DefaultHasher::new(); + value.hash(hasher); + hasher.finish() + } + #[test] - fn udf_equality_and_hash() { + fn test_async_udf_partial_eq_and_hash() { // Inner is same cloned arc -> equal let inner = Arc::new(TestAsyncUDFImpl1 { a: 1 }); let a = AsyncScalarUDF::new(Arc::clone(&inner) as Arc); let b = AsyncScalarUDF::new(inner); assert_eq!(a, b); - let mut set = HashSet::new(); - set.insert(a); - assert!(set.contains(&b)); + assert_eq!(hash(&a), hash(&b)); // Inner is distinct arc -> still equal let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); assert_eq!(a, b); - let mut set = HashSet::new(); - set.insert(a); - assert!(set.contains(&b)); + assert_eq!(hash(&a), hash(&b)); // Negative case: inner is different value -> not equal let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 2 })); assert_ne!(a, b); - let mut set = HashSet::new(); - set.insert(a); - assert!(!set.contains(&b)); + assert_ne!(hash(&a), hash(&b)); // Negative case: different functions -> not equal let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl2 { a: 1 })); assert_ne!(a, b); - let mut set = HashSet::new(); - set.insert(a); - assert!(!set.contains(&b)); + assert_ne!(hash(&a), hash(&b)); } }