From 0c33aac61d3e14ffcebc7810234bec1583d4cdca Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Fri, 18 Oct 2024 21:58:49 +0800 Subject: [PATCH] enhance unparsing plan with pushdown to avoid unnamed subquery --- datafusion/sql/src/unparser/plan.rs | 61 +++++++++++++++++++++-- datafusion/sql/src/unparser/rewrite.rs | 10 ++-- datafusion/sql/tests/cases/plan_to_sql.rs | 6 +-- 3 files changed, 65 insertions(+), 12 deletions(-) diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index c22400f1faa1..4002ba8617ff 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -617,9 +617,10 @@ impl Unparser<'_> { if !Self::is_scan_with_pushdown(table_scan) { return Ok(None); } + let table_schema = table_scan.source.schema(); let mut filter_alias_rewriter = alias.as_ref().map(|alias_name| TableAliasRewriter { - table_schema: table_scan.source.schema(), + table_schema: &table_schema, alias_name: alias_name.clone(), }); @@ -628,6 +629,17 @@ impl Unparser<'_> { Arc::clone(&table_scan.source), None, )?; + // We will rebase the column references to the new alias if it exists. + // If the projection or filters are empty, we will append alias to the table scan. + // + // Example: + // select t1.c1 from t1 where t1.c1 > 1 -> select a.c1 from t1 as a where a.c1 > 1 + if alias.is_some() + && (table_scan.projection.is_some() || !table_scan.filters.is_empty()) + { + builder = builder.alias(alias.clone().unwrap())?; + } + if let Some(project_vec) = &table_scan.projection { let project_columns = project_vec .iter() @@ -645,9 +657,6 @@ impl Unparser<'_> { } }) .collect::>(); - if let Some(alias) = alias { - builder = builder.alias(alias)?; - } builder = builder.project(project_columns)?; } @@ -677,6 +686,16 @@ impl Unparser<'_> { builder = builder.limit(0, Some(fetch))?; } + // If the table scan has an alias but no projection or filters, it means no column references are rebased. + // So we will append the alias to this subquery. + // Example: + // select * from t1 limit 10 -> (select * from t1 limit 10) as a + if alias.is_some() + && (table_scan.projection.is_none() && table_scan.filters.is_empty()) + { + builder = builder.alias(alias.clone().unwrap())?; + } + Ok(Some(builder.build()?)) } LogicalPlan::SubqueryAlias(subquery_alias) => { @@ -685,6 +704,40 @@ impl Unparser<'_> { Some(subquery_alias.alias.clone()), ) } + // SubqueryAlias could be rewritten to a plan with a projection as the top node by [rewrite::subquery_alias_inner_query_and_columns]. + // The inner table scan could be a scan with pushdown operations. + LogicalPlan::Projection(projection) => { + if let Some(plan) = + Self::unparse_table_scan_pushdown(&projection.input, alias.clone())? + { + let exprs = if alias.is_some() { + let mut alias_rewriter = + alias.as_ref().map(|alias_name| TableAliasRewriter { + table_schema: plan.schema().as_arrow(), + alias_name: alias_name.clone(), + }); + projection + .expr + .iter() + .cloned() + .map(|expr| { + if let Some(ref mut rewriter) = alias_rewriter { + expr.rewrite(rewriter).data() + } else { + Ok(expr) + } + }) + .collect::>>()? + } else { + projection.expr.clone() + }; + Ok(Some( + LogicalPlanBuilder::from(plan).project(exprs)?.build()?, + )) + } else { + Ok(None) + } + } _ => Ok(None), } } diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index 3049df9396cb..57d700f86955 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -20,7 +20,7 @@ use std::{ sync::Arc, }; -use arrow_schema::SchemaRef; +use arrow_schema::Schema; use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, Column, Result, TableReference, @@ -293,7 +293,7 @@ pub(super) fn inject_column_aliases_into_subquery( /// - `SELECT col1, col2 FROM table` with aliases `["alias_1", "some_alias_2"]` will be transformed to /// - `SELECT col1 AS alias_1, col2 AS some_alias_2 FROM table` pub(super) fn inject_column_aliases( - projection: &datafusion_expr::Projection, + projection: &Projection, aliases: impl IntoIterator, ) -> LogicalPlan { let mut updated_projection = projection.clone(); @@ -343,12 +343,12 @@ fn find_projection(logical_plan: &LogicalPlan) -> Option<&Projection> { /// from which the columns are referenced. This is used to look up columns by their names. /// * `alias_name`: The alias (`TableReference`) that will replace the table name /// in the column references when applicable. -pub struct TableAliasRewriter { - pub table_schema: SchemaRef, +pub struct TableAliasRewriter<'a> { + pub table_schema: &'a Schema, pub alias_name: TableReference, } -impl TreeNodeRewriter for TableAliasRewriter { +impl TreeNodeRewriter for TableAliasRewriter<'_> { type Node = Expr; fn f_down(&mut self, expr: Expr) -> Result> { diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 74abdf075f23..8f58fb1df039 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -726,7 +726,7 @@ fn test_table_scan_alias() -> Result<()> { let table_scan_with_two_filter = plan_to_sql(&table_scan_with_two_filter)?; assert_eq!( table_scan_with_two_filter.to_string(), - "SELECT * FROM (SELECT t1.id FROM t1 WHERE ((t1.id > 1) AND (t1.age < 2))) AS a" + "SELECT a.id FROM t1 AS a WHERE ((a.id > 1) AND (a.age < 2))" ); let table_scan_with_fetch = @@ -737,7 +737,7 @@ fn test_table_scan_alias() -> Result<()> { let table_scan_with_fetch = plan_to_sql(&table_scan_with_fetch)?; assert_eq!( table_scan_with_fetch.to_string(), - "SELECT * FROM (SELECT t1.id FROM (SELECT * FROM t1 LIMIT 10)) AS a" + "SELECT a.id FROM (SELECT * FROM t1 LIMIT 10) AS a" ); let table_scan_with_pushdown_all = table_scan_with_filter_and_fetch( @@ -753,7 +753,7 @@ fn test_table_scan_alias() -> Result<()> { let table_scan_with_pushdown_all = plan_to_sql(&table_scan_with_pushdown_all)?; assert_eq!( table_scan_with_pushdown_all.to_string(), - "SELECT * FROM (SELECT t1.id FROM (SELECT t1.id, t1.age FROM t1 WHERE (t1.id > 1) LIMIT 10)) AS a" + "SELECT a.id FROM (SELECT a.id, a.age FROM t1 AS a WHERE (a.id > 1) LIMIT 10) AS a" ); Ok(()) }