Commit 69202d5

address review comments

peter-toth committed Jun 19, 2020
1 parent adf120c

Showing 3 changed files with 23 additions and 52 deletions.
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -57,7 +57,6 @@ private[execution] object SparkPlanInfo {
       case ReusedSubqueryExec(child) => child :: Nil
       case a: AdaptiveSparkPlanExec => a.executedPlan :: Nil
       case stage: QueryStageExec => stage.plan :: Nil
-      case rr: RecursiveRelationExec => rr.anchorTerm +: rr.recursiveTermIterations
       case inMemTab: InMemoryTableScanExec => inMemTab.relation.cachedPlan :: Nil
       case _ => plan.children ++ plan.subqueries
     }
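With this case removed, RecursiveRelationExec falls through to the wildcard branch above, so the plan graph reported to the UI walks only the node's single physical child, the anchor term; the recursive term still appears via innerChildren (see the last file in this commit). A hedged restatement of the resulting traversal, outside the diff (childrenToTraverse is an illustrative name, not Spark API):

import org.apache.spark.sql.execution.SparkPlan

object TraversalSketch {
  // Default child selection once no special case matches: recurse into the
  // physical children plus any subquery plans, and nothing else.
  def childrenToTraverse(plan: SparkPlan): Seq[SparkPlan] =
    plan.children ++ plan.subqueries
}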
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -644,8 +644,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.ProjectExec(projectList, planLater(child)) :: Nil
       case logical.Filter(condition, child) =>
         execution.FilterExec(condition, planLater(child)) :: Nil
-      case rr @ logical.RecursiveRelation(name, anchorTerm, _) =>
-        execution.RecursiveRelationExec(name, planLater(anchorTerm), rr.output) ::
+      case rr @ logical.RecursiveRelation(name, anchorTerm, recursiveTerm) =>
+        execution.RecursiveRelationExec(name, planLater(anchorTerm), recursiveTerm, rr.output) ::
           Nil
       case logical.RecursiveReference(name, output, _, level, _, rdd) =>
         RDDScanExec(output, rdd, s"RecursiveReference $name, $level") :: Nil
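For context, the strategy now receives the logical recursive term directly from the pattern match instead of recovering it later from logicalLink. A query of roughly this shape would be planned through the branch above (a sketch only: the WITH RECURSIVE syntax and the session setup are assumptions about this in-progress recursive-CTE feature, not something the diff confirms):

import org.apache.spark.sql.SparkSession

object RecursiveCteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("recursive-cte-sketch")
      .getOrCreate()

    // The anchor term (SELECT 0) seeds the relation; the recursive term
    // references r itself and is re-planned by RecursiveRelationExec on
    // each iteration until it produces no rows.
    spark.sql(
      """WITH RECURSIVE r AS (
        |  SELECT 0 AS level
        |  UNION ALL
        |  SELECT level + 1 FROM r WHERE level < 9
        |)
        |SELECT * FROM r""".stripMargin).show()

    spark.stop()
  }
}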
sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -17,11 +17,11 @@

 package org.apache.spark.sql.execution

-import java.util.concurrent.{Future => JFuture, LinkedBlockingQueue}
+import java.util.concurrent.{Future => JFuture}
 import java.util.concurrent.TimeUnit._

 import scala.collection.mutable
-import scala.concurrent.{ExecutionContext}
+import scala.concurrent.ExecutionContext
 import scala.concurrent.duration.Duration

 import org.apache.spark._
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.{RecursiveReference, RecursiveRelation, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RecursiveReference, RecursiveRelation, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -277,39 +277,33 @@ case class FilterExec(condition: Expression, child: SparkPlan)
 case class RecursiveRelationExec(
     cteName: String,
     anchorTerm: SparkPlan,
-    output: Seq[Attribute]) extends SparkPlan {
-  @transient
-  lazy val logicalRecursiveTerm = logicalLink.get.asInstanceOf[RecursiveRelation].recursiveTerm
-
-  override def children: Seq[SparkPlan] = anchorTerm :: Nil
+    @transient logicalRecursiveTerm: LogicalPlan,
+    override val output: Seq[Attribute]) extends SparkPlan {
+  override def children: Seq[SparkPlan] = Seq(anchorTerm)

   override def innerChildren: Seq[QueryPlan[_]] = logicalRecursiveTerm +: super.innerChildren

   override def stringArgs: Iterator[Any] = Iterator(cteName, output)

-  private val physicalRecursiveTerms = new LinkedBlockingQueue[SparkPlan]
-
-  def recursiveTermIterations: Seq[SparkPlan] =
-    physicalRecursiveTerms.toArray(Array.empty[SparkPlan])
+  private var physicalRecursiveTerms = new mutable.ArrayBuffer[SparkPlan]

   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

-  /**
-   * Notify the listeners of the physical plan change.
-   */
-  private def onUpdatePlan(executionId: Long): Unit = {
-    val queryExecution = SQLExecution.getQueryExecution(executionId)
-    sparkContext.listenerBus.post(SparkListenerSQLAdaptiveExecutionUpdate(
-      executionId,
-      queryExecution.toString,
-      SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan)))
+  private def calculateStatistics(count: Long) = {
+    Statistics(EstimationUtils.getSizePerRow(output) * count, Some(count))
   }

+  private def unionRDDs(rdds: Seq[RDD[InternalRow]]): RDD[InternalRow] = {
+    if (rdds.size == 1) {
+      rdds.head
+    } else {
+      sparkContext.union(rdds)
+    }
+  }
+
   override protected def doExecute(): RDD[InternalRow] = {
-    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    val executionIdLong = Option(executionId).map(_.toLong)

     val levelLimit = conf.getConf(SQLConf.CTE_RECURSION_LEVEL_LIMIT)

     // TODO: cache before count, as the RDD can be reused in the next iteration
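The two helpers introduced here, calculateStatistics and unionRDDs, factor out logic that the next hunk previously spelled out inline at both call sites. The union short-circuit also stands on its own; a minimal sketch outside the class (unionAll is an illustrative name, not part of the patch):

import scala.reflect.ClassTag

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

object UnionSketch {
  // With a single input there is nothing to union, so the lone RDD is
  // returned as-is and no extra UnionRDD layer is added to the lineage.
  def unionAll[T: ClassTag](sc: SparkContext, rdds: Seq[RDD[T]]): RDD[T] =
    if (rdds.size == 1) rdds.head else sc.union(rdds)
}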
@@ -337,35 +331,17 @@ case class RecursiveRelationExec(
       val newLogicalRecursiveTerm = logicalRecursiveTerm.transform {
         case rr @ RecursiveReference(name, _, accumulated, _, _, _) if name == cteName =>
           val (newStatistics, newRDD) = if (accumulated) {
-            (
-              Statistics(
-                EstimationUtils.getSizePerRow(output) * accumulatedCount,
-                Some(accumulatedCount)
-              ),
-              if (accumulatedRDDs.size > 1) {
-                sparkContext.union(accumulatedRDDs)
-              } else {
-                accumulatedRDDs.head
-              }
-            )
+            (calculateStatistics(accumulatedCount), unionRDDs(accumulatedRDDs))
           } else {
-            (
-              Statistics(
-                EstimationUtils.getSizePerRow(output) * prevIterationCount,
-                Some(prevIterationCount)
-              ),
-              prevIterationRDD
-            )
+            (calculateStatistics(prevIterationCount), prevIterationRDD)
           }
           rr.withNewIteration(level, newStatistics, newRDD)
       }

       val physicalRecursiveTerm =
         QueryExecution.prepareExecutedPlan(sqlContext.sparkSession, newLogicalRecursiveTerm)

-      physicalRecursiveTerms.offer(physicalRecursiveTerm)
-
-      executionIdLong.foreach(onUpdatePlan)
+      physicalRecursiveTerms += physicalRecursiveTerm

       // TODO: cache before count, as the RDD can be reused in the next iteration
       prevIterationRDD = physicalRecursiveTerm.execute().map(_.copy())
@@ -374,14 +350,10 @@
       level = level + 1
     }

-    executionIdLong.foreach(onUpdatePlan)
-
     if (accumulatedRDDs.isEmpty) {
       prevIterationRDD
-    } else if (accumulatedRDDs.size == 1) {
-      accumulatedRDDs.head
     } else {
-      sparkContext.union(accumulatedRDDs)
+      unionRDDs(accumulatedRDDs)
     }
   }
 }
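Stripped of RDDs, metrics, and per-iteration plan preparation, the control flow that doExecute implements is a plain fixed-point loop. A toy model over in-memory collections (illustrative only; in particular, what the real operator does on reaching CTE_RECURSION_LEVEL_LIMIT, stop or fail, is not modeled here):

import scala.collection.mutable.ArrayBuffer

object FixedPointSketch {
  // Seed with the anchor rows, then repeatedly evaluate the recursive step
  // against the previous iteration's rows until an iteration yields nothing
  // or the level limit is reached. The result is the UNION ALL of all
  // iterations, analogous to accumulatedRDDs in the operator above.
  def evaluateRecursiveCte[A](
      anchor: Seq[A],
      recursiveStep: Seq[A] => Seq[A],
      levelLimit: Int): Seq[A] = {
    val accumulated = ArrayBuffer[Seq[A]](anchor)
    var prevIteration = anchor
    var level = 1
    while (prevIteration.nonEmpty && level <= levelLimit) {
      prevIteration = recursiveStep(prevIteration)
      if (prevIteration.nonEmpty) accumulated += prevIteration
      level += 1
    }
    accumulated.flatten.toSeq
  }

  // Example, mirroring the SQL sketch earlier (levels 0 to 9):
  // evaluateRecursiveCte(Seq(0), (rows: Seq[Int]) => rows.map(_ + 1).filter(_ < 10), 100)
}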