Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-20311][SQL] Support aliases for table value functions #17928

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -472,15 +472,23 @@ identifierComment
;

relationPrimary
: tableIdentifier sample? (AS? strictIdentifier)? #tableName
| '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery
| '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation
| inlineTable #inlineTableDefault2
| identifier '(' (expression (',' expression)*)? ')' #tableValuedFunction
: tableIdentifier sample? (AS? strictIdentifier)? #tableName
| '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery
| '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation
| inlineTable #inlineTableDefault2
| functionTable #tableValuedFunction
;

inlineTable
: VALUES expression (',' expression)* (AS? identifier identifierList?)?
: VALUES expression (',' expression)* tableAlias
;

functionTable
: identifier '(' (expression (',' expression)*)? ')' tableAlias
;

tableAlias
: (AS? strictIdentifier identifierList?)?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also hits another bug in inline tables. Could you also include the following query in the test file inline-table.sql?

sql("SELECT * FROM VALUES (\"one\", 1), (\"three\", null) CROSS JOIN VALUES (\"one\", 1), (\"three\", null)")

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

;

rowFormat
Expand Down
Expand Up @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis

import java.util.Locale

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range}
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

Expand Down Expand Up @@ -105,7 +105,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
case Some(tvf) =>
val resolved = tvf.flatMap { case (argList, resolver) =>
argList.implicitCast(u.functionArgs) match {
Expand All @@ -125,5 +125,21 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
case _ =>
u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function")
}

// If alias names assigned, add `Project` with the aliases
if (u.outputNames.nonEmpty) {
val outputAttrs = resolvedFunc.output
// Checks if the number of the aliases is equal to expected one
if (u.outputNames.size != outputAttrs.size) {
u.failAnalysis(s"expected ${outputAttrs.size} columns but " +
s"found ${u.outputNames.size} columns")
}
val aliases = outputAttrs.zip(u.outputNames).map {
case (attr, name) => Alias(attr, name)()
}
Project(aliases, resolvedFunc)
} else {
resolvedFunc
}
}
}
Expand Up @@ -66,10 +66,16 @@ case class UnresolvedInlineTable(
/**
* A table-valued function, e.g.
* {{{
* select * from range(10);
* select id from range(10);
*
* // Assign alias names
* select t.a from range(10) t(a);
* }}}
*/
case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq[Expression])
case class UnresolvedTableValuedFunction(
functionName: String,
functionArgs: Seq[Expression],
outputNames: Seq[String])
extends LeafNode {

override def output: Seq[Attribute] = Nil
Expand Down
Expand Up @@ -687,7 +687,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
*/
override def visitTableValuedFunction(ctx: TableValuedFunctionContext)
: LogicalPlan = withOrigin(ctx) {
UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression))
val func = ctx.functionTable
val aliases = if (func.tableAlias.identifierList != null) {
visitIdentifierList(func.tableAlias.identifierList)
} else {
Seq.empty
}

val tvf = UnresolvedTableValuedFunction(
func.identifier.getText, func.expression.asScala.map(expression), aliases)
tvf.optionalMap(func.tableAlias.strictIdentifier)(aliasPlan)
}

/**
Expand All @@ -705,14 +714,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}
}

val aliases = if (ctx.identifierList != null) {
visitIdentifierList(ctx.identifierList)
val aliases = if (ctx.tableAlias.identifierList != null) {
visitIdentifierList(ctx.tableAlias.identifierList)
} else {
Seq.tabulate(rows.head.size)(i => s"col${i + 1}")
}

val table = UnresolvedInlineTable(aliases, rows)
table.optionalMap(ctx.identifier)(aliasPlan)
table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan)
}

/**
Expand Down
Expand Up @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.Cross
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -441,4 +440,17 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {

checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation)
}

// Analyzer-level check for column aliases on the `range` table-valued function.
test("SPARK-20311 range(N) as alias") {
// Builds `SELECT * FROM range(args...) t(outputNames...)` directly as an unresolved plan:
// the integer args become literals and the whole function is wrapped in a "t" subquery alias.
def rangeWithAliases(args: Seq[Int], outputNames: Seq[String]): LogicalPlan = {
SubqueryAlias("t", UnresolvedTableValuedFunction("range", args.map(Literal(_)), outputNames))
.select(star())
}
// One alias is accepted for the one-, two- and three-argument forms of `range`.
assertAnalysisSuccess(rangeWithAliases(3 :: Nil, "a" :: Nil))
assertAnalysisSuccess(rangeWithAliases(1 :: 4 :: Nil, "b" :: Nil))
assertAnalysisSuccess(rangeWithAliases(2 :: 6 :: 2 :: Nil, "c" :: Nil))
// Supplying two aliases for the one-argument form must fail analysis with a
// column-count mismatch message.
assertAnalysisError(
rangeWithAliases(3 :: Nil, "a" :: "b" :: Nil),
Seq("expected 1 columns but found 2 columns"))
}
}
Expand Up @@ -468,7 +468,18 @@ class PlanParserSuite extends PlanTest {
test("table valued function") {
assertEqual(
"select * from range(2)",
UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star()))
UnresolvedTableValuedFunction("range", Literal(2) :: Nil, Seq.empty).select(star()))
}

// Parser-level check for table and column aliases on a table-valued function.
test("SPARK-20311 range(N) as alias") {
// A table alias without a column list: the function is wrapped in SubqueryAlias("t")
// and carries no output names.
assertEqual(
"SELECT * FROM range(10) AS t",
SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(10) :: Nil, Seq.empty))
.select(star()))
// A table alias with a column list: the listed names are stored as the function's
// output names, and the table alias still becomes a SubqueryAlias.
assertEqual(
"SELECT * FROM range(7) AS t(a)",
SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(7) :: Nil, "a" :: Nil))
.select(star()))
}

test("inline table") {
Expand Down
3 changes: 3 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/inline-table.sql
Expand Up @@ -49,3 +49,6 @@ select * from values ("one", count(1)), ("two", 2) as data(a, b);

-- string to timestamp
select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-12-06 01:00:00.0'), timestamp('1991-12-06 12:00:00.0'))) as data(a, b);

-- cross-join inline tables
SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null);
Copy link
Contributor

@cloud-fan cloud-fan May 11, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this expose the bug? If we treat CROSS as an alias, we still get the same result. How about we run EXPLAIN?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha, ok

Expand Up @@ -24,3 +24,7 @@ select * from RaNgE(2);

-- Explain
EXPLAIN select * from RaNgE(2);

-- cross-join table valued functions
SET spark.sql.crossJoin.enabled=true;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you remove this line? If we specify CROSS JOIN in the query, there is no need to set this parameter.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3);
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 17
-- Number of queries: 18


-- !query 0
Expand Down Expand Up @@ -151,3 +151,14 @@ select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-
struct<a:timestamp,b:array<timestamp>>
-- !query 16 output
1991-12-06 00:00:00 [1991-12-06 01:00:00.0,1991-12-06 12:00:00.0]


-- !query 17
SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null)
-- !query 17 schema
struct<col1:string,col2:int,col1:string,col2:int>
-- !query 17 output
one 1 one 1
one 1 three NULL
three NULL one 1
three NULL three NULL
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 9
-- Number of queries: 11


-- !query 0
Expand Down Expand Up @@ -103,3 +103,41 @@ struct<plan:string>
-- !query 8 output
== Physical Plan ==
*Range (0, 2, step=1, splits=2)


-- !query 9
SET spark.sql.crossJoin.enabled=true
-- !query 9 schema
struct<key:string,value:string>
-- !query 9 output
spark.sql.crossJoin.enabled true


-- !query 10
EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3)
-- !query 10 schema
struct<plan:string>
-- !query 10 output
== Parsed Logical Plan ==
'Project [*]
+- 'Join Cross
:- 'UnresolvedTableValuedFunction range, [3]
+- 'UnresolvedTableValuedFunction range, [3]

== Analyzed Logical Plan ==
id: bigint, id: bigint
Project [id#xL, id#xL]
+- Join Cross
:- Range (0, 3, step=1, splits=None)
+- Range (0, 3, step=1, splits=None)

== Optimized Logical Plan ==
Join Cross
:- Range (0, 3, step=1, splits=None)
+- Range (0, 3, step=1, splits=None)

== Physical Plan ==
BroadcastNestedLoopJoin BuildRight, Cross
:- *Range (0, 3, step=1, splits=2)
+- BroadcastExchange IdentityBroadcastMode
+- *Range (0, 3, step=1, splits=2)