[SPARK-18874][SQL] First phase: Deferring the correlated predicate pull up to Optimizer phase

## What changes were proposed in this pull request?
Currently, the Analyzer, as part of ResolveSubquery, pulls correlated predicates up to their
originating SubqueryExpression. The subquery plan is then transformed to remove the correlated
predicates after they are moved up to the outer plan. In this PR, the task of pulling up
correlated predicates is deferred to the Optimizer. This is the initial work that will allow us
to support forms of correlated subqueries that we don't support today. The design document
from nsyca can be found at the following link:
[DesignDoc](https://docs.google.com/document/d/1QDZ8JwU63RwGFS6KVF54Rjj9ZJyK33d49ZWbjFBaIgU/edit#)

A brief description of the code changes (hopefully to aid code review) can be found at the
following link:
[CodeChanges](https://docs.google.com/document/d/18mqjhL9V1An-tNta7aVE13HkALRZ5GZ24AATA-Vqqf0/edit#)
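As a concrete illustration of the class of queries involved (a spark-shell sketch with made-up tables and columns; not part of this PR):

```scala
import spark.implicits._
Seq((1, "a"), (2, "b")).toDF("id", "name").createOrReplaceTempView("t1")
Seq((1, 10), (3, 30)).toDF("id", "score").createOrReplaceTempView("t2")

// The predicate t2.id = t1.id is correlated with the outer query. After this
// change, analysis keeps it inside the subquery plan; only the Optimizer
// pulls it up into the outer plan when it rewrites the subquery into a join.
spark.sql("SELECT * FROM t1 WHERE EXISTS (SELECT 1 FROM t2 WHERE t2.id = t1.id)").show()
```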

## How was this patch tested?
The test case PRs were submitted earlier:
[16337](#16337) [16759](#16759) [16841](#16841) [16915](#16915) [16798](#16798) [16712](#16712) [16710](#16710) [16760](#16760) [16802](#16802)

Author: Dilip Biswal <dbiswal@us.ibm.com>

Closes #16954 from dilipbiswal/SPARK-18874.
nsyca authored and hvanhovell committed Mar 14, 2017
1 parent f6314ea commit 4ce970d
Showing 13 changed files with 675 additions and 300 deletions.


@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
@@ -133,10 +134,8 @@ trait CheckAnalysis extends PredicateHelper {
         if (conditions.isEmpty && query.output.size != 1) {
           failAnalysis(
             s"Scalar subquery must return only one column, but got ${query.output.size}")
-        } else if (conditions.nonEmpty) {
-          // Collect the columns from the subquery for further checking.
-          var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains)
-
+        }
+        else if (conditions.nonEmpty) {
           def checkAggregate(agg: Aggregate): Unit = {
             // Make sure correlated scalar subqueries contain one row for every outer row by
             // enforcing that they are aggregates containing exactly one aggregate expression.
@@ -152,6 +151,9 @@ trait CheckAnalysis extends PredicateHelper {
             // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns
             // are not part of the correlated columns.
             val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references))
+            // Collect the local references from the correlated predicate in the subquery.
+            val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references)
+              .filterNot(conditions.flatMap(_.references).contains)
             val correlatedCols = AttributeSet(subqueryColumns)
             val invalidCols = groupByCols -- correlatedCols
             // GROUP BY columns must be a subset of columns in the predicates
@@ -167,17 +169,7 @@ trait CheckAnalysis extends PredicateHelper {
             // For projects, do the necessary mapping and skip to its child.
             def cleanQuery(p: LogicalPlan): LogicalPlan = p match {
               case s: SubqueryAlias => cleanQuery(s.child)
-              case p: Project =>
-                // SPARK-18814: Map any aliases to their AttributeReference children
-                // for the checking in the Aggregate operators below this Project.
-                subqueryColumns = subqueryColumns.map {
-                  xs => p.projectList.collectFirst {
-                    case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId =>
-                      child
-                  }.getOrElse(xs)
-                }
-
-                cleanQuery(p.child)
+              case p: Project => cleanQuery(p.child)
               case child => child
             }
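An illustration of what the GROUP BY check above enforces (spark-shell sketch; tables and columns are made up for the example, not part of the diff):

```scala
import spark.implicits._
Seq(1, 2).toDF("id").createOrReplaceTempView("t1")
Seq((1, "x", 10), (1, "y", 20)).toDF("id", "grp", "score").createOrReplaceTempView("t2")

// OK: the GROUP BY column (t2.id) is also the correlated column, so the
// scalar subquery yields at most one row per outer row.
spark.sql(
  "SELECT id, (SELECT max(score) FROM t2 WHERE t2.id = t1.id GROUP BY t2.id) FROM t1").show()

// Fails analysis: t2.grp is not referenced by the correlated predicate, so
// grouping by it could yield several rows per outer row.
try {
  spark.sql(
    "SELECT id, (SELECT max(score) FROM t2 WHERE t2.id = t1.id GROUP BY t2.grp) FROM t1")
} catch {
  case e: org.apache.spark.sql.AnalysisException => println(e.getMessage)
}
```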

@@ -211,14 +203,9 @@ trait CheckAnalysis extends PredicateHelper {
               s"filter expression '${f.condition.sql}' " +
               s"of type ${f.condition.dataType.simpleString} is not a boolean.")

-          case Filter(condition, _) =>
-            splitConjunctivePredicates(condition).foreach {
-              case _: PredicateSubquery | Not(_: PredicateSubquery) =>
-              case e if PredicateSubquery.hasNullAwarePredicateWithinNot(e) =>
-                failAnalysis(s"Null-aware predicate sub-queries cannot be used in nested" +
-                  s" conditions: $e")
-              case e =>
-            }
+          case Filter(condition, _) if hasNullAwarePredicateWithinNot(condition) =>
+            failAnalysis("Null-aware predicate sub-queries cannot be used in nested " +
+              s"conditions: $condition")

           case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType =>
             failAnalysis(
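For illustration, the kind of query this check rejects (spark-shell sketch with made-up views; not part of the diff):

```scala
import spark.implicits._
Seq(1, 2).toDF("id").createOrReplaceTempView("t1")
Seq(2, 3).toDF("id").createOrReplaceTempView("t2")

// Accepted: NOT IN at the top level of the WHERE clause can be planned as a
// null-aware anti join.
spark.sql("SELECT * FROM t1 WHERE id NOT IN (SELECT id FROM t2)")

// Rejected by the check above: the null-aware NOT IN predicate is nested
// under a disjunction, which Spark cannot plan.
spark.sql("SELECT * FROM t1 WHERE id = 0 OR id NOT IN (SELECT id FROM t2)")
```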
@@ -306,8 +293,11 @@ trait CheckAnalysis extends PredicateHelper {
                 s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p")
             }

-          case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) =>
-            failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p")
+          case p if p.expressions.exists(SubqueryExpression.hasInOrExistsSubquery) =>
+            p match {
+              case _: Filter => // Ok
+              case _ => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p")
+            }

           case _: Union | _: SetOperation if operator.children.length > 1 =>
             def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType)
@@ -108,6 +108,28 @@ object TypeCoercion {
     case _ => None
   }

+  /**
+   * This function determines the target type of a comparison operator when one operand
+   * is a String and the other is not. It also handles when one op is a Date and the
+   * other is a Timestamp by making the target type to be String.
+   */
+  val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = {
+    // We should cast all relative timestamp/date/string comparison into string comparisons
+    // This behaves as a user would expect because timestamp strings sort lexicographically.
+    // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true
+    case (StringType, DateType) => Some(StringType)
+    case (DateType, StringType) => Some(StringType)
+    case (StringType, TimestampType) => Some(StringType)
+    case (TimestampType, StringType) => Some(StringType)
+    case (TimestampType, DateType) => Some(StringType)
+    case (DateType, TimestampType) => Some(StringType)
+    case (StringType, NullType) => Some(StringType)
+    case (NullType, StringType) => Some(StringType)
+    case (l: StringType, r: AtomicType) if r != StringType => Some(r)
+    case (l: AtomicType, r: StringType) if (l != StringType) => Some(l)
+    case (l, r) => None
+  }
+
   /**
    * Case 2 type widening (see the classdoc comment above for TypeCoercion).
    *
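A sketch of the behavior this coercion table encodes (assuming a spark-shell session, where `spark` is predefined; not part of the diff):

```scala
// Timestamp vs. string: both sides are cast to string, and lexicographic
// string ordering gives the expected answer here.
spark.sql("SELECT CAST('2013-01-01 00:00:00' AS TIMESTAMP) < '2014'").show()  // true

// String vs. non-string atomic type: the string side is cast to the other type.
spark.sql("SELECT '10' > 5").show()  // true, because '10' is cast to int (10 > 5)
```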
@@ -305,6 +327,14 @@ object TypeCoercion {
    * Promotes strings that appear in arithmetic expressions.
    */
   object PromoteStrings extends Rule[LogicalPlan] {
+    private def castExpr(expr: Expression, targetType: DataType): Expression = {
+      (expr.dataType, targetType) match {
+        case (NullType, dt) => Literal.create(null, targetType)
+        case (l, dt) if (l != dt) => Cast(expr, targetType)
+        case _ => expr
+      }
+    }
+
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
@@ -321,37 +351,10 @@ object TypeCoercion {
       case p @ Equality(left @ TimestampType(), right @ StringType()) =>
         p.makeCopy(Array(left, Cast(right, TimestampType)))

-      // We should cast all relative timestamp/date/string comparison into string comparisons
-      // This behaves as a user would expect because timestamp strings sort lexicographically.
-      // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true
-      case p @ BinaryComparison(left @ StringType(), right @ DateType()) =>
-        p.makeCopy(Array(left, Cast(right, StringType)))
-      case p @ BinaryComparison(left @ DateType(), right @ StringType()) =>
-        p.makeCopy(Array(Cast(left, StringType), right))
-      case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) =>
-        p.makeCopy(Array(left, Cast(right, StringType)))
-      case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) =>
-        p.makeCopy(Array(Cast(left, StringType), right))
-
-      // Comparisons between dates and timestamps.
-      case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) =>
-        p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))
-      case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) =>
-        p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))
-
-      // Checking NullType
-      case p @ BinaryComparison(left @ StringType(), right @ NullType()) =>
-        p.makeCopy(Array(left, Literal.create(null, StringType)))
-      case p @ BinaryComparison(left @ NullType(), right @ StringType()) =>
-        p.makeCopy(Array(Literal.create(null, StringType), right))
-
-      // When compare string with atomic type, case string to that type.
-      case p @ BinaryComparison(left @ StringType(), right @ AtomicType())
-        if right.dataType != StringType =>
-        p.makeCopy(Array(Cast(left, right.dataType), right))
-      case p @ BinaryComparison(left @ AtomicType(), right @ StringType())
-        if left.dataType != StringType =>
-        p.makeCopy(Array(left, Cast(right, left.dataType)))
+      case p @ BinaryComparison(left, right)
+        if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined =>
+        val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get
+        p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))

       case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
       case Average(e @ StringType()) => Average(Cast(e, DoubleType))
@@ -365,17 +368,72 @@ object TypeCoercion {
   }

   /**
-   * Convert the value and in list expressions to the common operator type
-   * by looking at all the argument types and finding the closest one that
-   * all the arguments can be cast to. When no common operator type is found
-   * the original expression will be returned and an Analysis Exception will
-   * be raised at type checking phase.
+   * Handles type coercion for both IN expression with subquery and IN
+   * expressions without subquery.
+   * 1. In the first case, find the common type by comparing the left hand side (LHS)
+   *    expression types against corresponding right hand side (RHS) expression derived
+   *    from the subquery expression's plan output. Inject appropriate casts in the
+   *    LHS and RHS side of IN expression.
+   *
+   * 2. In the second case, convert the value and in list expressions to the
+   *    common operator type by looking at all the argument types and finding
+   *    the closest one that all the arguments can be cast to. When no common
+   *    operator type is found the original expression will be returned and an
+   *    Analysis Exception will be raised at the type checking phase.
    */
   object InConversion extends Rule[LogicalPlan] {
+    private def flattenExpr(expr: Expression): Seq[Expression] = {
+      expr match {
+        // Multi columns in IN clause is represented as a CreateNamedStruct.
+        // flatten the named struct to get the list of expressions.
+        case cns: CreateNamedStruct => cns.valExprs
+        case expr => Seq(expr)
+      }
+    }
+
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e

+      // Handle type casting required between value expression and subquery output
+      // in IN subquery.
+      case i @ In(a, Seq(ListQuery(sub, children, exprId)))
+        if !i.resolved && flattenExpr(a).length == sub.output.length =>
+        // LHS is the value expression of IN subquery.
+        val lhs = flattenExpr(a)
+
+        // RHS is the subquery output.
+        val rhs = sub.output
+
+        val commonTypes = lhs.zip(rhs).flatMap { case (l, r) =>
+          findCommonTypeForBinaryComparison(l.dataType, r.dataType)
+            .orElse(findTightestCommonType(l.dataType, r.dataType))
+        }
+
+        // The number of columns/expressions must match between LHS and RHS of an
+        // IN subquery expression.
+        if (commonTypes.length == lhs.length) {
+          val castedRhs = rhs.zip(commonTypes).map {
+            case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
+            case (e, _) => e
+          }
+          val castedLhs = lhs.zip(commonTypes).map {
+            case (e, dt) if e.dataType != dt => Cast(e, dt)
+            case (e, _) => e
+          }
+
+          // Before constructing the In expression, wrap the multi values in LHS
+          // in a CreatedNamedStruct.
+          val newLhs = castedLhs match {
+            case Seq(lhs) => lhs
+            case _ => CreateStruct(castedLhs)
+          }
+
+          In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId)))
+        } else {
+          i
+        }
+
       case i @ In(a, b) if b.exists(_.dataType != a.dataType) =>
         findWiderCommonType(i.children.map(_.dataType)) match {
           case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))
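To illustrate the new subquery branch (spark-shell sketch; views and column names are made up, not part of the diff):

```scala
import spark.implicits._
Seq(1, 2).toDF("id").createOrReplaceTempView("t1")
Seq("1", "3").toDF("id").createOrReplaceTempView("t2")

// LHS type is int, subquery output type is string. The rule finds int as the
// common type, casts the subquery output, and wraps it in a Project, so the
// analyzed plan compares int against int.
spark.sql("SELECT * FROM t1 WHERE id IN (SELECT id FROM t2)").explain(true)
```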
@@ -123,19 +123,44 @@ case class Not(child: Expression)
  */
 @ExpressionDescription(
   usage = "expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr` equals to any valN.")
-case class In(value: Expression, list: Seq[Expression]) extends Predicate
-  with ImplicitCastInputTypes {
+case class In(value: Expression, list: Seq[Expression]) extends Predicate {

   require(list != null, "list should not be null")

-  override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType)
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (list.exists(l => l.dataType != value.dataType)) {
-      TypeCheckResult.TypeCheckFailure(
-        "Arguments must be same type")
-    } else {
-      TypeCheckResult.TypeCheckSuccess
+  override def checkInputDataTypes(): TypeCheckResult = {
+    list match {
+      case ListQuery(sub, _, _) :: Nil =>
+        val valExprs = value match {
+          case cns: CreateNamedStruct => cns.valExprs
+          case expr => Seq(expr)
+        }
+
+        val mismatchedColumns = valExprs.zip(sub.output).flatMap {
+          case (l, r) if l.dataType != r.dataType =>
+            s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
+          case _ => None
+        }
+
+        if (mismatchedColumns.nonEmpty) {
+          TypeCheckResult.TypeCheckFailure(
+            s"""
+               |The data type of one or more elements in the left hand side of an IN subquery
+               |is not compatible with the data type of the output of the subquery
+               |Mismatched columns:
+               |[${mismatchedColumns.mkString(", ")}]
+               |Left side:
+               |[${valExprs.map(_.dataType.catalogString).mkString(", ")}].
+               |Right side:
+               |[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
+             """.stripMargin)
+        } else {
+          TypeCheckResult.TypeCheckSuccess
+        }
+      case _ =>
+        if (list.exists(l => l.dataType != value.dataType)) {
+          TypeCheckResult.TypeCheckFailure("Arguments must be same type")
+        } else {
+          TypeCheckResult.TypeCheckSuccess
+        }
     }
   }
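A sketch of a query that reaches the new failure path (spark-shell; illustrative types chosen so that no common type exists, not part of the diff):

```scala
import spark.implicits._
Seq(true, false).toDF("b").createOrReplaceTempView("tb")
Seq(java.sql.Date.valueOf("2017-03-14")).toDF("d").createOrReplaceTempView("td")

// Boolean and date have no common comparison type, so InConversion leaves the
// expression unchanged and checkInputDataTypes reports the detailed
// "Mismatched columns" failure built above.
spark.sql("SELECT * FROM tb WHERE b IN (SELECT d FROM td)")
```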
