diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 61228946fdcd..97abf4d552a9 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -17,7 +17,7 @@ use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::error::Result; -use datafusion::optimizer::expr_simplifier::{ExprSimplifier, SimplifyContext}; +use datafusion::optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext}; use datafusion::physical_expr::execution_props::ExecutionProps; use datafusion::prelude::*; use datafusion_common::{ScalarValue, ToDFSchema}; diff --git a/datafusion/core/tests/simplification.rs b/datafusion/core/tests/simplification.rs index 5a4a2e92fbf9..6e74fc0d9be8 100644 --- a/datafusion/core/tests/simplification.rs +++ b/datafusion/core/tests/simplification.rs @@ -21,7 +21,7 @@ use arrow::datatypes::{DataType, Field, Schema}; use datafusion::common::DFSchema; use datafusion::{error::Result, execution::context::ExecutionProps, prelude::*}; use datafusion_expr::{Expr, ExprSchemable}; -use datafusion_optimizer::expr_simplifier::{ExprSimplifier, SimplifyInfo}; +use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyInfo}; /// In order to simplify expressions, DataFusion must have information /// about the expressions. diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 0fcdf5546d8a..874a80713a67 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -56,7 +56,9 @@ pub use accumulator::{Accumulator, AggregateState}; pub use aggregate_function::AggregateFunction; pub use built_in_function::BuiltinScalarFunction; pub use columnar_value::{ColumnarValue, NullColumnarValue}; -pub use expr::{Between, BinaryExpr, Case, Expr, GetIndexedField, GroupingSet, Like}; +pub use expr::{ + Between, BinaryExpr, Case, Cast, Expr, GetIndexedField, GroupingSet, Like, +}; pub use expr_fn::*; pub use expr_schema::ExprSchemable; pub use function::{ diff --git a/datafusion/optimizer/src/expr_simplifier.rs b/datafusion/optimizer/src/expr_simplifier.rs deleted file mode 100644 index a30fa9c3b5ee..000000000000 --- a/datafusion/optimizer/src/expr_simplifier.rs +++ /dev/null @@ -1,315 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Expression simplification API - -use crate::{ - simplify_expressions::{ConstEvaluator, Simplifier}, - type_coercion::TypeCoercionRewriter, -}; -use arrow::datatypes::DataType; -use datafusion_common::{DFSchemaRef, DataFusionError, Result}; -use datafusion_expr::{expr_rewriter::ExprRewritable, Expr, ExprSchemable}; -use datafusion_physical_expr::execution_props::ExecutionProps; - -#[allow(rustdoc::private_intra_doc_links)] -/// The information necessary to apply algebraic simplification to an -/// [Expr]. See [SimplifyContext] for one concrete implementation. -/// -/// This trait exists so that other systems can plug schema -/// information in without having to create `DFSchema` objects. If you -/// have a [`DFSchemaRef`] you can use [`SimplifyContext`] -pub trait SimplifyInfo { - /// returns true if this Expr has boolean type - fn is_boolean_type(&self, expr: &Expr) -> Result; - - /// returns true of this expr is nullable (could possibly be NULL) - fn nullable(&self, expr: &Expr) -> Result; - - /// Returns details needed for partial expression evaluation - fn execution_props(&self) -> &ExecutionProps; -} - -/// This structure handles API for expression simplification -pub struct ExprSimplifier { - info: S, -} - -impl ExprSimplifier { - /// Create a new `ExprSimplifier` with the given `info` such as an - /// instance of [`SimplifyContext`]. See - /// [`simplify`](Self::simplify) for an example. - pub fn new(info: S) -> Self { - Self { info } - } - - /// Simplifies this [`Expr`]`s as much as possible, evaluating - /// constants and applying algebraic simplifications. - /// - /// The types of the expression must match what operators expect, - /// or else an error may occur trying to evaluate. See - /// [`coerce`](Self::coerce) for a function to help. - /// - /// # Example: - /// - /// `b > 2 AND b > 2` - /// - /// can be written to - /// - /// `b > 2` - /// - /// ``` - /// use datafusion_expr::{col, lit, Expr}; - /// use datafusion_common::Result; - /// use datafusion_physical_expr::execution_props::ExecutionProps; - /// use datafusion_optimizer::expr_simplifier::{ExprSimplifier, SimplifyInfo}; - /// - /// /// Simple implementation that provides `Simplifier` the information it needs - /// /// See SimplifyContext for a structure that does this. - /// #[derive(Default)] - /// struct Info { - /// execution_props: ExecutionProps, - /// }; - /// - /// impl SimplifyInfo for Info { - /// fn is_boolean_type(&self, expr: &Expr) -> Result { - /// Ok(false) - /// } - /// fn nullable(&self, expr: &Expr) -> Result { - /// Ok(true) - /// } - /// fn execution_props(&self) -> &ExecutionProps { - /// &self.execution_props - /// } - /// } - /// - /// // Create the simplifier - /// let simplifier = ExprSimplifier::new(Info::default()); - /// - /// // b < 2 - /// let b_lt_2 = col("b").gt(lit(2)); - /// - /// // (b < 2) OR (b < 2) - /// let expr = b_lt_2.clone().or(b_lt_2.clone()); - /// - /// // (b < 2) OR (b < 2) --> (b < 2) - /// let expr = simplifier.simplify(expr).unwrap(); - /// assert_eq!(expr, b_lt_2); - /// ``` - pub fn simplify(&self, expr: Expr) -> Result { - let mut simplifier = Simplifier::new(&self.info); - let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; - - // TODO iterate until no changes are made during rewrite - // (evaluating constants can enable new simplifications and - // simplifications can enable new constant evaluation) - // https://github.com/apache/arrow-datafusion/issues/1160 - expr.rewrite(&mut const_evaluator)? - .rewrite(&mut simplifier)? - // run both passes twice to try an minimize simplifications that we missed - .rewrite(&mut const_evaluator)? - .rewrite(&mut simplifier) - } - - /// Apply type coercion to an [`Expr`] so that it can be - /// evaluated as a [`PhysicalExpr`](datafusion_physical_expr::PhysicalExpr). - /// - /// See the [type coercion module](datafusion_expr::type_coercion) - /// documentation for more details on type coercion - /// - // Would be nice if this API could use the SimplifyInfo - // rather than creating an DFSchemaRef coerces rather than doing - // it manually. - // https://github.com/apache/arrow-datafusion/issues/3793 - pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result { - let mut expr_rewrite = TypeCoercionRewriter { schema }; - - expr.rewrite(&mut expr_rewrite) - } -} - -/// Provides simplification information based on DFSchema and -/// [`ExecutionProps`]. This is the default implementation used by DataFusion -/// -/// For example: -/// ``` -/// use arrow::datatypes::{Schema, Field, DataType}; -/// use datafusion_expr::{col, lit}; -/// use datafusion_common::{DataFusionError, ToDFSchema}; -/// use datafusion_physical_expr::execution_props::ExecutionProps; -/// use datafusion_optimizer::expr_simplifier::{SimplifyContext, ExprSimplifier}; -/// -/// // Create the schema -/// let schema = Schema::new(vec![ -/// Field::new("i", DataType::Int64, false), -/// ]) -/// .to_dfschema_ref().unwrap(); -/// -/// // Create the simplifier -/// let props = ExecutionProps::new(); -/// let context = SimplifyContext::new(&props) -/// .with_schema(schema); -/// let simplifier = ExprSimplifier::new(context); -/// -/// // Use the simplifier -/// -/// // b < 2 or (1 > 3) -/// let expr = col("b").lt(lit(2)).or(lit(1).gt(lit(3))); -/// -/// // b < 2 -/// let simplified = simplifier.simplify(expr).unwrap(); -/// assert_eq!(simplified, col("b").lt(lit(2))); -/// ``` -pub struct SimplifyContext<'a> { - schemas: Vec, - props: &'a ExecutionProps, -} - -impl<'a> SimplifyContext<'a> { - /// Create a new SimplifyContext - pub fn new(props: &'a ExecutionProps) -> Self { - Self { - schemas: vec![], - props, - } - } - - /// Register a [`DFSchemaRef`] with this context - pub fn with_schema(mut self, schema: DFSchemaRef) -> Self { - self.schemas.push(schema); - self - } -} - -impl<'a> SimplifyInfo for SimplifyContext<'a> { - /// returns true if this Expr has boolean type - fn is_boolean_type(&self, expr: &Expr) -> Result { - for schema in &self.schemas { - if let Ok(DataType::Boolean) = expr.get_type(schema) { - return Ok(true); - } - } - - Ok(false) - } - /// Returns true if expr is nullable - fn nullable(&self, expr: &Expr) -> Result { - self.schemas - .iter() - .find_map(|schema| { - // expr may be from another input, so ignore errors - // by converting to None to keep trying - expr.nullable(schema.as_ref()).ok() - }) - .ok_or_else(|| { - // This means we weren't able to compute `Expr::nullable` with - // *any* input schemas, signalling a problem - DataFusionError::Internal(format!( - "Could not find columns in '{}' during simplify", - expr - )) - }) - } - - fn execution_props(&self) -> &ExecutionProps { - self.props - } -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow::datatypes::{Field, Schema}; - use datafusion_common::ToDFSchema; - use datafusion_expr::{col, lit, when}; - - #[test] - fn api_basic() { - let props = ExecutionProps::new(); - let simplifier = - ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema())); - - let expr = lit(1) + lit(2); - let expected = lit(3); - assert_eq!(expected, simplifier.simplify(expr).unwrap()); - } - - #[test] - fn basic_coercion() { - let schema = test_schema(); - let props = ExecutionProps::new(); - let simplifier = - ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema.clone())); - - // Note expr type is int32 (not int64) - // (1i64 + 2i32) < i - let expr = (lit(1i64) + lit(2i32)).lt(col("i")); - // should fully simplify to 3 < i (though i has been coerced to i64) - let expected = lit(3i64).lt(col("i")); - - // Would be nice if this API could use the SimplifyInfo - // rather than creating an DFSchemaRef coerces rather than doing - // it manually. - // https://github.com/apache/arrow-datafusion/issues/3793 - let expr = simplifier.coerce(expr, schema).unwrap(); - - assert_eq!(expected, simplifier.simplify(expr).unwrap()); - } - - fn test_schema() -> DFSchemaRef { - Schema::new(vec![ - Field::new("i", DataType::Int64, false), - Field::new("b", DataType::Boolean, true), - ]) - .to_dfschema_ref() - .unwrap() - } - - #[test] - fn simplify_and_constant_prop() { - let props = ExecutionProps::new(); - let simplifier = - ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema())); - - // should be able to simplify to false - // (i * (1 - 2)) > 0 - let expr = (col("i") * (lit(1) - lit(1))).gt(lit(0)); - let expected = lit(false); - assert_eq!(expected, simplifier.simplify(expr).unwrap()); - } - - #[test] - fn simplify_and_constant_prop_with_case() { - let props = ExecutionProps::new(); - let simplifier = - ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema())); - - // CASE - // WHEN i>5 AND false THEN i > 5 - // WHEN i<5 AND true THEN i < 5 - // ELSE false - // END - // - // Can be simplified to `i < 5` - let expr = when(col("i").gt(lit(5)).and(lit(false)), col("i").gt(lit(5))) - .when(col("i").lt(lit(5)).and(lit(true)), col("i").lt(lit(5))) - .otherwise(lit(false)) - .unwrap(); - let expected = col("i").lt(lit(5)); - assert_eq!(expected, simplifier.simplify(expr).unwrap()); - } -} diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 5e8108d6766e..13d4cf4a328a 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -20,7 +20,6 @@ pub mod decorrelate_where_exists; pub mod decorrelate_where_in; pub mod eliminate_filter; pub mod eliminate_limit; -pub mod expr_simplifier; pub mod filter_null_join_keys; pub mod filter_push_down; pub mod inline_table_scan; diff --git a/datafusion/optimizer/src/simplify_expressions/context.rs b/datafusion/optimizer/src/simplify_expressions/context.rs new file mode 100644 index 000000000000..da44a7e8fd28 --- /dev/null +++ b/datafusion/optimizer/src/simplify_expressions/context.rs @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Structs and traits to provide the information needed for expression simplification. + +use arrow::datatypes::DataType; +use datafusion_common::{DFSchemaRef, DataFusionError, Result}; +use datafusion_expr::{Expr, ExprSchemable}; +use datafusion_physical_expr::execution_props::ExecutionProps; + +#[allow(rustdoc::private_intra_doc_links)] +/// The information necessary to apply algebraic simplification to an +/// [Expr]. See [SimplifyContext] for one concrete implementation. +/// +/// This trait exists so that other systems can plug schema +/// information in without having to create `DFSchema` objects. If you +/// have a [`DFSchemaRef`] you can use [`SimplifyContext`] +pub trait SimplifyInfo { + /// returns true if this Expr has boolean type + fn is_boolean_type(&self, expr: &Expr) -> Result; + + /// returns true of this expr is nullable (could possibly be NULL) + fn nullable(&self, expr: &Expr) -> Result; + + /// Returns details needed for partial expression evaluation + fn execution_props(&self) -> &ExecutionProps; +} + +/// Provides simplification information based on DFSchema and +/// [`ExecutionProps`]. This is the default implementation used by DataFusion +/// +/// For example: +/// ``` +/// use arrow::datatypes::{Schema, Field, DataType}; +/// use datafusion_expr::{col, lit}; +/// use datafusion_common::{DataFusionError, ToDFSchema}; +/// use datafusion_physical_expr::execution_props::ExecutionProps; +/// use datafusion_optimizer::simplify_expressions::{SimplifyContext, ExprSimplifier}; +/// +/// // Create the schema +/// let schema = Schema::new(vec![ +/// Field::new("i", DataType::Int64, false), +/// ]) +/// .to_dfschema_ref().unwrap(); +/// +/// // Create the simplifier +/// let props = ExecutionProps::new(); +/// let context = SimplifyContext::new(&props) +/// .with_schema(schema); +/// let simplifier = ExprSimplifier::new(context); +/// +/// // Use the simplifier +/// +/// // b < 2 or (1 > 3) +/// let expr = col("b").lt(lit(2)).or(lit(1).gt(lit(3))); +/// +/// // b < 2 +/// let simplified = simplifier.simplify(expr).unwrap(); +/// assert_eq!(simplified, col("b").lt(lit(2))); +/// ``` +pub struct SimplifyContext<'a> { + schemas: Vec, + props: &'a ExecutionProps, +} + +impl<'a> SimplifyContext<'a> { + /// Create a new SimplifyContext + pub fn new(props: &'a ExecutionProps) -> Self { + Self { + schemas: vec![], + props, + } + } + + /// Register a [`DFSchemaRef`] with this context + pub fn with_schema(mut self, schema: DFSchemaRef) -> Self { + self.schemas.push(schema); + self + } +} + +impl<'a> SimplifyInfo for SimplifyContext<'a> { + /// returns true if this Expr has boolean type + fn is_boolean_type(&self, expr: &Expr) -> Result { + for schema in &self.schemas { + if let Ok(DataType::Boolean) = expr.get_type(schema) { + return Ok(true); + } + } + + Ok(false) + } + + /// Returns true if expr is nullable + fn nullable(&self, expr: &Expr) -> Result { + self.schemas + .iter() + .find_map(|schema| { + // expr may be from another input, so ignore errors + // by converting to None to keep trying + expr.nullable(schema.as_ref()).ok() + }) + .ok_or_else(|| { + // This means we weren't able to compute `Expr::nullable` with + // *any* input schemas, signalling a problem + DataFusionError::Internal(format!( + "Could not find columns in '{}' during simplify", + expr + )) + }) + } + + fn execution_props(&self) -> &ExecutionProps { + self.props + } +} diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs similarity index 57% rename from datafusion/optimizer/src/simplify_expressions.rs rename to datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 32c8c9bce856..6a3723644abb 100644 --- a/datafusion/optimizer/src/simplify_expressions.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -15,442 +15,121 @@ // specific language governing permissions and limitations // under the License. -//! Simplify expressions optimizer rule and implementation - -use crate::expr_simplifier::{ExprSimplifier, SimplifyContext}; -use crate::{expr_simplifier::SimplifyInfo, OptimizerConfig, OptimizerRule}; -use arrow::array::new_null_array; -use arrow::datatypes::{DataType, Field, Schema, DECIMAL128_MAX_PRECISION}; -use arrow::error::ArrowError; -use arrow::record_batch::RecordBatch; -use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; -use datafusion_expr::expr::BinaryExpr; +//! Expression simplification API + +use super::utils::*; +use crate::type_coercion::TypeCoercionRewriter; +use arrow::{ + array::new_null_array, + datatypes::{DataType, Field, Schema}, + error::ArrowError, + record_batch::RecordBatch, +}; +use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ - expr::Between, - expr_fn::{and, concat_ws, or}, + and, expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}, - lit, - logical_plan::LogicalPlan, - utils::from_plan, - BuiltinScalarFunction, ColumnarValue, Expr, Operator, Volatility, + lit, or, BinaryExpr, BuiltinScalarFunction, ColumnarValue, Expr, Volatility, }; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; -static POWS_OF_TEN: [i128; 38] = [ - 1, - 10, - 100, - 1000, - 10000, - 100000, - 1000000, - 10000000, - 100000000, - 1000000000, - 10000000000, - 100000000000, - 1000000000000, - 10000000000000, - 100000000000000, - 1000000000000000, - 10000000000000000, - 100000000000000000, - 1000000000000000000, - 10000000000000000000, - 100000000000000000000, - 1000000000000000000000, - 10000000000000000000000, - 100000000000000000000000, - 1000000000000000000000000, - 10000000000000000000000000, - 100000000000000000000000000, - 1000000000000000000000000000, - 10000000000000000000000000000, - 100000000000000000000000000000, - 1000000000000000000000000000000, - 10000000000000000000000000000000, - 100000000000000000000000000000000, - 1000000000000000000000000000000000, - 10000000000000000000000000000000000, - 100000000000000000000000000000000000, - 1000000000000000000000000000000000000, - 10000000000000000000000000000000000000, -]; - -/// Optimizer Pass that simplifies [`LogicalPlan`]s by rewriting -/// [`Expr`]`s evaluating constants and applying algebraic -/// simplifications -/// -/// # Introduction -/// It uses boolean algebra laws to simplify or reduce the number of terms in expressions. -/// -/// # Example: -/// `Filter: b > 2 AND b > 2` -/// is optimized to -/// `Filter: b > 2` -/// -#[derive(Default)] -pub struct SimplifyExpressions {} - -/// returns true if `needle` is found in a chain of search_op -/// expressions. Such as: (A AND B) AND C -fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool { - match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == search_op => { - expr_contains(left, needle, search_op) - || expr_contains(right, needle, search_op) - } - _ => expr == needle, - } -} - -fn is_zero(s: &Expr) -> bool { - match s { - Expr::Literal(ScalarValue::Int8(Some(0))) - | Expr::Literal(ScalarValue::Int16(Some(0))) - | Expr::Literal(ScalarValue::Int32(Some(0))) - | Expr::Literal(ScalarValue::Int64(Some(0))) - | Expr::Literal(ScalarValue::UInt8(Some(0))) - | Expr::Literal(ScalarValue::UInt16(Some(0))) - | Expr::Literal(ScalarValue::UInt32(Some(0))) - | Expr::Literal(ScalarValue::UInt64(Some(0))) => true, - Expr::Literal(ScalarValue::Float32(Some(v))) if *v == 0. => true, - Expr::Literal(ScalarValue::Float64(Some(v))) if *v == 0. => true, - Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s)) if *v == 0 => true, - _ => false, - } -} - -fn is_one(s: &Expr) -> bool { - match s { - Expr::Literal(ScalarValue::Int8(Some(1))) - | Expr::Literal(ScalarValue::Int16(Some(1))) - | Expr::Literal(ScalarValue::Int32(Some(1))) - | Expr::Literal(ScalarValue::Int64(Some(1))) - | Expr::Literal(ScalarValue::UInt8(Some(1))) - | Expr::Literal(ScalarValue::UInt16(Some(1))) - | Expr::Literal(ScalarValue::UInt32(Some(1))) - | Expr::Literal(ScalarValue::UInt64(Some(1))) => true, - Expr::Literal(ScalarValue::Float32(Some(v))) if *v == 1. => true, - Expr::Literal(ScalarValue::Float64(Some(v))) if *v == 1. => true, - Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s)) => { - *_s < DECIMAL128_MAX_PRECISION && POWS_OF_TEN[*_s as usize] == *v - } - _ => false, - } -} - -fn is_true(expr: &Expr) -> bool { - match expr { - Expr::Literal(ScalarValue::Boolean(Some(v))) => *v, - _ => false, - } -} - -/// returns true if expr is a -/// `Expr::Literal(ScalarValue::Boolean(v))` , false otherwise -fn is_bool_lit(expr: &Expr) -> bool { - matches!(expr, Expr::Literal(ScalarValue::Boolean(_))) -} - -/// Return a literal NULL value of Boolean data type -fn lit_bool_null() -> Expr { - Expr::Literal(ScalarValue::Boolean(None)) -} - -fn is_null(expr: &Expr) -> bool { - match expr { - Expr::Literal(v) => v.is_null(), - _ => false, - } -} - -fn is_false(expr: &Expr) -> bool { - match expr { - Expr::Literal(ScalarValue::Boolean(Some(v))) => !(*v), - _ => false, - } -} - -/// returns true if `haystack` looks like (needle OP X) or (X OP needle) -fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool { - matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref())) -} - -/// returns the contained boolean value in `expr` as -/// `Expr::Literal(ScalarValue::Boolean(v))`. -fn as_bool_lit(expr: Expr) -> Result> { - match expr { - Expr::Literal(ScalarValue::Boolean(v)) => Ok(v), - _ => Err(DataFusionError::Internal(format!( - "Expected boolean literal, got {:?}", - expr - ))), - } -} - -/// negate a Not clause -/// input is the clause to be negated.(args of Not clause) -/// For BinaryExpr, use the negator of op instead. -/// not ( A > B) ===> (A <= B) -/// For BoolExpr, not (A and B) ===> (not A) or (not B) -/// not (A or B) ===> (not A) and (not B) -/// not (not A) ===> A -/// For NullExpr, not (A is not null) ===> A is null -/// not (A is null) ===> A is not null -/// For InList, not (A not in (..)) ===> A in (..) -/// not (A in (..)) ===> A not in (..) -/// For Between, not (A between B and C) ===> (A not between B and C) -/// not (A not between B and C) ===> (A between B and C) -/// For others, use Not clause -fn negate_clause(expr: Expr) -> Expr { - match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - if let Some(negated_op) = op.negate() { - return Expr::BinaryExpr(BinaryExpr::new(left, negated_op, right)); - } - match op { - // not (A and B) ===> (not A) or (not B) - Operator::And => { - let left = negate_clause(*left); - let right = negate_clause(*right); - - or(left, right) - } - // not (A or B) ===> (not A) and (not B) - Operator::Or => { - let left = negate_clause(*left); - let right = negate_clause(*right); - - and(left, right) - } - // use not clause - _ => Expr::Not(Box::new(Expr::BinaryExpr(BinaryExpr::new( - left, op, right, - )))), - } - } - // not (not A) ===> A - Expr::Not(expr) => *expr, - // not (A is not null) ===> A is null - Expr::IsNotNull(expr) => expr.is_null(), - // not (A is null) ===> A is not null - Expr::IsNull(expr) => expr.is_not_null(), - // not (A not in (..)) ===> A in (..) - // not (A in (..)) ===> A not in (..) - Expr::InList { - expr, - list, - negated, - } => expr.in_list(list, !negated), - // not (A between B and C) ===> (A not between B and C) - // not (A not between B and C) ===> (A between B and C) - Expr::Between(between) => Expr::Between(Between::new( - between.expr, - !between.negated, - between.low, - between.high, - )), - // use not clause - _ => Expr::Not(Box::new(expr)), - } -} - -/// Simplify the `concat` function by -/// 1. filtering out all `null` literals -/// 2. concatenating contiguous literal arguments -/// -/// For example: -/// `concat(col(a), 'hello ', 'world', col(b), null)` -/// will be optimized to -/// `concat(col(a), 'hello world', col(b))` -fn simpl_concat(args: Vec) -> Result { - let mut new_args = Vec::with_capacity(args.len()); - let mut contiguous_scalar = "".to_string(); - for arg in args { - match arg { - // filter out `null` args - Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {} - // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. - // Concatenate it with the `contiguous_scalar`. - Expr::Literal( - ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)), - ) => contiguous_scalar += &v, - Expr::Literal(x) => { - return Err(DataFusionError::Internal(format!( - "The scalar {} should be casted to string type during the type coercion.", - x - ))); - } - // If the arg is not a literal, we should first push the current `contiguous_scalar` - // to the `new_args` (if it is not empty) and reset it to empty string. - // Then pushing this arg to the `new_args`. - arg => { - if !contiguous_scalar.is_empty() { - new_args.push(lit(contiguous_scalar)); - contiguous_scalar = "".to_string(); - } - new_args.push(arg); - } - } - } - if !contiguous_scalar.is_empty() { - new_args.push(lit(contiguous_scalar)); - } - - Ok(Expr::ScalarFunction { - fun: BuiltinScalarFunction::Concat, - args: new_args, - }) -} +use super::SimplifyInfo; -/// Simply the `concat_ws` function by -/// 1. folding to `null` if the delimiter is null -/// 2. filtering out `null` arguments -/// 3. using `concat` to replace `concat_ws` if the delimiter is an empty string -/// 4. concatenating contiguous literals if the delimiter is a literal. -fn simpl_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { - match delimiter { - Expr::Literal( - ScalarValue::Utf8(delimiter) | ScalarValue::LargeUtf8(delimiter), - ) => { - match delimiter { - // when the delimiter is an empty string, - // we can use `concat` to replace `concat_ws` - Some(delimiter) if delimiter.is_empty() => simpl_concat(args.to_vec()), - Some(delimiter) => { - let mut new_args = Vec::with_capacity(args.len()); - new_args.push(lit(delimiter)); - let mut contiguous_scalar = None; - for arg in args { - match arg { - // filter out null args - Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {} - Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v))) => { - match contiguous_scalar { - None => contiguous_scalar = Some(v.to_string()), - Some(mut pre) => { - pre += delimiter; - pre += v; - contiguous_scalar = Some(pre) - } - } - } - Expr::Literal(s) => return Err(DataFusionError::Internal(format!("The scalar {} should be casted to string type during the type coercion.", s))), - // If the arg is not a literal, we should first push the current `contiguous_scalar` - // to the `new_args` and reset it to None. - // Then pushing this arg to the `new_args`. - arg => { - if let Some(val) = contiguous_scalar { - new_args.push(lit(val)); - } - new_args.push(arg.clone()); - contiguous_scalar = None; - } - } - } - if let Some(val) = contiguous_scalar { - new_args.push(lit(val)); - } - Ok(Expr::ScalarFunction { - fun: BuiltinScalarFunction::ConcatWithSeparator, - args: new_args, - }) - } - // if the delimiter is null, then the value of the whole expression is null. - None => Ok(Expr::Literal(ScalarValue::Utf8(None))), - } - } - Expr::Literal(d) => Err(DataFusionError::Internal(format!( - "The scalar {} should be casted to string type during the type coercion.", - d - ))), - d => Ok(concat_ws( - d.clone(), - args.iter() - .cloned() - .filter(|x| !is_null(x)) - .collect::>(), - )), - } +/// This structure handles API for expression simplification +pub struct ExprSimplifier { + info: S, } -impl OptimizerRule for SimplifyExpressions { - fn name(&self) -> &str { - "simplify_expressions" +impl ExprSimplifier { + /// Create a new `ExprSimplifier` with the given `info` such as an + /// instance of [`SimplifyContext`]. See + /// [`simplify`](Self::simplify) for an example. + pub fn new(info: S) -> Self { + Self { info } } - fn optimize( - &self, - plan: &LogicalPlan, - optimizer_config: &mut OptimizerConfig, - ) -> Result { - let mut execution_props = ExecutionProps::new(); - execution_props.query_execution_start_time = - optimizer_config.query_execution_start_time(); - self.optimize_internal(plan, &execution_props) + /// Simplifies this [`Expr`]`s as much as possible, evaluating + /// constants and applying algebraic simplifications. + /// + /// The types of the expression must match what operators expect, + /// or else an error may occur trying to evaluate. See + /// [`coerce`](Self::coerce) for a function to help. + /// + /// # Example: + /// + /// `b > 2 AND b > 2` + /// + /// can be written to + /// + /// `b > 2` + /// + /// ``` + /// use datafusion_expr::{col, lit, Expr}; + /// use datafusion_common::Result; + /// use datafusion_physical_expr::execution_props::ExecutionProps; + /// use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyInfo}; + /// + /// /// Simple implementation that provides `Simplifier` the information it needs + /// /// See SimplifyContext for a structure that does this. + /// #[derive(Default)] + /// struct Info { + /// execution_props: ExecutionProps, + /// }; + /// + /// impl SimplifyInfo for Info { + /// fn is_boolean_type(&self, expr: &Expr) -> Result { + /// Ok(false) + /// } + /// fn nullable(&self, expr: &Expr) -> Result { + /// Ok(true) + /// } + /// fn execution_props(&self) -> &ExecutionProps { + /// &self.execution_props + /// } + /// } + /// + /// // Create the simplifier + /// let simplifier = ExprSimplifier::new(Info::default()); + /// + /// // b < 2 + /// let b_lt_2 = col("b").gt(lit(2)); + /// + /// // (b < 2) OR (b < 2) + /// let expr = b_lt_2.clone().or(b_lt_2.clone()); + /// + /// // (b < 2) OR (b < 2) --> (b < 2) + /// let expr = simplifier.simplify(expr).unwrap(); + /// assert_eq!(expr, b_lt_2); + /// ``` + pub fn simplify(&self, expr: Expr) -> Result { + let mut simplifier = Simplifier::new(&self.info); + let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; + + // TODO iterate until no changes are made during rewrite + // (evaluating constants can enable new simplifications and + // simplifications can enable new constant evaluation) + // https://github.com/apache/arrow-datafusion/issues/1160 + expr.rewrite(&mut const_evaluator)? + .rewrite(&mut simplifier)? + // run both passes twice to try an minimize simplifications that we missed + .rewrite(&mut const_evaluator)? + .rewrite(&mut simplifier) } -} - -impl SimplifyExpressions { - fn optimize_internal( - &self, - plan: &LogicalPlan, - execution_props: &ExecutionProps, - ) -> Result { - // We need to pass down the all schemas within the plan tree to `optimize_expr` in order to - // to evaluate expression types. For example, a projection plan's schema will only include - // projected columns. With just the projected schema, it's not possible to infer types for - // expressions that references non-projected columns within the same project plan or its - // children plans. - let info = plan - .all_schemas() - .into_iter() - .fold(SimplifyContext::new(execution_props), |context, schema| { - context.with_schema(schema.clone()) - }); - - let simplifier = ExprSimplifier::new(info); - - let new_inputs = plan - .inputs() - .iter() - .map(|input| self.optimize_internal(input, execution_props)) - .collect::>>()?; - - let expr = plan - .expressions() - .into_iter() - .map(|e| { - // We need to keep original expression name, if any. - // Constant folding should not change expression name. - let name = &e.display_name(); - - // Apply the actual simplification logic - let new_e = simplifier.simplify(e)?; - - let new_name = &new_e.display_name(); - - if let (Ok(expr_name), Ok(new_expr_name)) = (name, new_name) { - if expr_name != new_expr_name { - Ok(new_e.alias(expr_name)) - } else { - Ok(new_e) - } - } else { - Ok(new_e) - } - }) - .collect::>>()?; - from_plan(plan, &expr, &new_inputs) - } -} + /// Apply type coercion to an [`Expr`] so that it can be + /// evaluated as a [`PhysicalExpr`](datafusion_physical_expr::PhysicalExpr). + /// + /// See the [type coercion module](datafusion_expr::type_coercion) + /// documentation for more details on type coercion + /// + // Would be nice if this API could use the SimplifyInfo + // rather than creating an DFSchemaRef coerces rather than doing + // it manually. + // https://github.com/apache/arrow-datafusion/issues/3793 + pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result { + let mut expr_rewrite = TypeCoercionRewriter { schema }; -impl SimplifyExpressions { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} + expr.rewrite(&mut expr_rewrite) } } @@ -459,24 +138,7 @@ impl SimplifyExpressions { /// /// Note it does not handle algebraic rewrites such as `(a or false)` /// --> `a`, which is handled by [`Simplifier`] -/// -/// ``` -/// # use datafusion_expr::{col, lit}; -/// # use datafusion_optimizer::simplify_expressions::ConstEvaluator; -/// # use datafusion_physical_expr::execution_props::ExecutionProps; -/// # use datafusion_expr::expr_rewriter::ExprRewritable; -/// -/// let execution_props = ExecutionProps::new(); -/// let mut const_evaluator = ConstEvaluator::try_new(&execution_props).unwrap(); -/// -/// // (1 + 2) + a -/// let expr = (lit(1) + lit(2)) + col("a"); -/// -/// // is rewritten to (3 + a); -/// let rewritten = expr.rewrite(&mut const_evaluator).unwrap(); -/// assert_eq!(rewritten, lit(3) + col("a")); -/// ``` -pub struct ConstEvaluator<'a> { +struct ConstEvaluator<'a> { /// can_evaluate is used during the depth-first-search of the /// Expr tree to track if any siblings (or their descendants) were /// non evaluatable (e.g. had a column reference or volatile @@ -654,7 +316,7 @@ impl<'a> ConstEvaluator<'a> { /// * `false = true` and `true = false` to `false` /// * `!!expr` to `expr` /// * `expr = null` and `expr != null` to `null` -pub(crate) struct Simplifier<'a, S> { +struct Simplifier<'a, S> { info: &'a S, } @@ -667,7 +329,7 @@ impl<'a, S> Simplifier<'a, S> { impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { /// rewrite the expression simplifying any constant expressions fn mutate(&mut self, expr: Expr) -> Result { - use Operator::{And, Divide, Eq, Modulo, Multiply, NotEq, Or}; + use datafusion_expr::Operator::{And, Divide, Eq, Modulo, Multiply, NotEq, Or}; let info = self.info; let new_expr = match expr { @@ -1050,194 +712,453 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { } } -/// A macro to assert that one string is contained within another with -/// a nice error message if they are not. -/// -/// Usage: `assert_contains!(actual, expected)` -/// -/// Is a macro so test error -/// messages are on the same line as the failure; -/// -/// Both arguments must be convertable into Strings (Into) -#[macro_export] -macro_rules! assert_contains { - ($ACTUAL: expr, $EXPECTED: expr) => { - let actual_value: String = $ACTUAL.into(); - let expected_value: String = $EXPECTED.into(); - assert!( - actual_value.contains(&expected_value), - "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}", - expected_value, - actual_value - ); - }; -} - #[cfg(test)] mod tests { + use std::{collections::HashMap, sync::Arc}; + + use crate::simplify_expressions::{ + utils::for_test::{cast_to_int64_expr, now_expr, to_timestamp_expr}, + SimplifyContext, + }; + use super::*; - use arrow::array::{ArrayRef, Int32Array}; + use arrow::{ + array::{ArrayRef, Int32Array}, + datatypes::{DataType, Field, Schema}, + }; use chrono::{DateTime, TimeZone, Utc}; - use datafusion_common::{DFField, DFSchemaRef}; - use datafusion_expr::expr::{Case, Cast}; - use datafusion_expr::expr_fn::{concat, concat_ws}; - use datafusion_expr::logical_plan::table_scan; - use datafusion_expr::{ - and, binary_expr, call_fn, col, create_udf, lit, lit_timestamp_nano, - logical_plan::builder::LogicalPlanBuilder, BuiltinScalarFunction, Expr, - ExprSchemable, ScalarUDF, + use datafusion_common::{DFField, ToDFSchema}; + use datafusion_expr::*; + use datafusion_physical_expr::{ + execution_props::ExecutionProps, functions::make_scalar_function, }; - use datafusion_physical_expr::functions::make_scalar_function; - use std::collections::HashMap; - use std::sync::Arc; + // ------------------------------ + // --- ExprSimplifier tests ----- + // ------------------------------ #[test] - fn test_simplify_or_true() { - let expr_a = col("c2").or(lit(true)); - let expr_b = lit(true).or(col("c2")); - let expected = lit(true); + fn api_basic() { + let props = ExecutionProps::new(); + let simplifier = + ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema())); - assert_eq!(simplify(expr_a), expected); - assert_eq!(simplify(expr_b), expected); + let expr = lit(1) + lit(2); + let expected = lit(3); + assert_eq!(expected, simplifier.simplify(expr).unwrap()); } #[test] - fn test_simplify_or_false() { - let expr_a = lit(false).or(col("c2")); - let expr_b = col("c2").or(lit(false)); - let expected = col("c2"); + fn basic_coercion() { + let schema = test_schema(); + let props = ExecutionProps::new(); + let simplifier = + ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema.clone())); - assert_eq!(simplify(expr_a), expected); - assert_eq!(simplify(expr_b), expected); - } + // Note expr type is int32 (not int64) + // (1i64 + 2i32) < i + let expr = (lit(1i64) + lit(2i32)).lt(col("i")); + // should fully simplify to 3 < i (though i has been coerced to i64) + let expected = lit(3i64).lt(col("i")); - #[test] - fn test_simplify_or_same() { - let expr = col("c2").or(col("c2")); - let expected = col("c2"); + // Would be nice if this API could use the SimplifyInfo + // rather than creating an DFSchemaRef coerces rather than doing + // it manually. + // https://github.com/apache/arrow-datafusion/issues/3793 + let expr = simplifier.coerce(expr, schema).unwrap(); - assert_eq!(simplify(expr), expected); + assert_eq!(expected, simplifier.simplify(expr).unwrap()); } - #[test] - fn test_simplify_and_false() { - let expr_a = lit(false).and(col("c2")); - let expr_b = col("c2").and(lit(false)); - let expected = lit(false); - - assert_eq!(simplify(expr_a), expected); - assert_eq!(simplify(expr_b), expected); + fn test_schema() -> DFSchemaRef { + Schema::new(vec![ + Field::new("i", DataType::Int64, false), + Field::new("b", DataType::Boolean, true), + ]) + .to_dfschema_ref() + .unwrap() } #[test] - fn test_simplify_and_same() { - let expr = col("c2").and(col("c2")); - let expected = col("c2"); + fn simplify_and_constant_prop() { + let props = ExecutionProps::new(); + let simplifier = + ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema())); - assert_eq!(simplify(expr), expected); + // should be able to simplify to false + // (i * (1 - 2)) > 0 + let expr = (col("i") * (lit(1) - lit(1))).gt(lit(0)); + let expected = lit(false); + assert_eq!(expected, simplifier.simplify(expr).unwrap()); } #[test] - fn test_simplify_and_true() { - let expr_a = lit(true).and(col("c2")); - let expr_b = col("c2").and(lit(true)); - let expected = col("c2"); + fn simplify_and_constant_prop_with_case() { + let props = ExecutionProps::new(); + let simplifier = + ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema())); - assert_eq!(simplify(expr_a), expected); - assert_eq!(simplify(expr_b), expected); + // CASE + // WHEN i>5 AND false THEN i > 5 + // WHEN i<5 AND true THEN i < 5 + // ELSE false + // END + // + // Can be simplified to `i < 5` + let expr = when(col("i").gt(lit(5)).and(lit(false)), col("i").gt(lit(5))) + .when(col("i").lt(lit(5)).and(lit(true)), col("i").lt(lit(5))) + .otherwise(lit(false)) + .unwrap(); + let expected = col("i").lt(lit(5)); + assert_eq!(expected, simplifier.simplify(expr).unwrap()); } - #[test] - fn test_simplify_multiply_by_one() { - let expr_a = binary_expr(col("c2"), Operator::Multiply, lit(1)); - let expr_b = binary_expr(lit(1), Operator::Multiply, col("c2")); - let expected = col("c2"); + // ------------------------------ + // --- ConstEvaluator tests ----- + // ------------------------------ + fn test_evaluate_with_start_time( + input_expr: Expr, + expected_expr: Expr, + date_time: &DateTime, + ) { + let execution_props = ExecutionProps { + query_execution_start_time: *date_time, + var_providers: None, + }; - assert_eq!(simplify(expr_a), expected); - assert_eq!(simplify(expr_b), expected); + let mut const_evaluator = ConstEvaluator::try_new(&execution_props).unwrap(); + let evaluated_expr = input_expr + .clone() + .rewrite(&mut const_evaluator) + .expect("successfully evaluated"); - let expr = binary_expr( - col("c2"), - Operator::Multiply, - Expr::Literal(ScalarValue::Decimal128(Some(10000000000), 38, 10)), - ); - assert_eq!(simplify(expr), expected); - let expr = binary_expr( - Expr::Literal(ScalarValue::Decimal128(Some(10000000000), 31, 10)), - Operator::Multiply, - col("c2"), + assert_eq!( + evaluated_expr, expected_expr, + "Mismatch evaluating {}\n Expected:{}\n Got:{}", + input_expr, expected_expr, evaluated_expr ); - assert_eq!(simplify(expr), expected); } - #[test] - fn test_simplify_multiply_by_null() { - let null = Expr::Literal(ScalarValue::Null); - // A * null --> null - { - let expr = binary_expr(col("c2"), Operator::Multiply, null.clone()); - assert_eq!(simplify(expr), null); - } - // null * A --> null - { - let expr = binary_expr(null.clone(), Operator::Multiply, col("c2")); - assert_eq!(simplify(expr), null); - } + fn test_evaluate(input_expr: Expr, expected_expr: Expr) { + test_evaluate_with_start_time(input_expr, expected_expr, &Utc::now()) } - #[test] - fn test_simplify_multiply_by_zero() { - // cannot optimize A * null (null * A) if A is nullable - { - let expr_a = binary_expr(col("c2"), Operator::Multiply, lit(0)); - let expr_b = binary_expr(lit(0), Operator::Multiply, col("c2")); + // Make a UDF that adds its two values together, with the specified volatility + fn make_udf_add(volatility: Volatility) -> Arc { + let input_types = vec![DataType::Int32, DataType::Int32]; + let return_type = Arc::new(DataType::Int32); - assert_eq!(simplify(expr_a.clone()), expr_a); - assert_eq!(simplify(expr_b.clone()), expr_b); - } - // 0 * A --> 0 if A is not nullable - { - let expr = binary_expr(lit(0), Operator::Multiply, col("c2_non_null")); - assert_eq!(simplify(expr), lit(0)); - } - // A * 0 --> 0 if A is not nullable - { - let expr = binary_expr(col("c2_non_null"), Operator::Multiply, lit(0)); - assert_eq!(simplify(expr), lit(0)); - } - // A * Decimal128(0) --> 0 if A is not nullable - { - let expr = binary_expr( - col("c2_non_null"), - Operator::Multiply, - Expr::Literal(ScalarValue::Decimal128(Some(0), 31, 10)), - ); - assert_eq!( - simplify(expr), - Expr::Literal(ScalarValue::Decimal128(Some(0), 31, 10)) - ); - let expr = binary_expr( - Expr::Literal(ScalarValue::Decimal128(Some(0), 31, 10)), - Operator::Multiply, - col("c2_non_null"), - ); - assert_eq!( - simplify(expr), - Expr::Literal(ScalarValue::Decimal128(Some(0), 31, 10)) - ); - } - } + let fun = |args: &[ArrayRef]| { + let arg0 = &args[0] + .as_any() + .downcast_ref::() + .expect("cast failed"); + let arg1 = &args[1] + .as_any() + .downcast_ref::() + .expect("cast failed"); - #[test] - fn test_simplify_divide_by_one() { - let expr = binary_expr(col("c2"), Operator::Divide, lit(1)); - let expected = col("c2"); - assert_eq!(simplify(expr), expected); - let expr = binary_expr( - col("c2"), - Operator::Divide, - Expr::Literal(ScalarValue::Decimal128(Some(10000000000), 31, 10)), - ); + // 2. perform the computation + let array = arg0 + .iter() + .zip(arg1.iter()) + .map(|args| { + if let (Some(arg0), Some(arg1)) = args { + Some(arg0 + arg1) + } else { + // one or both args were Null + None + } + }) + .collect::(); + + Ok(Arc::new(array) as ArrayRef) + }; + + let fun = make_scalar_function(fun); + Arc::new(create_udf( + "udf_add", + input_types, + return_type, + volatility, + fun, + )) + } + + #[test] + fn test_const_evaluator() { + // true --> true + test_evaluate(lit(true), lit(true)); + // true or true --> true + test_evaluate(lit(true).or(lit(true)), lit(true)); + // true or false --> true + test_evaluate(lit(true).or(lit(false)), lit(true)); + + // "foo" == "foo" --> true + test_evaluate(lit("foo").eq(lit("foo")), lit(true)); + // "foo" != "foo" --> false + test_evaluate(lit("foo").not_eq(lit("foo")), lit(false)); + + // c = 1 --> c = 1 + test_evaluate(col("c").eq(lit(1)), col("c").eq(lit(1))); + // c = 1 + 2 --> c + 3 + test_evaluate(col("c").eq(lit(1) + lit(2)), col("c").eq(lit(3))); + // (foo != foo) OR (c = 1) --> false OR (c = 1) + test_evaluate( + (lit("foo").not_eq(lit("foo"))).or(col("c").eq(lit(1))), + lit(false).or(col("c").eq(lit(1))), + ); + } + + #[test] + fn test_const_evaluator_scalar_functions() { + // concat("foo", "bar") --> "foobar" + let expr = call_fn("concat", vec![lit("foo"), lit("bar")]).unwrap(); + test_evaluate(expr, lit("foobar")); + + // ensure arguments are also constant folded + // concat("foo", concat("bar", "baz")) --> "foobarbaz" + let concat1 = call_fn("concat", vec![lit("bar"), lit("baz")]).unwrap(); + let expr = call_fn("concat", vec![lit("foo"), concat1]).unwrap(); + test_evaluate(expr, lit("foobarbaz")); + + // Check non string arguments + // to_timestamp("2020-09-08T12:00:00+00:00") --> timestamp(1599566400000000000i64) + let expr = + call_fn("to_timestamp", vec![lit("2020-09-08T12:00:00+00:00")]).unwrap(); + test_evaluate(expr, lit_timestamp_nano(1599566400000000000i64)); + + // check that non foldable arguments are folded + // to_timestamp(a) --> to_timestamp(a) [no rewrite possible] + let expr = call_fn("to_timestamp", vec![col("a")]).unwrap(); + test_evaluate(expr.clone(), expr); + + // check that non foldable arguments are folded + // to_timestamp(a) --> to_timestamp(a) [no rewrite possible] + let expr = call_fn("to_timestamp", vec![col("a")]).unwrap(); + test_evaluate(expr.clone(), expr); + + // volatile / stable functions should not be evaluated + // rand() + (1 + 2) --> rand() + 3 + let fun = BuiltinScalarFunction::Random; + assert_eq!(fun.volatility(), Volatility::Volatile); + let rand = Expr::ScalarFunction { args: vec![], fun }; + let expr = rand.clone() + (lit(1) + lit(2)); + let expected = rand + lit(3); + test_evaluate(expr, expected); + + // parenthesization matters: can't rewrite + // (rand() + 1) + 2 --> (rand() + 1) + 2) + let fun = BuiltinScalarFunction::Random; + let rand = Expr::ScalarFunction { args: vec![], fun }; + let expr = (rand + lit(1)) + lit(2); + test_evaluate(expr.clone(), expr); + } + + #[test] + fn test_const_evaluator_now() { + let ts_nanos = 1599566400000000000i64; + let time = chrono::Utc.timestamp_nanos(ts_nanos); + let ts_string = "2020-09-08T12:05:00+00:00"; + // now() --> ts + test_evaluate_with_start_time(now_expr(), lit_timestamp_nano(ts_nanos), &time); + + // CAST(now() as int64) + 100_i64 --> ts + 100_i64 + let expr = cast_to_int64_expr(now_expr()) + lit(100_i64); + test_evaluate_with_start_time(expr, lit(ts_nanos + 100), &time); + + // CAST(now() as int64) < cast(to_timestamp(...) as int64) + 50000_i64 ---> true + let expr = cast_to_int64_expr(now_expr()) + .lt(cast_to_int64_expr(to_timestamp_expr(ts_string)) + lit(50000i64)); + test_evaluate_with_start_time(expr, lit(true), &time); + } + + #[test] + fn test_evaluator_udfs() { + let args = vec![lit(1) + lit(2), lit(30) + lit(40)]; + let folded_args = vec![lit(3), lit(70)]; + + // immutable UDF should get folded + // udf_add(1+2, 30+40) --> 73 + let expr = Expr::ScalarUDF { + args: args.clone(), + fun: make_udf_add(Volatility::Immutable), + }; + test_evaluate(expr, lit(73)); + + // stable UDF should be entirely folded + // udf_add(1+2, 30+40) --> 73 + let fun = make_udf_add(Volatility::Stable); + let expr = Expr::ScalarUDF { + args: args.clone(), + fun: Arc::clone(&fun), + }; + test_evaluate(expr, lit(73)); + + // volatile UDF should have args folded + // udf_add(1+2, 30+40) --> udf_add(3, 70) + let fun = make_udf_add(Volatility::Volatile); + let expr = Expr::ScalarUDF { + args, + fun: Arc::clone(&fun), + }; + let expected_expr = Expr::ScalarUDF { + args: folded_args, + fun: Arc::clone(&fun), + }; + test_evaluate(expr, expected_expr); + } + + // ------------------------------ + // --- Simplifier tests ----- + // ------------------------------ + + #[test] + fn test_simplify_or_true() { + let expr_a = col("c2").or(lit(true)); + let expr_b = lit(true).or(col("c2")); + let expected = lit(true); + + assert_eq!(simplify(expr_a), expected); + assert_eq!(simplify(expr_b), expected); + } + + #[test] + fn test_simplify_or_false() { + let expr_a = lit(false).or(col("c2")); + let expr_b = col("c2").or(lit(false)); + let expected = col("c2"); + + assert_eq!(simplify(expr_a), expected); + assert_eq!(simplify(expr_b), expected); + } + + #[test] + fn test_simplify_or_same() { + let expr = col("c2").or(col("c2")); + let expected = col("c2"); + + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_and_false() { + let expr_a = lit(false).and(col("c2")); + let expr_b = col("c2").and(lit(false)); + let expected = lit(false); + + assert_eq!(simplify(expr_a), expected); + assert_eq!(simplify(expr_b), expected); + } + + #[test] + fn test_simplify_and_same() { + let expr = col("c2").and(col("c2")); + let expected = col("c2"); + + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_and_true() { + let expr_a = lit(true).and(col("c2")); + let expr_b = col("c2").and(lit(true)); + let expected = col("c2"); + + assert_eq!(simplify(expr_a), expected); + assert_eq!(simplify(expr_b), expected); + } + + #[test] + fn test_simplify_multiply_by_one() { + let expr_a = binary_expr(col("c2"), Operator::Multiply, lit(1)); + let expr_b = binary_expr(lit(1), Operator::Multiply, col("c2")); + let expected = col("c2"); + + assert_eq!(simplify(expr_a), expected); + assert_eq!(simplify(expr_b), expected); + + let expr = binary_expr( + col("c2"), + Operator::Multiply, + Expr::Literal(ScalarValue::Decimal128(Some(10000000000), 38, 10)), + ); + assert_eq!(simplify(expr), expected); + let expr = binary_expr( + Expr::Literal(ScalarValue::Decimal128(Some(10000000000), 31, 10)), + Operator::Multiply, + col("c2"), + ); + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_multiply_by_null() { + let null = Expr::Literal(ScalarValue::Null); + // A * null --> null + { + let expr = binary_expr(col("c2"), Operator::Multiply, null.clone()); + assert_eq!(simplify(expr), null); + } + // null * A --> null + { + let expr = binary_expr(null.clone(), Operator::Multiply, col("c2")); + assert_eq!(simplify(expr), null); + } + } + + #[test] + fn test_simplify_multiply_by_zero() { + // cannot optimize A * null (null * A) if A is nullable + { + let expr_a = binary_expr(col("c2"), Operator::Multiply, lit(0)); + let expr_b = binary_expr(lit(0), Operator::Multiply, col("c2")); + + assert_eq!(simplify(expr_a.clone()), expr_a); + assert_eq!(simplify(expr_b.clone()), expr_b); + } + // 0 * A --> 0 if A is not nullable + { + let expr = binary_expr(lit(0), Operator::Multiply, col("c2_non_null")); + assert_eq!(simplify(expr), lit(0)); + } + // A * 0 --> 0 if A is not nullable + { + let expr = binary_expr(col("c2_non_null"), Operator::Multiply, lit(0)); + assert_eq!(simplify(expr), lit(0)); + } + // A * Decimal128(0) --> 0 if A is not nullable + { + let expr = binary_expr( + col("c2_non_null"), + Operator::Multiply, + Expr::Literal(ScalarValue::Decimal128(Some(0), 31, 10)), + ); + assert_eq!( + simplify(expr), + Expr::Literal(ScalarValue::Decimal128(Some(0), 31, 10)) + ); + let expr = binary_expr( + Expr::Literal(ScalarValue::Decimal128(Some(0), 31, 10)), + Operator::Multiply, + col("c2_non_null"), + ); + assert_eq!( + simplify(expr), + Expr::Literal(ScalarValue::Decimal128(Some(0), 31, 10)) + ); + } + } + + #[test] + fn test_simplify_divide_by_one() { + let expr = binary_expr(col("c2"), Operator::Divide, lit(1)); + let expected = col("c2"); + assert_eq!(simplify(expr), expected); + let expr = binary_expr( + col("c2"), + Operator::Divide, + Expr::Literal(ScalarValue::Decimal128(Some(10000000000), 31, 10)), + ); assert_eq!(simplify(expr), expected); } @@ -1538,241 +1459,32 @@ mod tests { { let sub_expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]); let expr = concat_ws(lit("|"), vec![sub_expr, col("c3")]); - assert_eq!(simplify(expr), concat_ws(lit("|"), vec![col("c3")])); - } - - // null delimiter (nested) - { - let sub_expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]); - let expr = concat_ws(sub_expr, vec![col("c3"), col("c4")]); - assert_eq!(simplify(expr), null); - } - } - - #[test] - fn test_simplify_concat() { - let null = Expr::Literal(ScalarValue::Utf8(None)); - let expr = concat(&[ - null.clone(), - col("c0"), - lit("hello "), - null.clone(), - lit("rust"), - col("c1"), - lit(""), - null, - ]); - let expected = concat(&[col("c0"), lit("hello rust"), col("c1")]); - assert_eq!(simplify(expr), expected) - } - - // ------------------------------ - // --- ConstEvaluator tests ----- - // ------------------------------ - - #[test] - fn test_const_evaluator() { - // true --> true - test_evaluate(lit(true), lit(true)); - // true or true --> true - test_evaluate(lit(true).or(lit(true)), lit(true)); - // true or false --> true - test_evaluate(lit(true).or(lit(false)), lit(true)); - - // "foo" == "foo" --> true - test_evaluate(lit("foo").eq(lit("foo")), lit(true)); - // "foo" != "foo" --> false - test_evaluate(lit("foo").not_eq(lit("foo")), lit(false)); - - // c = 1 --> c = 1 - test_evaluate(col("c").eq(lit(1)), col("c").eq(lit(1))); - // c = 1 + 2 --> c + 3 - test_evaluate(col("c").eq(lit(1) + lit(2)), col("c").eq(lit(3))); - // (foo != foo) OR (c = 1) --> false OR (c = 1) - test_evaluate( - (lit("foo").not_eq(lit("foo"))).or(col("c").eq(lit(1))), - lit(false).or(col("c").eq(lit(1))), - ); - } - - #[test] - fn test_const_evaluator_scalar_functions() { - // concat("foo", "bar") --> "foobar" - let expr = call_fn("concat", vec![lit("foo"), lit("bar")]).unwrap(); - test_evaluate(expr, lit("foobar")); - - // ensure arguments are also constant folded - // concat("foo", concat("bar", "baz")) --> "foobarbaz" - let concat1 = call_fn("concat", vec![lit("bar"), lit("baz")]).unwrap(); - let expr = call_fn("concat", vec![lit("foo"), concat1]).unwrap(); - test_evaluate(expr, lit("foobarbaz")); - - // Check non string arguments - // to_timestamp("2020-09-08T12:00:00+00:00") --> timestamp(1599566400000000000i64) - let expr = - call_fn("to_timestamp", vec![lit("2020-09-08T12:00:00+00:00")]).unwrap(); - test_evaluate(expr, lit_timestamp_nano(1599566400000000000i64)); - - // check that non foldable arguments are folded - // to_timestamp(a) --> to_timestamp(a) [no rewrite possible] - let expr = call_fn("to_timestamp", vec![col("a")]).unwrap(); - test_evaluate(expr.clone(), expr); - - // check that non foldable arguments are folded - // to_timestamp(a) --> to_timestamp(a) [no rewrite possible] - let expr = call_fn("to_timestamp", vec![col("a")]).unwrap(); - test_evaluate(expr.clone(), expr); - - // volatile / stable functions should not be evaluated - // rand() + (1 + 2) --> rand() + 3 - let fun = BuiltinScalarFunction::Random; - assert_eq!(fun.volatility(), Volatility::Volatile); - let rand = Expr::ScalarFunction { args: vec![], fun }; - let expr = rand.clone() + (lit(1) + lit(2)); - let expected = rand + lit(3); - test_evaluate(expr, expected); - - // parenthesization matters: can't rewrite - // (rand() + 1) + 2 --> (rand() + 1) + 2) - let fun = BuiltinScalarFunction::Random; - let rand = Expr::ScalarFunction { args: vec![], fun }; - let expr = (rand + lit(1)) + lit(2); - test_evaluate(expr.clone(), expr); - } - - #[test] - fn test_const_evaluator_now() { - let ts_nanos = 1599566400000000000i64; - let time = chrono::Utc.timestamp_nanos(ts_nanos); - let ts_string = "2020-09-08T12:05:00+00:00"; - // now() --> ts - test_evaluate_with_start_time(now_expr(), lit_timestamp_nano(ts_nanos), &time); - - // CAST(now() as int64) + 100_i64 --> ts + 100_i64 - let expr = cast_to_int64_expr(now_expr()) + lit(100_i64); - test_evaluate_with_start_time(expr, lit(ts_nanos + 100), &time); - - // CAST(now() as int64) < cast(to_timestamp(...) as int64) + 50000_i64 ---> true - let expr = cast_to_int64_expr(now_expr()) - .lt(cast_to_int64_expr(to_timestamp_expr(ts_string)) + lit(50000i64)); - test_evaluate_with_start_time(expr, lit(true), &time); - } - - fn now_expr() -> Expr { - call_fn("now", vec![]).unwrap() - } - - fn cast_to_int64_expr(expr: Expr) -> Expr { - Expr::Cast(Cast::new(expr.into(), DataType::Int64)) - } - - fn to_timestamp_expr(arg: impl Into) -> Expr { - call_fn("to_timestamp", vec![lit(arg.into())]).unwrap() - } - - #[test] - fn test_evaluator_udfs() { - let args = vec![lit(1) + lit(2), lit(30) + lit(40)]; - let folded_args = vec![lit(3), lit(70)]; - - // immutable UDF should get folded - // udf_add(1+2, 30+40) --> 73 - let expr = Expr::ScalarUDF { - args: args.clone(), - fun: make_udf_add(Volatility::Immutable), - }; - test_evaluate(expr, lit(73)); - - // stable UDF should be entirely folded - // udf_add(1+2, 30+40) --> 73 - let fun = make_udf_add(Volatility::Stable); - let expr = Expr::ScalarUDF { - args: args.clone(), - fun: Arc::clone(&fun), - }; - test_evaluate(expr, lit(73)); - - // volatile UDF should have args folded - // udf_add(1+2, 30+40) --> udf_add(3, 70) - let fun = make_udf_add(Volatility::Volatile); - let expr = Expr::ScalarUDF { - args, - fun: Arc::clone(&fun), - }; - let expected_expr = Expr::ScalarUDF { - args: folded_args, - fun: Arc::clone(&fun), - }; - test_evaluate(expr, expected_expr); - } - - // Make a UDF that adds its two values together, with the specified volatility - fn make_udf_add(volatility: Volatility) -> Arc { - let input_types = vec![DataType::Int32, DataType::Int32]; - let return_type = Arc::new(DataType::Int32); - - let fun = |args: &[ArrayRef]| { - let arg0 = &args[0] - .as_any() - .downcast_ref::() - .expect("cast failed"); - let arg1 = &args[1] - .as_any() - .downcast_ref::() - .expect("cast failed"); - - // 2. perform the computation - let array = arg0 - .iter() - .zip(arg1.iter()) - .map(|args| { - if let (Some(arg0), Some(arg1)) = args { - Some(arg0 + arg1) - } else { - // one or both args were Null - None - } - }) - .collect::(); - - Ok(Arc::new(array) as ArrayRef) - }; - - let fun = make_scalar_function(fun); - Arc::new(create_udf( - "udf_add", - input_types, - return_type, - volatility, - fun, - )) - } - - fn test_evaluate_with_start_time( - input_expr: Expr, - expected_expr: Expr, - date_time: &DateTime, - ) { - let execution_props = ExecutionProps { - query_execution_start_time: *date_time, - var_providers: None, - }; - - let mut const_evaluator = ConstEvaluator::try_new(&execution_props).unwrap(); - let evaluated_expr = input_expr - .clone() - .rewrite(&mut const_evaluator) - .expect("successfully evaluated"); - - assert_eq!( - evaluated_expr, expected_expr, - "Mismatch evaluating {}\n Expected:{}\n Got:{}", - input_expr, expected_expr, evaluated_expr - ); + assert_eq!(simplify(expr), concat_ws(lit("|"), vec![col("c3")])); + } + + // null delimiter (nested) + { + let sub_expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]); + let expr = concat_ws(sub_expr, vec![col("c3"), col("c4")]); + assert_eq!(simplify(expr), null); + } } - fn test_evaluate(input_expr: Expr, expected_expr: Expr) { - test_evaluate_with_start_time(input_expr, expected_expr, &Utc::now()) + #[test] + fn test_simplify_concat() { + let null = Expr::Literal(ScalarValue::Utf8(None)); + let expr = concat(&[ + null.clone(), + col("c0"), + lit("hello "), + null.clone(), + lit("rust"), + col("c1"), + lit(""), + null, + ]); + let expected = concat(&[col("c0"), lit("hello rust"), col("c1")]); + assert_eq!(simplify(expr), expected) } // ------------------------------ @@ -2101,696 +1813,4 @@ mod tests { or(col("c2").lt(lit(3)), col("c2").gt(lit(4))) ); } - - // ------------------------------ - // -- SimplifyExpressions tests - - // (test plans are simplified correctly) - // ------------------------------ - - fn test_table_scan() -> LogicalPlan { - let schema = Schema::new(vec![ - Field::new("a", DataType::Boolean, false), - Field::new("b", DataType::Boolean, false), - Field::new("c", DataType::Boolean, false), - Field::new("d", DataType::UInt32, false), - ]); - table_scan(Some("test"), &schema, None) - .expect("creating scan") - .build() - .expect("building plan") - } - - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { - let rule = SimplifyExpressions::new(); - let optimized_plan = rule - .optimize(plan, &mut OptimizerConfig::new()) - .expect("failed to optimize plan"); - let formatted_plan = format!("{:?}", optimized_plan); - assert_eq!(formatted_plan, expected); - } - - #[test] - fn test_simplify_optimized_plan() { - let table_scan = test_table_scan(); - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a")]) - .unwrap() - .filter(and(col("b").gt(lit(1)), col("b").gt(lit(1)))) - .unwrap() - .build() - .unwrap(); - - assert_optimized_plan_eq( - &plan, - "\ - Filter: test.b > Int32(1)\ - \n Projection: test.a\ - \n TableScan: test", - ); - } - - #[test] - fn test_simplify_optimized_plan_with_or() { - let table_scan = test_table_scan(); - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a")]) - .unwrap() - .filter(or(col("b").gt(lit(1)), col("b").gt(lit(1)))) - .unwrap() - .build() - .unwrap(); - - assert_optimized_plan_eq( - &plan, - "\ - Filter: test.b > Int32(1)\ - \n Projection: test.a\ - \n TableScan: test", - ); - } - - #[test] - fn test_simplify_optimized_plan_with_composed_and() { - let table_scan = test_table_scan(); - // ((c > 5) AND (d < 6)) AND (c > 5) --> (c > 5) AND (d < 6) - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a"), col("b")]) - .unwrap() - .filter(and( - and(col("a").gt(lit(5)), col("b").lt(lit(6))), - col("a").gt(lit(5)), - )) - .unwrap() - .build() - .unwrap(); - - assert_optimized_plan_eq( - &plan, - "\ - Filter: test.a > Int32(5) AND test.b < Int32(6)\ - \n Projection: test.a, test.b\ - \n TableScan: test", - ); - } - - #[test] - fn test_simplity_optimized_plan_eq_expr() { - let table_scan = test_table_scan(); - let plan = LogicalPlanBuilder::from(table_scan) - .filter(col("b").eq(lit(true))) - .unwrap() - .filter(col("c").eq(lit(false))) - .unwrap() - .project(vec![col("a")]) - .unwrap() - .build() - .unwrap(); - - let expected = "\ - Projection: test.a\ - \n Filter: NOT test.c\ - \n Filter: test.b\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn test_simplity_optimized_plan_not_eq_expr() { - let table_scan = test_table_scan(); - let plan = LogicalPlanBuilder::from(table_scan) - .filter(col("b").not_eq(lit(true))) - .unwrap() - .filter(col("c").not_eq(lit(false))) - .unwrap() - .limit(0, Some(1)) - .unwrap() - .project(vec![col("a")]) - .unwrap() - .build() - .unwrap(); - - let expected = "\ - Projection: test.a\ - \n Limit: skip=0, fetch=1\ - \n Filter: test.c\ - \n Filter: NOT test.b\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn test_simplity_optimized_plan_and_expr() { - let table_scan = test_table_scan(); - let plan = LogicalPlanBuilder::from(table_scan) - .filter(col("b").not_eq(lit(true)).and(col("c").eq(lit(true)))) - .unwrap() - .project(vec![col("a")]) - .unwrap() - .build() - .unwrap(); - - let expected = "\ - Projection: test.a\ - \n Filter: NOT test.b AND test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn test_simplity_optimized_plan_or_expr() { - let table_scan = test_table_scan(); - let plan = LogicalPlanBuilder::from(table_scan) - .filter(col("b").not_eq(lit(true)).or(col("c").eq(lit(false)))) - .unwrap() - .project(vec![col("a")]) - .unwrap() - .build() - .unwrap(); - - let expected = "\ - Projection: test.a\ - \n Filter: NOT test.b OR NOT test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn test_simplity_optimized_plan_not_expr() { - let table_scan = test_table_scan(); - let plan = LogicalPlanBuilder::from(table_scan) - .filter(col("b").eq(lit(false)).not()) - .unwrap() - .project(vec![col("a")]) - .unwrap() - .build() - .unwrap(); - - let expected = "\ - Projection: test.a\ - \n Filter: test.b\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn test_simplity_optimized_plan_support_projection() { - let table_scan = test_table_scan(); - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a"), col("d"), col("b").eq(lit(false))]) - .unwrap() - .build() - .unwrap(); - - let expected = "\ - Projection: test.a, test.d, NOT test.b AS test.b = Boolean(false)\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn test_simplity_optimized_plan_support_aggregate() { - let table_scan = test_table_scan(); - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a"), col("c"), col("b")]) - .unwrap() - .aggregate( - vec![col("a"), col("c")], - vec![ - datafusion_expr::max(col("b").eq(lit(true))), - datafusion_expr::min(col("b")), - ], - ) - .unwrap() - .build() - .unwrap(); - - let expected = "\ - Aggregate: groupBy=[[test.a, test.c]], aggr=[[MAX(test.b) AS MAX(test.b = Boolean(true)), MIN(test.b)]]\ - \n Projection: test.a, test.c, test.b\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn test_simplity_optimized_plan_support_values() { - let expr1 = Expr::BinaryExpr(BinaryExpr::new( - Box::new(lit(1)), - Operator::Plus, - Box::new(lit(2)), - )); - let expr2 = Expr::BinaryExpr(BinaryExpr::new( - Box::new(lit(2)), - Operator::Minus, - Box::new(lit(1)), - )); - let values = vec![vec![expr1, expr2]]; - let plan = LogicalPlanBuilder::values(values).unwrap().build().unwrap(); - - let expected = "\ - Values: (Int32(3) AS Int32(1) + Int32(2), Int32(1) AS Int32(2) - Int32(1))"; - - assert_optimized_plan_eq(&plan, expected); - } - - // expect optimizing will result in an error, returning the error string - fn get_optimized_plan_err(plan: &LogicalPlan, date_time: &DateTime) -> String { - let mut config = - OptimizerConfig::new().with_query_execution_start_time(*date_time); - let rule = SimplifyExpressions::new(); - - let err = rule - .optimize(plan, &mut config) - .expect_err("expected optimization to fail"); - - err.to_string() - } - - fn get_optimized_plan_formatted( - plan: &LogicalPlan, - date_time: &DateTime, - ) -> String { - let mut config = - OptimizerConfig::new().with_query_execution_start_time(*date_time); - let rule = SimplifyExpressions::new(); - - let optimized_plan = rule - .optimize(plan, &mut config) - .expect("failed to optimize plan"); - format!("{:?}", optimized_plan) - } - - #[test] - fn to_timestamp_expr_folded() { - let table_scan = test_table_scan(); - let proj = vec![to_timestamp_expr("2020-09-08T12:00:00+00:00")]; - - let plan = LogicalPlanBuilder::from(table_scan) - .project(proj) - .unwrap() - .build() - .unwrap(); - - let expected = "Projection: TimestampNanosecond(1599566400000000000, None) AS totimestamp(Utf8(\"2020-09-08T12:00:00+00:00\"))\ - \n TableScan: test" - .to_string(); - let actual = get_optimized_plan_formatted(&plan, &Utc::now()); - assert_eq!(expected, actual); - } - - #[test] - fn to_timestamp_expr_wrong_arg() { - let table_scan = test_table_scan(); - let proj = vec![to_timestamp_expr("I'M NOT A TIMESTAMP")]; - let plan = LogicalPlanBuilder::from(table_scan) - .project(proj) - .unwrap() - .build() - .unwrap(); - - let expected = "Error parsing 'I'M NOT A TIMESTAMP' as timestamp"; - let actual = get_optimized_plan_err(&plan, &Utc::now()); - assert_contains!(actual, expected); - } - - #[test] - fn cast_expr() { - let table_scan = test_table_scan(); - let proj = vec![Expr::Cast(Cast::new(Box::new(lit("0")), DataType::Int32))]; - let plan = LogicalPlanBuilder::from(table_scan) - .project(proj) - .unwrap() - .build() - .unwrap(); - - let expected = "Projection: Int32(0) AS Utf8(\"0\")\ - \n TableScan: test"; - let actual = get_optimized_plan_formatted(&plan, &Utc::now()); - assert_eq!(expected, actual); - } - - #[test] - fn cast_expr_wrong_arg() { - let table_scan = test_table_scan(); - let proj = vec![Expr::Cast(Cast::new(Box::new(lit("")), DataType::Int32))]; - let plan = LogicalPlanBuilder::from(table_scan) - .project(proj) - .unwrap() - .build() - .unwrap(); - - let expected = "Cannot cast string '' to value of Int32 type"; - let actual = get_optimized_plan_err(&plan, &Utc::now()); - assert_contains!(actual, expected); - } - - #[test] - fn multiple_now_expr() { - let table_scan = test_table_scan(); - let time = Utc::now(); - let proj = vec![ - now_expr(), - Expr::Alias(Box::new(now_expr()), "t2".to_string()), - ]; - let plan = LogicalPlanBuilder::from(table_scan) - .project(proj) - .unwrap() - .build() - .unwrap(); - - // expect the same timestamp appears in both exprs - let actual = get_optimized_plan_formatted(&plan, &time); - let expected = format!( - "Projection: TimestampNanosecond({}, Some(\"UTC\")) AS now(), TimestampNanosecond({}, Some(\"UTC\")) AS t2\ - \n TableScan: test", - time.timestamp_nanos(), - time.timestamp_nanos() - ); - - assert_eq!(expected, actual); - } - - #[test] - fn simplify_and_eval() { - // demonstrate a case where the evaluation needs to run prior - // to the simplifier for it to work - let table_scan = test_table_scan(); - let time = Utc::now(); - // (true or false) != col --> !col - let proj = vec![lit(true).or(lit(false)).not_eq(col("a"))]; - let plan = LogicalPlanBuilder::from(table_scan) - .project(proj) - .unwrap() - .build() - .unwrap(); - - let actual = get_optimized_plan_formatted(&plan, &time); - let expected = - "Projection: NOT test.a AS Boolean(true) OR Boolean(false) != test.a\ - \n TableScan: test"; - - assert_eq!(expected, actual); - } - - #[test] - fn now_less_than_timestamp() { - let table_scan = test_table_scan(); - - let ts_string = "2020-09-08T12:05:00+00:00"; - let time = chrono::Utc.timestamp_nanos(1599566400000000000i64); - - // cast(now() as int) < cast(to_timestamp(...) as int) + 50000_i64 - let plan = - LogicalPlanBuilder::from(table_scan) - .filter( - cast_to_int64_expr(now_expr()) - .lt(cast_to_int64_expr(to_timestamp_expr(ts_string)) - + lit(50000_i64)), - ) - .unwrap() - .build() - .unwrap(); - - // Note that constant folder runs and folds the entire - // expression down to a single constant (true) - let expected = "Filter: Boolean(true)\ - \n TableScan: test"; - let actual = get_optimized_plan_formatted(&plan, &time); - - assert_eq!(expected, actual); - } - - #[test] - fn select_date_plus_interval() { - let table_scan = test_table_scan(); - - let ts_string = "2020-09-08T12:05:00+00:00"; - let time = chrono::Utc.timestamp_nanos(1599566400000000000i64); - - // now() < cast(to_timestamp(...) as int) + 5000000000 - let schema = table_scan.schema(); - - let date_plus_interval_expr = to_timestamp_expr(ts_string) - .cast_to(&DataType::Date32, schema) - .unwrap() - + Expr::Literal(ScalarValue::IntervalDayTime(Some(123i64 << 32))); - - let plan = LogicalPlanBuilder::from(table_scan.clone()) - .project(vec![date_plus_interval_expr]) - .unwrap() - .build() - .unwrap(); - - println!("{:?}", plan); - - // Note that constant folder runs and folds the entire - // expression down to a single constant (true) - let expected = r#"Projection: Date32("18636") AS totimestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("528280977408") - TableScan: test"#; - let actual = get_optimized_plan_formatted(&plan, &time); - - assert_eq!(expected, actual); - } - - #[test] - fn simplify_not_binary() { - let table_scan = test_table_scan(); - - let plan = LogicalPlanBuilder::from(table_scan) - .filter(col("d").gt(lit(10)).not()) - .unwrap() - .build() - .unwrap(); - let expected = "Filter: test.d <= Int32(10)\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn simplify_not_bool_and() { - let table_scan = test_table_scan(); - - let plan = LogicalPlanBuilder::from(table_scan) - .filter(col("d").gt(lit(10)).and(col("d").lt(lit(100))).not()) - .unwrap() - .build() - .unwrap(); - let expected = "Filter: test.d <= Int32(10) OR test.d >= Int32(100)\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn simplify_not_bool_or() { - let table_scan = test_table_scan(); - - let plan = LogicalPlanBuilder::from(table_scan) - .filter(col("d").gt(lit(10)).or(col("d").lt(lit(100))).not()) - .unwrap() - .build() - .unwrap(); - let expected = "Filter: test.d <= Int32(10) AND test.d >= Int32(100)\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn simplify_not_not() { - let table_scan = test_table_scan(); - - let plan = LogicalPlanBuilder::from(table_scan) - .filter(col("d").gt(lit(10)).not().not()) - .unwrap() - .build() - .unwrap(); - let expected = "Filter: test.d > Int32(10)\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn simplify_not_null() { - let table_scan = test_table_scan(); - - let plan = LogicalPlanBuilder::from(table_scan) - .filter(col("d").is_null().not()) - .unwrap() - .build() - .unwrap(); - let expected = "Filter: test.d IS NOT NULL\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn simplify_not_not_null() { - let table_scan = test_table_scan(); - - let plan = LogicalPlanBuilder::from(table_scan) - .filter(col("d").is_not_null().not()) - .unwrap() - .build() - .unwrap(); - let expected = "Filter: test.d IS NULL\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn simplify_not_in() { - let table_scan = test_table_scan(); - - let plan = LogicalPlanBuilder::from(table_scan) - .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], false).not()) - .unwrap() - .build() - .unwrap(); - let expected = "Filter: test.d NOT IN ([Int32(1), Int32(2), Int32(3)])\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn simplify_not_not_in() { - let table_scan = test_table_scan(); - - let plan = LogicalPlanBuilder::from(table_scan) - .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], true).not()) - .unwrap() - .build() - .unwrap(); - let expected = "Filter: test.d IN ([Int32(1), Int32(2), Int32(3)])\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn simplify_not_between() { - let table_scan = test_table_scan(); - let qual = Expr::Between(Between::new( - Box::new(col("d")), - false, - Box::new(lit(1)), - Box::new(lit(10)), - )); - - let plan = LogicalPlanBuilder::from(table_scan) - .filter(qual.not()) - .unwrap() - .build() - .unwrap(); - let expected = "Filter: test.d < Int32(1) OR test.d > Int32(10)\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn simplify_not_not_between() { - let table_scan = test_table_scan(); - let qual = Expr::Between(Between::new( - Box::new(col("d")), - true, - Box::new(lit(1)), - Box::new(lit(10)), - )); - - let plan = LogicalPlanBuilder::from(table_scan) - .filter(qual.not()) - .unwrap() - .build() - .unwrap(); - let expected = "Filter: test.d >= Int32(1) AND test.d <= Int32(10)\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn simplify_not_like() { - let schema = Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Utf8, false), - ]); - let table_scan = table_scan(Some("test"), &schema, None) - .expect("creating scan") - .build() - .expect("building plan"); - - let plan = LogicalPlanBuilder::from(table_scan) - .filter(col("a").like(col("b")).not()) - .unwrap() - .build() - .unwrap(); - let expected = "Filter: test.a NOT LIKE test.b\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn simplify_not_not_like() { - let schema = Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Utf8, false), - ]); - let table_scan = table_scan(Some("test"), &schema, None) - .expect("creating scan") - .build() - .expect("building plan"); - - let plan = LogicalPlanBuilder::from(table_scan) - .filter(col("a").not_like(col("b")).not()) - .unwrap() - .build() - .unwrap(); - let expected = "Filter: test.a LIKE test.b\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn simplify_not_distinct_from() { - let table_scan = test_table_scan(); - - let plan = LogicalPlanBuilder::from(table_scan) - .filter(binary_expr(col("d"), Operator::IsDistinctFrom, lit(10)).not()) - .unwrap() - .build() - .unwrap(); - let expected = "Filter: test.d IS NOT DISTINCT FROM Int32(10)\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } - - #[test] - fn simplify_not_not_distinct_from() { - let table_scan = test_table_scan(); - - let plan = LogicalPlanBuilder::from(table_scan) - .filter(binary_expr(col("d"), Operator::IsNotDistinctFrom, lit(10)).not()) - .unwrap() - .build() - .unwrap(); - let expected = "Filter: test.d IS DISTINCT FROM Int32(10)\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected); - } } diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs new file mode 100644 index 000000000000..0eb3359c692f --- /dev/null +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod context; +pub mod expr_simplifier; +pub mod simplify_exprs; +mod utils; + +pub use context::*; +pub use expr_simplifier::*; +pub use simplify_exprs::*; diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs new file mode 100644 index 000000000000..f5ace71b8e51 --- /dev/null +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -0,0 +1,846 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Simplify expressions optimizer rule and implementation + +use super::{ExprSimplifier, SimplifyContext}; +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::Result; +use datafusion_expr::{logical_plan::LogicalPlan, utils::from_plan}; +use datafusion_physical_expr::execution_props::ExecutionProps; + +/// Optimizer Pass that simplifies [`LogicalPlan`]s by rewriting +/// [`Expr`]`s evaluating constants and applying algebraic +/// simplifications +/// +/// # Introduction +/// It uses boolean algebra laws to simplify or reduce the number of terms in expressions. +/// +/// # Example: +/// `Filter: b > 2 AND b > 2` +/// is optimized to +/// `Filter: b > 2` +/// +#[derive(Default)] +pub struct SimplifyExpressions {} + +impl OptimizerRule for SimplifyExpressions { + fn name(&self) -> &str { + "simplify_expressions" + } + + fn optimize( + &self, + plan: &LogicalPlan, + optimizer_config: &mut OptimizerConfig, + ) -> Result { + let mut execution_props = ExecutionProps::new(); + execution_props.query_execution_start_time = + optimizer_config.query_execution_start_time(); + self.optimize_internal(plan, &execution_props) + } +} + +impl SimplifyExpressions { + fn optimize_internal( + &self, + plan: &LogicalPlan, + execution_props: &ExecutionProps, + ) -> Result { + // We need to pass down the all schemas within the plan tree to `optimize_expr` in order to + // to evaluate expression types. For example, a projection plan's schema will only include + // projected columns. With just the projected schema, it's not possible to infer types for + // expressions that references non-projected columns within the same project plan or its + // children plans. + let info = plan + .all_schemas() + .into_iter() + .fold(SimplifyContext::new(execution_props), |context, schema| { + context.with_schema(schema.clone()) + }); + + let simplifier = ExprSimplifier::new(info); + + let new_inputs = plan + .inputs() + .iter() + .map(|input| self.optimize_internal(input, execution_props)) + .collect::>>()?; + + let expr = plan + .expressions() + .into_iter() + .map(|e| { + // We need to keep original expression name, if any. + // Constant folding should not change expression name. + let name = &e.display_name(); + + // Apply the actual simplification logic + let new_e = simplifier.simplify(e)?; + + let new_name = &new_e.display_name(); + + if let (Ok(expr_name), Ok(new_expr_name)) = (name, new_name) { + if expr_name != new_expr_name { + Ok(new_e.alias(expr_name)) + } else { + Ok(new_e) + } + } else { + Ok(new_e) + } + }) + .collect::>>()?; + + from_plan(plan, &expr, &new_inputs) + } +} + +impl SimplifyExpressions { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +#[cfg(test)] +mod tests { + use crate::simplify_expressions::utils::for_test::{ + cast_to_int64_expr, now_expr, to_timestamp_expr, + }; + + use super::*; + use arrow::datatypes::{DataType, Field, Schema}; + use chrono::{DateTime, TimeZone, Utc}; + use datafusion_common::ScalarValue; + use datafusion_expr::{or, Between, BinaryExpr, Cast, Operator}; + + use datafusion_expr::logical_plan::table_scan; + use datafusion_expr::{ + and, binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, + ExprSchemable, + }; + + /// A macro to assert that one string is contained within another with + /// a nice error message if they are not. + /// + /// Usage: `assert_contains!(actual, expected)` + /// + /// Is a macro so test error + /// messages are on the same line as the failure; + /// + /// Both arguments must be convertable into Strings (Into) + macro_rules! assert_contains { + ($ACTUAL: expr, $EXPECTED: expr) => { + let actual_value: String = $ACTUAL.into(); + let expected_value: String = $EXPECTED.into(); + assert!( + actual_value.contains(&expected_value), + "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}", + expected_value, + actual_value + ); + }; + } + + fn test_table_scan() -> LogicalPlan { + let schema = Schema::new(vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Boolean, false), + Field::new("c", DataType::Boolean, false), + Field::new("d", DataType::UInt32, false), + ]); + table_scan(Some("test"), &schema, None) + .expect("creating scan") + .build() + .expect("building plan") + } + + fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { + let rule = SimplifyExpressions::new(); + let optimized_plan = rule + .optimize(plan, &mut OptimizerConfig::new()) + .expect("failed to optimize plan"); + let formatted_plan = format!("{:?}", optimized_plan); + assert_eq!(formatted_plan, expected); + } + + #[test] + fn test_simplify_optimized_plan() { + let table_scan = test_table_scan(); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a")]) + .unwrap() + .filter(and(col("b").gt(lit(1)), col("b").gt(lit(1)))) + .unwrap() + .build() + .unwrap(); + + assert_optimized_plan_eq( + &plan, + "\ + Filter: test.b > Int32(1)\ + \n Projection: test.a\ + \n TableScan: test", + ); + } + + #[test] + fn test_simplify_optimized_plan_with_or() { + let table_scan = test_table_scan(); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a")]) + .unwrap() + .filter(or(col("b").gt(lit(1)), col("b").gt(lit(1)))) + .unwrap() + .build() + .unwrap(); + + assert_optimized_plan_eq( + &plan, + "\ + Filter: test.b > Int32(1)\ + \n Projection: test.a\ + \n TableScan: test", + ); + } + + #[test] + fn test_simplify_optimized_plan_with_composed_and() { + let table_scan = test_table_scan(); + // ((c > 5) AND (d < 6)) AND (c > 5) --> (c > 5) AND (d < 6) + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")]) + .unwrap() + .filter(and( + and(col("a").gt(lit(5)), col("b").lt(lit(6))), + col("a").gt(lit(5)), + )) + .unwrap() + .build() + .unwrap(); + + assert_optimized_plan_eq( + &plan, + "\ + Filter: test.a > Int32(5) AND test.b < Int32(6)\ + \n Projection: test.a, test.b\ + \n TableScan: test", + ); + } + + #[test] + fn test_simplity_optimized_plan_eq_expr() { + let table_scan = test_table_scan(); + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("b").eq(lit(true))) + .unwrap() + .filter(col("c").eq(lit(false))) + .unwrap() + .project(vec![col("a")]) + .unwrap() + .build() + .unwrap(); + + let expected = "\ + Projection: test.a\ + \n Filter: NOT test.c\ + \n Filter: test.b\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn test_simplity_optimized_plan_not_eq_expr() { + let table_scan = test_table_scan(); + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("b").not_eq(lit(true))) + .unwrap() + .filter(col("c").not_eq(lit(false))) + .unwrap() + .limit(0, Some(1)) + .unwrap() + .project(vec![col("a")]) + .unwrap() + .build() + .unwrap(); + + let expected = "\ + Projection: test.a\ + \n Limit: skip=0, fetch=1\ + \n Filter: test.c\ + \n Filter: NOT test.b\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn test_simplity_optimized_plan_and_expr() { + let table_scan = test_table_scan(); + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("b").not_eq(lit(true)).and(col("c").eq(lit(true)))) + .unwrap() + .project(vec![col("a")]) + .unwrap() + .build() + .unwrap(); + + let expected = "\ + Projection: test.a\ + \n Filter: NOT test.b AND test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn test_simplity_optimized_plan_or_expr() { + let table_scan = test_table_scan(); + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("b").not_eq(lit(true)).or(col("c").eq(lit(false)))) + .unwrap() + .project(vec![col("a")]) + .unwrap() + .build() + .unwrap(); + + let expected = "\ + Projection: test.a\ + \n Filter: NOT test.b OR NOT test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn test_simplity_optimized_plan_not_expr() { + let table_scan = test_table_scan(); + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("b").eq(lit(false)).not()) + .unwrap() + .project(vec![col("a")]) + .unwrap() + .build() + .unwrap(); + + let expected = "\ + Projection: test.a\ + \n Filter: test.b\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn test_simplity_optimized_plan_support_projection() { + let table_scan = test_table_scan(); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("d"), col("b").eq(lit(false))]) + .unwrap() + .build() + .unwrap(); + + let expected = "\ + Projection: test.a, test.d, NOT test.b AS test.b = Boolean(false)\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn test_simplity_optimized_plan_support_aggregate() { + let table_scan = test_table_scan(); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("c"), col("b")]) + .unwrap() + .aggregate( + vec![col("a"), col("c")], + vec![ + datafusion_expr::max(col("b").eq(lit(true))), + datafusion_expr::min(col("b")), + ], + ) + .unwrap() + .build() + .unwrap(); + + let expected = "\ + Aggregate: groupBy=[[test.a, test.c]], aggr=[[MAX(test.b) AS MAX(test.b = Boolean(true)), MIN(test.b)]]\ + \n Projection: test.a, test.c, test.b\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn test_simplity_optimized_plan_support_values() { + let expr1 = Expr::BinaryExpr(BinaryExpr::new( + Box::new(lit(1)), + Operator::Plus, + Box::new(lit(2)), + )); + let expr2 = Expr::BinaryExpr(BinaryExpr::new( + Box::new(lit(2)), + Operator::Minus, + Box::new(lit(1)), + )); + let values = vec![vec![expr1, expr2]]; + let plan = LogicalPlanBuilder::values(values).unwrap().build().unwrap(); + + let expected = "\ + Values: (Int32(3) AS Int32(1) + Int32(2), Int32(1) AS Int32(2) - Int32(1))"; + + assert_optimized_plan_eq(&plan, expected); + } + + // expect optimizing will result in an error, returning the error string + fn get_optimized_plan_err(plan: &LogicalPlan, date_time: &DateTime) -> String { + let mut config = + OptimizerConfig::new().with_query_execution_start_time(*date_time); + let rule = SimplifyExpressions::new(); + + let err = rule + .optimize(plan, &mut config) + .expect_err("expected optimization to fail"); + + err.to_string() + } + + fn get_optimized_plan_formatted( + plan: &LogicalPlan, + date_time: &DateTime, + ) -> String { + let mut config = + OptimizerConfig::new().with_query_execution_start_time(*date_time); + let rule = SimplifyExpressions::new(); + + let optimized_plan = rule + .optimize(plan, &mut config) + .expect("failed to optimize plan"); + format!("{:?}", optimized_plan) + } + + #[test] + fn to_timestamp_expr_folded() { + let table_scan = test_table_scan(); + let proj = vec![to_timestamp_expr("2020-09-08T12:00:00+00:00")]; + + let plan = LogicalPlanBuilder::from(table_scan) + .project(proj) + .unwrap() + .build() + .unwrap(); + + let expected = "Projection: TimestampNanosecond(1599566400000000000, None) AS totimestamp(Utf8(\"2020-09-08T12:00:00+00:00\"))\ + \n TableScan: test" + .to_string(); + let actual = get_optimized_plan_formatted(&plan, &Utc::now()); + assert_eq!(expected, actual); + } + + #[test] + fn to_timestamp_expr_wrong_arg() { + let table_scan = test_table_scan(); + let proj = vec![to_timestamp_expr("I'M NOT A TIMESTAMP")]; + let plan = LogicalPlanBuilder::from(table_scan) + .project(proj) + .unwrap() + .build() + .unwrap(); + + let expected = "Error parsing 'I'M NOT A TIMESTAMP' as timestamp"; + let actual = get_optimized_plan_err(&plan, &Utc::now()); + assert_contains!(actual, expected); + } + + #[test] + fn cast_expr() { + let table_scan = test_table_scan(); + let proj = vec![Expr::Cast(Cast::new(Box::new(lit("0")), DataType::Int32))]; + let plan = LogicalPlanBuilder::from(table_scan) + .project(proj) + .unwrap() + .build() + .unwrap(); + + let expected = "Projection: Int32(0) AS Utf8(\"0\")\ + \n TableScan: test"; + let actual = get_optimized_plan_formatted(&plan, &Utc::now()); + assert_eq!(expected, actual); + } + + #[test] + fn cast_expr_wrong_arg() { + let table_scan = test_table_scan(); + let proj = vec![Expr::Cast(Cast::new(Box::new(lit("")), DataType::Int32))]; + let plan = LogicalPlanBuilder::from(table_scan) + .project(proj) + .unwrap() + .build() + .unwrap(); + + let expected = "Cannot cast string '' to value of Int32 type"; + let actual = get_optimized_plan_err(&plan, &Utc::now()); + assert_contains!(actual, expected); + } + + #[test] + fn multiple_now_expr() { + let table_scan = test_table_scan(); + let time = Utc::now(); + let proj = vec![ + now_expr(), + Expr::Alias(Box::new(now_expr()), "t2".to_string()), + ]; + let plan = LogicalPlanBuilder::from(table_scan) + .project(proj) + .unwrap() + .build() + .unwrap(); + + // expect the same timestamp appears in both exprs + let actual = get_optimized_plan_formatted(&plan, &time); + let expected = format!( + "Projection: TimestampNanosecond({}, Some(\"UTC\")) AS now(), TimestampNanosecond({}, Some(\"UTC\")) AS t2\ + \n TableScan: test", + time.timestamp_nanos(), + time.timestamp_nanos() + ); + + assert_eq!(expected, actual); + } + + #[test] + fn simplify_and_eval() { + // demonstrate a case where the evaluation needs to run prior + // to the simplifier for it to work + let table_scan = test_table_scan(); + let time = Utc::now(); + // (true or false) != col --> !col + let proj = vec![lit(true).or(lit(false)).not_eq(col("a"))]; + let plan = LogicalPlanBuilder::from(table_scan) + .project(proj) + .unwrap() + .build() + .unwrap(); + + let actual = get_optimized_plan_formatted(&plan, &time); + let expected = + "Projection: NOT test.a AS Boolean(true) OR Boolean(false) != test.a\ + \n TableScan: test"; + + assert_eq!(expected, actual); + } + + #[test] + fn now_less_than_timestamp() { + let table_scan = test_table_scan(); + + let ts_string = "2020-09-08T12:05:00+00:00"; + let time = chrono::Utc.timestamp_nanos(1599566400000000000i64); + + // cast(now() as int) < cast(to_timestamp(...) as int) + 50000_i64 + let plan = + LogicalPlanBuilder::from(table_scan) + .filter( + cast_to_int64_expr(now_expr()) + .lt(cast_to_int64_expr(to_timestamp_expr(ts_string)) + + lit(50000_i64)), + ) + .unwrap() + .build() + .unwrap(); + + // Note that constant folder runs and folds the entire + // expression down to a single constant (true) + let expected = "Filter: Boolean(true)\ + \n TableScan: test"; + let actual = get_optimized_plan_formatted(&plan, &time); + + assert_eq!(expected, actual); + } + + #[test] + fn select_date_plus_interval() { + let table_scan = test_table_scan(); + + let ts_string = "2020-09-08T12:05:00+00:00"; + let time = chrono::Utc.timestamp_nanos(1599566400000000000i64); + + // now() < cast(to_timestamp(...) as int) + 5000000000 + let schema = table_scan.schema(); + + let date_plus_interval_expr = to_timestamp_expr(ts_string) + .cast_to(&DataType::Date32, schema) + .unwrap() + + Expr::Literal(ScalarValue::IntervalDayTime(Some(123i64 << 32))); + + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .project(vec![date_plus_interval_expr]) + .unwrap() + .build() + .unwrap(); + + println!("{:?}", plan); + + // Note that constant folder runs and folds the entire + // expression down to a single constant (true) + let expected = r#"Projection: Date32("18636") AS totimestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("528280977408") + TableScan: test"#; + let actual = get_optimized_plan_formatted(&plan, &time); + + assert_eq!(expected, actual); + } + + #[test] + fn simplify_not_binary() { + let table_scan = test_table_scan(); + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("d").gt(lit(10)).not()) + .unwrap() + .build() + .unwrap(); + let expected = "Filter: test.d <= Int32(10)\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn simplify_not_bool_and() { + let table_scan = test_table_scan(); + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("d").gt(lit(10)).and(col("d").lt(lit(100))).not()) + .unwrap() + .build() + .unwrap(); + let expected = "Filter: test.d <= Int32(10) OR test.d >= Int32(100)\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn simplify_not_bool_or() { + let table_scan = test_table_scan(); + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("d").gt(lit(10)).or(col("d").lt(lit(100))).not()) + .unwrap() + .build() + .unwrap(); + let expected = "Filter: test.d <= Int32(10) AND test.d >= Int32(100)\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn simplify_not_not() { + let table_scan = test_table_scan(); + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("d").gt(lit(10)).not().not()) + .unwrap() + .build() + .unwrap(); + let expected = "Filter: test.d > Int32(10)\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn simplify_not_null() { + let table_scan = test_table_scan(); + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("d").is_null().not()) + .unwrap() + .build() + .unwrap(); + let expected = "Filter: test.d IS NOT NULL\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn simplify_not_not_null() { + let table_scan = test_table_scan(); + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("d").is_not_null().not()) + .unwrap() + .build() + .unwrap(); + let expected = "Filter: test.d IS NULL\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn simplify_not_in() { + let table_scan = test_table_scan(); + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], false).not()) + .unwrap() + .build() + .unwrap(); + let expected = "Filter: test.d NOT IN ([Int32(1), Int32(2), Int32(3)])\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn simplify_not_not_in() { + let table_scan = test_table_scan(); + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], true).not()) + .unwrap() + .build() + .unwrap(); + let expected = "Filter: test.d IN ([Int32(1), Int32(2), Int32(3)])\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn simplify_not_between() { + let table_scan = test_table_scan(); + let qual = Expr::Between(Between::new( + Box::new(col("d")), + false, + Box::new(lit(1)), + Box::new(lit(10)), + )); + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(qual.not()) + .unwrap() + .build() + .unwrap(); + let expected = "Filter: test.d < Int32(1) OR test.d > Int32(10)\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn simplify_not_not_between() { + let table_scan = test_table_scan(); + let qual = Expr::Between(Between::new( + Box::new(col("d")), + true, + Box::new(lit(1)), + Box::new(lit(10)), + )); + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(qual.not()) + .unwrap() + .build() + .unwrap(); + let expected = "Filter: test.d >= Int32(1) AND test.d <= Int32(10)\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn simplify_not_like() { + let schema = Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ]); + let table_scan = table_scan(Some("test"), &schema, None) + .expect("creating scan") + .build() + .expect("building plan"); + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("a").like(col("b")).not()) + .unwrap() + .build() + .unwrap(); + let expected = "Filter: test.a NOT LIKE test.b\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn simplify_not_not_like() { + let schema = Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ]); + let table_scan = table_scan(Some("test"), &schema, None) + .expect("creating scan") + .build() + .expect("building plan"); + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("a").not_like(col("b")).not()) + .unwrap() + .build() + .unwrap(); + let expected = "Filter: test.a LIKE test.b\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn simplify_not_distinct_from() { + let table_scan = test_table_scan(); + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(binary_expr(col("d"), Operator::IsDistinctFrom, lit(10)).not()) + .unwrap() + .build() + .unwrap(); + let expected = "Filter: test.d IS NOT DISTINCT FROM Int32(10)\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } + + #[test] + fn simplify_not_not_distinct_from() { + let table_scan = test_table_scan(); + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(binary_expr(col("d"), Operator::IsNotDistinctFrom, lit(10)).not()) + .unwrap() + .build() + .unwrap(); + let expected = "Filter: test.d IS DISTINCT FROM Int32(10)\ + \n TableScan: test"; + + assert_optimized_plan_eq(&plan, expected); + } +} diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs new file mode 100644 index 000000000000..d9314e329458 --- /dev/null +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -0,0 +1,369 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utitity functions for expression simplification + +use arrow::datatypes::DECIMAL128_MAX_PRECISION; + +use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_expr::{ + expr::{Between, BinaryExpr}, + expr_fn::{and, concat_ws, or}, + lit, BuiltinScalarFunction, Expr, Operator, +}; + +pub static POWS_OF_TEN: [i128; 38] = [ + 1, + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 10000000, + 100000000, + 1000000000, + 10000000000, + 100000000000, + 1000000000000, + 10000000000000, + 100000000000000, + 1000000000000000, + 10000000000000000, + 100000000000000000, + 1000000000000000000, + 10000000000000000000, + 100000000000000000000, + 1000000000000000000000, + 10000000000000000000000, + 100000000000000000000000, + 1000000000000000000000000, + 10000000000000000000000000, + 100000000000000000000000000, + 1000000000000000000000000000, + 10000000000000000000000000000, + 100000000000000000000000000000, + 1000000000000000000000000000000, + 10000000000000000000000000000000, + 100000000000000000000000000000000, + 1000000000000000000000000000000000, + 10000000000000000000000000000000000, + 100000000000000000000000000000000000, + 1000000000000000000000000000000000000, + 10000000000000000000000000000000000000, +]; + +/// returns true if `needle` is found in a chain of search_op +/// expressions. Such as: (A AND B) AND C +pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool { + match expr { + Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == search_op => { + expr_contains(left, needle, search_op) + || expr_contains(right, needle, search_op) + } + _ => expr == needle, + } +} + +pub fn is_zero(s: &Expr) -> bool { + match s { + Expr::Literal(ScalarValue::Int8(Some(0))) + | Expr::Literal(ScalarValue::Int16(Some(0))) + | Expr::Literal(ScalarValue::Int32(Some(0))) + | Expr::Literal(ScalarValue::Int64(Some(0))) + | Expr::Literal(ScalarValue::UInt8(Some(0))) + | Expr::Literal(ScalarValue::UInt16(Some(0))) + | Expr::Literal(ScalarValue::UInt32(Some(0))) + | Expr::Literal(ScalarValue::UInt64(Some(0))) => true, + Expr::Literal(ScalarValue::Float32(Some(v))) if *v == 0. => true, + Expr::Literal(ScalarValue::Float64(Some(v))) if *v == 0. => true, + Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s)) if *v == 0 => true, + _ => false, + } +} + +pub fn is_one(s: &Expr) -> bool { + match s { + Expr::Literal(ScalarValue::Int8(Some(1))) + | Expr::Literal(ScalarValue::Int16(Some(1))) + | Expr::Literal(ScalarValue::Int32(Some(1))) + | Expr::Literal(ScalarValue::Int64(Some(1))) + | Expr::Literal(ScalarValue::UInt8(Some(1))) + | Expr::Literal(ScalarValue::UInt16(Some(1))) + | Expr::Literal(ScalarValue::UInt32(Some(1))) + | Expr::Literal(ScalarValue::UInt64(Some(1))) => true, + Expr::Literal(ScalarValue::Float32(Some(v))) if *v == 1. => true, + Expr::Literal(ScalarValue::Float64(Some(v))) if *v == 1. => true, + Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s)) => { + *_s < DECIMAL128_MAX_PRECISION && POWS_OF_TEN[*_s as usize] == *v + } + _ => false, + } +} + +pub fn is_true(expr: &Expr) -> bool { + match expr { + Expr::Literal(ScalarValue::Boolean(Some(v))) => *v, + _ => false, + } +} + +/// returns true if expr is a +/// `Expr::Literal(ScalarValue::Boolean(v))` , false otherwise +pub fn is_bool_lit(expr: &Expr) -> bool { + matches!(expr, Expr::Literal(ScalarValue::Boolean(_))) +} + +/// Return a literal NULL value of Boolean data type +pub fn lit_bool_null() -> Expr { + Expr::Literal(ScalarValue::Boolean(None)) +} + +pub fn is_null(expr: &Expr) -> bool { + match expr { + Expr::Literal(v) => v.is_null(), + _ => false, + } +} + +pub fn is_false(expr: &Expr) -> bool { + match expr { + Expr::Literal(ScalarValue::Boolean(Some(v))) => !(*v), + _ => false, + } +} + +/// returns true if `haystack` looks like (needle OP X) or (X OP needle) +pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool { + matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref())) +} + +/// returns the contained boolean value in `expr` as +/// `Expr::Literal(ScalarValue::Boolean(v))`. +pub fn as_bool_lit(expr: Expr) -> Result> { + match expr { + Expr::Literal(ScalarValue::Boolean(v)) => Ok(v), + _ => Err(DataFusionError::Internal(format!( + "Expected boolean literal, got {:?}", + expr + ))), + } +} + +/// negate a Not clause +/// input is the clause to be negated.(args of Not clause) +/// For BinaryExpr, use the negator of op instead. +/// not ( A > B) ===> (A <= B) +/// For BoolExpr, not (A and B) ===> (not A) or (not B) +/// not (A or B) ===> (not A) and (not B) +/// not (not A) ===> A +/// For NullExpr, not (A is not null) ===> A is null +/// not (A is null) ===> A is not null +/// For InList, not (A not in (..)) ===> A in (..) +/// not (A in (..)) ===> A not in (..) +/// For Between, not (A between B and C) ===> (A not between B and C) +/// not (A not between B and C) ===> (A between B and C) +/// For others, use Not clause +pub fn negate_clause(expr: Expr) -> Expr { + match expr { + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + if let Some(negated_op) = op.negate() { + return Expr::BinaryExpr(BinaryExpr::new(left, negated_op, right)); + } + match op { + // not (A and B) ===> (not A) or (not B) + Operator::And => { + let left = negate_clause(*left); + let right = negate_clause(*right); + + or(left, right) + } + // not (A or B) ===> (not A) and (not B) + Operator::Or => { + let left = negate_clause(*left); + let right = negate_clause(*right); + + and(left, right) + } + // use not clause + _ => Expr::Not(Box::new(Expr::BinaryExpr(BinaryExpr::new( + left, op, right, + )))), + } + } + // not (not A) ===> A + Expr::Not(expr) => *expr, + // not (A is not null) ===> A is null + Expr::IsNotNull(expr) => expr.is_null(), + // not (A is null) ===> A is not null + Expr::IsNull(expr) => expr.is_not_null(), + // not (A not in (..)) ===> A in (..) + // not (A in (..)) ===> A not in (..) + Expr::InList { + expr, + list, + negated, + } => expr.in_list(list, !negated), + // not (A between B and C) ===> (A not between B and C) + // not (A not between B and C) ===> (A between B and C) + Expr::Between(between) => Expr::Between(Between::new( + between.expr, + !between.negated, + between.low, + between.high, + )), + // use not clause + _ => Expr::Not(Box::new(expr)), + } +} + +/// Simplify the `concat` function by +/// 1. filtering out all `null` literals +/// 2. concatenating contiguous literal arguments +/// +/// For example: +/// `concat(col(a), 'hello ', 'world', col(b), null)` +/// will be optimized to +/// `concat(col(a), 'hello world', col(b))` +pub fn simpl_concat(args: Vec) -> Result { + let mut new_args = Vec::with_capacity(args.len()); + let mut contiguous_scalar = "".to_string(); + for arg in args { + match arg { + // filter out `null` args + Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {} + // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. + // Concatenate it with the `contiguous_scalar`. + Expr::Literal( + ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)), + ) => contiguous_scalar += &v, + Expr::Literal(x) => { + return Err(DataFusionError::Internal(format!( + "The scalar {} should be casted to string type during the type coercion.", + x + ))) + } + // If the arg is not a literal, we should first push the current `contiguous_scalar` + // to the `new_args` (if it is not empty) and reset it to empty string. + // Then pushing this arg to the `new_args`. + arg => { + if !contiguous_scalar.is_empty() { + new_args.push(lit(contiguous_scalar)); + contiguous_scalar = "".to_string(); + } + new_args.push(arg); + } + } + } + if !contiguous_scalar.is_empty() { + new_args.push(lit(contiguous_scalar)); + } + + Ok(Expr::ScalarFunction { + fun: BuiltinScalarFunction::Concat, + args: new_args, + }) +} + +/// Simply the `concat_ws` function by +/// 1. folding to `null` if the delimiter is null +/// 2. filtering out `null` arguments +/// 3. using `concat` to replace `concat_ws` if the delimiter is an empty string +/// 4. concatenating contiguous literals if the delimiter is a literal. +pub fn simpl_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { + match delimiter { + Expr::Literal( + ScalarValue::Utf8(delimiter) | ScalarValue::LargeUtf8(delimiter), + ) => { + match delimiter { + // when the delimiter is an empty string, + // we can use `concat` to replace `concat_ws` + Some(delimiter) if delimiter.is_empty() => simpl_concat(args.to_vec()), + Some(delimiter) => { + let mut new_args = Vec::with_capacity(args.len()); + new_args.push(lit(delimiter)); + let mut contiguous_scalar = None; + for arg in args { + match arg { + // filter out null args + Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {} + Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v))) => { + match contiguous_scalar { + None => contiguous_scalar = Some(v.to_string()), + Some(mut pre) => { + pre += delimiter; + pre += v; + contiguous_scalar = Some(pre) + } + } + } + Expr::Literal(s) => return Err(DataFusionError::Internal(format!("The scalar {} should be casted to string type during the type coercion.", s))), + // If the arg is not a literal, we should first push the current `contiguous_scalar` + // to the `new_args` and reset it to None. + // Then pushing this arg to the `new_args`. + arg => { + if let Some(val) = contiguous_scalar { + new_args.push(lit(val)); + } + new_args.push(arg.clone()); + contiguous_scalar = None; + } + } + } + if let Some(val) = contiguous_scalar { + new_args.push(lit(val)); + } + Ok(Expr::ScalarFunction { + fun: BuiltinScalarFunction::ConcatWithSeparator, + args: new_args, + }) + } + // if the delimiter is null, then the value of the whole expression is null. + None => Ok(Expr::Literal(ScalarValue::Utf8(None))), + } + } + Expr::Literal(d) => Err(DataFusionError::Internal(format!( + "The scalar {} should be casted to string type during the type coercion.", + d + ))), + d => Ok(concat_ws( + d.clone(), + args.iter() + .cloned() + .filter(|x| !is_null(x)) + .collect::>(), + )), + } +} + +#[cfg(test)] +pub mod for_test { + use arrow::datatypes::DataType; + use datafusion_expr::{call_fn, lit, Cast, Expr}; + + pub fn now_expr() -> Expr { + call_fn("now", vec![]).unwrap() + } + + pub fn cast_to_int64_expr(expr: Expr) -> Expr { + Expr::Cast(Cast::new(expr.into(), DataType::Int64)) + } + + pub fn to_timestamp_expr(arg: impl Into) -> Expr { + call_fn("to_timestamp", vec![lit(arg.into())]).unwrap() + } +}