diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 852b350b27df..e082cabaadaf 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -22,6 +22,7 @@ use rstest::rstest; use datafusion::config::ConfigOptions; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::metrics::Timestamp; +use object_store::path::Path; #[tokio::test] async fn explain_analyze_baseline_metrics() { @@ -727,6 +728,130 @@ async fn parquet_explain_analyze() { assert_contains!(&formatted, "row_groups_pruned_statistics=0"); } +// This test reproduces the behavior described in +// https://github.com/apache/datafusion/issues/16684 where projection +// pushdown with recursive CTEs could fail to remove unused columns +// (e.g. nested/recursive expansion causing full schema to be scanned). +// Keeping this test ensures we don't regress that behavior. +#[tokio::test] +#[cfg_attr(tarpaulin, ignore)] +async fn parquet_recursive_projection_pushdown() -> Result<()> { + use parquet::arrow::arrow_writer::ArrowWriter; + use parquet::file::properties::WriterProperties; + + let temp_dir = TempDir::new().unwrap(); + let parquet_path = temp_dir.path().join("hierarchy.parquet"); + + let ids = Int64Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + let parent_ids = Int64Array::from(vec![0, 1, 1, 2, 2, 3, 4, 5, 6, 7]); + let values = Int64Array::from(vec![10, 20, 30, 40, 50, 60, 70, 80, 90, 100]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("parent_id", DataType::Int64, true), + Field::new("value", DataType::Int64, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(ids), Arc::new(parent_ids), Arc::new(values)], + ) + .unwrap(); + + let file = File::create(&parquet_path).unwrap(); + let props = WriterProperties::builder().build(); + let mut writer = 
ArrowWriter::try_new(file, schema, Some(props)).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + let ctx = SessionContext::new(); + ctx.register_parquet( + "hierarchy", + parquet_path.to_str().unwrap(), + ParquetReadOptions::default(), + ) + .await?; + + let sql = r#" + WITH RECURSIVE number_series AS ( + SELECT id, 1 as level + FROM hierarchy + WHERE id = 1 + + UNION ALL + + SELECT ns.id + 1, ns.level + 1 + FROM number_series ns + WHERE ns.id < 10 + ) + SELECT * FROM number_series ORDER BY id + "#; + + let dataframe = ctx.sql(sql).await?; + let physical_plan = dataframe.create_physical_plan().await?; + + let normalizer = ExplainNormalizer::new(); + let mut actual = format!("{}", displayable(physical_plan.as_ref()).indent(true)) + .trim() + .lines() + .map(|line| normalizer.normalize(line)) + .collect::<Vec<_>>() + .join("\n"); + + fn replace_path_variants(actual: &mut String, path: &str) { + let mut candidates = vec![path.to_string()]; + + let trimmed = path.trim_start_matches(std::path::MAIN_SEPARATOR); + if trimmed != path { + candidates.push(trimmed.to_string()); + } + + let forward_slash = path.replace('\\', "/"); + if forward_slash != path { + candidates.push(forward_slash.clone()); + + let trimmed_forward = forward_slash.trim_start_matches('/'); + if trimmed_forward != forward_slash { + candidates.push(trimmed_forward.to_string()); + } + } + + for candidate in candidates { + *actual = actual.replace(&candidate, "TMP_DIR"); + } + } + + let temp_dir_path = temp_dir.path(); + let fs_path = temp_dir_path.to_string_lossy().to_string(); + replace_path_variants(&mut actual, &fs_path); + + if let Ok(url_path) = Path::from_filesystem_path(temp_dir_path) { + replace_path_variants(&mut actual, url_path.as_ref()); + } + + assert_snapshot!( + actual, + @r" + SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] + RecursiveQueryExec: name=number_series, is_distinct=false + CoalescePartitionsExec + ProjectionExec: expr=[id@0 as id, 1 as 
level] + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: id@0 = 1 + RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1 + DataSourceExec: file_groups={1 group: [[TMP_DIR/hierarchy.parquet]]}, projection=[id], file_type=parquet, predicate=id@0 = 1, pruning_predicate=id_null_count@2 != row_count@3 AND id_min@0 <= 1 AND 1 <= id_max@1, required_guarantees=[id in (1)] + CoalescePartitionsExec + ProjectionExec: expr=[id@0 + 1 as ns.id + Int64(1), level@1 + 1 as ns.level + Int64(1)] + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: id@0 < 10 + RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1 + WorkTableExec: name=number_series + " + ); + + Ok(()) +} + #[tokio::test] #[cfg_attr(tarpaulin, ignore)] async fn parquet_explain_analyze_verbose() { diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 312e788db7be..5db71417bc8f 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -880,7 +880,9 @@ pub fn is_projection_unnecessary( /// pushdown for now because we cannot safely reason about their column usage. 
fn plan_contains_other_subqueries(plan: &LogicalPlan, cte_name: &str) -> bool { if let LogicalPlan::SubqueryAlias(alias) = plan { - if alias.alias.table() != cte_name { + if alias.alias.table() != cte_name + && !subquery_alias_targets_recursive_cte(alias.input.as_ref(), cte_name) + { return true; } } @@ -913,6 +915,23 @@ fn expr_contains_subquery(expr: &Expr) -> bool { .unwrap() } +fn subquery_alias_targets_recursive_cte(plan: &LogicalPlan, cte_name: &str) -> bool { + match plan { + LogicalPlan::TableScan(scan) => scan.table_name.table() == cte_name, + LogicalPlan::SubqueryAlias(alias) => { + subquery_alias_targets_recursive_cte(alias.input.as_ref(), cte_name) + } + _ => { + let inputs = plan.inputs(); + if inputs.len() == 1 { + subquery_alias_targets_recursive_cte(inputs[0], cte_name) + } else { + false + } + } + } +} + #[cfg(test)] mod tests { use std::cmp::Ordering; diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index b17375ac01b7..c0f48b8ebfc4 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -548,6 +548,29 @@ fn recursive_cte_projection_pushdown() -> Result<()> { Ok(()) } +#[test] +fn recursive_cte_with_aliased_self_reference() -> Result<()> { + let sql = "WITH RECURSIVE nodes AS (\ + SELECT col_int32 AS id, col_utf8 AS name FROM test \ + UNION ALL \ + SELECT child.id + 1, child.name FROM nodes AS child WHERE child.id < 3\ + ) SELECT id FROM nodes"; + let plan = test_sql(sql)?; + + assert_snapshot!( + format!("{plan}"), + @r#"SubqueryAlias: nodes + RecursiveQuery: is_distinct=false + Projection: test.col_int32 AS id + TableScan: test projection=[col_int32] + Projection: CAST(CAST(child.id AS Int64) + Int64(1) AS Int32) + SubqueryAlias: child + Filter: nodes.id < Int32(3) + TableScan: nodes projection=[id]"#, + ); + Ok(()) +} + #[test] fn recursive_cte_with_unused_columns() -> Result<()> { // Test projection 
pushdown with a recursive CTE where the base case