Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1081,10 +1081,10 @@ class Analyzer(
// Step 2: Pull out the predicates if the plan is resolved.
if (current.resolved) {
// Make sure the resolved query has the required number of output columns. This is only
// needed for IN expressions.
// needed for Scalar and IN subqueries.
if (requiredColumns > 0 && requiredColumns != current.output.size) {
failAnalysis(s"The number of fields in the value ($requiredColumns) does not " +
s"match with the number of columns in the subquery (${current.output.size})")
failAnalysis(s"The number of columns in the subquery (${current.output.size}) " +
s"does not match the required number of columns ($requiredColumns)")
}
// Pullout predicates and construct a new plan.
f.tupled(rewriteSubQuery(current, plans))
Expand All @@ -1099,8 +1099,11 @@ class Analyzer(
*/
private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
plan transformExpressions {
case s @ ScalarSubquery(sub, conditions, exprId)
if sub.resolved && conditions.isEmpty && sub.output.size != 1 =>
failAnalysis(s"Scalar subquery must return only one column, but got ${sub.output.size}")
case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, exprId) =>
resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, exprId))
case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.{Inner, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.UsingJoin
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -60,9 +60,6 @@ trait CheckAnalysis extends PredicateHelper {
val from = operator.inputSet.map(_.name).mkString(", ")
a.failAnalysis(s"cannot resolve '${a.sql}' given input columns: [$from]")

case ScalarSubquery(_, conditions, _) if conditions.nonEmpty =>
failAnalysis("Correlated scalar subqueries are not supported.")

case e: Expression if e.checkInputDataTypes().isFailure =>
e.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckFailure(message) =>
Expand Down Expand Up @@ -104,6 +101,36 @@ trait CheckAnalysis extends PredicateHelper {
failAnalysis(s"Window specification $s is not valid because $m")
case None => w
}

case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty =>
// Make sure we are using equi-joins.
conditions.foreach {
case _: EqualTo | _: EqualNullSafe => // ok
case e => failAnalysis(
s"The correlated scalar subquery can only contain equality predicates: $e")
}

// Make sure correlated scalar subqueries contain one row for every outer row by
// enforcing that they are aggregates which contain at least one aggregate expression.
// The analyzer has already checked that the subquery contains only one output column, and
// added all the grouping expressions to the aggregate.
// Fails analysis unless at least one expression of the given Aggregate contains an
// AggregateExpression somewhere in its tree.
def checkAggregate(a: Aggregate): Unit = {
  val containsAggregateExpression = a.expressions.exists { expression =>
    expression.find(_.isInstanceOf[AggregateExpression]).isDefined
  }
  if (!containsAggregateExpression) {
    failAnalysis("The output of a correlated scalar subquery must be aggregated")
  }
}

query match {
case a: Aggregate => checkAggregate(a)
case Filter(_, a: Aggregate) => checkAggregate(a)
case Project(_, a: Aggregate) => checkAggregate(a)
case Project(_, Filter(_, a: Aggregate)) => checkAggregate(a)
case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can it have a Filter on top of the Aggregate (HAVING clause)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure I'll add it.

}
s
}

operator match {
Expand Down Expand Up @@ -220,6 +247,13 @@ trait CheckAnalysis extends PredicateHelper {
| but one table has '${firstError.output.length}' columns and another table has
| '${s.children.head.output.length}' columns""".stripMargin)

case p if p.expressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) =>
p match {
case _: Filter | _: Aggregate | _: Project => // Ok
case other => failAnalysis(
s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p")
}

case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) =>
failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ abstract class SubqueryExpression extends Expression {
protected def conditionString: String = children.mkString("[", " && ", "]")
}

object SubqueryExpression {
  /**
   * Returns true when the expression tree rooted at `e` contains at least one
   * [[SubqueryExpression]] whose list of children is non-empty.
   */
  def hasCorrelatedSubquery(e: Expression): Boolean = {
    val firstWithChildren = e.find {
      case subquery: SubqueryExpression => subquery.children.nonEmpty
      case _ => false
    }
    firstWithChildren.nonEmpty
  }
}

/**
* A subquery that will return only one row and one column. This will be converted into a physical
* scalar subquery during planning.
Expand All @@ -55,28 +64,26 @@ case class ScalarSubquery(
children: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId)
extends SubqueryExpression with Unevaluable {

override def plan: LogicalPlan = SubqueryAlias(toString, query)

override lazy val resolved: Boolean = childrenResolved && query.resolved

override def dataType: DataType = query.schema.fields.head.dataType

override def checkInputDataTypes(): TypeCheckResult = {
if (query.schema.length != 1) {
TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one column, but got " +
query.schema.length.toString)
} else {
TypeCheckResult.TypeCheckSuccess
}
override lazy val references: AttributeSet = {
if (query.resolved) super.references -- query.outputSet
else super.references
}

override def dataType: DataType = query.schema.fields.head.dataType
override def foldable: Boolean = false
override def nullable: Boolean = true

override def plan: LogicalPlan = SubqueryAlias(toString, query)
override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(query = plan)
override def toString: String = s"scalar-subquery#${exprId.id} $conditionString"
}

override def toString: String = s"subquery#${exprId.id} $conditionString"
object ScalarSubquery {
  /**
   * Returns true when the expression tree rooted at `e` contains at least one
   * [[ScalarSubquery]] whose list of children is non-empty.
   */
  def hasCorrelatedScalarSubquery(e: Expression): Boolean = {
    val firstWithChildren = e.find {
      case subquery: ScalarSubquery => subquery.children.nonEmpty
      case _ => false
    }
    firstWithChildren.nonEmpty
  }
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.annotation.tailrec
import scala.collection.immutable.HashSet
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases, EmptyFunctionRegistry}
Expand Down Expand Up @@ -100,6 +101,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
EliminateSorts,
SimplifyCasts,
SimplifyCaseConversionExpressions,
RewriteCorrelatedScalarSubquery,
EliminateSerialization) ::
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) ::
Expand Down Expand Up @@ -1081,7 +1083,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
assert(input.size >= 2)
if (input.size == 2) {
val (joinConditions, others) = conditions.partition(
e => !PredicateSubquery.hasPredicateSubquery(e))
e => !SubqueryExpression.hasCorrelatedSubquery(e))
val join = Join(input(0), input(1), Inner, joinConditions.reduceLeftOption(And))
if (others.nonEmpty) {
Filter(others.reduceLeft(And), join)
Expand All @@ -1101,7 +1103,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {

val joinedRefs = left.outputSet ++ right.outputSet
val (joinConditions, others) = conditions.partition(
e => e.references.subsetOf(joinedRefs) && !PredicateSubquery.hasPredicateSubquery(e))
e => e.references.subsetOf(joinedRefs) && !SubqueryExpression.hasCorrelatedSubquery(e))
val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And))

// should not have reference to same logical plan
Expand Down Expand Up @@ -1134,7 +1136,7 @@ object OuterJoinElimination extends Rule[LogicalPlan] with PredicateHelper {
* Returns whether the expression returns null or false when all inputs are nulls.
*/
private def canFilterOutNull(e: Expression): Boolean = {
if (!e.deterministic || PredicateSubquery.hasPredicateSubquery(e)) return false
if (!e.deterministic || SubqueryExpression.hasCorrelatedSubquery(e)) return false
val attributes = e.references.toSeq
val emptyRow = new GenericInternalRow(attributes.length)
val v = BindReferences.bindReference(e, attributes).eval(emptyRow)
Expand Down Expand Up @@ -1203,7 +1205,6 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) =>
val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
split(splitConjunctivePredicates(filterCondition), left, right)

joinType match {
case Inner =>
// push down the single side `where` condition into respective sides
Expand All @@ -1212,7 +1213,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
val newRight = rightFilterConditions.
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val (newJoinConditions, others) =
commonFilterCondition.partition(e => !PredicateSubquery.hasPredicateSubquery(e))
commonFilterCondition.partition(e => !SubqueryExpression.hasCorrelatedSubquery(e))
val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And)

val join = Join(newLeft, newRight, Inner, newJoinCond)
Expand Down Expand Up @@ -1573,3 +1574,74 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}
}
}

/**
* This rule rewrites correlated [[ScalarSubquery]] expressions into LEFT OUTER joins.
*/
object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
/**
* Extract all correlated scalar subqueries from an expression. The subqueries are collected using
* the given collector. The expression is rewritten and returned.
*/
private def extractCorrelatedScalarSubqueries[E <: Expression](
    expression: E,
    subqueries: ArrayBuffer[ScalarSubquery]): E = {
  // Every scalar subquery with a non-empty list of children is appended to `subqueries`
  // (a side effect on the caller's buffer) and replaced in the tree by the first output
  // attribute of its query. Scalar subqueries without children are left untouched.
  val rewritten = expression.transform {
    case correlated: ScalarSubquery if correlated.children.nonEmpty =>
      subqueries += correlated
      correlated.query.output.head
  }
  // `transform` returns a plain Expression, so the result is cast back to the caller's type.
  rewritten.asInstanceOf[E]
}

/**
* Construct a new child plan by left joining the given subqueries to a base plan.
*/
private def constructLeftJoins(
child: LogicalPlan,
subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = {
subqueries.foldLeft(child) {
case (currentChild, ScalarSubquery(query, conditions, _)) =>
Project(
currentChild.output :+ query.output.head,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we know that the query can have only one output column, then this Project is not needed, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The query contains both the column we are interested in and the join columns. We don't want the join columns in the output, so the Project removes them.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thanks

Join(currentChild, query, LeftOuter, conditions.reduceOption(And)))
}
}

/**
* Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing correlated scalar
* subqueries.
*/
def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
  // Each case extracts the scalar subqueries that have children from the operator's
  // expressions. If none were found the operator is returned unchanged; otherwise the child
  // is replaced by the plan built by constructLeftJoins.
  case aggregate @ Aggregate(grouping, expressions, child) =>
    val collected = ArrayBuffer.empty[ScalarSubquery]
    val rewrittenExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, collected))
    if (collected.isEmpty) {
      aggregate
    } else {
      // We currently only allow correlated subqueries in an aggregate if they are part of the
      // grouping expressions. As a result, every grouping expression that is semantically
      // equal to a collected subquery is replaced by that subquery's result attribute.
      val rewrittenGrouping = grouping.map { groupingExpr =>
        collected.find(_.semanticEquals(groupingExpr)) match {
          case Some(subquery) => subquery.query.output.head
          case None => groupingExpr
        }
      }
      Aggregate(rewrittenGrouping, rewrittenExpressions, constructLeftJoins(child, collected))
    }

  case project @ Project(expressions, child) =>
    val collected = ArrayBuffer.empty[ScalarSubquery]
    val rewrittenExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, collected))
    if (collected.isEmpty) {
      project
    } else {
      Project(rewrittenExpressions, constructLeftJoins(child, collected))
    }

  case filter @ Filter(condition, child) =>
    val collected = ArrayBuffer.empty[ScalarSubquery]
    val rewrittenCondition = extractCorrelatedScalarSubqueries(condition, collected)
    if (collected.isEmpty) {
      filter
    } else {
      // The joined child exposes extra subquery result columns; projecting onto the original
      // filter output keeps them out of this operator's output.
      val originalOutput = filter.output
      Project(originalOutput, Filter(rewrittenCondition, constructLeftJoins(child, collected)))
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ case class Filter(condition: Expression, child: LogicalPlan)

override protected def validConstraints: Set[Expression] = {
val predicates = splitConjunctivePredicates(condition)
.filterNot(PredicateSubquery.hasPredicateSubquery)
.filterNot(SubqueryExpression.hasCorrelatedSubquery)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@davies I changed the filter to prevent any correlated subquery from being propagated.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see, I missed this one from latest changes.

child.constraints.union(predicates.toSet)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ class AnalysisErrorSuite extends AnalysisTest {
"scalar subquery with 2 columns",
testRelation.select(
(ScalarSubquery(testRelation.select('a, dateLit.as('b))) + Literal(1)).as('a)),
"Scalar subquery must return only one column, but got 2" :: Nil)
"The number of columns in the subquery (2)" ::
"does not match the required number of columns (1)":: Nil)

errorTest(
"scalar subquery with no column",
Expand Down Expand Up @@ -499,12 +500,4 @@ class AnalysisErrorSuite extends AnalysisTest {
LocalRelation(a))
assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil)
}

test("Correlated Scalar Subquery") {
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val sub = Project(Seq(b), Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))
val plan = Project(Seq(a, Alias(ScalarSubquery(sub), "b")()), LocalRelation(a))
assertAnalysisError(plan, "Correlated scalar subqueries are not supported." :: Nil)
}
}
47 changes: 47 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -234,4 +234,51 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
sql("select a from l group by 1 having exists (select 1 from r where d < min(b))"),
Row(null) :: Row(1) :: Row(3) :: Nil)
}

// Correlated scalar subquery (max(d) over r, correlated through a = c) used as the
// right-hand side of a comparison in the WHERE clause.
test("correlated scalar subquery in where") {
checkAnswer(
sql("select * from l where b < (select max(d) from r where a = c)"),
Row(2, 1.0) :: Row(2, 1.0) :: Nil)
}

// Correlated scalar subquery in the select list: sum(b) over l, correlated on l2.a = l1.a.
// The expected rows pair the null `a` values with a null sum.
test("correlated scalar subquery in select") {
checkAnswer(
sql("select a, (select sum(b) from l l2 where l2.a = l1.a) sum_b from l l1"),
Row(1, 4.0) :: Row(1, 4.0) :: Row(2, 2.0) :: Row(2, 2.0) :: Row(3, 3.0) ::
Row(null, null) :: Row(null, null) :: Row(6, null) :: Nil)
}

// Correlated scalar subquery in the select list using the null-safe equality operator `<=>`.
// The expected rows pair the null `a` values with a non-null sum (5.0).
test("correlated scalar subquery in select (null safe)") {
checkAnswer(
sql("select a, (select sum(b) from l l2 where l2.a <=> l1.a) sum_b from l l1"),
Row(1, 4.0) :: Row(1, 4.0) :: Row(2, 2.0) :: Row(2, 2.0) :: Row(3, 3.0) ::
Row(null, 5.0) :: Row(null, 5.0) :: Row(6, null) :: Nil)
}

// The correlated scalar subquery is the second select item and is also used as a grouping
// key through the ordinal `group by 1, 2`.
test("correlated scalar subquery in aggregate") {
checkAnswer(
sql("select a, (select sum(d) from r where a = c) sum_d from l l1 group by 1, 2"),
Row(1, null) :: Row(2, 6.0) :: Row(3, 2.0) :: Row(null, null) :: Row(6, null) :: Nil)
}

test("non-aggregated correlated scalar subquery") {
  // Subquery without any aggregation.
  val plainSelect = "select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1"
  val notAggregatedError = intercept[AnalysisException](sql(plainSelect))
  assert(notAggregatedError.getMessage.contains(
    "Correlated scalar subqueries must be Aggregated"))

  // Subquery with a group by but no aggregate function in its output.
  val groupedSelect =
    "select a, (select b from l l2 where l2.a = l1.a group by 1) sum_b from l l1"
  val noAggregateFunctionError = intercept[AnalysisException](sql(groupedSelect))
  assert(noAggregateFunctionError.getMessage.contains(
    "The output of a correlated scalar subquery must be aggregated"))
}

test("non-equal correlated scalar subquery") {
  // The correlation predicate uses `<` rather than an equality operator.
  val nonEqualSelect = "select a, (select b from l l2 where l2.a < l1.a) sum_b from l l1"
  val nonEqualityError = intercept[AnalysisException](sql(nonEqualSelect))
  assert(nonEqualityError.getMessage.contains(
    "The correlated scalar subquery can only contain equality predicates"))
}
}