From 7d5359f6a7520ebe966231c8a45f160c2d09dd7d Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 9 Mar 2026 22:20:30 -0700 Subject: [PATCH 1/3] Optimize logical optimizer: skip map_subqueries + in-place rewriting Three optimizations that together yield ~23-25% faster optimization on TPC-H/TPC-DS and up to 33% on expression-heavy queries: 1. map_subqueries short-circuit: skip expression tree walks when no subquery expressions exist. Previously rewrite_with_subqueries called map_subqueries at every plan node, walking all expression trees via ownership-based transform_down even with no subqueries. 2. plan_has_subqueries per-pass check: when no subqueries exist in the plan, bypass rewrite_with_subqueries entirely and use the cheaper rewrite_plan_in_place path. 3. rewrite_plan_in_place with Arc::make_mut: new map_children_mut method that mutates children in-place, avoiding the Arc::unwrap_or_clone + Arc::new allocation cycle at every node. The owned-plan rule API is bridged with std::mem::take, which is allocation-free: LogicalPlan::default() is an EmptyRelation that shares the process-wide static empty schema. Also adds optimizer-only benchmarks that isolate optimizer performance from SQL parsing and analysis overhead. Co-Authored-By: Claude Opus 4.6 Co-Authored-By: Claude Opus 4.7 (1M context) --- datafusion/core/benches/sql_planner.rs | 455 +++++++++++++++++- datafusion/expr/src/logical_plan/tree_node.rs | 142 ++++++ datafusion/optimizer/src/optimizer.rs | 135 +++++- 3 files changed, 727 insertions(+), 5 deletions(-) diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index fcc8da30fedd9..5e4d3d2b253d3 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -41,11 +41,37 @@ const BENCHMARKS_PATH_1: &str = "../../benchmarks/"; const BENCHMARKS_PATH_2: &str = "./benchmarks/"; const CLICKBENCH_DATA_PATH: &str = "data/hits_partitioned/"; -/// Create a logical plan from the specified sql +/// Create a logical plan from the specified sql (parse + analyze only, NO optimization) fn logical_plan(ctx: &SessionContext, rt: &Runtime, sql: &str) { black_box(rt.block_on(ctx.sql(sql)).unwrap()); } +/// Parse SQL and run the analyzer to get an analyzed (but unoptimized) LogicalPlan. +/// This is the input to the optimizer. +fn analyzed_plan( + ctx: &SessionContext, + rt: &Runtime, + sql: &str, +) -> datafusion_expr::LogicalPlan { + let state = ctx.state(); + let plan = rt.block_on(state.create_logical_plan(sql)).unwrap(); + state + .analyzer() + .execute_and_check(plan, state.config().options(), |_, _| {}) + .unwrap() +} + +/// Run ONLY the optimizer on a pre-analyzed plan. Measures optimizer cost in isolation. +fn optimize_plan(ctx: &SessionContext, plan: &datafusion_expr::LogicalPlan) { + let state = ctx.state(); + black_box( + state + .optimizer() + .optimize(plan.clone(), &state, |_, _| {}) + .unwrap(), + ); +} + /// Create a physical ExecutionPlan (by way of logical plan) fn physical_plan(ctx: &SessionContext, rt: &Runtime, sql: &str) { black_box(rt.block_on(async { @@ -646,6 +672,433 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("with_param_values_many_columns", |b| { benchmark_with_param_values_many_columns(&ctx, &rt, b); }); + + // ========================================================================== + // Optimizer-focused benchmarks + // These benchmarks are designed to stress the logical optimizer with + // varying plan sizes, expression counts, and node type distributions. + // ========================================================================== + + // --- Deep join trees (many plan nodes, few expressions) --- + // Tests optimizer traversal cost as plan node count grows. + // Each join adds ~3 nodes (Join, TableScan, CrossJoin/Filter). + + // Register additional tables for join benchmarks + for i in 3..=16 { + ctx.register_table(format!("j{i}"), create_table_provider("x", 10)) + .unwrap(); + } + + c.bench_function("logical_join_chain_4", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT j3.x0 FROM j3 \ + JOIN j4 ON j3.x0 = j4.x0 \ + JOIN j5 ON j4.x0 = j5.x0 \ + JOIN j6 ON j5.x0 = j6.x0", + ) + }) + }); + + c.bench_function("logical_join_chain_8", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT j3.x0 FROM j3 \ + JOIN j4 ON j3.x0 = j4.x0 \ + JOIN j5 ON j4.x0 = j5.x0 \ + JOIN j6 ON j5.x0 = j6.x0 \ + JOIN j7 ON j6.x0 = j7.x0 \ + JOIN j8 ON j7.x0 = j8.x0 \ + JOIN j9 ON j8.x0 = j9.x0 \ + JOIN j10 ON j9.x0 = j10.x0", + ) + }) + }); + + c.bench_function("logical_join_chain_16", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT j3.x0 FROM j3 \ + JOIN j4 ON j3.x0 = j4.x0 \ + JOIN j5 ON j4.x0 = j5.x0 \ + JOIN j6 ON j5.x0 = j6.x0 \ + JOIN j7 ON j6.x0 = j7.x0 \ + JOIN j8 ON j7.x0 = j8.x0 \ + JOIN j9 ON j8.x0 = j9.x0 \ + JOIN j10 ON j9.x0 = j10.x0 \ + JOIN j11 ON j10.x0 = j11.x0 \ + JOIN j12 ON j11.x0 = j12.x0 \ + JOIN j13 ON j12.x0 = j13.x0 \ + JOIN j14 ON j13.x0 = j14.x0 \ + JOIN j15 ON j14.x0 = j15.x0 \ + JOIN j16 ON j15.x0 = j16.x0 \ + JOIN j3 AS j3b ON j16.x0 = j3b.x0 \ + JOIN j4 AS j4b ON j3b.x0 = j4b.x0", + ) + }) + }); + + // --- Wide expressions (few plan nodes, many expressions) --- + // Tests expression processing overhead in optimizer rules like + // SimplifyExpressions, CommonSubexprEliminate, OptimizeProjections. + + // Many WHERE clauses (filter expressions) + { + let predicates: Vec = (0..50).map(|i| format!("a{i} > 0")).collect(); + let query = format!("SELECT a0 FROM t1 WHERE {}", predicates.join(" AND ")); + c.bench_function("logical_wide_filter_50_predicates", |b| { + b.iter(|| logical_plan(&ctx, &rt, &query)) + }); + } + + { + let predicates: Vec = (0..200).map(|i| format!("a{i} > 0")).collect(); + let query = format!("SELECT a0 FROM t1 WHERE {}", predicates.join(" AND ")); + c.bench_function("logical_wide_filter_200_predicates", |b| { + b.iter(|| logical_plan(&ctx, &rt, &query)) + }); + } + + // Many aggregate expressions + { + let aggs: Vec = + (0..50).map(|i| format!("SUM(a{i}), AVG(a{i})")).collect(); + let query = format!("SELECT {} FROM t1", aggs.join(", ")); + c.bench_function("logical_wide_aggregate_100_exprs", |b| { + b.iter(|| logical_plan(&ctx, &rt, &query)) + }); + } + + // Many CASE WHEN expressions (complex expressions) + { + let cases: Vec = (0..50) + .map(|i| { + format!("CASE WHEN a{i} > 0 THEN a{i} * 2 ELSE a{i} + 1 END AS r{i}") + }) + .collect(); + let query = format!("SELECT {} FROM t1", cases.join(", ")); + c.bench_function("logical_wide_case_50_exprs", |b| { + b.iter(|| logical_plan(&ctx, &rt, &query)) + }); + } + + // --- Mixed: deep plan + wide expressions --- + // This is the worst case for optimizer: many nodes AND many expressions. + + c.bench_function("logical_join_4_with_agg_and_filter", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT j3.x0, SUM(j4.x1), AVG(j5.x2), COUNT(j6.x3), \ + MIN(j3.x4), MAX(j4.x5) \ + FROM j3 \ + JOIN j4 ON j3.x0 = j4.x0 \ + JOIN j5 ON j4.x0 = j5.x0 \ + JOIN j6 ON j5.x0 = j6.x0 \ + WHERE j3.x1 > 0 AND j4.x2 < 100 AND j5.x3 != j6.x4 \ + GROUP BY j3.x0 \ + HAVING SUM(j4.x1) > 10 \ + ORDER BY j3.x0", + ) + }) + }); + + c.bench_function("logical_join_8_with_agg_sort_limit", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT j3.x0, j4.x1, j5.x2, \ + SUM(j6.x3), AVG(j7.x4), COUNT(j8.x5), \ + MIN(j9.x6), MAX(j10.x7) \ + FROM j3 \ + JOIN j4 ON j3.x0 = j4.x0 \ + JOIN j5 ON j4.x0 = j5.x0 \ + JOIN j6 ON j5.x0 = j6.x0 \ + JOIN j7 ON j6.x0 = j7.x0 \ + JOIN j8 ON j7.x0 = j8.x0 \ + JOIN j9 ON j8.x0 = j9.x0 \ + JOIN j10 ON j9.x0 = j10.x0 \ + WHERE j3.x1 > 0 AND j5.x2 < 100 \ + GROUP BY j3.x0, j4.x1, j5.x2 \ + ORDER BY j3.x0 DESC \ + LIMIT 100", + ) + }) + }); + + // --- Subqueries (trigger decorrelation rules) --- + // Tests rules like DecorrelatePredicateSubquery, ScalarSubqueryToJoin. + + c.bench_function("logical_correlated_subquery_exists", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT a0, a1 FROM t1 \ + WHERE EXISTS (SELECT 1 FROM t2 WHERE t2.b0 = t1.a0)", + ) + }) + }); + + c.bench_function("logical_correlated_subquery_in", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT a0, a1 FROM t1 \ + WHERE a0 IN (SELECT b0 FROM t2 WHERE t2.b1 = t1.a1)", + ) + }) + }); + + c.bench_function("logical_scalar_subquery", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT a0, (SELECT MAX(b1) FROM t2 WHERE t2.b0 = t1.a0) AS max_b \ + FROM t1", + ) + }) + }); + + c.bench_function("logical_multiple_subqueries", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT a0, a1 FROM t1 \ + WHERE a0 IN (SELECT b0 FROM t2 WHERE b1 > 0) \ + AND EXISTS (SELECT 1 FROM t2 WHERE t2.b0 = t1.a0 AND t2.b1 < 100) \ + AND a1 > (SELECT AVG(b1) FROM t2)", + ) + }) + }); + + // --- UNION queries (test OptimizeUnions, PropagateEmptyRelation) --- + + c.bench_function("logical_union_4_branches", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT a0, a1 FROM t1 WHERE a0 > 0 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 10 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 20 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 30", + ) + }) + }); + + c.bench_function("logical_union_8_branches", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT a0, a1 FROM t1 WHERE a0 > 0 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 10 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 20 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 30 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 40 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 50 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 60 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 70", + ) + }) + }); + + // --- DISTINCT (test ReplaceDistinctWithAggregate) --- + + c.bench_function("logical_distinct_many_columns", |b| { + let cols: Vec = (0..50).map(|i| format!("a{i}")).collect(); + let query = format!("SELECT DISTINCT {} FROM t1", cols.join(", ")); + b.iter(|| logical_plan(&ctx, &rt, &query)) + }); + + // --- Nested views / CTEs (deeper plan trees) --- + + c.bench_function("logical_nested_cte_4_levels", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "WITH \ + cte1 AS (SELECT a0, a1, a2 FROM t1 WHERE a0 > 0), \ + cte2 AS (SELECT a0, a1 FROM cte1 WHERE a1 > 0), \ + cte3 AS (SELECT a0 FROM cte2 WHERE a0 < 100), \ + cte4 AS (SELECT a0, COUNT(*) AS cnt FROM cte3 GROUP BY a0) \ + SELECT * FROM cte4 ORDER BY a0 LIMIT 10", + ) + }) + }); + + // --- TPC-H logical plans (uncommented from existing code) --- + // These test real-world query patterns with moderate plan complexity. + + c.bench_function("logical_plan_tpch_all", |b| { + b.iter(|| { + for sql in &all_tpch_sql_queries { + logical_plan(&tpch_ctx, &rt, sql) + } + }) + }); + + c.bench_function("logical_plan_tpcds_all", |b| { + b.iter(|| { + for sql in &all_tpcds_sql_queries { + logical_plan(&tpcds_ctx, &rt, sql) + } + }) + }); + + // ========================================================================== + // Optimizer-only benchmarks + // These measure ONLY the optimizer, not SQL parsing or analysis. + // Plans are pre-parsed and pre-analyzed in setup, then only optimization + // is measured in the benchmark loop. + // ========================================================================== + + // Simple select (baseline: few nodes, few expressions) + { + let plan = analyzed_plan(&ctx, &rt, "SELECT c1 FROM t700"); + c.bench_function("optimizer_select_one_from_700", |b| { + b.iter(|| optimize_plan(&ctx, &plan)) + }); + } + + // Wide select (many expressions, few nodes) + { + let plan = analyzed_plan(&ctx, &rt, "SELECT * FROM t1000"); + c.bench_function("optimizer_select_all_from_1000", |b| { + b.iter(|| optimize_plan(&ctx, &plan)) + }); + } + + // Deep join chains (many nodes, few expressions) + { + let plan = analyzed_plan( + &ctx, + &rt, + "SELECT j3.x0 FROM j3 \ + JOIN j4 ON j3.x0 = j4.x0 \ + JOIN j5 ON j4.x0 = j5.x0 \ + JOIN j6 ON j5.x0 = j6.x0", + ); + c.bench_function("optimizer_join_chain_4", |b| { + b.iter(|| optimize_plan(&ctx, &plan)) + }); + } + + { + let plan = analyzed_plan( + &ctx, + &rt, + "SELECT j3.x0 FROM j3 \ + JOIN j4 ON j3.x0 = j4.x0 \ + JOIN j5 ON j4.x0 = j5.x0 \ + JOIN j6 ON j5.x0 = j6.x0 \ + JOIN j7 ON j6.x0 = j7.x0 \ + JOIN j8 ON j7.x0 = j8.x0 \ + JOIN j9 ON j8.x0 = j9.x0 \ + JOIN j10 ON j9.x0 = j10.x0", + ); + c.bench_function("optimizer_join_chain_8", |b| { + b.iter(|| optimize_plan(&ctx, &plan)) + }); + } + + // Wide filter (many expressions) + { + let predicates: Vec = (0..200).map(|i| format!("a{i} > 0")).collect(); + let query = format!("SELECT a0 FROM t1 WHERE {}", predicates.join(" AND ")); + let plan = analyzed_plan(&ctx, &rt, &query); + c.bench_function("optimizer_wide_filter_200", |b| { + b.iter(|| optimize_plan(&ctx, &plan)) + }); + } + + // Wide aggregate (many expressions) + { + let aggs: Vec = + (0..50).map(|i| format!("SUM(a{i}), AVG(a{i})")).collect(); + let query = format!("SELECT {} FROM t1", aggs.join(", ")); + let plan = analyzed_plan(&ctx, &rt, &query); + c.bench_function("optimizer_wide_aggregate_100", |b| { + b.iter(|| optimize_plan(&ctx, &plan)) + }); + } + + // Subquery (tests decorrelation rules) + { + let plan = analyzed_plan( + &ctx, + &rt, + "SELECT a0, a1 FROM t1 \ + WHERE EXISTS (SELECT 1 FROM t2 WHERE t2.b0 = t1.a0)", + ); + c.bench_function("optimizer_correlated_exists", |b| { + b.iter(|| optimize_plan(&ctx, &plan)) + }); + } + + // Mixed: joins + aggregates + filter + { + let plan = analyzed_plan( + &ctx, + &rt, + "SELECT j3.x0, SUM(j4.x1), AVG(j5.x2), COUNT(j6.x3), \ + MIN(j3.x4), MAX(j4.x5) \ + FROM j3 \ + JOIN j4 ON j3.x0 = j4.x0 \ + JOIN j5 ON j4.x0 = j5.x0 \ + JOIN j6 ON j5.x0 = j6.x0 \ + WHERE j3.x1 > 0 AND j4.x2 < 100 AND j5.x3 != j6.x4 \ + GROUP BY j3.x0 \ + HAVING SUM(j4.x1) > 10 \ + ORDER BY j3.x0", + ); + c.bench_function("optimizer_join_4_with_agg_filter", |b| { + b.iter(|| optimize_plan(&ctx, &plan)) + }); + } + + // TPC-H all queries (optimizer only) + { + let plans: Vec<_> = all_tpch_sql_queries + .iter() + .map(|sql| analyzed_plan(&tpch_ctx, &rt, sql)) + .collect(); + c.bench_function("optimizer_tpch_all", |b| { + b.iter(|| { + for plan in &plans { + optimize_plan(&tpch_ctx, plan) + } + }) + }); + } + + // TPC-DS all queries (optimizer only) + { + let plans: Vec<_> = all_tpcds_sql_queries + .iter() + .map(|sql| analyzed_plan(&tpcds_ctx, &rt, sql)) + .collect(); + c.bench_function("optimizer_tpcds_all", |b| { + b.iter(|| { + for plan in &plans { + optimize_plan(&tpcds_ctx, plan) + } + }) + }); + } } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index ef9382a57209a..bf04580e07692 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -52,6 +52,7 @@ use datafusion_common::tree_node::{ TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{Result, internal_err}; +use std::sync::Arc; impl TreeNode for LogicalPlan { fn apply_children<'n, F: FnMut(&'n Self) -> Result>( @@ -393,6 +394,113 @@ macro_rules! handle_transform_recursion { } impl LogicalPlan { + /// Applies `f` to each child (input) of this plan node in place, + /// using [`Arc::make_mut`] for copy-on-write semantics. + /// + /// When the `Arc` refcount is 1 (the common case in the optimizer), + /// `Arc::make_mut` returns a `&mut` reference without cloning. + /// When the refcount is >1, it clones the inner value first. + /// + /// Returns `Ok(true)` if any child was modified by `f`. + pub fn map_children_mut Result>( + &mut self, + mut f: F, + ) -> Result { + Ok(match self { + LogicalPlan::Projection(Projection { input, .. }) + | LogicalPlan::Filter(Filter { input, .. }) + | LogicalPlan::Repartition(Repartition { input, .. }) + | LogicalPlan::Window(Window { input, .. }) + | LogicalPlan::Aggregate(Aggregate { input, .. }) + | LogicalPlan::Sort(Sort { input, .. }) + | LogicalPlan::Limit(Limit { input, .. }) + | LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) + | LogicalPlan::Analyze(Analyze { input, .. }) + | LogicalPlan::Dml(DmlStatement { input, .. }) + | LogicalPlan::Copy(CopyTo { input, .. }) + | LogicalPlan::Unnest(Unnest { input, .. }) => f(Arc::make_mut(input))?, + LogicalPlan::Subquery(Subquery { subquery, .. }) => { + f(Arc::make_mut(subquery))? + } + LogicalPlan::Join(Join { left, right, .. }) => { + let l = f(Arc::make_mut(left))?; + let r = f(Arc::make_mut(right))?; + l || r + } + LogicalPlan::Union(Union { inputs, .. }) => { + let mut changed = false; + for input in inputs { + changed |= f(Arc::make_mut(input))?; + } + changed + } + LogicalPlan::Distinct(Distinct::All(input)) => f(Arc::make_mut(input))?, + LogicalPlan::Distinct(Distinct::On(DistinctOn { input, .. })) => { + f(Arc::make_mut(input))? + } + LogicalPlan::Explain(Explain { plan, .. }) => f(Arc::make_mut(plan))?, + LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable { + input, + .. + })) + | LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { input, .. })) => { + f(Arc::make_mut(input))? + } + LogicalPlan::RecursiveQuery(RecursiveQuery { + static_term, + recursive_term, + .. + }) => { + let s = f(Arc::make_mut(static_term))?; + let r = f(Arc::make_mut(recursive_term))?; + s || r + } + LogicalPlan::Statement(Statement::Prepare(p)) => { + f(Arc::make_mut(&mut p.input))? + } + LogicalPlan::Extension(Extension { node }) => { + let inputs = node.inputs(); + if inputs.is_empty() { + false + } else { + // Extension nodes don't expose mutable children, + // fall back to the ownership-based API + let mut changed = false; + let exprs = node.expressions(); + let new_inputs: Vec = inputs + .into_iter() + .map(|input| { + let mut plan = input.clone(); + if f(&mut plan)? { + changed = true; + } + Ok(plan) + }) + .collect::>>()?; + if changed { + *node = node.with_exprs_and_inputs(exprs, new_inputs)?; + } + changed + } + } + // plans without inputs + LogicalPlan::TableScan { .. } + | LogicalPlan::EmptyRelation { .. } + | LogicalPlan::Values { .. } + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Ddl(DdlStatement::CreateExternalTable(_)) + | LogicalPlan::Ddl(DdlStatement::CreateCatalogSchema(_)) + | LogicalPlan::Ddl(DdlStatement::CreateCatalog(_)) + | LogicalPlan::Ddl(DdlStatement::CreateIndex(_)) + | LogicalPlan::Ddl(DdlStatement::DropTable(_)) + | LogicalPlan::Ddl(DdlStatement::DropView(_)) + | LogicalPlan::Ddl(DdlStatement::DropCatalogSchema(_)) + | LogicalPlan::Ddl(DdlStatement::CreateFunction(_)) + | LogicalPlan::Ddl(DdlStatement::DropFunction(_)) + | LogicalPlan::Statement(_) => false, + }) + } + /// Calls `f` on all expressions in the current `LogicalPlan` node. /// /// # Notes @@ -841,6 +949,32 @@ impl LogicalPlan { }) } + /// Returns true if any expression in this node contains a subquery + /// (Exists, InSubquery, SetComparison, or ScalarSubquery). + fn has_subquery_expressions(&self) -> bool { + let mut found = false; + let _ = self.apply_expressions(|expr| { + if found { + return Ok(TreeNodeRecursion::Stop); + } + expr.apply(|e| { + if matches!( + e, + Expr::Exists(_) + | Expr::InSubquery(_) + | Expr::SetComparison(_) + | Expr::ScalarSubquery(_) + ) { + found = true; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }) + }); + found + } + /// Similarly to [`Self::map_children`], rewrites all subqueries that may /// appear in expressions such as `IN (SELECT ...)` using `f`. /// @@ -849,6 +983,14 @@ impl LogicalPlan { self, mut f: F, ) -> Result> { + // Fast path: skip the expensive ownership-based expression traversal + // when this node has no subquery expressions. This avoids + // map_expressions → transform_down walking every expression node + // via consume+recreate just to find no subqueries. + if !self.has_subquery_expressions() { + return Ok(Transformed::no(self)); + } + self.map_expressions(|expr| { expr.transform_down(|expr| match expr { Expr::Exists(Exists { subquery, negated }) => { diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index d0fbb31414dab..701749f45a8e1 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -28,8 +28,11 @@ use log::{debug, warn}; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; use datafusion_common::instant::Instant; -use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, +}; use datafusion_common::{DFSchema, DataFusionError, HashSet, Result, internal_err}; +use datafusion_expr::Expr; use datafusion_expr::logical_plan::LogicalPlan; use crate::common_subexpr_eliminate::CommonSubexprEliminate; @@ -357,6 +360,95 @@ impl TreeNodeRewriter for Rewriter<'_> { } } +/// Rewrites a plan tree in place using `Arc::make_mut` for +/// copy-on-write semantics on `Arc` children. +/// +/// This avoids the `Arc::unwrap_or_clone` + `Arc::new` cycle that the +/// ownership-based `TreeNode::rewrite` performs at every child node. +/// When the `Arc` refcount is 1 (always true in the optimizer), +/// `Arc::make_mut` is essentially free. +/// +/// The `rule.rewrite()` API takes ownership, so we bridge the `&mut` to an +/// owned value with [`std::mem::take`]. `LogicalPlan::default()` is a cheap +/// empty placeholder (shared empty schema, no allocation) and is overwritten +/// with the rule's output on the very next line. +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] +fn rewrite_plan_in_place( + plan: &mut LogicalPlan, + apply_order: ApplyOrder, + rule: &dyn OptimizerRule, + config: &dyn OptimizerConfig, +) -> Result { + // f_down phase + let mut changed = false; + if apply_order == ApplyOrder::TopDown { + let owned = std::mem::take(plan); + let result = rule.rewrite(owned, config)?; + *plan = result.data; + changed |= result.transformed; + // Respect TreeNodeRecursion::Stop/Jump from the rule + if result.tnr == TreeNodeRecursion::Stop { + return Ok(changed); + } + } + + // Recurse into children using Arc::make_mut (zero-cost when refcount == 1) + changed |= plan.map_children_mut(|child| { + rewrite_plan_in_place(child, apply_order, rule, config) + })?; + + // f_up phase + if apply_order == ApplyOrder::BottomUp { + let owned = std::mem::take(plan); + let result = rule.rewrite(owned, config)?; + *plan = result.data; + changed |= result.transformed; + } + + Ok(changed) +} + +/// Returns true if the plan contains any subquery expressions +/// (EXISTS, IN subquery, scalar subquery, set comparison). +/// +/// Used to determine whether the more expensive `rewrite_with_subqueries` +/// traversal is needed. When the plan has no subqueries, the cheaper +/// `rewrite` traversal is sufficient since all plan nodes are reachable +/// via direct children. +fn plan_has_subqueries(plan: &LogicalPlan) -> bool { + let mut found = false; + let _ = plan.apply(|node| { + if found { + return Ok(TreeNodeRecursion::Stop); + } + node.apply_expressions(|expr| { + if found { + return Ok(TreeNodeRecursion::Stop); + } + expr.apply(|e| { + if matches!( + e, + Expr::Exists(_) + | Expr::InSubquery(_) + | Expr::SetComparison(_) + | Expr::ScalarSubquery(_) + ) { + found = true; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }) + })?; + Ok(if found { + TreeNodeRecursion::Stop + } else { + TreeNodeRecursion::Continue + }) + }); + found +} + impl Optimizer { /// Optimizes the logical plan by applying optimizer rules, and /// invoking observer function after each call @@ -386,6 +478,14 @@ impl Optimizer { while i < options.optimizer.max_passes { log_plan(&format!("Optimizer input (pass {i})"), &new_plan); + // Check once per pass whether the plan contains subquery + // expressions. When there are no subqueries, we use the + // cheaper `rewrite` traversal instead of + // `rewrite_with_subqueries`, avoiding the per-node + // map_subqueries call that walks all expression trees + // via ownership-based transform_down. + let has_subqueries = plan_has_subqueries(&new_plan); + for rule in &self.rules { // If skipping failed rules, copy plan before attempting to rewrite // as rewriting is destructive @@ -398,9 +498,36 @@ impl Optimizer { let result = match rule.apply_order() { // optimizer handles recursion - Some(apply_order) => new_plan.rewrite_with_subqueries( - &mut Rewriter::new(apply_order, rule.as_ref(), config), - ), + Some(apply_order) => { + if has_subqueries { + // Plans with subqueries need the full + // rewrite_with_subqueries traversal to + // recurse into subquery plans. + new_plan.rewrite_with_subqueries( + &mut Rewriter::new( + apply_order, + rule.as_ref(), + config, + ), + ) + } else { + // No subqueries: use in-place rewriting + // with Arc::make_mut for zero-cost CoW on + // children, avoiding Arc unwrap/rewrap. + rewrite_plan_in_place( + &mut new_plan, + apply_order, + rule.as_ref(), + config, + ) + .map(|transformed| { + Transformed::new_transformed( + std::mem::take(&mut new_plan), + transformed, + ) + }) + } + } // rule handles recursion itself None => { rule.rewrite(new_plan, config) From 3e1f856fa4224a6950c0b7b677ac6a8da60f35c7 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 18 May 2026 14:09:50 -0700 Subject: [PATCH 2/3] docs: clarify error semantics of in-place plan rewriting Address review feedback on the in-place optimizer rewrite: - Document the error contract of `rewrite_plan_in_place`: on `Err` the plan is left in an unspecified state because `rule.rewrite()` consumes it by value, and explain why it cannot be recovered without the clone the function exists to avoid. Note how `Optimizer::optimize` handles it. - Move the `mem::take` bridge explanation from the doc comment into an inline comment next to the code it describes. - Drop the inaccurate "Arc refcount is 1 is always true" claim. - Document that `LogicalPlan::map_children_mut` does not roll back partial mutations when `f` fails. Comment-only changes; no behavior change. Co-Authored-By: Claude Opus 4.7 (1M context) --- datafusion/expr/src/logical_plan/tree_node.rs | 8 +++++ datafusion/optimizer/src/optimizer.rs | 31 +++++++++++++++---- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index bf04580e07692..4d5a940f12310 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -402,6 +402,14 @@ impl LogicalPlan { /// When the refcount is >1, it clones the inner value first. /// /// Returns `Ok(true)` if any child was modified by `f`. + /// + /// # Error semantics + /// + /// If `f` returns `Err` for a child, this method returns that error + /// immediately. Children visited earlier in the same call keep whatever + /// modifications `f` already applied to them — they are **not** rolled + /// back. Callers that need the pre-call tree on error must save a copy + /// beforehand. pub fn map_children_mut Result>( &mut self, mut f: F, diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 701749f45a8e1..a8154653d4e37 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -365,13 +365,22 @@ impl TreeNodeRewriter for Rewriter<'_> { /// /// This avoids the `Arc::unwrap_or_clone` + `Arc::new` cycle that the /// ownership-based `TreeNode::rewrite` performs at every child node. -/// When the `Arc` refcount is 1 (always true in the optimizer), -/// `Arc::make_mut` is essentially free. /// -/// The `rule.rewrite()` API takes ownership, so we bridge the `&mut` to an -/// owned value with [`std::mem::take`]. `LogicalPlan::default()` is a cheap -/// empty placeholder (shared empty schema, no allocation) and is overwritten -/// with the rule's output on the very next line. +/// # Error semantics +/// +/// On `Err`, `*plan` is left in an **unspecified** state and must not be used. +/// Because `rule.rewrite()` consumes the plan by value, a failing rule drops +/// the node it was handed and the [`std::mem::take`] placeholder +/// (`LogicalPlan::default()`, an `EmptyRelation`) is left in its place — at the +/// root for a top-level failure, or somewhere in a subtree for a failure deeper +/// in the recursion. The pre-rule plan is **not** recoverable here: restoring it +/// would require cloning it before every rule invocation, which is exactly the +/// allocation this function exists to avoid. +/// +/// Callers must therefore discard `*plan` on `Err`, or restore it from a copy +/// saved beforehand. [`Optimizer::optimize`] does this: with `skip_failed_rules` +/// it restores `new_plan` from the `prev_plan` clone it already holds, and +/// without it the error aborts the pass and the plan is dropped unobserved. #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn rewrite_plan_in_place( plan: &mut LogicalPlan, @@ -382,6 +391,10 @@ fn rewrite_plan_in_place( // f_down phase let mut changed = false; if apply_order == ApplyOrder::TopDown { + // `rule.rewrite()` takes the plan by value, so bridge the `&mut` to an + // owned value with `std::mem::take`. `LogicalPlan::default()` is a cheap + // empty placeholder (shared empty schema, no allocation) and is + // overwritten with the rule's output on the next line. let owned = std::mem::take(plan); let result = rule.rewrite(owned, config)?; *plan = result.data; @@ -514,6 +527,12 @@ impl Optimizer { // No subqueries: use in-place rewriting // with Arc::make_mut for zero-cost CoW on // children, avoiding Arc unwrap/rewrap. + // + // On error `new_plan` is left in an unspecified + // state (see `rewrite_plan_in_place`); the result + // handling below discards it, restoring `prev_plan` + // when `skip_failed_rules` is set or propagating + // the error otherwise. rewrite_plan_in_place( &mut new_plan, apply_order, From 646d2afb298b1ce538ede6cef5cea8f15fd4a5bf Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 18 May 2026 15:52:02 -0700 Subject: [PATCH 3/3] refactor: keep map_children_mut private to the optimizer crate `map_children_mut` was added as a public method on `LogicalPlan` in `datafusion-expr` only because the optimizer, in a different crate, needed to call it. But the optimizer is its sole consumer, and the `Arc::make_mut` in-place trick does not generalize to the other tree types (`Expr` children are `Box`ed; `PhysicalExpr`/`ExecutionPlan` children are `Arc`, which `Arc::make_mut` cannot handle), so committing to it as public API is not warranted. Move it into `optimizer.rs` as a private free function next to `rewrite_plan_in_place`, its only caller. `datafusion-expr` is now untouched by this PR except for the `map_subqueries` short-circuit, and the optimizer adds no public API. If `TreeNode` ever grows an in-place traversal this logic can move there with no breaking change. No behavior change; the 713 datafusion-optimizer tests pass unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) --- datafusion/expr/src/logical_plan/tree_node.rs | 116 ---------------- datafusion/optimizer/src/optimizer.rs | 125 +++++++++++++++++- 2 files changed, 123 insertions(+), 118 deletions(-) diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 4d5a940f12310..1f58de37e93b0 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -52,7 +52,6 @@ use datafusion_common::tree_node::{ TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{Result, internal_err}; -use std::sync::Arc; impl TreeNode for LogicalPlan { fn apply_children<'n, F: FnMut(&'n Self) -> Result>( @@ -394,121 +393,6 @@ macro_rules! handle_transform_recursion { } impl LogicalPlan { - /// Applies `f` to each child (input) of this plan node in place, - /// using [`Arc::make_mut`] for copy-on-write semantics. - /// - /// When the `Arc` refcount is 1 (the common case in the optimizer), - /// `Arc::make_mut` returns a `&mut` reference without cloning. - /// When the refcount is >1, it clones the inner value first. - /// - /// Returns `Ok(true)` if any child was modified by `f`. - /// - /// # Error semantics - /// - /// If `f` returns `Err` for a child, this method returns that error - /// immediately. Children visited earlier in the same call keep whatever - /// modifications `f` already applied to them — they are **not** rolled - /// back. Callers that need the pre-call tree on error must save a copy - /// beforehand. - pub fn map_children_mut Result>( - &mut self, - mut f: F, - ) -> Result { - Ok(match self { - LogicalPlan::Projection(Projection { input, .. }) - | LogicalPlan::Filter(Filter { input, .. }) - | LogicalPlan::Repartition(Repartition { input, .. }) - | LogicalPlan::Window(Window { input, .. }) - | LogicalPlan::Aggregate(Aggregate { input, .. }) - | LogicalPlan::Sort(Sort { input, .. }) - | LogicalPlan::Limit(Limit { input, .. }) - | LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) - | LogicalPlan::Analyze(Analyze { input, .. }) - | LogicalPlan::Dml(DmlStatement { input, .. }) - | LogicalPlan::Copy(CopyTo { input, .. }) - | LogicalPlan::Unnest(Unnest { input, .. }) => f(Arc::make_mut(input))?, - LogicalPlan::Subquery(Subquery { subquery, .. }) => { - f(Arc::make_mut(subquery))? - } - LogicalPlan::Join(Join { left, right, .. }) => { - let l = f(Arc::make_mut(left))?; - let r = f(Arc::make_mut(right))?; - l || r - } - LogicalPlan::Union(Union { inputs, .. }) => { - let mut changed = false; - for input in inputs { - changed |= f(Arc::make_mut(input))?; - } - changed - } - LogicalPlan::Distinct(Distinct::All(input)) => f(Arc::make_mut(input))?, - LogicalPlan::Distinct(Distinct::On(DistinctOn { input, .. })) => { - f(Arc::make_mut(input))? - } - LogicalPlan::Explain(Explain { plan, .. }) => f(Arc::make_mut(plan))?, - LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable { - input, - .. - })) - | LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { input, .. })) => { - f(Arc::make_mut(input))? - } - LogicalPlan::RecursiveQuery(RecursiveQuery { - static_term, - recursive_term, - .. - }) => { - let s = f(Arc::make_mut(static_term))?; - let r = f(Arc::make_mut(recursive_term))?; - s || r - } - LogicalPlan::Statement(Statement::Prepare(p)) => { - f(Arc::make_mut(&mut p.input))? - } - LogicalPlan::Extension(Extension { node }) => { - let inputs = node.inputs(); - if inputs.is_empty() { - false - } else { - // Extension nodes don't expose mutable children, - // fall back to the ownership-based API - let mut changed = false; - let exprs = node.expressions(); - let new_inputs: Vec = inputs - .into_iter() - .map(|input| { - let mut plan = input.clone(); - if f(&mut plan)? { - changed = true; - } - Ok(plan) - }) - .collect::>>()?; - if changed { - *node = node.with_exprs_and_inputs(exprs, new_inputs)?; - } - changed - } - } - // plans without inputs - LogicalPlan::TableScan { .. } - | LogicalPlan::EmptyRelation { .. } - | LogicalPlan::Values { .. } - | LogicalPlan::DescribeTable(_) - | LogicalPlan::Ddl(DdlStatement::CreateExternalTable(_)) - | LogicalPlan::Ddl(DdlStatement::CreateCatalogSchema(_)) - | LogicalPlan::Ddl(DdlStatement::CreateCatalog(_)) - | LogicalPlan::Ddl(DdlStatement::CreateIndex(_)) - | LogicalPlan::Ddl(DdlStatement::DropTable(_)) - | LogicalPlan::Ddl(DdlStatement::DropView(_)) - | LogicalPlan::Ddl(DdlStatement::DropCatalogSchema(_)) - | LogicalPlan::Ddl(DdlStatement::CreateFunction(_)) - | LogicalPlan::Ddl(DdlStatement::DropFunction(_)) - | LogicalPlan::Statement(_) => false, - }) - } - /// Calls `f` on all expressions in the current `LogicalPlan` node. /// /// # Notes diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index a8154653d4e37..9e015cf55b75d 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -32,8 +32,14 @@ use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{DFSchema, DataFusionError, HashSet, Result, internal_err}; -use datafusion_expr::Expr; +use datafusion_expr::dml::CopyTo; use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::{ + Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, Distinct, + DistinctOn, DmlStatement, Explain, Expr, Extension, Filter, Join, Limit, Projection, + RecursiveQuery, Repartition, Sort, Statement, Subquery, SubqueryAlias, Union, Unnest, + Window, +}; use crate::common_subexpr_eliminate::CommonSubexprEliminate; use crate::decorrelate_lateral_join::DecorrelateLateralJoin; @@ -360,6 +366,121 @@ impl TreeNodeRewriter for Rewriter<'_> { } } +/// Applies `f` to each child (input) of `plan` in place, using +/// [`Arc::make_mut`] for copy-on-write semantics on `Arc` +/// children. When the `Arc` refcount is 1 (the common case here) +/// `Arc::make_mut` hands out a `&mut` without cloning; when it is >1 the +/// inner value is cloned first. +/// +/// Returns `Ok(true)` if any child was modified by `f`. +/// +/// This is deliberately private to the optimizer rather than a method on +/// [`LogicalPlan`]: it is an implementation detail of in-place rewriting, and +/// the `Arc::make_mut` approach does not generalize to the other tree types +/// (`Expr` children are `Box`ed; `PhysicalExpr`/`ExecutionPlan` children are +/// `Arc`, which `Arc::make_mut` cannot handle). If `TreeNode` ever +/// grows an in-place traversal this logic can move there. +/// +/// # Error semantics +/// +/// If `f` returns `Err` for a child, that error is returned immediately; +/// children visited earlier keep whatever modifications `f` already applied +/// to them — they are **not** rolled back. +fn map_children_mut Result>( + plan: &mut LogicalPlan, + mut f: F, +) -> Result { + Ok(match plan { + LogicalPlan::Projection(Projection { input, .. }) + | LogicalPlan::Filter(Filter { input, .. }) + | LogicalPlan::Repartition(Repartition { input, .. }) + | LogicalPlan::Window(Window { input, .. }) + | LogicalPlan::Aggregate(Aggregate { input, .. }) + | LogicalPlan::Sort(Sort { input, .. }) + | LogicalPlan::Limit(Limit { input, .. }) + | LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) + | LogicalPlan::Analyze(Analyze { input, .. }) + | LogicalPlan::Dml(DmlStatement { input, .. }) + | LogicalPlan::Copy(CopyTo { input, .. }) + | LogicalPlan::Unnest(Unnest { input, .. }) => f(Arc::make_mut(input))?, + LogicalPlan::Subquery(Subquery { subquery, .. }) => f(Arc::make_mut(subquery))?, + LogicalPlan::Join(Join { left, right, .. }) => { + let l = f(Arc::make_mut(left))?; + let r = f(Arc::make_mut(right))?; + l || r + } + LogicalPlan::Union(Union { inputs, .. }) => { + let mut changed = false; + for input in inputs { + changed |= f(Arc::make_mut(input))?; + } + changed + } + LogicalPlan::Distinct(Distinct::All(input)) => f(Arc::make_mut(input))?, + LogicalPlan::Distinct(Distinct::On(DistinctOn { input, .. })) => { + f(Arc::make_mut(input))? + } + LogicalPlan::Explain(Explain { plan, .. }) => f(Arc::make_mut(plan))?, + LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable { + input, + .. + })) + | LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { input, .. })) => { + f(Arc::make_mut(input))? + } + LogicalPlan::RecursiveQuery(RecursiveQuery { + static_term, + recursive_term, + .. + }) => { + let s = f(Arc::make_mut(static_term))?; + let r = f(Arc::make_mut(recursive_term))?; + s || r + } + LogicalPlan::Statement(Statement::Prepare(p)) => f(Arc::make_mut(&mut p.input))?, + LogicalPlan::Extension(Extension { node }) => { + let inputs = node.inputs(); + if inputs.is_empty() { + false + } else { + // Extension nodes don't expose mutable children, + // fall back to the ownership-based API + let mut changed = false; + let exprs = node.expressions(); + let new_inputs: Vec = inputs + .into_iter() + .map(|input| { + let mut plan = input.clone(); + if f(&mut plan)? { + changed = true; + } + Ok(plan) + }) + .collect::>>()?; + if changed { + *node = node.with_exprs_and_inputs(exprs, new_inputs)?; + } + changed + } + } + // plans without inputs + LogicalPlan::TableScan { .. } + | LogicalPlan::EmptyRelation { .. } + | LogicalPlan::Values { .. } + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Ddl(DdlStatement::CreateExternalTable(_)) + | LogicalPlan::Ddl(DdlStatement::CreateCatalogSchema(_)) + | LogicalPlan::Ddl(DdlStatement::CreateCatalog(_)) + | LogicalPlan::Ddl(DdlStatement::CreateIndex(_)) + | LogicalPlan::Ddl(DdlStatement::DropTable(_)) + | LogicalPlan::Ddl(DdlStatement::DropView(_)) + | LogicalPlan::Ddl(DdlStatement::DropCatalogSchema(_)) + | LogicalPlan::Ddl(DdlStatement::CreateFunction(_)) + | LogicalPlan::Ddl(DdlStatement::DropFunction(_)) + | LogicalPlan::Statement(_) => false, + }) +} + /// Rewrites a plan tree in place using `Arc::make_mut` for /// copy-on-write semantics on `Arc` children. /// @@ -406,7 +527,7 @@ fn rewrite_plan_in_place( } // Recurse into children using Arc::make_mut (zero-cost when refcount == 1) - changed |= plan.map_children_mut(|child| { + changed |= map_children_mut(plan, |child| { rewrite_plan_in_place(child, apply_order, rule, config) })?;