From b98865127a39bde885f9b1680cfe608629d59d51 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Fri, 29 Jul 2016 17:43:56 -0400 Subject: [PATCH 01/16] [SPARK-16804][SQL] Correlated subqueries containing LIMIT return incorrect results ## What changes were proposed in this pull request? This patch fixes the incorrect results in the rule ResolveSubquery in Catalyst's Analysis phase. ## How was this patch tested? ./dev/run-tests a new unit test on the problematic pattern. --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 10 ++++++++++ .../sql/catalyst/analysis/AnalysisErrorSuite.scala | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2efa997ff22d2..c3ee6517875c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1021,6 +1021,16 @@ class Analyzer( case e: Expand => failOnOuterReferenceInSubTree(e, "an EXPAND") e + case l @ LocalLimit(_, child) => + failOnOuterReferenceInSubTree(l, "LIMIT") + l + // Since LIMIT is represented as GlobalLimit(, (LocalLimit (, child)) + // and we are walking bottom up, we will fail on LocalLimit before + // reaching GlobalLimit. + // The code below is just a safety net. 
+ case g @ GlobalLimit(_, child) => + failOnOuterReferenceInSubTree(g, "LIMIT") + g case p => failOnOuterReference(p) p diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index ff112c51697ad..b78a988eddbb0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -533,5 +533,13 @@ class AnalysisErrorSuite extends AnalysisTest { Exists(Union(LocalRelation(b), Filter(EqualTo(OuterReference(a), c), LocalRelation(c)))), LocalRelation(a)) assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil) + + val plan4 = Filter( + Exists( + Limit(1, + Filter(EqualTo(OuterReference(a), b), LocalRelation(b))) + ), + LocalRelation(a)) + assertAnalysisError(plan4, "Accessing outer query column is not allowed in LIMIT" :: Nil) } } From 069ed8f8e5f14dca7a15701945d42fc27fe82f3c Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Fri, 29 Jul 2016 17:50:02 -0400 Subject: [PATCH 02/16] [SPARK-16804][SQL] Correlated subqueries containing LIMIT return incorrect results ## What changes were proposed in this pull request? This patch fixes the incorrect results in the rule ResolveSubquery in Catalyst's Analysis phase. ## How was this patch tested? ./dev/run-tests a new unit test on the problematic pattern. 
--- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c3ee6517875c7..357c763f59467 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1022,14 +1022,14 @@ class Analyzer( failOnOuterReferenceInSubTree(e, "an EXPAND") e case l @ LocalLimit(_, child) => - failOnOuterReferenceInSubTree(l, "LIMIT") + failOnOuterReferenceInSubTree(l, "a LIMIT") l // Since LIMIT is represented as GlobalLimit(, (LocalLimit (, child)) // and we are walking bottom up, we will fail on LocalLimit before // reaching GlobalLimit. // The code below is just a safety net. case g @ GlobalLimit(_, child) => - failOnOuterReferenceInSubTree(g, "LIMIT") + failOnOuterReferenceInSubTree(g, "a LIMIT") g case p => failOnOuterReference(p) From edca333c081e6d4e53a91b496fba4a3ef4ee89ac Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Fri, 29 Jul 2016 20:28:15 -0400 Subject: [PATCH 03/16] New positive test cases --- .../org/apache/spark/sql/SubquerySuite.scala | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index afed342ff8e2a..52387b4b72a16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -571,4 +571,33 @@ class SubquerySuite extends QueryTest with SharedSQLContext { Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) :: Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil) } + + test("SPARK-16804: Correlated subqueries 
containing LIMIT - 1") { + withTempView("onerow") { + Seq(1).toDF("c1").createOrReplaceTempView("onerow") + + checkAnswer( + sql( + """ + | select c1 from onerow t1 + | where exists (select 1 from onerow t2 where t1.c1=t2.c1) + | and exists (select 1 from onerow LIMIT 1)""".stripMargin), + Row(1) :: Nil) + } + } + + test("SPARK-16804: Correlated subqueries containing LIMIT - 2") { + withTempView("onerow") { + Seq(1).toDF("c1").createOrReplaceTempView("onerow") + + checkAnswer( + sql( + """ + | select c1 from onerow t1 + | where exists (select 1 + | from (select 1 from onerow t2 LIMIT 1) + | where t1.c1=t2.c1)""".stripMargin), + Row(1) :: Nil) + } + } } From 64184fdb77c1a305bb2932e82582da28bb4c0e53 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Mon, 1 Aug 2016 09:20:09 -0400 Subject: [PATCH 04/16] Fix unit test case failure --- .../apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index b78a988eddbb0..c08de826bd945 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -540,6 +540,6 @@ class AnalysisErrorSuite extends AnalysisTest { Filter(EqualTo(OuterReference(a), b), LocalRelation(b))) ), LocalRelation(a)) - assertAnalysisError(plan4, "Accessing outer query column is not allowed in LIMIT" :: Nil) + assertAnalysisError(plan4, "Accessing outer query column is not allowed in a LIMIT" :: Nil) } } From 29f82b05c9e40e7934397257c674b260a8e8a996 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Fri, 5 Aug 2016 13:42:01 -0400 Subject: [PATCH 05/16] blocking TABLESAMPLE --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 7 +++++-- 
.../spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 8 ++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 357c763f59467..9d99c4173d4af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1021,16 +1021,19 @@ class Analyzer( case e: Expand => failOnOuterReferenceInSubTree(e, "an EXPAND") e - case l @ LocalLimit(_, child) => + case l @ LocalLimit(_, _) => failOnOuterReferenceInSubTree(l, "a LIMIT") l // Since LIMIT is represented as GlobalLimit(, (LocalLimit (, child)) // and we are walking bottom up, we will fail on LocalLimit before // reaching GlobalLimit. // The code below is just a safety net. - case g @ GlobalLimit(_, child) => + case g @ GlobalLimit(_, _) => failOnOuterReferenceInSubTree(g, "a LIMIT") g + case s @ Sample(_, _, _, _, _) => + failOnOuterReferenceInSubTree(s, "a TABLESAMPLE") + s case p => failOnOuterReference(p) p diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index c08de826bd945..0b7d681be5114 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -541,5 +541,13 @@ class AnalysisErrorSuite extends AnalysisTest { ), LocalRelation(a)) assertAnalysisError(plan4, "Accessing outer query column is not allowed in a LIMIT" :: Nil) + + val plan5 = Filter( + Exists( + Sample(0.0, 0.5, false, 1L, + Filter(EqualTo(OuterReference(a), b), LocalRelation(b)))().select('b) + ), + LocalRelation(a)) + assertAnalysisError(plan5, "Accessing 
outer query column is not allowed in a TABLESAMPLE" :: Nil) } } From ac43ab47907a1ccd6d22f920415fbb4de93d4720 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Fri, 5 Aug 2016 17:10:19 -0400 Subject: [PATCH 06/16] Fixing code styling --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9d99c4173d4af..29ede7048a2db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1021,17 +1021,17 @@ class Analyzer( case e: Expand => failOnOuterReferenceInSubTree(e, "an EXPAND") e - case l @ LocalLimit(_, _) => + case l : LocalLimit => failOnOuterReferenceInSubTree(l, "a LIMIT") l // Since LIMIT is represented as GlobalLimit(, (LocalLimit (, child)) // and we are walking bottom up, we will fail on LocalLimit before // reaching GlobalLimit. // The code below is just a safety net. 
- case g @ GlobalLimit(_, _) => + case g : GlobalLimit => failOnOuterReferenceInSubTree(g, "a LIMIT") g - case s @ Sample(_, _, _, _, _) => + case s : Sample => failOnOuterReferenceInSubTree(s, "a TABLESAMPLE") s case p => From 631d396031e8bf627eb1f4872a4d3a17c144536c Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Sun, 7 Aug 2016 14:39:44 -0400 Subject: [PATCH 07/16] Correcting Scala test style --- .../spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 0b7d681be5114..8935d979414ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -548,6 +548,7 @@ class AnalysisErrorSuite extends AnalysisTest { Filter(EqualTo(OuterReference(a), b), LocalRelation(b)))().select('b) ), LocalRelation(a)) - assertAnalysisError(plan5, "Accessing outer query column is not allowed in a TABLESAMPLE" :: Nil) + assertAnalysisError(plan5, + "Accessing outer query column is not allowed in a TABLESAMPLE" :: Nil) } } From 7eb9b2dbba3633a1958e38e0019e3ce816300514 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Sun, 7 Aug 2016 22:31:09 -0400 Subject: [PATCH 08/16] One (last) attempt to correct the Scala style tests --- .../apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 8935d979414ae..6438065fb292e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -548,7 +548,7 @@ class AnalysisErrorSuite extends AnalysisTest { Filter(EqualTo(OuterReference(a), b), LocalRelation(b)))().select('b) ), LocalRelation(a)) - assertAnalysisError(plan5, + assertAnalysisError(plan5, "Accessing outer query column is not allowed in a TABLESAMPLE" :: Nil) } } From e8717831683f5f9a78b2660abe420042c8d1df6c Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Sat, 10 Dec 2016 10:45:18 -0500 Subject: [PATCH 09/16] first pass --- .../spark/sql/catalyst/analysis/Analyzer.scala | 18 ++++++++++++++++-- .../sql/catalyst/analysis/CheckAnalysis.scala | 13 ------------- .../org/apache/spark/sql/SubquerySuite.scala | 12 +++++++++++- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ed6e17a8eb465..c15fbfd2e784a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1230,10 +1230,24 @@ class Analyzer( */ private def rewriteSubQuery( sub: LogicalPlan, - outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { + outer: Seq[LogicalPlan], + scalarSubq: Boolean = false): (LogicalPlan, Seq[Expression]) = { // Pull out the tagged predicates and rewrite the subquery in the process. val (basePlan, baseConditions) = pullOutCorrelatedPredicates(sub) + // SPARK-18504: block cases where GROUP BY columns + // are not part of the correlated columns + if (scalarSubq && sub.isInstanceOf[Aggregate]) { + val groupByCols = ExpressionSet.apply(sub.asInstanceOf[Aggregate]. 
+ groupingExpressions.flatMap(_.references)) + val conditionsCols = ExpressionSet.apply(baseConditions.flatMap(_.references)) + val invalidCols = groupByCols.diff(conditionsCols) + if (invalidCols.nonEmpty) { + failAnalysis("a GROUP BY clause in a scalar correlated subquery " + + "cannot contain non-correlated columns: " + + invalidCols.mkString(",")) + } + } // Make sure the inner and the outer query attributes do not collide. val outputSet = outer.map(_.outputSet).reduce(_ ++ _) val duplicates = basePlan.outputSet.intersect(outputSet) @@ -1298,7 +1312,7 @@ class Analyzer( s"does not match the required number of columns ($requiredColumns)") } // Pullout predicates and construct a new plan. - f.tupled(rewriteSubQuery(current, plans)) + f.tupled(rewriteSubQuery(current, plans, e.isInstanceOf[ScalarSubquery])) } else { e.withNewPlan(current) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 235a79973d6ee..caf4a5ad35831 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -135,19 +135,6 @@ trait CheckAnalysis extends PredicateHelper { if (aggregates.isEmpty) { failAnalysis("The output of a correlated scalar subquery must be aggregated") } - - // SPARK-18504: block cases where GROUP BY columns - // are not part of the correlated columns - val groupByCols = ExpressionSet.apply(agg.groupingExpressions.flatMap(_.references)) - val predicateCols = ExpressionSet.apply(conditions.flatMap(_.references)) - val invalidCols = groupByCols.diff(predicateCols) - // GROUP BY columns must be a subset of columns in the predicates - if (invalidCols.nonEmpty) { - failAnalysis( - "a GROUP BY clause in a scalar correlated subquery " + - "cannot contain non-correlated columns: " + - 
invalidCols.mkString(",")) - } } // Skip projects and subquery aliases added by the Analyzer and the SQLBuilder. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 0f2f520006e35..cc565d9c61cd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -495,6 +495,16 @@ class SubquerySuite extends QueryTest with SharedSQLContext { } } + test("SPARK-18814 extra GROUP BY column in correlated scalar subquery is not permitted") { + withTempView("p", "c") { + Seq((1,1)).toDF("pk","pv").createOrReplaceTempView("p") + Seq((1,1)).toDF("ck","cv").createOrReplaceTempView("c") + checkAnswer( + sql("select pk, cv from p,c where p.pk=c.ck and c.cv = (select avg(c1.cv) from c c1 where c1.ck = p.pk)"), + Row(1, 1) :: Nil) + } + } + test("non-aggregated correlated scalar subquery") { val msg1 = intercept[AnalysisException] { sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1") @@ -505,7 +515,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { sql("select a, (select b from l l2 where l2.a = l1.a group by 1) sum_b from l l1") } assert(msg2.getMessage.contains( - "The output of a correlated scalar subquery must be aggregated")) + "a GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns:")) } test("non-equal correlated scalar subquery") { From b93b3ce38205a76e9faa40c5520ca7affed441b4 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Sat, 10 Dec 2016 10:53:00 -0500 Subject: [PATCH 10/16] second pass --- .../test/scala/org/apache/spark/sql/SubquerySuite.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index cc565d9c61cd7..2a519284c10b9 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -497,10 +497,11 @@ class SubquerySuite extends QueryTest with SharedSQLContext { test("SPARK-18814 extra GROUP BY column in correlated scalar subquery is not permitted") { withTempView("p", "c") { - Seq((1,1)).toDF("pk","pv").createOrReplaceTempView("p") - Seq((1,1)).toDF("ck","cv").createOrReplaceTempView("c") + Seq((1, 1)).toDF("pk", "pv").createOrReplaceTempView("p") + Seq((1, 1)).toDF("ck", "cv").createOrReplaceTempView("c") checkAnswer( - sql("select pk, cv from p,c where p.pk=c.ck and c.cv = (select avg(c1.cv) from c c1 where c1.ck = p.pk)"), + sql("select pk, cv from p,c where p.pk=c.ck and " + + "c.cv = (select avg(c1.cv) from c c1 where c1.ck = p.pk)"), Row(1, 1) :: Nil) } } From 09b543ba540c048751974ec37b12aa436721c0d8 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Sat, 10 Dec 2016 20:48:28 -0500 Subject: [PATCH 11/16] address @gatorsmile's comments --- .../sql/catalyst/analysis/Analyzer.scala | 34 ++++++++++++------- .../sql-tests/inputs/scalar-subquery.sql | 10 ++++++ .../sql-tests/results/scalar-subquery.sql.out | 31 +++++++++++++++++ .../org/apache/spark/sql/SubquerySuite.scala | 15 ++------ 4 files changed, 64 insertions(+), 26 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/scalar-subquery.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/scalar-subquery.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c15fbfd2e784a..9d1896d043a51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1225,8 +1225,12 @@ class Analyzer( } /** - * Rewrite the subquery in a safe way by preventing that 
the subquery and the outer use the same - * attributes. + * Rewrite the subquery in a safe way by preventing that the subquery and + * the outer use the same attributes. + * + * If this is a scalar subquery, check that GROUP BY columns are a subset + * of the columns used in the correlated predicate(s). Otherwise, pulling up + * correlated predicates could cause incorrect results. */ private def rewriteSubQuery( sub: LogicalPlan, @@ -1235,17 +1239,21 @@ class Analyzer( // Pull out the tagged predicates and rewrite the subquery in the process. val (basePlan, baseConditions) = pullOutCorrelatedPredicates(sub) - // SPARK-18504: block cases where GROUP BY columns - // are not part of the correlated columns - if (scalarSubq && sub.isInstanceOf[Aggregate]) { - val groupByCols = ExpressionSet.apply(sub.asInstanceOf[Aggregate]. - groupingExpressions.flatMap(_.references)) - val conditionsCols = ExpressionSet.apply(baseConditions.flatMap(_.references)) - val invalidCols = groupByCols.diff(conditionsCols) - if (invalidCols.nonEmpty) { - failAnalysis("a GROUP BY clause in a scalar correlated subquery " + - "cannot contain non-correlated columns: " + - invalidCols.mkString(",")) + // SPARK-18504/SPARK-18814: + // Block cases where GROUP BY columns are not part of the correlated columns + // of a scalar subquery. + if (scalarSubq) { + sub match { + case a @ Aggregate(grouping, _, _) => + val groupByCols = ExpressionSet(grouping.flatMap(_.references)) + val conditionsCols = ExpressionSet(baseConditions.flatMap(_.references)) + val invalidCols = groupByCols.diff(conditionsCols) + if (invalidCols.nonEmpty) { + failAnalysis("A GROUP BY clause in a scalar correlated subquery " + + "cannot contain non-correlated columns: " + + invalidCols.mkString(",")) + } + case _ => None } } // Make sure the inner and the outer query attributes do not collide. 
diff --git a/sql/core/src/test/resources/sql-tests/inputs/scalar-subquery.sql b/sql/core/src/test/resources/sql-tests/inputs/scalar-subquery.sql new file mode 100644 index 0000000000000..7de384179a854 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/scalar-subquery.sql @@ -0,0 +1,10 @@ +CREATE OR REPLACE TEMPORARY VIEW p AS VALUES (1, 1) AS T(pk, pv); +CREATE OR REPLACE TEMPORARY VIEW c AS VALUES (1, 1) AS T(ck, cv); + +-- SPARK-18814: Simplified version of TPCDS-Q32 +SELECT pk, cv +FROM p, c +WHERE p.pk = c.ck +AND c.cv = (SELECT avg(c1.cv) + FROM c c1 + WHERE c1.ck = p.pk); diff --git a/sql/core/src/test/resources/sql-tests/results/scalar-subquery.sql.out b/sql/core/src/test/resources/sql-tests/results/scalar-subquery.sql.out new file mode 100644 index 0000000000000..502b12939db7c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/scalar-subquery.sql.out @@ -0,0 +1,31 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 3 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW p AS VALUES (1, 1) AS T(pk, pv) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE OR REPLACE TEMPORARY VIEW c AS VALUES (1, 1) AS T(ck, cv) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT pk, cv +FROM p, c +WHERE p.pk = c.ck +AND c.cv = (SELECT avg(c1.cv) + FROM c c1 + WHERE c1.ck = p.pk) +-- !query 2 schema +struct +-- !query 2 output +1 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 2a519284c10b9..9833d55c0ca99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -491,18 +491,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { sql("select (select sum(-1) from t t2 where t1.c2 = t2.c1 group by t2.c2) sum from t t1") } assert(errMsg.getMessage.contains( - "a GROUP BY 
clause in a scalar correlated subquery cannot contain non-correlated columns:")) - } - } - - test("SPARK-18814 extra GROUP BY column in correlated scalar subquery is not permitted") { - withTempView("p", "c") { - Seq((1, 1)).toDF("pk", "pv").createOrReplaceTempView("p") - Seq((1, 1)).toDF("ck", "cv").createOrReplaceTempView("c") - checkAnswer( - sql("select pk, cv from p,c where p.pk=c.ck and " + - "c.cv = (select avg(c1.cv) from c c1 where c1.ck = p.pk)"), - Row(1, 1) :: Nil) + "A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns:")) } } @@ -516,7 +505,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { sql("select a, (select b from l l2 where l2.a = l1.a group by 1) sum_b from l l1") } assert(msg2.getMessage.contains( - "a GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns:")) + "A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns:")) } test("non-equal correlated scalar subquery") { From f88a2058f3d16e28ac68f771de80126d245e8adf Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Mon, 12 Dec 2016 09:25:52 -0500 Subject: [PATCH 12/16] Address @gatorsmile's 2nd round comments --- .../sql/catalyst/analysis/Analyzer.scala | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9d1896d043a51..6ebd47ab8443d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1235,26 +1235,23 @@ class Analyzer( private def rewriteSubQuery( sub: LogicalPlan, outer: Seq[LogicalPlan], - scalarSubq: Boolean = false): (LogicalPlan, Seq[Expression]) = { + isScalarSubq: Boolean): (LogicalPlan, Seq[Expression]) = { // Pull out the tagged 
predicates and rewrite the subquery in the process. val (basePlan, baseConditions) = pullOutCorrelatedPredicates(sub) // SPARK-18504/SPARK-18814: // Block cases where GROUP BY columns are not part of the correlated columns // of a scalar subquery. - if (scalarSubq) { - sub match { - case a @ Aggregate(grouping, _, _) => - val groupByCols = ExpressionSet(grouping.flatMap(_.references)) - val conditionsCols = ExpressionSet(baseConditions.flatMap(_.references)) - val invalidCols = groupByCols.diff(conditionsCols) - if (invalidCols.nonEmpty) { - failAnalysis("A GROUP BY clause in a scalar correlated subquery " + - "cannot contain non-correlated columns: " + - invalidCols.mkString(",")) - } - case _ => None - } + sub collect { + case a @ Aggregate(grouping, _, _) if (isScalarSubq) => + val groupByCols = ExpressionSet(grouping.flatMap(_.references)) + val conditionsCols = ExpressionSet(baseConditions.flatMap(_.references)) + val invalidCols = groupByCols.diff(conditionsCols) + if (invalidCols.nonEmpty) { + failAnalysis("A GROUP BY clause in a scalar correlated subquery " + + "cannot contain non-correlated columns: " + + invalidCols.mkString(",")) + } } // Make sure the inner and the outer query attributes do not collide. 
val outputSet = outer.map(_.outputSet).reduce(_ ++ _) From ca1dc96e8a4e7a4ae3b7036d5bc598a178228694 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Mon, 12 Dec 2016 10:49:40 -0500 Subject: [PATCH 13/16] Address @gatorsmile's 3rd round comments --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6ebd47ab8443d..bea41ee4d54bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1243,7 +1243,7 @@ class Analyzer( // Block cases where GROUP BY columns are not part of the correlated columns // of a scalar subquery. sub collect { - case a @ Aggregate(grouping, _, _) if (isScalarSubq) => + case a @ Aggregate(grouping, _, _) if isScalarSubq => val groupByCols = ExpressionSet(grouping.flatMap(_.references)) val conditionsCols = ExpressionSet(baseConditions.flatMap(_.references)) val invalidCols = groupByCols.diff(conditionsCols) From 724335a9e08bc0e8979fbd5de8bcc046f70d87f4 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Mon, 12 Dec 2016 10:51:15 -0500 Subject: [PATCH 14/16] Address @gatorsmile's 3rd round comments(2) --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index bea41ee4d54bf..12a95627632d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1243,7 +1243,7 @@ class Analyzer( // Block cases 
where GROUP BY columns are not part of the correlated columns // of a scalar subquery. sub collect { - case a @ Aggregate(grouping, _, _) if isScalarSubq => + case Aggregate(grouping, _, _) if isScalarSubq => val groupByCols = ExpressionSet(grouping.flatMap(_.references)) val conditionsCols = ExpressionSet(baseConditions.flatMap(_.references)) val invalidCols = groupByCols.diff(conditionsCols) From 6040dcff767d8e8d71ef1c42151315abcf7b6f1d Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Tue, 13 Dec 2016 14:47:42 -0500 Subject: [PATCH 15/16] Code the fix based on @hvanhovell's solution --- .../sql/catalyst/analysis/Analyzer.scala | 27 +++----------- .../sql/catalyst/analysis/CheckAnalysis.scala | 35 +++++++++++++++++-- .../sql-tests/inputs/scalar-subquery.sql | 12 ++++++- .../sql-tests/results/scalar-subquery.sql.out | 17 ++++++++- .../org/apache/spark/sql/SubquerySuite.scala | 2 +- 5 files changed, 65 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 12a95627632d7..ed6e17a8eb465 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1225,34 +1225,15 @@ class Analyzer( } /** - * Rewrite the subquery in a safe way by preventing that the subquery and - * the outer use the same attributes. - * - * If this is a scalar subquery, check that GROUP BY columns are a subset - * of the columns used in the correlated predicate(s). Otherwise, pulling up - * correlated predicates could cause incorrect results. + * Rewrite the subquery in a safe way by preventing that the subquery and the outer use the same + * attributes. 
*/ private def rewriteSubQuery( sub: LogicalPlan, - outer: Seq[LogicalPlan], - isScalarSubq: Boolean): (LogicalPlan, Seq[Expression]) = { + outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { // Pull out the tagged predicates and rewrite the subquery in the process. val (basePlan, baseConditions) = pullOutCorrelatedPredicates(sub) - // SPARK-18504/SPARK-18814: - // Block cases where GROUP BY columns are not part of the correlated columns - // of a scalar subquery. - sub collect { - case Aggregate(grouping, _, _) if isScalarSubq => - val groupByCols = ExpressionSet(grouping.flatMap(_.references)) - val conditionsCols = ExpressionSet(baseConditions.flatMap(_.references)) - val invalidCols = groupByCols.diff(conditionsCols) - if (invalidCols.nonEmpty) { - failAnalysis("A GROUP BY clause in a scalar correlated subquery " + - "cannot contain non-correlated columns: " + - invalidCols.mkString(",")) - } - } // Make sure the inner and the outer query attributes do not collide. val outputSet = outer.map(_.outputSet).reduce(_ ++ _) val duplicates = basePlan.outputSet.intersect(outputSet) @@ -1317,7 +1298,7 @@ class Analyzer( s"does not match the required number of columns ($requiredColumns)") } // Pullout predicates and construct a new plan. 
- f.tupled(rewriteSubQuery(current, plans, e.isInstanceOf[ScalarSubquery])) + f.tupled(rewriteSubQuery(current, plans)) } else { e.withNewPlan(current) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index caf4a5ad35831..37a2311d9991c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -124,6 +124,12 @@ trait CheckAnalysis extends PredicateHelper { s"Scalar subquery must return only one column, but got ${query.output.size}") case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty => + + // Collect the columns from the subquery for further checking. + var subqueryColumns = conditions.flatMap(_.references).collect { + case xs if query.output.contains(xs) => + xs + } def checkAggregate(agg: Aggregate): Unit = { // Make sure correlated scalar subqueries contain one row for every outer row by // enforcing that they are aggregates which contain exactly one aggregate expressions. @@ -135,12 +141,37 @@ trait CheckAnalysis extends PredicateHelper { if (aggregates.isEmpty) { failAnalysis("The output of a correlated scalar subquery must be aggregated") } + + // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns + // are not part of the correlated columns. + val groupByCols = ExpressionSet(agg.groupingExpressions.flatMap(_.references)) + val correlatedCols = ExpressionSet(subqueryColumns) + val invalidCols = groupByCols.diff(correlatedCols) + // GROUP BY columns must be a subset of columns in the predicates + if (invalidCols.nonEmpty) { + failAnalysis( + "A GROUP BY clause in a scalar correlated subquery " + + "cannot contain non-correlated columns: " + + invalidCols.mkString(",")) + } } - // Skip projects and subquery aliases added by the Analyzer and the SQLBuilder. 
+ // Skip subquery aliases added by the Analyzer and the SQLBuilder. + // For projects, do the necessary mapping and skip to its child. def cleanQuery(p: LogicalPlan): LogicalPlan = p match { case s: SubqueryAlias => cleanQuery(s.child) - case p: Project => cleanQuery(p.child) + case p: Project => + // SPARK-18814: Map any aliases to their AttributeReference children + // for the checking in the Aggregate operators below this Project. + subqueryColumns = subqueryColumns.map { + case xs => + p.projectList.collectFirst { + case e @ Alias(child : AttributeReference, _) if e.toAttribute equals xs => + child + }.getOrElse(xs) + } + + cleanQuery(p.child) case child => child } diff --git a/sql/core/src/test/resources/sql-tests/inputs/scalar-subquery.sql b/sql/core/src/test/resources/sql-tests/inputs/scalar-subquery.sql index 7de384179a854..3acc9db09cb80 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/scalar-subquery.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/scalar-subquery.sql @@ -1,10 +1,20 @@ CREATE OR REPLACE TEMPORARY VIEW p AS VALUES (1, 1) AS T(pk, pv); CREATE OR REPLACE TEMPORARY VIEW c AS VALUES (1, 1) AS T(ck, cv); --- SPARK-18814: Simplified version of TPCDS-Q32 +-- SPARK-18814.1: Simplified version of TPCDS-Q32 SELECT pk, cv FROM p, c WHERE p.pk = c.ck AND c.cv = (SELECT avg(c1.cv) FROM c c1 WHERE c1.ck = p.pk); + +-- SPARK-18814.2: Adding stack of aggregates +SELECT pk, cv +FROM p, c +WHERE p.pk = c.ck +AND c.cv = (SELECT max(avg) + FROM (SELECT c1.cv, avg(c1.cv) avg + FROM c c1 + WHERE c1.ck = p.pk + GROUP BY c1.cv)); diff --git a/sql/core/src/test/resources/sql-tests/results/scalar-subquery.sql.out b/sql/core/src/test/resources/sql-tests/results/scalar-subquery.sql.out index 502b12939db7c..c249329d6a61c 100644 --- a/sql/core/src/test/resources/sql-tests/results/scalar-subquery.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/scalar-subquery.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of 
queries: 3 +-- Number of queries: 4 -- !query 0 @@ -29,3 +29,18 @@ AND c.cv = (SELECT avg(c1.cv) struct<pk:int,cv:int> -- !query 2 output 1 1 + + +-- !query 3 +SELECT pk, cv +FROM p, c +WHERE p.pk = c.ck +AND c.cv = (SELECT max(avg) + FROM (SELECT c1.cv, avg(c1.cv) avg + FROM c c1 + WHERE c1.ck = p.pk + GROUP BY c1.cv)) +-- !query 3 schema +struct<pk:int,cv:int> +-- !query 3 output +1 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 9833d55c0ca99..5a4b1cfe95e27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -505,7 +505,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { sql("select a, (select b from l l2 where l2.a = l1.a group by 1) sum_b from l l1") } assert(msg2.getMessage.contains( - "A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns:")) + "The output of a correlated scalar subquery must be aggregated")) } test("non-equal correlated scalar subquery") { From 0b6bfd4ba4b4a4b445a0f44441518cb487c03f78 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Tue, 13 Dec 2016 18:47:20 -0500 Subject: [PATCH 16/16] Address @hvanhovell's 2nd comments --- .../sql/catalyst/analysis/CheckAnalysis.scala | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 37a2311d9991c..aa77a6efef347 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -126,10 +126,8 @@ trait CheckAnalysis extends PredicateHelper { case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty => // Collect the columns
from the subquery for further checking. - var subqueryColumns = conditions.flatMap(_.references).collect { - case xs if query.output.contains(xs) => - xs - } + var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains) + def checkAggregate(agg: Aggregate): Unit = { // Make sure correlated scalar subqueries contain one row for every outer row by // enforcing that they are aggregates which contain exactly one aggregate expressions. @@ -144,9 +142,9 @@ trait CheckAnalysis extends PredicateHelper { // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns // are not part of the correlated columns. - val groupByCols = ExpressionSet(agg.groupingExpressions.flatMap(_.references)) - val correlatedCols = ExpressionSet(subqueryColumns) - val invalidCols = groupByCols.diff(correlatedCols) + val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) + val correlatedCols = AttributeSet(subqueryColumns) + val invalidCols = groupByCols -- correlatedCols // GROUP BY columns must be a subset of columns in the predicates if (invalidCols.nonEmpty) { failAnalysis( @@ -164,11 +162,10 @@ trait CheckAnalysis extends PredicateHelper { // SPARK-18814: Map any aliases to their AttributeReference children // for the checking in the Aggregate operators below this Project. subqueryColumns = subqueryColumns.map { - case xs => - p.projectList.collectFirst { - case e @ Alias(child : AttributeReference, _) if e.toAttribute equals xs => - child - }.getOrElse(xs) + xs => p.projectList.collectFirst { + case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId => + child + }.getOrElse(xs) } cleanQuery(p.child)