Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 130 additions & 4 deletions datafusion/expr/src/async_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
// specific language governing permissions and limitations
// under the License.

use crate::ptr_eq::{arc_ptr_eq, arc_ptr_hash};
use crate::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
use arrow::datatypes::{DataType, FieldRef};
use async_trait::async_trait;
Expand Down Expand Up @@ -62,17 +61,18 @@ pub struct AsyncScalarUDF {

impl PartialEq for AsyncScalarUDF {
fn eq(&self, other: &Self) -> bool {
// Deconstruct to catch any new fields added in future
let Self { inner } = self;
// TODO when MSRV >= 1.86.0, switch to `inner.equals(other.inner.as_ref())` leveraging trait upcasting.
arc_ptr_eq(inner, &other.inner)
inner.dyn_eq(other.inner.as_any())
}
}
// Marker impl: declares that the `PartialEq` relation on `AsyncScalarUDF` is a
// full equivalence relation.
impl Eq for AsyncScalarUDF {}

impl Hash for AsyncScalarUDF {
fn hash<H: Hasher>(&self, state: &mut H) {
// Deconstruct to catch any new fields added in future
let Self { inner } = self;
arc_ptr_hash(inner, state);
inner.dyn_hash(state);
}
}

Expand Down Expand Up @@ -132,3 +132,129 @@ impl Display for AsyncScalarUDF {
write!(f, "AsyncScalarUDF: {}", self.inner.name())
}
}

#[cfg(test)]
mod tests {
use std::{
hash::{DefaultHasher, Hash, Hasher},
sync::Arc,
};

use arrow::datatypes::DataType;
use async_trait::async_trait;
use datafusion_common::error::Result;
use datafusion_expr_common::{columnar_value::ColumnarValue, signature::Signature};

use crate::{
async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl},
ScalarFunctionArgs, ScalarUDFImpl,
};

/// First stand-in UDF implementation used by the equality/hash test.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct TestAsyncUDFImpl1 {
    /// The only state; the derived `PartialEq`/`Hash` compare and hash it.
    a: i32,
}

/// Synchronous UDF surface for `TestAsyncUDFImpl1`.
///
/// Only `as_any` is functional. Every other method is a `todo!()` stub that
/// panics if it is ever reached.
impl ScalarUDFImpl for TestAsyncUDFImpl1 {
    /// Exposes `self` as `&dyn Any`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Stub: panics via `todo!()`.
    fn name(&self) -> &str {
        todo!()
    }

    /// Stub: panics via `todo!()`.
    fn signature(&self) -> &Signature {
        todo!()
    }

    /// Stub: panics via `todo!()`.
    fn return_type(&self, _input_types: &[DataType]) -> Result<DataType> {
        todo!()
    }

    /// Stub: panics via `todo!()`.
    fn invoke_with_args(&self, _call_args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        todo!()
    }
}

/// Asynchronous UDF surface for `TestAsyncUDFImpl1`; the single method is a
/// `todo!()` stub.
#[async_trait]
impl AsyncScalarUDFImpl for TestAsyncUDFImpl1 {
    /// Stub: panics via `todo!()` if reached.
    async fn invoke_async_with_args(
        &self,
        _call_args: ScalarFunctionArgs,
    ) -> Result<ColumnarValue> {
        todo!()
    }
}

/// Second stand-in UDF implementation: same shape as `TestAsyncUDFImpl1` but a
/// distinct type, used for the "different functions" case of the test.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct TestAsyncUDFImpl2 {
    /// The only state; the derived `PartialEq`/`Hash` compare and hash it.
    a: i32,
}

/// Synchronous UDF surface for `TestAsyncUDFImpl2`.
///
/// Only `as_any` is functional. Every other method is a `todo!()` stub that
/// panics if it is ever reached.
impl ScalarUDFImpl for TestAsyncUDFImpl2 {
    /// Exposes `self` as `&dyn Any`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Stub: panics via `todo!()`.
    fn name(&self) -> &str {
        todo!()
    }

    /// Stub: panics via `todo!()`.
    fn signature(&self) -> &Signature {
        todo!()
    }

    /// Stub: panics via `todo!()`.
    fn return_type(&self, _input_types: &[DataType]) -> Result<DataType> {
        todo!()
    }

    /// Stub: panics via `todo!()`.
    fn invoke_with_args(&self, _call_args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        todo!()
    }
}

/// Asynchronous UDF surface for `TestAsyncUDFImpl2`; the single method is a
/// `todo!()` stub.
#[async_trait]
impl AsyncScalarUDFImpl for TestAsyncUDFImpl2 {
    /// Stub: panics via `todo!()` if reached.
    async fn invoke_async_with_args(
        &self,
        _call_args: ScalarFunctionArgs,
    ) -> Result<ColumnarValue> {
        todo!()
    }
}

/// Hashes `value` with a freshly constructed `DefaultHasher` and returns the
/// finished 64-bit digest.
fn hash<T: Hash>(value: &T) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(value, &mut state);
    state.finish()
}

/// Checks that `AsyncScalarUDF` equality and hashing follow the wrapped
/// implementation's value and type rather than the identity of its `Arc`.
#[test]
fn test_async_udf_partial_eq_and_hash() {
    // Small builders that wrap a freshly allocated stand-in implementation.
    let wrap_impl1 = |a: i32| AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a }));
    let wrap_impl2 = |a: i32| AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl2 { a }));

    // Both wrappers share one allocation -> equal
    let shared = Arc::new(TestAsyncUDFImpl1 { a: 1 });
    let lhs = AsyncScalarUDF::new(Arc::clone(&shared) as Arc<dyn AsyncScalarUDFImpl>);
    let rhs = AsyncScalarUDF::new(shared);
    assert_eq!(lhs, rhs);
    assert_eq!(hash(&lhs), hash(&rhs));

    // Separate allocations holding the same value -> still equal
    let (lhs, rhs) = (wrap_impl1(1), wrap_impl1(1));
    assert_eq!(lhs, rhs);
    assert_eq!(hash(&lhs), hash(&rhs));

    // Negative case: same type, different value -> not equal
    let (lhs, rhs) = (wrap_impl1(1), wrap_impl1(2));
    assert_ne!(lhs, rhs);
    assert_ne!(hash(&lhs), hash(&rhs));

    // Negative case: same value, different implementation type -> not equal
    let (lhs, rhs) = (wrap_impl1(1), wrap_impl2(1));
    assert_ne!(lhs, rhs);
    assert_ne!(hash(&lhs), hash(&rhs));
}
}