diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 38dfbb3ed551..2620c7ee9a8e 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1348,6 +1348,17 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // Expr::Negative(inner) => Transformed::yes(distribute_negation(*inner)), + // + // Rules for Associativity + // + Expr::BinaryExpr(BinaryExpr { + left, + op: op @ (Operator::Plus | Operator::Multiply), + right, + }) if can_rearrange_literals(&left, op, &right) => { + Transformed::yes(rearrange_literals(*left, op, *right)) + } + // // Rules for Case // @@ -3189,6 +3200,27 @@ mod tests { ); } + #[test] + fn simplify_associative() { + // i + 1 + 2 => i + 3 + assert_eq!(simplify(col("c3") + lit(1) + lit(2)), (col("c3") + lit(3))); + + // (i + 1) + 2 => i + 3 + assert_eq!( + simplify((col("c3") + lit(1)) + lit(2)), + (col("c3") + lit(3)) + ); + + // i * 2 * 3 => i * 6 + assert_eq!(simplify(col("c3") * lit(2) * lit(3)), (col("c3") * lit(6))); + + // (i * 2) * 3 => i * 6 + assert_eq!( + simplify((col("c3") * lit(2)) * lit(3)), + (col("c3") * lit(6)) + ); + } + #[test] fn simplify_expr_case_when_then_else() { // CASE WHEN c2 != false THEN "ok" == "not_ok" ELSE c2 == true diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 5da727cb5990..4eb49c9a6aa1 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -341,3 +341,45 @@ pub fn distribute_negation(expr: Expr) -> Expr { _ => Expr::Negative(Box::new(expr)), } } + +pub fn can_rearrange_literals(left: &Expr, op: Operator, right: &Expr) -> bool { + if let Expr::BinaryExpr(BinaryExpr { + left: l_left, + op: l_op, + right: l_right, + }) = left + { + if l_op == &op + && matches!(**l_left, Expr::Column(_)) + && matches!(**l_right, Expr::Literal(_)) + && matches!(right, Expr::Literal(_)) + { + return true; + } + }; + false +} + +pub fn rearrange_literals(left: Expr, op: Operator, right: Expr) -> Expr { + if let Expr::BinaryExpr(BinaryExpr { + left: l_left, + op: l_op, + right: l_right, + }) = &left + { + if l_op == &op + && matches!(**l_left, Expr::Column(_)) + && matches!(**l_right, Expr::Literal(_)) + && matches!(right, Expr::Literal(_)) + { + let right_expr = Expr::BinaryExpr(BinaryExpr { + left: l_right.clone(), + op, + right: Box::new(right), + }); + return Expr::BinaryExpr(BinaryExpr::new(l_left.clone(), op, Box::new(right_expr))); + } + }; + + Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right))) +}