From fcef61ec0a2246ecaf37cf9e17536c88915b5ff6 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 10 Jul 2019 07:46:21 +0200 Subject: [PATCH] [SPARK-27878][SQL] Support ARRAY(sub-SELECT) expressions --- .../spark/sql/catalyst/parser/SqlBase.g4 | 1 + .../sql/catalyst/analysis/Analyzer.scala | 12 +++++ .../expressions/complexTypeCreator.scala | 21 +++++++++ .../sql/catalyst/parser/AstBuilder.scala | 7 +++ .../parser/ExpressionParserSuite.scala | 4 ++ .../spark/sql/catalyst/plans/PlanTest.scala | 2 + .../test/resources/sql-tests/inputs/array.sql | 7 +++ .../inputs/pgSQL/aggregates_part2.sql | 2 +- .../resources/sql-tests/results/array.sql.out | 44 ++++++++++++++++++- 9 files changed, 98 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index a1c11504a9036..ddad4a0c15f98 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -703,6 +703,7 @@ primaryExpression FROM srcStr=valueExpression ')' #trim | OVERLAY '(' input=valueExpression PLACING replace=valueExpression FROM position=valueExpression (FOR length=valueExpression)? ')' #overlay + | ARRAY '(' query ')' #array ; constant diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5d37e909f80aa..5e6d186917ee6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1533,6 +1533,18 @@ class Analyzer( ListQuery(plan, exprs, exprId, plan.output) }) InSubquery(values, expr.asInstanceOf[ListQuery]) + case a @ ArraySubquery(sub, _, exprId) if !sub.resolved => + resolveSubQuery(a, plans) { (plan, children) => + // Array subquery must return one column as output. + if (plan.output.size != 1) { + failAnalysis( + s"Array subquery must return only one column, but got ${plan.output.size}") + } + ScalarSubquery(Aggregate(Seq.empty, Seq( + Alias(AggregateExpression(CollectList(plan.output.head), Complete, false), "array()") + () + ), plan)) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 319a7fc87e59a..015003d8ab903 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -70,6 +71,26 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def prettyName: String = "array" } +case class ArraySubquery( + plan: LogicalPlan, + children: Seq[Expression] = Seq.empty, + exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression(plan, children, exprId) with Unevaluable { + override def dataType: DataType = { + assert(plan.schema.fields.nonEmpty, "Array subquery should have only one column") + plan.schema.fields.head.dataType + } + override def nullable: Boolean = true + override def withNewPlan(plan: LogicalPlan): ArraySubquery = copy(plan = plan) + override def toString: String = s"array-subquery#${exprId.id} $conditionString" + override lazy val canonicalized: Expression = { + ArraySubquery( + plan.canonicalized, + children.map(_.canonicalized), + ExprId(0)) + } +} + private [sql] object GenArrayData { /** * Return Java code pieces based on DataType and array size to allocate ArrayData class diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 6c5ad55e88bea..e22fa02a84f42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1436,6 +1436,13 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } } + /** + * Create an Array from a query. + */ + override def visitArray(ctx: ArrayContext): Expression = withOrigin(ctx) { + ArraySubquery(plan(ctx.query)) + } + /** * Create a (windowed) Function expression. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index e16262ddb9cd3..bdf27a9e79370 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -780,4 +780,8 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("current_timestamp", UnresolvedAttribute.quoted("current_timestamp")) } } + + test("Array from subquery") { + assertEqual("array(SELECT c FROM t)", ArraySubquery(table("t").select('c))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 5394732f41f2d..7b2e1ac862782 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -73,6 +73,8 @@ trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite => e.copy(exprId = ExprId(0)) case l: ListQuery => l.copy(exprId = ExprId(0)) + case a: ArraySubquery => + a.copy(exprId = ExprId(0)) case a: AttributeReference => AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) case a: Alias => diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql index 984321ab795fc..d7cd6d43c1e41 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/array.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql @@ -90,3 +90,10 @@ select size(date_array), size(timestamp_array) from primitive_arrays; + +-- array from subquery +select array(select 1); +select array(select a from data); +select array(select a from data where false); +select array(select 1, 2); +select array(select a, a from data); diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/aggregates_part2.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/aggregates_part2.sql index c4613701ec747..3c5a26de70e2d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/aggregates_part2.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/aggregates_part2.sql @@ -28,7 +28,7 @@ create temporary view int4_tbl as select * from values -- from generate_series(1, 3) s2 group by s2) ss -- order by 1, 2; --- [SPARK-27878] Support ARRAY(sub-SELECT) expressions +-- [SPARK-27769] Handling of sublinks within outer-level aggregates. -- explain (verbose, costs off) -- select array(select sum(x+y) s -- from generate_series(1,3) y group by y order by s) diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out index 5f5d988771847..d4960673f46cc 100644 --- a/sql/core/src/test/resources/sql-tests/results/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 12 +-- Number of queries: 17 -- !query 0 @@ -160,3 +160,45 @@ from primitive_arrays struct -- !query 11 output 1 2 2 2 2 2 2 2 2 2 + + +-- !query 12 +select array(select 1) +-- !query 12 schema +struct> +-- !query 12 output +[1] + + +-- !query 13 +select array(select a from data) +-- !query 13 schema +struct> +-- !query 13 output +["one","two"] + + +-- !query 14 +select array(select a from data where false) +-- !query 14 schema +struct> +-- !query 14 output +[] + + +-- !query 15 +select array(select 1, 2) +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +Array subquery must return only one column, but got 2; + + +-- !query 16 +select array(select a, a from data) +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.AnalysisException +Array subquery must return only one column, but got 2;