[SPARK-28227][SQL] Support projection, aggregate/window functions, and lateral view in the TRANSFORM clause

### What changes were proposed in this pull request?
Spark SQL cannot handle script transform queries that use an aggregationClause, windowClause, or LateralView, so such Hive SQL cannot be migrated directly to Spark SQL.

In this PR, we treat the query part of a script transform statement (everything except the transform clause itself) as a separate query block, make it the ScriptTransformation's child, and pass an UnresolvedStar as the ScriptTransformation's input. At analysis time we then substitute the child's output for that star, so any kind of normal SELECT query can be combined with a script transformation.

For example, a transform with aggregation:
```
SELECT TRANSFORM ( d2, max(d1) as max_d1, sum(d3))
USING 'cat' AS (a,b,c)
FROM script_trans
WHERE d1 <= 100
GROUP BY d2
HAVING max_d1 > 0
```
When we build the AST, we treat it as:
```
SELECT TRANSFORM (*)
USING 'cat' AS (a,b,c)
FROM (
    SELECT d2, max(d1) as max_d1, sum(d3)
    FROM script_trans
    WHERE d1 <= 100
    GROUP BY d2
    HAVING max_d1 > 0
) tmp
```
Then, in the Analyzer's `ResolveReferences` rule, the `UnresolvedStar` is expanded, so the SQL behaves like:
```
SELECT TRANSFORM ( d2, max(d1) as max_d1, sum(d3))
USING 'cat' AS (a,b,c)
FROM script_trans
WHERE d1 <= 100
GROUP BY d2
HAVING max_d1 > 0
```
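
The same rewrite also covers window functions and lateral views. For instance, this mirrors one of the new tests (where `script_trans` has columns `a`, `b`, `c`):
```
SELECT TRANSFORM(b, MAX(a) OVER w as max_a, CAST(SUM(c) OVER w AS STRING))
USING 'cat' AS (a, b, c)
FROM script_trans
WHERE a <= 4
WINDOW w AS (PARTITION BY b ORDER BY a)
```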

As for unit tests, this PR adds many different SQL statements to check that all such queries are supported and that each kind of expression (alias, CASE WHEN, binary arithmetic, etc.) works correctly.
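For instance, aliases, CASE WHEN, and arithmetic on the transform input are exercised like this (taken from the new `transform.sql` tests):
```
SELECT TRANSFORM(
    b AS d5, a,
    CASE
      WHEN c > 100 THEN 1
      WHEN c < 100 THEN 2
      ELSE 3 END)
USING 'cat' AS (a, b, c)
FROM script_trans
WHERE a <= 4
```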

### Why are the changes needed?
Supports TRANSFORM with an aggregationClause, windowClause, LateralView, etc., making SQL migration from Hive smoother.

### Does this PR introduce _any_ user-facing change?
Yes. Users can now write TRANSFORM with an aggregationClause, windowClause, or LateralView.

### How was this patch tested?
Added UT

Closes #29087 from AngersZhuuuu/SPARK-28227-NEW.

Lead-authored-by: angerszhu <angers.zhu@gmail.com>
Co-authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Co-authored-by: AngersZhuuuu <angers.zhu@gmail.com>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
AngersZhuuuu authored and maropu committed Apr 13, 2021
1 parent 9c1f807 commit 278203d
Showing 8 changed files with 662 additions and 76 deletions.
@@ -509,7 +509,11 @@ fromStatementBody
querySpecification
: transformClause
fromClause?
whereClause? #transformQuerySpecification
lateralView*
whereClause?
aggregationClause?
havingClause?
windowClause? #transformQuerySpecification
| selectClause
fromClause?
lateralView*
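
With this grammar change, a single TRANSFORM query can carry the whole clause sequence; one combined form exercised by the new tests:
```
SELECT TRANSFORM(b, MAX(a) as max_a, CAST(SUM(c) AS STRING), myCol, myCol2)
USING 'cat' AS (a, b, c, d, e)
FROM script_trans
LATERAL VIEW explode(array(array(1,2,3))) myTable AS myCol
LATERAL VIEW explode(myTable.myCol) myTable2 AS myCol2
WHERE a <= 4
GROUP BY b, myCol, myCol2
HAVING max(a) > 1
```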
@@ -1383,14 +1383,9 @@ class Analyzer(override val catalogManager: CatalogManager)
} else {
a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
}
// If the script transformation input contains Stars, expand it.
// TODO: Remove this logic and see SPARK-34035
case t: ScriptTransformation if containsStar(t.input) =>
t.copy(
input = t.input.flatMap {
case s: Star => s.expand(t.child, resolver)
case o => o :: Nil
}
)
t.copy(input = t.child.output)
case g: Generate if containsStar(g.generator.children) =>
throw QueryCompilationErrors.invalidStarUsageError("explode/json_tuple/UDTF")

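
For the aggregation example above, parsing yields a ScriptTransformation whose input is `UnresolvedStar` and whose child is the separate query block; after this rule runs, the resolved plan looks roughly like the following (a hand-drawn sketch; actual operator and attribute details may differ):
```
ScriptTransformation [d2, max_d1, sum(d3)], cat, [a, b, c]
+- Filter (max_d1 > 0)
   +- Aggregate [d2], [d2, max(d1) AS max_d1, sum(d3)]
      +- Filter (d1 <= 100)
         +- Relation script_trans
```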
@@ -150,7 +150,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logging {
withTransformQuerySpecification(
ctx,
ctx.transformClause,
ctx.lateralView,
ctx.whereClause,
ctx.aggregationClause,
ctx.havingClause,
ctx.windowClause,
plan
)
} else {
@@ -587,7 +591,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logging {
val from = OneRowRelation().optional(ctx.fromClause) {
visitFromClause(ctx.fromClause)
}
withTransformQuerySpecification(ctx, ctx.transformClause, ctx.whereClause, from)
withTransformQuerySpecification(
ctx,
ctx.transformClause,
ctx.lateralView,
ctx.whereClause,
ctx.aggregationClause,
ctx.havingClause,
ctx.windowClause,
from
)
}

override def visitRegularQuerySpecification(
@@ -641,14 +654,12 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logging {
private def withTransformQuerySpecification(
ctx: ParserRuleContext,
transformClause: TransformClauseContext,
lateralView: java.util.List[LateralViewContext],
whereClause: WhereClauseContext,
relation: LogicalPlan): LogicalPlan = withOrigin(ctx) {
// Add where.
val withFilter = relation.optionalMap(whereClause)(withWhereClause)

// Create the transform.
val expressions = visitNamedExpressionSeq(transformClause.namedExpressionSeq)

aggregationClause: AggregationClauseContext,
havingClause: HavingClauseContext,
windowClause: WindowClauseContext,
relation: LogicalPlan): LogicalPlan = withOrigin(ctx) {
// Create the attributes.
val (attributes, schemaLess) = if (transformClause.colTypeList != null) {
// Typed return columns.
@@ -664,12 +675,22 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logging {
AttributeReference("value", StringType)()), true)
}

// Create the transform.
val plan = visitCommonSelectQueryClausePlan(
relation,
lateralView,
transformClause.namedExpressionSeq,
whereClause,
aggregationClause,
havingClause,
windowClause,
isDistinct = false)

ScriptTransformation(
expressions,
// TODO: Remove this logic and see SPARK-34035
Seq(UnresolvedStar(None)),
string(transformClause.script),
attributes,
withFilter,
plan,
withScriptIOSchema(
ctx,
transformClause.inRowFormat,
@@ -697,13 +718,40 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logging {
havingClause: HavingClauseContext,
windowClause: WindowClauseContext,
relation: LogicalPlan): LogicalPlan = withOrigin(ctx) {
val isDistinct = selectClause.setQuantifier() != null &&
selectClause.setQuantifier().DISTINCT() != null

val plan = visitCommonSelectQueryClausePlan(
relation,
lateralView,
selectClause.namedExpressionSeq,
whereClause,
aggregationClause,
havingClause,
windowClause,
isDistinct)

// Hint
selectClause.hints.asScala.foldRight(plan)(withHints)
}

def visitCommonSelectQueryClausePlan(
relation: LogicalPlan,
lateralView: java.util.List[LateralViewContext],
namedExpressionSeq: NamedExpressionSeqContext,
whereClause: WhereClauseContext,
aggregationClause: AggregationClauseContext,
havingClause: HavingClauseContext,
windowClause: WindowClauseContext,
isDistinct: Boolean): LogicalPlan = {
// Add lateral views.
val withLateralView = lateralView.asScala.foldLeft(relation)(withGenerate)

// Add where.
val withFilter = withLateralView.optionalMap(whereClause)(withWhereClause)

val expressions = visitNamedExpressionSeq(selectClause.namedExpressionSeq)
val expressions = visitNamedExpressionSeq(namedExpressionSeq)

// Add aggregation or a project.
val namedExpressions = expressions.map {
case e: NamedExpression => e
@@ -737,9 +785,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logging {
}

// Distinct
val withDistinct = if (
selectClause.setQuantifier() != null &&
selectClause.setQuantifier().DISTINCT() != null) {
val withDistinct = if (isDistinct) {
Distinct(withProject)
} else {
withProject
@@ -748,8 +794,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logging {
// Window
val withWindow = withDistinct.optionalMap(windowClause)(withWindowClause)

// Hint
selectClause.hints.asScala.foldRight(withWindow)(withHints)
withWindow
}

// Script Transform's input/output format.
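
Because both branches now flow through `visitCommonSelectQueryClausePlan`, a transform block composes like any other query block; for example, the new tests nest one inside a FROM clause (lightly reformatted):
```
FROM (
  FROM script_trans
  SELECT TRANSFORM(a, b)
  USING 'cat' AS (`a` INT, b STRING)
) t
SELECT a + 1
```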
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.parser

import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction}
import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -1074,11 +1074,11 @@ class PlanParserSuite extends AnalysisTest {
|FROM testData
""".stripMargin,
ScriptTransformation(
Seq('a, 'b, 'c),
Seq(UnresolvedStar(None)),
"cat",
Seq(AttributeReference("key", StringType)(),
AttributeReference("value", StringType)()),
UnresolvedRelation(TableIdentifier("testData")),
Project(Seq('a, 'b, 'c), UnresolvedRelation(TableIdentifier("testData"))),
ScriptInputOutputSchema(List.empty, List.empty, None, None,
List.empty, List.empty, None, None, true))
)
|FROM testData
""".stripMargin,
ScriptTransformation(
Seq('a, 'b, 'c),
Seq(UnresolvedStar(None)),
"cat",
Seq(AttributeReference("a", StringType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", StringType)()),
UnresolvedRelation(TableIdentifier("testData")),
Project(Seq('a, 'b, 'c), UnresolvedRelation(TableIdentifier("testData"))),
ScriptInputOutputSchema(List.empty, List.empty, None, None,
List.empty, List.empty, None, None, false)))

|FROM testData
""".stripMargin,
ScriptTransformation(
Seq('a, 'b, 'c),
Seq(UnresolvedStar(None)),
"cat",
Seq(AttributeReference("a", IntegerType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", LongType)()),
UnresolvedRelation(TableIdentifier("testData")),
Project(Seq('a, 'b, 'c), UnresolvedRelation(TableIdentifier("testData"))),
ScriptInputOutputSchema(List.empty, List.empty, None, None,
List.empty, List.empty, None, None, false)))

|FROM testData
""".stripMargin,
ScriptTransformation(
Seq('a, 'b, 'c),
Seq(UnresolvedStar(None)),
"cat",
Seq(AttributeReference("a", StringType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", StringType)()),
UnresolvedRelation(TableIdentifier("testData")),
Project(Seq('a, 'b, 'c), UnresolvedRelation(TableIdentifier("testData"))),
ScriptInputOutputSchema(
Seq(("TOK_TABLEROWFORMATFIELD", "\t"),
("TOK_TABLEROWFORMATCOLLITEMS", "\u0002"),
132 changes: 132 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/transform.sql
@@ -5,6 +5,12 @@ CREATE OR REPLACE TEMPORARY VIEW t AS SELECT * FROM VALUES
('3', true, unhex('537061726B2053514C'), tinyint(3), 3, smallint(300), bigint(3), float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03'))
AS t(a, b, c, d, e, f, g, h, i, j, k, l);

CREATE OR REPLACE TEMPORARY VIEW script_trans AS SELECT * FROM VALUES
(1, 2, 3),
(4, 5, 6),
(7, 8, 9)
AS script_trans(a, b, c);

SELECT TRANSFORM(a)
USING 'cat' AS (a)
FROM t;
@@ -184,6 +190,132 @@ SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l FROM (
FROM t
) tmp;

SELECT TRANSFORM(b, a, CAST(c AS STRING))
USING 'cat' AS (a, b, c)
FROM script_trans
WHERE a <= 4;

SELECT TRANSFORM(1, 2, 3)
USING 'cat' AS (a, b, c)
FROM script_trans
WHERE a <= 4;

SELECT TRANSFORM(1, 2)
USING 'cat' AS (a INT, b INT)
FROM script_trans
LIMIT 1;

SELECT TRANSFORM(
b AS d5, a,
CASE
WHEN c > 100 THEN 1
WHEN c < 100 THEN 2
ELSE 3 END)
USING 'cat' AS (a, b, c)
FROM script_trans
WHERE a <= 4;

SELECT TRANSFORM(b, a, c + 1)
USING 'cat' AS (a, b, c)
FROM script_trans
WHERE a <= 4;

SELECT TRANSFORM(*)
USING 'cat' AS (a, b, c)
FROM script_trans
WHERE a <= 4;

SELECT TRANSFORM(b AS d, MAX(a) as max_a, CAST(SUM(c) AS STRING))
USING 'cat' AS (a, b, c)
FROM script_trans
WHERE a <= 4
GROUP BY b;

SELECT TRANSFORM(b AS d, MAX(a) FILTER (WHERE a > 3) AS max_a, CAST(SUM(c) AS STRING))
USING 'cat' AS (a,b,c)
FROM script_trans
WHERE a <= 4
GROUP BY b;

SELECT TRANSFORM(b, MAX(a) as max_a, CAST(sum(c) AS STRING))
USING 'cat' AS (a, b, c)
FROM script_trans
WHERE a <= 2
GROUP BY b;

SELECT TRANSFORM(b, MAX(a) as max_a, CAST(SUM(c) AS STRING))
USING 'cat' AS (a, b, c)
FROM script_trans
WHERE a <= 4
GROUP BY b
HAVING max_a > 0;

SELECT TRANSFORM(b, MAX(a) as max_a, CAST(SUM(c) AS STRING))
USING 'cat' AS (a, b, c)
FROM script_trans
WHERE a <= 4
GROUP BY b
HAVING max(a) > 1;

SELECT TRANSFORM(b, MAX(a) OVER w as max_a, CAST(SUM(c) OVER w AS STRING))
USING 'cat' AS (a, b, c)
FROM script_trans
WHERE a <= 4
WINDOW w AS (PARTITION BY b ORDER BY a);

SELECT TRANSFORM(b, MAX(a) as max_a, CAST(SUM(c) AS STRING), myCol, myCol2)
USING 'cat' AS (a, b, c, d, e)
FROM script_trans
LATERAL VIEW explode(array(array(1,2,3))) myTable AS myCol
LATERAL VIEW explode(myTable.myCol) myTable2 AS myCol2
WHERE a <= 4
GROUP BY b, myCol, myCol2
HAVING max(a) > 1;

FROM(
FROM script_trans
SELECT TRANSFORM(a, b)
USING 'cat' AS (`a` INT, b STRING)
) t
SELECT a + 1;

FROM(
SELECT TRANSFORM(a, SUM(b) b)
USING 'cat' AS (`a` INT, b STRING)
FROM script_trans
GROUP BY a
) t
SELECT (b + 1) AS result
ORDER BY result;

MAP k / 10 USING 'cat' AS (one) FROM (SELECT 10 AS k);

FROM (SELECT 1 AS key, 100 AS value) src
MAP src.*, src.key, CAST(src.key / 10 AS INT), CAST(src.key % 10 AS INT), src.value
USING 'cat' AS (k, v, tkey, ten, one, tvalue);

SELECT TRANSFORM(1)
USING 'cat' AS (a)
FROM script_trans
HAVING true;

SET spark.sql.legacy.parser.havingWithoutGroupByAsWhere=true;

SELECT TRANSFORM(1)
USING 'cat' AS (a)
FROM script_trans
HAVING true;

SET spark.sql.legacy.parser.havingWithoutGroupByAsWhere=false;

SET spark.sql.parser.quotedRegexColumnNames=true;

SELECT TRANSFORM(`(a|b)?+.+`)
USING 'cat' AS (c)
FROM script_trans;

SET spark.sql.parser.quotedRegexColumnNames=false;

-- SPARK-34634: self join using CTE contains transform
WITH temp AS (
SELECT TRANSFORM(a) USING 'cat' AS (b string) FROM t
