Skip to content

Commit

Permalink
Use column aliases specified by WITH statements (#3717)
Browse files Browse the repository at this point in the history
  • Loading branch information
isidentical committed Oct 5, 2022
1 parent 64669e9 commit 23682f6
Showing 1 changed file with 97 additions and 24 deletions.
121 changes: 97 additions & 24 deletions datafusion/sql/src/planner.rs
Expand Up @@ -59,8 +59,8 @@ use sqlparser::ast::{
BinaryOperator, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, FunctionArg,
FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator, ObjectName,
Offset as SQLOffset, Query, Select, SelectItem, SetExpr, SetOperator,
ShowCreateObject, ShowStatementFilter, TableFactor, TableWithJoins, TimezoneInfo,
TrimWhereField, UnaryOperator, Value, Values as SQLValues,
ShowCreateObject, ShowStatementFilter, TableAlias, TableFactor, TableWithJoins,
TimezoneInfo, TrimWhereField, UnaryOperator, Value, Values as SQLValues,
};
use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption};
use sqlparser::ast::{ObjectType, OrderByExpr, Statement};
Expand Down Expand Up @@ -376,6 +376,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
&mut ctes.clone(),
outer_query_schema,
)?;

// Each `WITH` block can change the column names in the last
// projection (e.g. "WITH table(t1, t2) AS (SELECT 1, 2)").
let logical_plan = self.apply_table_alias(logical_plan, cte.alias)?;

ctes.insert(cte_name, logical_plan);
}
}
Expand Down Expand Up @@ -785,33 +790,40 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
};
if let Some(alias) = alias {
let columns_alias = alias.clone().columns;
if columns_alias.is_empty() {
// sqlparser-rs encodes AS t as an empty list of column alias
Ok(plan)
} else if columns_alias.len() != plan.schema().fields().len() {
Err(DataFusionError::Plan(format!(
"Source table contains {} columns but only {} names given as column alias",
plan.schema().fields().len(),
columns_alias.len(),
)))
} else {
Ok(LogicalPlanBuilder::from(plan.clone())
.project_with_alias(
plan.schema().fields().iter().zip(columns_alias.iter()).map(
|(field, ident)| {
col(field.name()).alias(&normalize_ident(ident))
},
),
Some(normalize_ident(&alias.name)),
)?
.build()?)
}
self.apply_table_alias(plan, alias)
} else {
Ok(plan)
}
}

/// Apply the given `TableAlias` to the top-level projection of `plan`.
///
/// When the alias carries no column names, `plan` is returned untouched.
/// Otherwise a new projection is stacked on top of `plan` that renames every
/// output field, positionally, to the corresponding (normalized) column alias
/// and tags the projection with the (normalized) alias name.
///
/// # Errors
///
/// Returns `DataFusionError::Plan` when the number of column aliases differs
/// from the number of fields in the plan's schema, and propagates any error
/// raised while building the projection.
fn apply_table_alias(
    &self,
    plan: LogicalPlan,
    alias: TableAlias,
) -> Result<LogicalPlan> {
    // Borrow the column list instead of cloning the whole alias.
    let columns_alias = &alias.columns;
    if columns_alias.is_empty() {
        // sqlparser-rs encodes `AS t` as an empty list of column aliases
        return Ok(plan);
    }

    let field_count = plan.schema().fields().len();
    if columns_alias.len() != field_count {
        // The wording is kept verbatim: tests compare against this message.
        return Err(DataFusionError::Plan(format!(
            "Source table contains {} columns but only {} names given as column alias",
            field_count,
            columns_alias.len(),
        )));
    }

    // Materialize the renaming expressions first so the borrow of the plan's
    // schema ends before `plan` is moved into the builder (no plan clone needed).
    let renamed_columns: Vec<_> = plan
        .schema()
        .fields()
        .iter()
        .zip(columns_alias.iter())
        .map(|(field, ident)| col(field.name()).alias(&normalize_ident(ident)))
        .collect();

    LogicalPlanBuilder::from(plan)
        .project_with_alias(renamed_columns, Some(normalize_ident(&alias.name)))?
        .build()
}

/// Generate a logical plan from the selection clause; this function contains an optimization that converts a cross join into an inner join
/// Related PR: <https://github.com/apache/arrow-datafusion/pull/1566>
fn plan_selection(
Expand Down Expand Up @@ -5046,6 +5058,67 @@ mod tests {
quick_test(sql, expected)
}

#[test]
fn cte_with_no_column_names() {
    // A CTE declared without a column list keeps the names produced by its
    // own query, so only the CTE's projection (tagged with the CTE name)
    // appears below the outer `SELECT *`.
    let expected_plan = "Projection: numbers.a, numbers.b, numbers.c\
        \n Projection: Int64(1) AS a, Int64(2) AS b, Int64(3) AS c, alias=numbers\
        \n EmptyRelation";

    let query = "WITH \
        numbers AS ( \
        SELECT 1 as a, 2 as b, 3 as c \
        ) \
        SELECT * FROM numbers;";

    quick_test(query, expected_plan)
}

#[test]
fn cte_with_column_names() {
    // The column list on the CTE (`numbers(a, b, c)`) names the otherwise
    // unnamed literals; the expected plan shows an extra renaming projection
    // on top of the CTE's own projection.
    let expected_plan = "Projection: numbers.a, numbers.b, numbers.c\
        \n Projection: numbers.Int64(1) AS a, numbers.Int64(2) AS b, numbers.Int64(3) AS c, alias=numbers\
        \n Projection: Int64(1), Int64(2), Int64(3), alias=numbers\
        \n EmptyRelation";

    let query = "WITH \
        numbers(a, b, c) AS ( \
        SELECT 1, 2, 3 \
        ) \
        SELECT * FROM numbers;";

    quick_test(query, expected_plan)
}

#[test]
fn cte_with_column_aliases_precedence() {
    // Names given in the CTE specification win over the aliases used inside
    // the CTE's query: x/y/z are renamed to a/b/c by the outer projection.
    let expected_plan = "Projection: numbers.a, numbers.b, numbers.c\
        \n Projection: numbers.x AS a, numbers.y AS b, numbers.z AS c, alias=numbers\
        \n Projection: Int64(1) AS x, Int64(2) AS y, Int64(3) AS z, alias=numbers\
        \n EmptyRelation";

    let query = "WITH \
        numbers(a, b, c) AS ( \
        SELECT 1 as x, 2 as y, 3 as z \
        ) \
        SELECT * FROM numbers;";

    quick_test(query, expected_plan)
}

#[test]
fn cte_unbalanced_number_of_columns() {
    // One column alias for a three-column query must be rejected at planning
    // time with the column-count mismatch message.
    let expected_message = "Error during planning: Source table contains 3 columns but only 1 names given as column alias";

    let query = "WITH \
        numbers(a) AS ( \
        SELECT 1, 2, 3 \
        ) \
        SELECT * FROM numbers;";

    let planning_error = logical_plan(query).err().unwrap();
    assert_eq!(expected_message, planning_error.to_string());
}

#[test]
fn aggregate_with_rollup() {
let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, ROLLUP (state, age)";
Expand Down

0 comments on commit 23682f6

Please sign in to comment.