Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 53 additions & 29 deletions datafusion/optimizer/src/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use datafusion_expr::type_coercion::other::{
};
use datafusion_expr::utils::from_plan;
use datafusion_expr::{
is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown,
BuiltinScalarFunction, Expr, LogicalPlan, Operator,
function, is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown,
Expr, LogicalPlan, Operator,
};
use datafusion_expr::{ExprSchemable, Signature};
use std::sync::Arc;
Expand Down Expand Up @@ -311,18 +311,6 @@ impl ExprRewriter for TypeCoercionRewriter {
};
Ok(expr)
}
Expr::ScalarUDF { fun, args } => {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just move these code

let new_expr = coerce_arguments_for_signature(
args.as_slice(),
&self.schema,
&fun.signature,
)?;
let expr = Expr::ScalarUDF {
fun,
args: new_expr,
};
Ok(expr)
}
Expr::InList {
expr,
list,
Expand Down Expand Up @@ -395,20 +383,30 @@ impl ExprRewriter for TypeCoercionRewriter {
}
}
}
Expr::ScalarFunction { fun, args } => match fun {
BuiltinScalarFunction::Concat
| BuiltinScalarFunction::ConcatWithSeparator => {
let new_args = args
.iter()
.map(|e| e.clone().cast_to(&DataType::Utf8, &self.schema))
.collect::<Result<Vec<_>>>()?;
Ok(Expr::ScalarFunction {
fun,
args: new_args,
})
}
fun => Ok(Expr::ScalarFunction { fun, args }),
},
Expr::ScalarUDF { fun, args } => {
let new_expr = coerce_arguments_for_signature(
args.as_slice(),
&self.schema,
&fun.signature,
)?;
let expr = Expr::ScalarUDF {
fun,
args: new_expr,
};
Ok(expr)
}
Expr::ScalarFunction { fun, args } => {
let nex_expr = coerce_arguments_for_signature(
args.as_slice(),
&self.schema,
&function::signature(&fun),
)?;
let expr = Expr::ScalarFunction {
fun,
args: nex_expr,
};
Ok(expr)
}
expr => Ok(expr),
}
}
Expand Down Expand Up @@ -457,7 +455,9 @@ mod test {
use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, Result, ScalarValue};
use datafusion_expr::expr_rewriter::ExprRewritable;
use datafusion_expr::{cast, col, concat, concat_ws, is_true, ColumnarValue};
use datafusion_expr::{
cast, col, concat, concat_ws, is_true, BuiltinScalarFunction, ColumnarValue,
};
use datafusion_expr::{
lit,
logical_plan::{EmptyRelation, Projection},
Expand Down Expand Up @@ -572,6 +572,30 @@ mod test {
Ok(())
}

#[test]
fn scalar_function() -> Result<()> {
let empty = empty();
let lit_expr = lit(10i64);
let fun: BuiltinScalarFunction = BuiltinScalarFunction::Abs;
let scalar_function_expr = Expr::ScalarFunction {
fun,
args: vec![lit_expr],
};
let plan = LogicalPlan::Projection(Projection::try_new(
vec![scalar_function_expr],
empty,
None,
)?);
let rule = TypeCoercion::new();
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
"Projection: abs(CAST(Int64(10) AS Float64))\n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())
}

#[test]
fn binary_op_date32_add_interval() -> Result<()> {
//CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640")
Expand Down
69 changes: 42 additions & 27 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ use crate::execution_props::ExecutionProps;
use crate::{
array_expressions, conditional_expressions, datetime_expressions,
expressions::{cast_column, nullif_func, DEFAULT_DATAFUSION_CAST_OPTIONS},
math_expressions, string_expressions, struct_expressions,
type_coercion::coerce,
PhysicalExpr, ScalarFunctionExpr,
math_expressions, string_expressions, struct_expressions, PhysicalExpr,
ScalarFunctionExpr,
};
use arrow::{
array::ArrayRef,
Expand All @@ -58,23 +57,20 @@ pub fn create_physical_expr(
input_schema: &Schema,
execution_props: &ExecutionProps,
) -> Result<Arc<dyn PhysicalExpr>> {
let coerced_phy_exprs =
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove type coercion in the physical phase

coerce(input_phy_exprs, input_schema, &function::signature(fun))?;

let coerced_expr_types = coerced_phy_exprs
let input_expr_types = input_phy_exprs
.iter()
.map(|e| e.data_type(input_schema))
.collect::<Result<Vec<_>>>()?;

let data_type = function::return_type(fun, &coerced_expr_types)?;
let data_type = function::return_type(fun, &input_expr_types)?;

let fun_expr: ScalarFunctionImplementation = match fun {
// These functions need args and input schema to pick an implementation
// Unlike the string functions, which actually figure out the function to use with each array,
// here we return either a cast fn or string timestamp translation based on the expression data type
// so we don't have to pay a per-array/batch cost.
BuiltinScalarFunction::ToTimestamp => {
Arc::new(match coerced_phy_exprs[0].data_type(input_schema) {
Arc::new(match input_phy_exprs[0].data_type(input_schema) {
Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
|col_values: &[ColumnarValue]| {
cast_column(
Expand All @@ -89,12 +85,12 @@ pub fn create_physical_expr(
return Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function to_timestamp",
other,
)))
)));
}
})
}
BuiltinScalarFunction::ToTimestampMillis => {
Arc::new(match coerced_phy_exprs[0].data_type(input_schema) {
Arc::new(match input_phy_exprs[0].data_type(input_schema) {
Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
|col_values: &[ColumnarValue]| {
cast_column(
Expand All @@ -109,12 +105,12 @@ pub fn create_physical_expr(
return Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function to_timestamp_millis",
other,
)))
)));
}
})
}
BuiltinScalarFunction::ToTimestampMicros => {
Arc::new(match coerced_phy_exprs[0].data_type(input_schema) {
Arc::new(match input_phy_exprs[0].data_type(input_schema) {
Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
|col_values: &[ColumnarValue]| {
cast_column(
Expand All @@ -129,12 +125,12 @@ pub fn create_physical_expr(
return Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function to_timestamp_micros",
other,
)))
)));
}
})
}
BuiltinScalarFunction::ToTimestampSeconds => Arc::new({
match coerced_phy_exprs[0].data_type(input_schema) {
match input_phy_exprs[0].data_type(input_schema) {
Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
|col_values: &[ColumnarValue]| {
cast_column(
Expand All @@ -149,12 +145,12 @@ pub fn create_physical_expr(
return Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function to_timestamp_seconds",
other,
)))
)));
}
}
}),
BuiltinScalarFunction::FromUnixtime => Arc::new({
match coerced_phy_exprs[0].data_type(input_schema) {
match input_phy_exprs[0].data_type(input_schema) {
Ok(DataType::Int64) => |col_values: &[ColumnarValue]| {
cast_column(
&col_values[0],
Expand All @@ -166,12 +162,12 @@ pub fn create_physical_expr(
return Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function from_unixtime",
other,
)))
)));
}
}
}),
BuiltinScalarFunction::ArrowTypeof => {
let input_data_type = coerced_phy_exprs[0].data_type(input_schema)?;
let input_data_type = input_phy_exprs[0].data_type(input_schema)?;
Arc::new(move |_| {
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(format!(
"{}",
Expand All @@ -186,7 +182,7 @@ pub fn create_physical_expr(
Ok(Arc::new(ScalarFunctionExpr::new(
&format!("{}", fun),
fun_expr,
coerced_phy_exprs,
input_phy_exprs.to_vec(),
&data_type,
)))
}
Expand Down Expand Up @@ -727,7 +723,7 @@ pub fn create_physical_fun(
return Err(DataFusionError::Internal(format!(
"create_physical_fun: Unsupported scalar function {:?}",
fun
)))
)));
}
})
}
Expand All @@ -737,6 +733,7 @@ mod tests {
use super::*;
use crate::expressions::{col, lit};
use crate::from_slice::FromSlice;
use crate::type_coercion::coerce;
use arrow::{
array::{
Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, Float32Array,
Expand Down Expand Up @@ -764,7 +761,7 @@ mod tests {
let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from_slice(&[1]))];

let expr =
create_physical_expr(&BuiltinScalarFunction::$FUNC, $ARGS, &schema, &execution_props)?;
create_physical_expr_with_type_coercion(&BuiltinScalarFunction::$FUNC, $ARGS, &schema, &execution_props)?;

// type is correct
assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TYPE);
Expand Down Expand Up @@ -2683,7 +2680,12 @@ mod tests {
];

for fun in funs.iter() {
let expr = create_physical_expr(fun, &[], &schema, &execution_props);
let expr = create_physical_expr_with_type_coercion(
fun,
&[],
&schema,
&execution_props,
);

match expr {
Ok(..) => {
Expand Down Expand Up @@ -2720,7 +2722,7 @@ mod tests {
let funs = [BuiltinScalarFunction::Now, BuiltinScalarFunction::Random];

for fun in funs.iter() {
create_physical_expr(fun, &[], &schema, &execution_props)?;
create_physical_expr_with_type_coercion(fun, &[], &schema, &execution_props)?;
}
Ok(())
}
Expand All @@ -2739,7 +2741,7 @@ mod tests {
let columns: Vec<ArrayRef> = vec![value1, value2];
let execution_props = ExecutionProps::new();

let expr = create_physical_expr(
let expr = create_physical_expr_with_type_coercion(
&BuiltinScalarFunction::MakeArray,
&[col("a", &schema)?, col("b", &schema)?],
&schema,
Expand Down Expand Up @@ -2805,7 +2807,7 @@ mod tests {
let col_value: ArrayRef = Arc::new(StringArray::from_slice(&["aaa-555"]));
let pattern = lit(r".*-(\d*)");
let columns: Vec<ArrayRef> = vec![col_value];
let expr = create_physical_expr(
let expr = create_physical_expr_with_type_coercion(
&BuiltinScalarFunction::RegexpMatch,
&[col("a", &schema)?, pattern],
&schema,
Expand Down Expand Up @@ -2844,7 +2846,7 @@ mod tests {
let col_value = lit("aaa-555");
let pattern = lit(r".*-(\d*)");
let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from_slice(&[1]))];
let expr = create_physical_expr(
let expr = create_physical_expr_with_type_coercion(
&BuiltinScalarFunction::RegexpMatch,
&[col_value, pattern],
&schema,
Expand Down Expand Up @@ -2872,4 +2874,17 @@ mod tests {

Ok(())
}

// Helper function
// The type coercion will be done in the logical phase, should do the type coercion for the test
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

BTW this is basically what I need to do in IOx / why I made #3708

fn create_physical_expr_with_type_coercion(
fun: &BuiltinScalarFunction,
input_phy_exprs: &[Arc<dyn PhysicalExpr>],
input_schema: &Schema,
execution_props: &ExecutionProps,
) -> Result<Arc<dyn PhysicalExpr>> {
let type_coerced_phy_exprs =
coerce(input_phy_exprs, input_schema, &function::signature(fun)).unwrap();
create_physical_expr(fun, &type_coerced_phy_exprs, input_schema, execution_props)
}
}