diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index fcfe6eaaa84d..5438632bafbf 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -30,8 +30,8 @@ use datafusion_expr::type_coercion::other::{ }; use datafusion_expr::utils::from_plan; use datafusion_expr::{ - is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, - BuiltinScalarFunction, Expr, LogicalPlan, Operator, + function, is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, + Expr, LogicalPlan, Operator, }; use datafusion_expr::{ExprSchemable, Signature}; use std::sync::Arc; @@ -311,18 +311,6 @@ impl ExprRewriter for TypeCoercionRewriter { }; Ok(expr) } - Expr::ScalarUDF { fun, args } => { - let new_expr = coerce_arguments_for_signature( - args.as_slice(), - &self.schema, - &fun.signature, - )?; - let expr = Expr::ScalarUDF { - fun, - args: new_expr, - }; - Ok(expr) - } Expr::InList { expr, list, @@ -395,20 +383,30 @@ impl ExprRewriter for TypeCoercionRewriter { } } } - Expr::ScalarFunction { fun, args } => match fun { - BuiltinScalarFunction::Concat - | BuiltinScalarFunction::ConcatWithSeparator => { - let new_args = args - .iter() - .map(|e| e.clone().cast_to(&DataType::Utf8, &self.schema)) - .collect::>>()?; - Ok(Expr::ScalarFunction { - fun, - args: new_args, - }) - } - fun => Ok(Expr::ScalarFunction { fun, args }), - }, + Expr::ScalarUDF { fun, args } => { + let new_expr = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + &fun.signature, + )?; + let expr = Expr::ScalarUDF { + fun, + args: new_expr, + }; + Ok(expr) + } + Expr::ScalarFunction { fun, args } => { + let nex_expr = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + &function::signature(&fun), + )?; + let expr = Expr::ScalarFunction { + fun, + args: nex_expr, + }; + Ok(expr) + } expr => Ok(expr), } } @@ -457,7 +455,9 @@ mod test { use arrow::datatypes::DataType; use datafusion_common::{DFField, DFSchema, Result, ScalarValue}; use datafusion_expr::expr_rewriter::ExprRewritable; - use datafusion_expr::{cast, col, concat, concat_ws, is_true, ColumnarValue}; + use datafusion_expr::{ + cast, col, concat, concat_ws, is_true, BuiltinScalarFunction, ColumnarValue, + }; use datafusion_expr::{ lit, logical_plan::{EmptyRelation, Projection}, @@ -572,6 +572,30 @@ mod test { Ok(()) } + #[test] + fn scalar_function() -> Result<()> { + let empty = empty(); + let lit_expr = lit(10i64); + let fun: BuiltinScalarFunction = BuiltinScalarFunction::Abs; + let scalar_function_expr = Expr::ScalarFunction { + fun, + args: vec![lit_expr], + }; + let plan = LogicalPlan::Projection(Projection::try_new( + vec![scalar_function_expr], + empty, + None, + )?); + let rule = TypeCoercion::new(); + let mut config = OptimizerConfig::default(); + let plan = rule.optimize(&plan, &mut config)?; + assert_eq!( + "Projection: abs(CAST(Int64(10) AS Float64))\n EmptyRelation", + &format!("{:?}", plan) + ); + Ok(()) + } + #[test] fn binary_op_date32_add_interval() -> Result<()> { //CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640") diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 5796f8f7d5f0..7d9e89b52e32 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -34,9 +34,8 @@ use crate::execution_props::ExecutionProps; use crate::{ array_expressions, conditional_expressions, datetime_expressions, expressions::{cast_column, nullif_func, DEFAULT_DATAFUSION_CAST_OPTIONS}, - math_expressions, string_expressions, struct_expressions, - type_coercion::coerce, - PhysicalExpr, ScalarFunctionExpr, + math_expressions, string_expressions, struct_expressions, PhysicalExpr, + ScalarFunctionExpr, }; use arrow::{ array::ArrayRef, @@ -58,15 +57,12 @@ pub fn create_physical_expr( input_schema: &Schema, execution_props: &ExecutionProps, ) -> Result> { - let coerced_phy_exprs = - coerce(input_phy_exprs, input_schema, &function::signature(fun))?; - - let coerced_expr_types = coerced_phy_exprs + let input_expr_types = input_phy_exprs .iter() .map(|e| e.data_type(input_schema)) .collect::>>()?; - let data_type = function::return_type(fun, &coerced_expr_types)?; + let data_type = function::return_type(fun, &input_expr_types)?; let fun_expr: ScalarFunctionImplementation = match fun { // These functions need args and input schema to pick an implementation @@ -74,7 +70,7 @@ pub fn create_physical_expr( // here we return either a cast fn or string timestamp translation based on the expression data type // so we don't have to pay a per-array/batch cost. BuiltinScalarFunction::ToTimestamp => { - Arc::new(match coerced_phy_exprs[0].data_type(input_schema) { + Arc::new(match input_phy_exprs[0].data_type(input_schema) { Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { |col_values: &[ColumnarValue]| { cast_column( @@ -89,12 +85,12 @@ pub fn create_physical_expr( return Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function to_timestamp", other, - ))) + ))); } }) } BuiltinScalarFunction::ToTimestampMillis => { - Arc::new(match coerced_phy_exprs[0].data_type(input_schema) { + Arc::new(match input_phy_exprs[0].data_type(input_schema) { Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { |col_values: &[ColumnarValue]| { cast_column( @@ -109,12 +105,12 @@ pub fn create_physical_expr( return Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function to_timestamp_millis", other, - ))) + ))); } }) } BuiltinScalarFunction::ToTimestampMicros => { - Arc::new(match coerced_phy_exprs[0].data_type(input_schema) { + Arc::new(match input_phy_exprs[0].data_type(input_schema) { Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { |col_values: &[ColumnarValue]| { cast_column( @@ -129,12 +125,12 @@ pub fn create_physical_expr( return Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function to_timestamp_micros", other, - ))) + ))); } }) } BuiltinScalarFunction::ToTimestampSeconds => Arc::new({ - match coerced_phy_exprs[0].data_type(input_schema) { + match input_phy_exprs[0].data_type(input_schema) { Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { |col_values: &[ColumnarValue]| { cast_column( @@ -149,12 +145,12 @@ pub fn create_physical_expr( return Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function to_timestamp_seconds", other, - ))) + ))); } } }), BuiltinScalarFunction::FromUnixtime => Arc::new({ - match coerced_phy_exprs[0].data_type(input_schema) { + match input_phy_exprs[0].data_type(input_schema) { Ok(DataType::Int64) => |col_values: &[ColumnarValue]| { cast_column( &col_values[0], @@ -166,12 +162,12 @@ pub fn create_physical_expr( return Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function from_unixtime", other, - ))) + ))); } } }), BuiltinScalarFunction::ArrowTypeof => { - let input_data_type = coerced_phy_exprs[0].data_type(input_schema)?; + let input_data_type = input_phy_exprs[0].data_type(input_schema)?; Arc::new(move |_| { Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(format!( "{}", @@ -186,7 +182,7 @@ pub fn create_physical_expr( Ok(Arc::new(ScalarFunctionExpr::new( &format!("{}", fun), fun_expr, - coerced_phy_exprs, + input_phy_exprs.to_vec(), &data_type, ))) } @@ -727,7 +723,7 @@ pub fn create_physical_fun( return Err(DataFusionError::Internal(format!( "create_physical_fun: Unsupported scalar function {:?}", fun - ))) + ))); } }) } @@ -737,6 +733,7 @@ mod tests { use super::*; use crate::expressions::{col, lit}; use crate::from_slice::FromSlice; + use crate::type_coercion::coerce; use arrow::{ array::{ Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, Float32Array, @@ -764,7 +761,7 @@ mod tests { let columns: Vec = vec![Arc::new(Int32Array::from_slice(&[1]))]; let expr = - create_physical_expr(&BuiltinScalarFunction::$FUNC, $ARGS, &schema, &execution_props)?; + create_physical_expr_with_type_coercion(&BuiltinScalarFunction::$FUNC, $ARGS, &schema, &execution_props)?; // type is correct assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TYPE); @@ -2683,7 +2680,12 @@ mod tests { ]; for fun in funs.iter() { - let expr = create_physical_expr(fun, &[], &schema, &execution_props); + let expr = create_physical_expr_with_type_coercion( + fun, + &[], + &schema, + &execution_props, + ); match expr { Ok(..) => { @@ -2720,7 +2722,7 @@ mod tests { let funs = [BuiltinScalarFunction::Now, BuiltinScalarFunction::Random]; for fun in funs.iter() { - create_physical_expr(fun, &[], &schema, &execution_props)?; + create_physical_expr_with_type_coercion(fun, &[], &schema, &execution_props)?; } Ok(()) } @@ -2739,7 +2741,7 @@ mod tests { let columns: Vec = vec![value1, value2]; let execution_props = ExecutionProps::new(); - let expr = create_physical_expr( + let expr = create_physical_expr_with_type_coercion( &BuiltinScalarFunction::MakeArray, &[col("a", &schema)?, col("b", &schema)?], &schema, @@ -2805,7 +2807,7 @@ mod tests { let col_value: ArrayRef = Arc::new(StringArray::from_slice(&["aaa-555"])); let pattern = lit(r".*-(\d*)"); let columns: Vec = vec![col_value]; - let expr = create_physical_expr( + let expr = create_physical_expr_with_type_coercion( &BuiltinScalarFunction::RegexpMatch, &[col("a", &schema)?, pattern], &schema, @@ -2844,7 +2846,7 @@ mod tests { let col_value = lit("aaa-555"); let pattern = lit(r".*-(\d*)"); let columns: Vec = vec![Arc::new(Int32Array::from_slice(&[1]))]; - let expr = create_physical_expr( + let expr = create_physical_expr_with_type_coercion( &BuiltinScalarFunction::RegexpMatch, &[col_value, pattern], &schema, @@ -2872,4 +2874,17 @@ mod tests { Ok(()) } + + // Helper function + // The type coercion will be done in the logical phase, should do the type coercion for the test + fn create_physical_expr_with_type_coercion( + fun: &BuiltinScalarFunction, + input_phy_exprs: &[Arc], + input_schema: &Schema, + execution_props: &ExecutionProps, + ) -> Result> { + let type_coerced_phy_exprs = + coerce(input_phy_exprs, input_schema, &function::signature(fun)).unwrap(); + create_physical_expr(fun, &type_coerced_phy_exprs, input_schema, execution_props) + } }