Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,9 @@ trait CheckAnalysis extends PredicateHelper {
case f @ Filter(condition, child) =>
splitConjunctivePredicates(condition).foreach {
case _: PredicateSubquery | Not(_: PredicateSubquery) =>
case e if PredicateSubquery.hasPredicateSubquery(e) =>
failAnalysis(s"Predicate sub-queries cannot be used in nested conditions: $e")
case e if PredicateSubquery.hasNullAwarePredicateWithinNot(e) =>
failAnalysis(s"Null-aware predicate sub-queries cannot be used in nested" +
s" conditions: $e")
case e =>
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ case class PredicateSubquery(
extends SubqueryExpression with Predicate with Unevaluable {
override lazy val resolved = childrenResolved && query.resolved
override lazy val references: AttributeSet = super.references -- query.outputSet
override def nullable: Boolean = false
override def nullable: Boolean = nullAware
override def plan: LogicalPlan = SubqueryAlias(toString, query)
override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(query = plan)
override def toString: String = s"predicate-subquery#${exprId.id} $conditionString"
Expand All @@ -105,6 +105,19 @@ object PredicateSubquery {
case _ => false
}.isDefined
}

/**
 * Returns whether there are any null-aware predicate subqueries inside a [[Not]]. If there are
 * none, we could turn a null-aware predicate into a not-null-aware predicate.
 *
 * @param e the expression tree to search
 * @return true iff some `Not` node within `e` contains a null-aware [[PredicateSubquery]]
 *         in its own subtree
 */
def hasNullAwarePredicateWithinNot(e: Expression): Boolean = {
  e.find {
    // Search only inside the Not's subtree (`child`). The previous version searched the
    // whole expression `e` again, which wrongly returned true whenever a Not and a
    // null-aware subquery appeared in unrelated branches, e.g. `Not(a) AND b IN (subq)`.
    case Not(child) =>
      child.find {
        case p: PredicateSubquery => p.nullAware
        case _ => false
      }.isDefined
    case _ => false
  }.isDefined
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,18 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
EliminateSorts,
SimplifyCasts,
SimplifyCaseConversionExpressions,
EliminateSerialization,
RewritePredicateSubquery) ::
EliminateSerialization) ::
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) ::
Batch("Typed Filter Optimization", fixedPoint,
EmbedSerializerInFilter) ::
Batch("LocalRelation", fixedPoint,
ConvertToLocalRelation) ::
Batch("OptimizeCodegen", Once,
OptimizeCodegen(conf)) :: Nil
OptimizeCodegen(conf)) ::
Batch("RewriteSubquery", Once,
RewritePredicateSubquery,
CollapseProject) :: Nil
}

/**
Expand Down Expand Up @@ -1077,7 +1079,14 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = {
assert(input.size >= 2)
if (input.size == 2) {
Join(input(0), input(1), Inner, conditions.reduceLeftOption(And))
val (joinConditions, others) = conditions.partition(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: It might be easier to flip the names and call PredicateSubquery.hasPredicateSubquery directly

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's weird to see others come before joinConditions, so I made it this way.

e => !PredicateSubquery.hasPredicateSubquery(e))
val join = Join(input(0), input(1), Inner, joinConditions.reduceLeftOption(And))
if (others.nonEmpty) {
Filter(others.reduceLeft(And), join)
} else {
join
}
} else {
val left :: rest = input.toList
// find out the first join that have at least one join condition
Expand All @@ -1090,7 +1099,8 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
val right = conditionalJoin.getOrElse(rest.head)

val joinedRefs = left.outputSet ++ right.outputSet
val (joinConditions, others) = conditions.partition(_.references.subsetOf(joinedRefs))
val (joinConditions, others) = conditions.partition(
e => e.references.subsetOf(joinedRefs) && !PredicateSubquery.hasPredicateSubquery(e))
val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And))

// should not have reference to same logical plan
Expand Down Expand Up @@ -1200,9 +1210,16 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
val newRight = rightFilterConditions.
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val newJoinCond = (commonFilterCondition ++ joinCondition).reduceLeftOption(And)
val (newJoinConditions, others) =
commonFilterCondition.partition(e => !PredicateSubquery.hasPredicateSubquery(e))
val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And)

Join(newLeft, newRight, Inner, newJoinCond)
val join = Join(newLeft, newRight, Inner, newJoinCond)
if (others.nonEmpty) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2nd time you need this. Almost warrants an inner method.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's repeated only twice, so I think it's OK.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nvm - two different rules.

Filter(others.reduceLeft(And), join)
} else {
join
}
case RightOuter =>
// push down the right side only `where` condition
val newLeft = left
Expand Down Expand Up @@ -1530,6 +1547,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// Note that will almost certainly be planned as a Broadcast Nested Loop join. Use EXISTS
// if performance matters to you.
Join(p, sub, LeftAnti, Option(Or(anyNull, condition)))
case (p, predicate) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't plan these joins outside of filters right? So this is not working yet:

select a.*, a.value in (select value from b) as in_b from a

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. If that's needed, we could support that as follow-up PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Follow-up PR works.

var joined = p
val replaced = predicate transformUp {
case PredicateSubquery(sub, conditions, nullAware, _) =>
// TODO: support null-aware join
val exists = AttributeReference("exists", BooleanType, false)()
joined = Join(joined, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
exists
}
Project(p.output, Filter(replaced, joined))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.Attribute

object JoinType {
def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match {
Expand Down Expand Up @@ -69,6 +70,14 @@ case object LeftAnti extends JoinType {
override def sql: String = "LEFT ANTI"
}

/**
 * Internal-only join type. A join with this type emits every left-side row extended with the
 * boolean attribute `exists`, which records whether a matching right-side row was found.
 */
case class ExistenceJoin(exists: Attribute) extends JoinType {
  override def sql: String =
    // This join type only appears at the end of optimization and in physical plans; it is
    // never rendered back to SQL text, so no SQL representation is defined.
    throw new UnsupportedOperationException
}

case class NaturalJoin(tpe: JoinType) extends JoinType {
require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe),
"Unsupported natural join type " + tpe)
Expand All @@ -84,6 +93,7 @@ case class UsingJoin(tpe: JoinType, usingColumns: Seq[UnresolvedAttribute]) exte
/**
 * Extractor matching the join types whose output rows come solely from the left side:
 * [[LeftSemi]], [[LeftAnti]] and [[ExistenceJoin]].
 */
object LeftExistence {
  def unapply(joinType: JoinType): Option[JoinType] = joinType match {
    // `_: ExistenceJoin` replaces the previous `j: ExistenceJoin`, whose binding was
    // never used (compiler warning); folding it into one alternation keeps the cases flat.
    case LeftSemi | LeftAnti | _: ExistenceJoin => Some(joinType)
    case _ => None
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,8 @@ case class Join(

override def output: Seq[Attribute] = {
joinType match {
case j: ExistenceJoin =>
left.output :+ j.exists
case LeftExistence(_) =>
left.output
case LeftOuter =>
Expand All @@ -295,6 +297,8 @@ case class Join(
case LeftSemi if condition.isDefined =>
left.constraints
.union(splitConjunctivePredicates(condition.get).toSet)
case j: ExistenceJoin =>
left.constraints
case Inner =>
left.constraints.union(right.constraints)
case LeftExistence(_) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,11 +459,14 @@ class AnalysisErrorSuite extends AnalysisTest {
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val c = AttributeReference("c", BooleanType)()
val plan1 = Filter(Cast(In(a, Seq(ListQuery(LocalRelation(b)))), BooleanType), LocalRelation(a))
assertAnalysisError(plan1, "Predicate sub-queries cannot be used in nested conditions" :: Nil)
val plan1 = Filter(Cast(Not(In(a, Seq(ListQuery(LocalRelation(b))))), BooleanType),
LocalRelation(a))
assertAnalysisError(plan1,
"Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil)

val plan2 = Filter(Or(In(a, Seq(ListQuery(LocalRelation(b)))), c), LocalRelation(a, c))
assertAnalysisError(plan2, "Predicate sub-queries cannot be used in nested conditions" :: Nil)
val plan2 = Filter(Or(Not(In(a, Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c))
assertAnalysisError(plan2,
"Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil)
}

test("PredicateSubQuery correlated predicate is nested in an illegal plan") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

// Returns true when the right (build) side of the given join type may be built into a hash
// relation. Existence joins behave like semi/anti joins here: only the left side is streamed.
// `_: ExistenceJoin` avoids the unused binding `j` of the previous version.
private def canBuildRight(joinType: JoinType): Boolean = joinType match {
  case Inner | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true
  case _ => false
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ case class BroadcastHashJoinExec(
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning

override def requiredChildDistribution: Seq[Distribution] = {
val mode = HashedRelationBroadcastMode(buildKeys)
buildSide match {
Expand Down Expand Up @@ -85,6 +83,7 @@ case class BroadcastHashJoinExec(
case LeftOuter | RightOuter => codegenOuter(ctx, input)
case LeftSemi => codegenSemi(ctx, input)
case LeftAnti => codegenAnti(ctx, input)
case j: ExistenceJoin => codegenExistence(ctx, input)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashJoin should not take $x as the JoinType")
Expand Down Expand Up @@ -407,4 +406,67 @@ case class BroadcastHashJoinExec(
""".stripMargin
}
}

/**
 * Generates whole-stage-codegen Java source for an existence join: each streamed row is
 * emitted exactly once, extended with a boolean column telling whether a build-side row
 * matched (and, when a join condition is present, satisfied it).
 *
 * NOTE(review): relies on the build side having been broadcast as a HashedRelation via
 * `prepareBroadcast` — confirm against the enclosing BroadcastHashJoinExec setup.
 */
private def codegenExistence(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// Broadcast hash relation for the build side, and the generated-code term referencing it.
val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
// keyEv: code evaluating the streamed row's join key; anyNull: term true if any key part is null.
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
val numOutput = metricTerm(ctx, "numOutputRows")
// Name of the generated boolean variable that holds the existence result for the row.
val existsVar = ctx.freshName("exists")

val matched = ctx.freshName("matched")
val buildVars = genBuildSideVars(ctx, matched)
// Generated Java fragment that sets existsVar for one candidate build-side row.
val checkCondition = if (condition.isDefined) {
val expr = condition.get
// evaluate the variables from the build side that are used by the condition
val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
// filter the output via condition
ctx.currentVars = input ++ buildVars
val ev =
BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx)
// A null condition result counts as no match, hence the !isNull guard.
s"""
|$eval
|${ev.code}
|$existsVar = !${ev.isNull} && ${ev.value};
""".stripMargin
} else {
// No condition: any key match proves existence.
s"$existsVar = true;"
}

// Output row = streamed columns plus the (never-null) existence boolean.
val resultVar = input ++ Seq(ExprCode("", "false", existsVar))
if (broadcastRelation.value.keyIsUnique) {
// Unique keys: at most one candidate row, so a single lookup suffices.
s"""
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashedRelation
|UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
|boolean $existsVar = false;
|if ($matched != null) {
| $checkCondition
|}
|$numOutput.add(1);
|${consume(ctx, resultVar)}
""".stripMargin
} else {
// Non-unique keys: iterate candidates, stopping at the first that satisfies the condition.
val matches = ctx.freshName("matches")
val iteratorCls = classOf[Iterator[UnsafeRow]].getName
s"""
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashRelation
|$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
|boolean $existsVar = false;
|if ($matches != null) {
| while (!$existsVar && $matches.hasNext()) {
| UnsafeRow $matched = (UnsafeRow) $matches.next();
| $checkCondition
| }
|}
|$numOutput.add(1);
|${consume(ctx, resultVar)}
""".stripMargin
}
}
}
Loading