Skip to content

Commit

Permalink
Support aliases for table value functions
Browse files Browse the repository at this point in the history
  • Loading branch information
maropu committed May 10, 2017
1 parent c0189ab commit 0904fc9
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 18 deletions.
Expand Up @@ -472,15 +472,23 @@ identifierComment
;

relationPrimary
: tableIdentifier sample? (AS? strictIdentifier)? #tableName
| '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery
| '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation
| inlineTable #inlineTableDefault2
| identifier '(' (expression (',' expression)*)? ')' #tableValuedFunction
: tableIdentifier sample? (AS? strictIdentifier)? #tableName
| '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery
| '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation
| inlineTable #inlineTableDefault2
| functionTable #tableValuedFunction
;

inlineTable
: VALUES expression (',' expression)* (AS? identifier identifierList?)?
: VALUES expression (',' expression)* tableAlias
;

// A table-valued function call used as a relation: a function name followed by a
// parenthesized, possibly empty, comma-separated list of argument expressions and
// then a `tableAlias` (which may itself match nothing).
functionTable
: identifier '(' (expression (',' expression)*)? ')' tableAlias
;

// An entirely optional alias clause: when present it is an optional AS keyword, a
// strictIdentifier naming the relation, and an optional identifierList.
// Because the whole body is optional, this rule can match the empty input.
tableAlias
: (AS? strictIdentifier identifierList?)?
;

rowFormat
Expand Down
Expand Up @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis

import java.util.Locale

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range}
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

Expand Down Expand Up @@ -105,7 +105,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
case Some(tvf) =>
val resolved = tvf.flatMap { case (argList, resolver) =>
argList.implicitCast(u.functionArgs) match {
Expand All @@ -125,5 +125,21 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
case _ =>
u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function")
}

// If alias names assigned, add `Project` with the aliases
if (u.outputNames.nonEmpty) {
val outputAttrs = resolvedFunc.output
// Checks if the number of the aliases is equal to expected one
if (u.outputNames.size != outputAttrs.size) {
u.failAnalysis(s"expected ${outputAttrs.size} columns but " +
s"found ${u.outputNames.size} columns")
}
val aliases = outputAttrs.zip(u.outputNames).map {
case (attr, name) => Alias(attr, name)()
}
Project(aliases, resolvedFunc)
} else {
resolvedFunc
}
}
}
Expand Up @@ -66,10 +66,16 @@ case class UnresolvedInlineTable(
/**
* A table-valued function, e.g.
* {{{
* select * from range(10);
* select id from range(10);
*
* // Assign alias names
* select t.a from range(10) t(a);
* }}}
*/
case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq[Expression])
case class UnresolvedTableValuedFunction(
functionName: String,
functionArgs: Seq[Expression],
outputNames: Seq[String])
extends LeafNode {

override def output: Seq[Attribute] = Nil
Expand Down
Expand Up @@ -687,7 +687,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
*/
override def visitTableValuedFunction(ctx: TableValuedFunctionContext)
: LogicalPlan = withOrigin(ctx) {
UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression))
val func = ctx.functionTable
val aliases = if (func.tableAlias.identifierList != null) {
visitIdentifierList(func.tableAlias.identifierList)
} else {
Seq.empty
}

val tvf = UnresolvedTableValuedFunction(
func.identifier.getText, func.expression.asScala.map(expression), aliases)
tvf.optionalMap(func.tableAlias.strictIdentifier)(aliasPlan)
}

/**
Expand All @@ -705,14 +714,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}
}

val aliases = if (ctx.identifierList != null) {
visitIdentifierList(ctx.identifierList)
val aliases = if (ctx.tableAlias.identifierList != null) {
visitIdentifierList(ctx.tableAlias.identifierList)
} else {
Seq.tabulate(rows.head.size)(i => s"col${i + 1}")
}

val table = UnresolvedInlineTable(aliases, rows)
table.optionalMap(ctx.identifier)(aliasPlan)
table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan)
}

/**
Expand Down
Expand Up @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.Cross
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -441,4 +440,17 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {

checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation)
}

// Analyzer-level coverage for column aliases on the `range` table-valued function.
test("SPARK-20311 range(N) as alias") {
  // Builds the unresolved plan for `select * from range(<args>) t(<outputNames>)`.
  def rangeWithAliases(args: Seq[Int], outputNames: Seq[String]): LogicalPlan = {
    val literalArgs = args.map(Literal(_))
    val tvf = UnresolvedTableValuedFunction("range", literalArgs, outputNames)
    SubqueryAlias("t", tvf).select(star())
  }

  // One alias name is accepted when `range` is called with one, two and three arguments.
  val accepted = Seq(
    Seq(3) -> "a",
    Seq(1, 4) -> "b",
    Seq(2, 6, 2) -> "c")
  accepted.foreach { case (rangeArgs, columnName) =>
    assertAnalysisSuccess(rangeWithAliases(rangeArgs, Seq(columnName)))
  }

  // Two alias names for a single output column must be rejected.
  assertAnalysisError(
    rangeWithAliases(Seq(3), Seq("a", "b")),
    Seq("expected 1 columns but found 2 columns"))
}
}
Expand Up @@ -468,7 +468,18 @@ class PlanParserSuite extends PlanTest {
test("table valued function") {
assertEqual(
"select * from range(2)",
UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star()))
UnresolvedTableValuedFunction("range", Literal(2) :: Nil, Seq.empty).select(star()))
}

// Parser-level coverage for aliasing a table-valued function.
test("SPARK-20311 range(N) as alias") {
  // Relation alias only: the function node carries no output names.
  val withoutColumnNames = UnresolvedTableValuedFunction("range", Seq(Literal(10)), Nil)
  assertEqual(
    "select * from range(10) AS t",
    SubqueryAlias("t", withoutColumnNames).select(star()))

  // Relation alias plus a column list: the names are stored on the function node.
  val withColumnNames = UnresolvedTableValuedFunction("range", Seq(Literal(7)), Seq("a"))
  assertEqual(
    "select * from range(7) AS t(a)",
    SubqueryAlias("t", withColumnNames).select(star()))
}

test("inline table") {
Expand Down
Expand Up @@ -24,3 +24,7 @@ select * from RaNgE(2);

-- Explain
EXPLAIN select * from RaNgE(2);

-- cross-join table valued functions
set spark.sql.crossJoin.enabled=true;
EXPLAIN EXTENDED select * from range(3) cross join range(3);
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 9
-- Number of queries: 11


-- !query 0
Expand Down Expand Up @@ -103,3 +103,41 @@ struct<plan:string>
-- !query 8 output
== Physical Plan ==
*Range (0, 2, step=1, splits=2)


-- !query 9
set spark.sql.crossJoin.enabled=true
-- !query 9 schema
struct<key:string,value:string>
-- !query 9 output
spark.sql.crossJoin.enabled true


-- !query 10
EXPLAIN EXTENDED select * from range(3) cross join range(3)
-- !query 10 schema
struct<plan:string>
-- !query 10 output
== Parsed Logical Plan ==
'Project [*]
+- 'Join Cross
:- 'UnresolvedTableValuedFunction range, [3]
+- 'UnresolvedTableValuedFunction range, [3]

== Analyzed Logical Plan ==
id: bigint, id: bigint
Project [id#xL, id#xL]
+- Join Cross
:- Range (0, 3, step=1, splits=None)
+- Range (0, 3, step=1, splits=None)

== Optimized Logical Plan ==
Join Cross
:- Range (0, 3, step=1, splits=None)
+- Range (0, 3, step=1, splits=None)

== Physical Plan ==
BroadcastNestedLoopJoin BuildRight, Cross
:- *Range (0, 3, step=1, splits=2)
+- BroadcastExchange IdentityBroadcastMode
+- *Range (0, 3, step=1, splits=2)

0 comments on commit 0904fc9

Please sign in to comment.