
Commit

fix
ulysses-you committed May 21, 2021
1 parent 05e074c commit 48dd92a
Showing 3 changed files with 38 additions and 37 deletions.
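
The substance of the fix, as the hunks below show, is an API cleanup in PropagateEmptyRelationBase: the AQE hooks checkRowCount and isRelationWithAllNullKeys change from Option-wrapped function values into ordinary methods that return Option[Boolean], and the shared rewrite cases move into a protected applyInternal so each concrete rule object drives its own traversal. A minimal, self-contained sketch of the before/after hook shape (PlanLike, OldHooks and NewHooks are illustrative stand-ins, not Spark classes):

// Illustrative sketch only: PlanLike stands in for Catalyst's LogicalPlan.
trait PlanLike

// Before: an Option wrapping a function, so every call site had to test
// isDefined and then go through .get.apply(...).
trait OldHooks {
  protected def checkRowCount: Option[(PlanLike, Boolean) => Boolean] = None

  def rightHasRows(plan: PlanLike): Boolean =
    checkRowCount.isDefined && checkRowCount.get.apply(plan, true)
}

// After: a plain method returning Option[Boolean]; None means "no runtime
// information available", so call sites compose with .contains(true) or
// .map(...).getOrElse(default).
trait NewHooks {
  protected def checkRowCount(plan: PlanLike, hasRow: Boolean): Option[Boolean] = None

  def rightHasRows(plan: PlanLike): Boolean =
    checkRowCount(plan, hasRow = true).contains(true)
}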
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
 import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _}
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_FALSE_LITERAL}

@@ -91,28 +91,25 @@ object PropagateEmptyRelationBasic extends Rule[LogicalPlan] {
  * - Aggregate with all empty children and at least one grouping expression.
  * - Generate(Explode) with all empty children. Others like Hive UDTF may return results.
  */
-trait PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSupport {
+abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSupport {
   /**
    * At AQE side, we use this function to check if a plan has output rows or not
    */
-  protected def checkRowCount: Option[(LogicalPlan, Boolean) => Boolean] = None
+  protected def checkRowCount(plan: LogicalPlan, hasRow: Boolean): Option[Boolean] = None

   /**
    * At AQE side, we use the broadcast query stage to do the check
    */
-  protected def isRelationWithAllNullKeys: Option[LogicalPlan => Boolean] = None
+  protected def isRelationWithAllNullKeys(plan: LogicalPlan): Option[Boolean] = None

   private def isEmptyLocalRelation(plan: LogicalPlan): Boolean = {
     val defaultEmptyRelation: Boolean = plan match {
       case p: LocalRelation => p.data.isEmpty
       case _ => false
     }

-    if (checkRowCount.isDefined) {
-      checkRowCount.get.apply(plan, false) || defaultEmptyRelation
-    } else {
-      defaultEmptyRelation
-    }
+    checkRowCount(plan, false).map(_ || defaultEmptyRelation)
+      .getOrElse(defaultEmptyRelation)
   }

   private def empty(plan: LogicalPlan) =
@@ -122,12 +119,9 @@ trait PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSupport {
   private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] =
     plan.output.map{ a => Alias(cast(Literal(null), a.dataType), a.name)(a.exprId) }

-  // We can not use transformUpWithPruning here since this rule is used by both normal Optimizer
-  // and AQE Optimizer. And this may only effective at AQE side.
-  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
-    _.containsAnyPattern(LOCAL_RELATION, TRUE_OR_FALSE_LITERAL), ruleId) {
+  protected def applyInternal: PartialFunction[LogicalPlan, LogicalPlan] = {
     case j @ ExtractSingleColumnNullAwareAntiJoin(_, _)
-        if isRelationWithAllNullKeys.isDefined && isRelationWithAllNullKeys.get(j.right) =>
+        if isRelationWithAllNullKeys(j.right).contains(true) =>
       empty(j)

     // Joins on empty LocalRelations generated from streaming sources are not eliminated
@@ -162,10 +156,10 @@ trait PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSupport {
           case _ => p
         }
       } else if (joinType == LeftSemi && conditionOpt.isEmpty &&
-        checkRowCount.isDefined && checkRowCount.get.apply(p.right, true)) {
+        checkRowCount(p.right, true).contains(true)) {
         p.left
       } else if (joinType == LeftAnti && conditionOpt.isEmpty &&
-        checkRowCount.isDefined && checkRowCount.get.apply(p.right, true)) {
+        checkRowCount(p.right, true).contains(true)) {
         empty(p)
       } else {
         p
@@ -199,4 +193,9 @@ trait PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSupport {
   }
 }

-object PropagateEmptyRelation extends PropagateEmptyRelationBase
+object PropagateEmptyRelation extends PropagateEmptyRelationBase {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsAnyPattern(LOCAL_RELATION, TRUE_OR_FALSE_LITERAL), ruleId) {
+    applyInternal
+  }
+}
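
With apply removed from the base class, the traversal policy now lives in the concrete objects: PropagateEmptyRelation above keeps the pruned traversal (transformUpWithPruning with its ruleId), while AQEPropagateEmptyRelation in the last file temporarily falls back to a plain transformUp. A rough standalone sketch of that split, using a toy Node type instead of Catalyst's TreeNode API (all names here are invented for illustration):

// Toy plan type; Leaf(0) plays the role of an empty LocalRelation.
sealed trait Node
case class Leaf(rows: Int) extends Node
case class Join(left: Node, right: Node) extends Node

abstract class PropagateEmptyBaseSketch {
  // Shared rewrite cases, analogous to applyInternal in the patch.
  protected def applyInternal: PartialFunction[Node, Node] = {
    case Join(Leaf(0), _) => Leaf(0)
    case Join(_, Leaf(0)) => Leaf(0)
  }

  // Simplified bottom-up traversal standing in for transformUp / transformUpWithPruning.
  protected def transformUp(plan: Node)(rule: PartialFunction[Node, Node]): Node = {
    val rewrittenChildren = plan match {
      case Join(l, r) => Join(transformUp(l)(rule), transformUp(r)(rule))
      case leaf => leaf
    }
    rule.applyOrElse(rewrittenChildren, identity[Node])
  }
}

// Each concrete object decides how the shared cases are driven over the tree.
object PropagateEmptySketch extends PropagateEmptyBaseSketch {
  def apply(plan: Node): Node = transformUp(plan)(applyInternal)
}

For example, PropagateEmptySketch(Join(Leaf(3), Leaf(0))) collapses to Leaf(0), which is the shape of rewrite the real rule performs on plans with provably empty children.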
@@ -119,7 +119,6 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.optimizer.OptimizeUpdateFields"::
       "org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelationBasic" ::
       "org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelation" ::
-      "org.apache.spark.sql.catalyst.optimizer.AQEPropagateEmptyRelation" ::
       "org.apache.spark.sql.catalyst.optimizer.PruneFilters" ::
       "org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin" ::
       "org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin" ::
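
This one-line removal matches the TODO in the last file: rules listed in RuleIdCollection get a stable ruleId so that transformUpWithPruning can skip subtrees on which the rule is already known to be ineffective, and since AQEPropagateEmptyRelation now calls plain transformUp it presumably no longer needs an id. A toy illustration of the idea behind id-based pruning (not Spark's implementation):

// A node remembers which rule ids were already tried without effect,
// so a re-run of the same rule can skip the rewrite cheaply.
final class TaggedNode(var value: Int) {
  private val ineffectiveRuleIds = scala.collection.mutable.Set.empty[Int]

  def transformWithPruning(ruleId: Int)(rule: PartialFunction[Int, Int]): Unit = {
    if (ineffectiveRuleIds.contains(ruleId)) return // pruned: known no-op here
    val rewritten = rule.applyOrElse(value, identity[Int])
    if (rewritten == value) ineffectiveRuleIds += ruleId else value = rewritten
  }
}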
@@ -22,29 +22,32 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys

 /**
- * Rule [[PropagateEmptyRelation]] at AQE side.
+ * Rule [[PropagateEmptyRelationBase]] at AQE side.
  */
 object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {

-  override protected def checkRowCount: Option[(LogicalPlan, Boolean) => Boolean] = {
-    case (plan, hasRow) =>
-      Some(plan match {
-        case LogicalQueryStage(_, stage: QueryStageExec) if stage.resultOption.get().isDefined =>
-          stage.getRuntimeStatistics.rowCount match {
-            case Some(count) => hasRow == (count > 0)
-            case _ => false
-          }
-        case _ => false
-      })
+  override protected def checkRowCount(plan: LogicalPlan, hasRow: Boolean): Option[Boolean] = {
+    Some(plan match {
+      case LogicalQueryStage(_, stage: QueryStageExec) if stage.resultOption.get().isDefined =>
+        stage.getRuntimeStatistics.rowCount match {
+          case Some(count) => hasRow == (count > 0)
+          case _ => false
+        }
+      case _ => false
+    })
   }

-  override protected def isRelationWithAllNullKeys: Option[LogicalPlan => Boolean] = {
-    case plan =>
-      Some(plan match {
-        case LogicalQueryStage(_, stage: BroadcastQueryStageExec)
-          if stage.resultOption.get().isDefined =>
-          stage.broadcast.relationFuture.get().value == HashedRelationWithAllNullKeys
-        case _ => false
-      })
+  override protected def isRelationWithAllNullKeys(plan: LogicalPlan): Option[Boolean] = {
+    Some(plan match {
+      case LogicalQueryStage(_, stage: BroadcastQueryStageExec)
+        if stage.resultOption.get().isDefined =>
+        stage.broadcast.relationFuture.get().value == HashedRelationWithAllNullKeys
+      case _ => false
+    })
   }
+
+  // TODO we need use transformUpWithPruning instead of transformUp
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    applyInternal
+  }
 }
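
The Option[Boolean] convention the overrides above rely on: at the normal optimizer side the hooks return None, while the AQE overrides always answer with Some(...), reporting true only when the materialized query stage confirms the condition, so guards like checkRowCount(p.right, true).contains(true) fire only with runtime confirmation. A small sketch of how those combinators behave (the values are invented for illustration):

object OptionBooleanSketch {
  def main(args: Array[String]): Unit = {
    val noRuntimeInfo: Option[Boolean] = None // base rule: no AQE statistics available
    val confirmed: Option[Boolean] = Some(true) // AQE: stage ran and the queried condition holds

    val defaultEmptyRelation = false

    // Mirrors isEmptyLocalRelation: fall back to the static check when the
    // runtime answer is unknown, otherwise OR the two results together.
    println(noRuntimeInfo.map(_ || defaultEmptyRelation).getOrElse(defaultEmptyRelation)) // false
    println(confirmed.map(_ || defaultEmptyRelation).getOrElse(defaultEmptyRelation)) // true

    // Mirrors the LeftSemi/LeftAnti branches: .contains(true) is false for None,
    // so the rewrite never fires without runtime confirmation.
    println(noRuntimeInfo.contains(true)) // false
    println(confirmed.contains(true)) // true
  }
}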
