diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs index 2de1ce9125a7..e6762258b002 100644 --- a/datafusion/sql/src/unparser/ast.rs +++ b/datafusion/sql/src/unparser/ast.rs @@ -25,6 +25,8 @@ use core::fmt; use sqlparser::ast; +use super::rewrite::remove_dangling_expr; + #[derive(Clone)] pub(super) struct QueryBuilder { with: Option, @@ -238,7 +240,101 @@ impl SelectBuilder { self.value_table_mode = value; self } - pub fn build(&self) -> Result { + fn collect_valid_idents(&self, relation_builder: &RelationBuilder) -> Vec { + let mut all_idents = Vec::new(); + if let Some(source_alias) = relation_builder.get_alias() { + all_idents.push(source_alias); + } else if let Some(source_name) = relation_builder.get_name() { + let full_ident = source_name.to_string(); + if let Some(name) = source_name.0.last() { + if full_ident != name.to_string() { + // supports identifiers that contain the entire path, as well as just the end table leaf + // like catalog.schema.table and table + all_idents.push(name.to_string()); + } + } + all_idents.push(full_ident); + } + + if let Some(twg) = self.from.last() { + twg.joins.iter().for_each(|join| match &join.relation { + ast::TableFactor::Table { alias, name, .. } => { + if let Some(alias) = alias { + all_idents.push(alias.name.to_string()); + } else { + let full_ident = name.to_string(); + if let Some(name) = name.0.last() { + if full_ident != name.to_string() { + // supports identifiers that contain the entire path, as well as just the end table leaf + // like catalog.schema.table and table + all_idents.push(name.to_string()); + } + } + all_idents.push(full_ident); + } + } + ast::TableFactor::Derived { + alias: Some(alias), .. + } => { + all_idents.push(alias.name.to_string()); + } + _ => {} + }); + } + + all_idents + } + + /// Remove any dangling table identifiers from the projection, selection, group by, order by and function arguments + /// This removes any references to tables that are not part of any from/source or join, as they would be invalid + fn remove_dangling_identifiers( + &mut self, + query: &mut Option, + relation_builder: &RelationBuilder, + ) { + let all_idents = self.collect_valid_idents(relation_builder); + + // Ensure that the projection contains references to sources that actually exist + self.projection.iter_mut().for_each(|select_item| { + if let ast::SelectItem::UnnamedExpr(expr) = select_item { + *expr = remove_dangling_expr(expr.clone(), &all_idents); + } + }); + + // replace dangling references in the selection + if let Some(expr) = self.selection.as_ref() { + self.selection = Some(remove_dangling_expr(expr.clone(), &all_idents)); + } + + // Check the order by as well + if let Some(query) = query.as_mut() { + query.order_by.iter_mut().for_each(|sort_item| { + sort_item.expr = + remove_dangling_expr(sort_item.expr.clone(), &all_idents); + }); + } + + // Order by could be a sort in the select builder + self.sort_by.iter_mut().for_each(|sort_item| { + *sort_item = remove_dangling_expr(sort_item.clone(), &all_idents); + }); + + // check the group by as well + if let Some(ast::GroupByExpr::Expressions(ref mut group_by, _)) = + self.group_by.as_mut() + { + group_by.iter_mut().for_each(|expr| { + *expr = remove_dangling_expr(expr.clone(), &all_idents); + }); + } + } + pub fn build( + &mut self, + query: &mut Option, + relation_builder: &RelationBuilder, + ) -> Result { + self.remove_dangling_identifiers(query, relation_builder); + Ok(ast::Select { distinct: self.distinct.clone(), top: self.top.clone(), @@ -307,7 +403,6 @@ impl TableWithJoinsBuilder { self.relation = Some(value); self } - pub fn joins(&mut self, value: Vec) -> &mut Self { self.joins = value; self @@ -360,6 +455,23 @@ impl RelationBuilder { pub fn has_relation(&self) -> bool { self.relation.is_some() } + pub fn get_name(&self) -> Option<&ast::ObjectName> { + match self.relation { + Some(TableFactorBuilder::Table(ref value)) => value.name.as_ref(), + _ => None, + } + } + pub fn get_alias(&self) -> Option { + match self.relation { + Some(TableFactorBuilder::Table(ref value)) => { + value.alias.as_ref().map(|a| a.name.to_string()) + } + Some(TableFactorBuilder::Derived(ref value)) => { + value.alias.as_ref().map(|a| a.name.to_string()) + } + _ => None, + } + } pub fn table(&mut self, value: TableRelationBuilder) -> &mut Self { self.relation = Some(TableFactorBuilder::Table(value)); self diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 433c456855a3..6d49c291f2e7 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -158,10 +158,12 @@ impl Unparser<'_> { } let mut twj = select_builder.pop_from().unwrap(); - twj.relation(relation_builder); + twj.relation(relation_builder.clone()); select_builder.push_from(twj); - Ok(SetExpr::Select(Box::new(select_builder.build()?))) + Ok(SetExpr::Select(Box::new( + select_builder.build(query, &relation_builder)?, + ))) } /// Reconstructs a SELECT SQL statement from a logical plan by unprojecting column expressions diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index 6b3b999ba04b..fa8b5cfd3fdf 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -24,7 +24,7 @@ use datafusion_common::{ }; use datafusion_expr::{expr::Alias, tree_node::transform_sort_vec}; use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr}; -use sqlparser::ast::Ident; +use sqlparser::ast::{self, display_separated, Ident}; /// Normalize the schema of a union plan to remove qualifiers from the schema fields and sort expressions. /// @@ -363,3 +363,138 @@ impl TreeNodeRewriter for TableAliasRewriter<'_> { } } } + +/// Takes an input list of identifiers and a list of identifiers that are available from relations or joins. +/// Removes any table identifiers that are not present in the list of available identifiers, retains original column names. +pub fn remove_dangling_identifiers(idents: &mut Vec, available_idents: &[String]) { + if idents.len() > 1 { + let ident_source = display_separated( + &idents + .clone() + .into_iter() + .take(idents.len() - 1) + .collect::>(), + ".", + ) + .to_string(); + // If the identifier is not present in the list of all identifiers, it refers to a table that does not exist + if !available_idents.contains(&ident_source) { + let Some(last) = idents.last() else { + unreachable!("CompoundIdentifier must have a last element"); + }; + // Reset the identifiers to only the last element, which is the column name + *idents = vec![last.clone()]; + } + } +} + +/// Handle removing dangling identifiers from an expression +/// This function can call itself recursively to handle nested expressions +/// Like binary ops or functions which contain nested expressions/arguments +pub fn remove_dangling_expr( + expr: ast::Expr, + available_idents: &Vec, +) -> ast::Expr { + match expr { + ast::Expr::BinaryOp { left, op, right } => { + let left = remove_dangling_expr(*left, available_idents); + let right = remove_dangling_expr(*right, available_idents); + ast::Expr::BinaryOp { + left: Box::new(left), + op, + right: Box::new(right), + } + } + ast::Expr::Nested(expr) => { + let expr = remove_dangling_expr(*expr, available_idents); + ast::Expr::Nested(Box::new(expr)) + } + ast::Expr::CompoundIdentifier(idents) => { + let mut idents = idents.clone(); + remove_dangling_identifiers(&mut idents, available_idents); + + if idents.is_empty() { + unreachable!("Identifier must have at least one element"); + } else if idents.len() == 1 { + ast::Expr::Identifier(idents[0].clone()) + } else { + ast::Expr::CompoundIdentifier(idents) + } + } + ast::Expr::Function(ast::Function { + args, + name, + parameters, + filter, + null_treatment, + over, + within_group, + }) => { + let args = if let ast::FunctionArguments::List(mut args) = args { + args.args.iter_mut().for_each(|arg| match arg { + ast::FunctionArg::Named { + arg: ast::FunctionArgExpr::Expr(expr), + .. + } + | ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(expr)) => { + *expr = remove_dangling_expr(expr.clone(), available_idents); + } + _ => {} + }); + + ast::FunctionArguments::List(args) + } else { + args + }; + + ast::Expr::Function(ast::Function { + args, + name, + parameters, + filter, + null_treatment, + over, + within_group, + }) + } + _ => expr, + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_remove_dangling_identifiers() { + let tests = vec![ + (vec![], vec![Ident::new("column1".to_string())]), + ( + vec!["table1.table2".to_string()], + vec![ + Ident::new("table1".to_string()), + Ident::new("table2".to_string()), + Ident::new("column1".to_string()), + ], + ), + ( + vec!["table1".to_string()], + vec![Ident::new("column1".to_string())], + ), + ]; + + for test in tests { + let test_in = test.0; + let test_out = test.1; + + let mut idents = vec![ + Ident::new("table1".to_string()), + Ident::new("table2".to_string()), + Ident::new("column1".to_string()), + ]; + + remove_dangling_identifiers(&mut idents, &test_in); + assert_eq!(idents, test_out); + } + } +} diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 94e420066d8b..405a4c0e0339 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -345,6 +345,86 @@ fn roundtrip_statement_with_dialect() -> Result<()> { parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), }, + TestStatementWithDialect { + sql: "select j1_id from (select ta.j1_id from j1 ta)", + expected: + // This seems like desirable behavior, but is actually hiding an underlying issue + // The re-written identifier is `ta`.`j1_id`, because `reconstuct_select_statement` runs before the derived projection + // and for some reason, the derived table alias is pre-set to `ta` for the top-level projection + "SELECT `j1_id` FROM (SELECT `ta`.`j1_id` FROM `j1` AS `ta`) AS `derived_projection`", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select j1_id from (select ta.j1_id from j1 ta)", + expected: + "SELECT j1_id FROM (SELECT ta.j1_id FROM j1 AS ta)", + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "select j1_id from (select ta.j1_id from j1 ta) where j1_id > 1", + expected: + "SELECT j1_id FROM (SELECT ta.j1_id FROM j1 AS ta) WHERE (j1_id > 1)", + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "select j1_id from (select ta.j1_id from j1 ta) group by j1_id", + expected: + "SELECT j1_id FROM (SELECT ta.j1_id FROM j1 AS ta) GROUP BY j1_id", + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "select j1_id from (select ta.j1_id from j1 ta) order by j1_id", + expected: + "SELECT j1_id FROM (SELECT ta.j1_id FROM j1 AS ta) ORDER BY j1_id ASC NULLS LAST", + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "select j1_id from (select ta.j1_id from j1 ta) order by j1_id", + expected: + "SELECT `j1_id` FROM (SELECT `ta`.`j1_id` FROM `j1` AS `ta`) AS `derived_projection` ORDER BY `j1_id` ASC", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select j1_id from (select ta.j1_id from j1 ta) AS tbl1", + expected: + "SELECT tbl1.j1_id FROM (SELECT ta.j1_id FROM j1 AS ta) AS tbl1", + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "select j1_id, j2_id from (select ta.j1_id from j1 ta) AS tbl1, (select ta.j1_id as j2_id from j1 ta) as tbl2", + expected: + "SELECT tbl1.j1_id, tbl2.j2_id FROM (SELECT ta.j1_id FROM j1 AS ta) AS tbl1 JOIN (SELECT ta.j1_id AS j2_id FROM j1 AS ta) AS tbl2", + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "select j1_id, j2_id from (select ta.j1_id from j1 ta) AS tbl1, (select ta.j1_id as j2_id from j1 ta) as tbl2", + expected: + "SELECT `tbl1`.`j1_id`, `tbl2`.`j2_id` FROM (SELECT `ta`.`j1_id` FROM `j1` AS `ta`) AS `tbl1` JOIN (SELECT `ta`.`j1_id` AS `j2_id` FROM `j1` AS `ta`) AS `tbl2`", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select j1_id, j2_id from (select ta.j1_id from j1 ta), (select ta.j1_id as j2_id from j1 ta)", + expected: + "SELECT `j1_id`, `j2_id` FROM (SELECT `ta`.`j1_id` FROM `j1` AS `ta`) AS `derived_projection` JOIN (SELECT `ta`.`j1_id` AS `j2_id` FROM `j1` AS `ta`) AS `derived_projection`", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select j1_id, j2_id from (select ta.j1_id from j1 ta), (select ta.j1_id AS j2_id from j1 ta)", + expected: + "SELECT j1_id, j2_id FROM (SELECT ta.j1_id FROM j1 AS ta) JOIN (SELECT ta.j1_id AS j2_id FROM j1 AS ta)", + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, TestStatementWithDialect { sql: " SELECT @@ -585,7 +665,7 @@ fn test_aggregation_without_projection() -> Result<()> { assert_eq!( actual, - r#"SELECT sum(users.age), users."name" FROM (SELECT users."name", users.age FROM users) GROUP BY users."name""# + r#"SELECT sum(age), "name" FROM (SELECT users."name", users.age FROM users) GROUP BY "name""# ); Ok(())