Skip to content

Commit

Permalink
Support Serde for ScalarUDF in Physical Expressions (#9436)
Browse files Browse the repository at this point in the history
* initial try

* revert

* stage commit

* use ScalarFunctionDefinition to rewrite PhysicalExpr proto

* cargo fmt

* feat: add test

* fix bug

* fix code wrongly deleted when resolving conflict

* Update datafusion/proto/src/physical_plan/to_proto.rs

Co-authored-by: Dan Harris <1327726+thinkharderdev@users.noreply.github.com>

* Update datafusion/proto/tests/cases/roundtrip_physical_plan.rs

Co-authored-by: Dan Harris <1327726+thinkharderdev@users.noreply.github.com>

* address the comment

---------

Co-authored-by: Dan Harris <1327726+thinkharderdev@users.noreply.github.com>
  • Loading branch information
yyy1000 and thinkharderdev authored Mar 19, 2024
1 parent 7fab5ac commit 0974759
Show file tree
Hide file tree
Showing 11 changed files with 634 additions and 264 deletions.
58 changes: 53 additions & 5 deletions datafusion/core/src/physical_optimizer/projection_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1287,6 +1287,7 @@ fn new_join_children(
#[cfg(test)]
mod tests {
use super::*;
use std::any::Any;
use std::sync::Arc;

use crate::datasource::file_format::file_compression_type::FileCompressionType;
Expand All @@ -1313,7 +1314,10 @@ mod tests {
use datafusion_common::{JoinSide, JoinType, Result, ScalarValue, Statistics};
use datafusion_execution::object_store::ObjectStoreUrl;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::{ColumnarValue, Operator};
use datafusion_expr::{
ColumnarValue, Operator, ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl,
Signature, Volatility,
};
use datafusion_physical_expr::expressions::{
BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr,
};
Expand All @@ -1329,6 +1333,42 @@ mod tests {

use itertools::Itertools;

/// Mock scalar UDF for the tests in this module.
///
/// It is wrapped in `ScalarFunctionDefinition::UDF` wherever a test needs to
/// build a `ScalarFunctionExpr`. Its `invoke` is `unimplemented!`, so it only
/// serves as a function definition and must not be evaluated.
#[derive(Debug)]
struct DummyUDF {
// Signature handed back by `ScalarUDFImpl::signature`.
signature: Signature,
}

impl DummyUDF {
fn new() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
}
}
}

impl ScalarUDFImpl for DummyUDF {
    /// Exposes the mock as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Fixed name under which the mock reports itself.
    fn name(&self) -> &str {
        "dummy_udf"
    }

    /// Borrows the signature stored at construction time.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always reports `Int32`; the argument types are ignored.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int32)
    }

    /// Not meant to be evaluated: panics through `unimplemented!` when called.
    fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
        unimplemented!("DummyUDF::invoke")
    }
}

#[test]
fn test_update_matching_exprs() -> Result<()> {
let exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
Expand All @@ -1345,7 +1385,9 @@ mod tests {
Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))),
Arc::new(ScalarFunctionExpr::new(
"scalar_expr",
Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")),
ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
DummyUDF::new(),
))),
vec![
Arc::new(BinaryExpr::new(
Arc::new(Column::new("b", 1)),
Expand Down Expand Up @@ -1412,7 +1454,9 @@ mod tests {
Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))),
Arc::new(ScalarFunctionExpr::new(
"scalar_expr",
Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")),
ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
DummyUDF::new(),
))),
vec![
Arc::new(BinaryExpr::new(
Arc::new(Column::new("b", 1)),
Expand Down Expand Up @@ -1482,7 +1526,9 @@ mod tests {
Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))),
Arc::new(ScalarFunctionExpr::new(
"scalar_expr",
Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")),
ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
DummyUDF::new(),
))),
vec![
Arc::new(BinaryExpr::new(
Arc::new(Column::new("b", 1)),
Expand Down Expand Up @@ -1549,7 +1595,9 @@ mod tests {
Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))),
Arc::new(ScalarFunctionExpr::new(
"scalar_expr",
Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")),
ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
DummyUDF::new(),
))),
vec![
Arc::new(BinaryExpr::new(
Arc::new(Column::new("b_new", 1)),
Expand Down
10 changes: 4 additions & 6 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use arrow_array::Array;
use datafusion_common::{exec_err, Result, ScalarValue};
use datafusion_expr::execution_props::ExecutionProps;
pub use datafusion_expr::FuncMonotonicity;
use datafusion_expr::ScalarFunctionDefinition;
use datafusion_expr::{
type_coercion::functions::data_types, BuiltinScalarFunction, ColumnarValue,
ScalarFunctionImplementation,
Expand All @@ -57,7 +58,7 @@ pub fn create_physical_expr(
fun: &BuiltinScalarFunction,
input_phy_exprs: &[Arc<dyn PhysicalExpr>],
input_schema: &Schema,
execution_props: &ExecutionProps,
_execution_props: &ExecutionProps,
) -> Result<Arc<dyn PhysicalExpr>> {
let input_expr_types = input_phy_exprs
.iter()
Expand All @@ -69,14 +70,12 @@ pub fn create_physical_expr(

let data_type = fun.return_type(&input_expr_types)?;

let fun_expr: ScalarFunctionImplementation =
create_physical_fun(fun, execution_props)?;

let monotonicity = fun.monotonicity();

let fun_def = ScalarFunctionDefinition::BuiltIn(*fun);
Ok(Arc::new(ScalarFunctionExpr::new(
&format!("{fun}"),
fun_expr,
fun_def,
input_phy_exprs.to_vec(),
data_type,
monotonicity,
Expand Down Expand Up @@ -195,7 +194,6 @@ where
/// Create a physical scalar function.
pub fn create_physical_fun(
fun: &BuiltinScalarFunction,
_execution_props: &ExecutionProps,
) -> Result<ScalarFunctionImplementation> {
Ok(match fun {
// math functions
Expand Down
26 changes: 18 additions & 8 deletions datafusion/physical-expr/src/scalar_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,22 @@ use std::fmt::{self, Debug, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use crate::functions::out_ordering;
use crate::functions::{create_physical_fun, out_ordering};
use crate::physical_expr::{down_cast_any_ref, physical_exprs_equal};
use crate::sort_properties::SortProperties;
use crate::PhysicalExpr;

use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_common::{internal_err, Result};
use datafusion_expr::{
expr_vec_fmt, BuiltinScalarFunction, ColumnarValue, FuncMonotonicity,
ScalarFunctionImplementation,
ScalarFunctionDefinition,
};

/// Physical expression of a scalar function
pub struct ScalarFunctionExpr {
fun: ScalarFunctionImplementation,
fun: ScalarFunctionDefinition,
name: String,
args: Vec<Arc<dyn PhysicalExpr>>,
return_type: DataType,
Expand Down Expand Up @@ -79,7 +79,7 @@ impl ScalarFunctionExpr {
/// Create a new Scalar function
pub fn new(
name: &str,
fun: ScalarFunctionImplementation,
fun: ScalarFunctionDefinition,
args: Vec<Arc<dyn PhysicalExpr>>,
return_type: DataType,
monotonicity: Option<FuncMonotonicity>,
Expand All @@ -96,7 +96,7 @@ impl ScalarFunctionExpr {
}

/// Get the scalar function definition
pub fn fun(&self) -> &ScalarFunctionImplementation {
pub fn fun(&self) -> &ScalarFunctionDefinition {
&self.fun
}

Expand Down Expand Up @@ -172,8 +172,18 @@ impl PhysicalExpr for ScalarFunctionExpr {
};

// evaluate the function
let fun = self.fun.as_ref();
(fun)(&inputs)
match self.fun {
ScalarFunctionDefinition::BuiltIn(ref fun) => {
let fun = create_physical_fun(fun)?;
(fun)(&inputs)
}
ScalarFunctionDefinition::UDF(ref fun) => fun.invoke(&inputs),
ScalarFunctionDefinition::Name(_) => {
internal_err!(
"Name function must be resolved to one of the other variants prior to physical planning"
)
}
}
}

fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
Expand Down
7 changes: 5 additions & 2 deletions datafusion/physical-expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ use crate::{PhysicalExpr, ScalarFunctionExpr};
use arrow_schema::Schema;
use datafusion_common::{DFSchema, Result};
pub use datafusion_expr::ScalarUDF;
use datafusion_expr::{type_coercion::functions::data_types, Expr};
use datafusion_expr::{
type_coercion::functions::data_types, Expr, ScalarFunctionDefinition,
};
use std::sync::Arc;

/// Create a physical expression of the UDF.
Expand All @@ -45,9 +47,10 @@ pub fn create_physical_expr(
let return_type =
fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?;

let fun_def = ScalarFunctionDefinition::UDF(Arc::new(fun.clone()));
Ok(Arc::new(ScalarFunctionExpr::new(
fun.name(),
fun.fun(),
fun_def,
input_phy_exprs.to_vec(),
return_type,
fun.monotonicity()?,
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1458,6 +1458,7 @@ message PhysicalExprNode {
// Serialized form of a scalar UDF call inside a physical expression tree.
message PhysicalScalarUdfNode {
// Name of the UDF.
string name = 1;
// Argument expressions, in call order.
repeated PhysicalExprNode args = 2;
// Optional raw bytes describing the function itself.
// NOTE(review): presumably an opaque, codec-defined encoding of the UDF that
// may be absent when the function can be resolved by name alone; confirm
// against the to_proto / from_proto implementations.
optional bytes fun_definition = 3;
// Arrow type the function call returns.
ArrowType return_type = 4;
}

Expand Down
21 changes: 21 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 0974759

Please sign in to comment.