Skip to content

Commit

Permalink
small fix
Browse files Browse the repository at this point in the history
  • Loading branch information
JkSelf committed Jan 3, 2005
1 parent 81ad999 commit c37f397
Showing 1 changed file with 7 additions and 7 deletions.
Expand Up @@ -121,12 +121,12 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
stage.shuffle.shuffleDependency.rdd.partitions.length
}

private def getShuffleQueryStage(plan : SparkPlan): (Boolean, Option[ShuffleQueryStageExec]) =
private def getShuffleQueryStage(plan : SparkPlan): Option[ShuffleQueryStageExec] =
plan match {
case stage: ShuffleQueryStageExec => (true, Some(stage))
case stage: ShuffleQueryStageExec => Some(stage)
case SortExec(_, _, s: ShuffleQueryStageExec, _) =>
(true, Some(s))
case _ => (false, None)
Some(s)
case _ => None
}

private def reOptimizeChild(
Expand Down Expand Up @@ -154,10 +154,10 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
*/
def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
case smj @ SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, leftPlan, rightPlan)
if (getShuffleQueryStage(leftPlan)._1 && getShuffleQueryStage(rightPlan)._1) &&
if (getShuffleQueryStage(leftPlan).nonEmpty && getShuffleQueryStage(rightPlan).nonEmpty) &&
supportedJoinTypes.contains(joinType) =>
val left = getShuffleQueryStage(leftPlan)._2.get
val right = getShuffleQueryStage(rightPlan)._2.get
val left = getShuffleQueryStage(leftPlan).get
val right = getShuffleQueryStage(rightPlan).get
val leftStats = getStatistics(left)
val rightStats = getStatistics(right)
val numPartitions = leftStats.bytesByPartitionId.length
Expand Down

0 comments on commit c37f397

Please sign in to comment.