diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index a797a9e4fb28a..db587dd98685e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -57,7 +57,6 @@ private[execution] object SparkPlanInfo {
       case ReusedSubqueryExec(child) => child :: Nil
       case a: AdaptiveSparkPlanExec => a.executedPlan :: Nil
       case stage: QueryStageExec => stage.plan :: Nil
-      case rr: RecursiveRelationExec => rr.anchorTerm +: rr.recursiveTermIterations
       case inMemTab: InMemoryTableScanExec => inMemTab.relation.cachedPlan :: Nil
       case _ => plan.children ++ plan.subqueries
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f2910e84b7c25..088a1c8d83dbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -644,8 +644,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.ProjectExec(projectList, planLater(child)) :: Nil
       case logical.Filter(condition, child) =>
         execution.FilterExec(condition, planLater(child)) :: Nil
-      case rr @ logical.RecursiveRelation(name, anchorTerm, _) =>
-        execution.RecursiveRelationExec(name, planLater(anchorTerm), rr.output) ::
+      case rr @ logical.RecursiveRelation(name, anchorTerm, recursiveTerm) =>
+        execution.RecursiveRelationExec(name, planLater(anchorTerm), recursiveTerm, rr.output) ::
           Nil
       case logical.RecursiveReference(name, output, _, level, _, rdd) =>
         RDDScanExec(output, rdd, s"RecursiveReference $name, $level") :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index ec06c4a34ade0..fc2111b62efaa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -17,11 +17,11 @@
 
 package org.apache.spark.sql.execution
 
-import java.util.concurrent.{Future => JFuture, LinkedBlockingQueue}
+import java.util.concurrent.{Future => JFuture}
 import java.util.concurrent.TimeUnit._
 
 import scala.collection.mutable
-import scala.concurrent.{ExecutionContext}
+import scala.concurrent.ExecutionContext
 import scala.concurrent.duration.Duration
 
 import org.apache.spark._
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.{RecursiveReference, RecursiveRelation, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RecursiveReference, RecursiveRelation, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -277,39 +277,33 @@ case class FilterExec(condition: Expression, child: SparkPlan)
 case class RecursiveRelationExec(
     cteName: String,
     anchorTerm: SparkPlan,
-    output: Seq[Attribute]) extends SparkPlan {
-  @transient
-  lazy val logicalRecursiveTerm = logicalLink.get.asInstanceOf[RecursiveRelation].recursiveTerm
-
-  override def children: Seq[SparkPlan] = anchorTerm :: Nil
+    @transient logicalRecursiveTerm: LogicalPlan,
+    override val output: Seq[Attribute]) extends SparkPlan {
+  override def children: Seq[SparkPlan] = Seq(anchorTerm)
 
   override def innerChildren: Seq[QueryPlan[_]] = logicalRecursiveTerm +: super.innerChildren
 
   override def stringArgs: Iterator[Any] = Iterator(cteName, output)
 
-  private val physicalRecursiveTerms = new LinkedBlockingQueue[SparkPlan]
-
-  def recursiveTermIterations: Seq[SparkPlan] =
-    physicalRecursiveTerms.toArray(Array.empty[SparkPlan])
+  private var physicalRecursiveTerms = new mutable.ArrayBuffer[SparkPlan]
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
-  /**
-   * Notify the listeners of the physical plan change.
-   */
-  private def onUpdatePlan(executionId: Long): Unit = {
-    val queryExecution = SQLExecution.getQueryExecution(executionId)
-    sparkContext.listenerBus.post(SparkListenerSQLAdaptiveExecutionUpdate(
-      executionId,
-      queryExecution.toString,
-      SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan)))
+  private def calculateStatistics(count: Long) = {
+    Statistics(EstimationUtils.getSizePerRow(output) * count, Some(count))
+  }
+
+  private def unionRDDs(rdds: Seq[RDD[InternalRow]]): RDD[InternalRow] = {
+    if (rdds.size == 1) {
+      rdds.head
+    } else {
+      sparkContext.union(rdds)
+    }
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    val executionIdLong = Option(executionId).map(_.toLong)
-
     val levelLimit = conf.getConf(SQLConf.CTE_RECURSION_LEVEL_LIMIT)
 
     // TODO: cache before count, as the RDD can be reused in the next iteration
@@ -337,25 +331,9 @@ case class RecursiveRelationExec(
       val newLogicalRecursiveTerm = logicalRecursiveTerm.transform {
         case rr @ RecursiveReference(name, _, accumulated, _, _, _) if name == cteName =>
           val (newStatistics, newRDD) = if (accumulated) {
-            (
-              Statistics(
-                EstimationUtils.getSizePerRow(output) * accumulatedCount,
-                Some(accumulatedCount)
-              ),
-              if (accumulatedRDDs.size > 1) {
-                sparkContext.union(accumulatedRDDs)
-              } else {
-                accumulatedRDDs.head
-              }
-            )
+            (calculateStatistics(accumulatedCount), unionRDDs(accumulatedRDDs))
           } else {
-            (
-              Statistics(
-                EstimationUtils.getSizePerRow(output) * prevIterationCount,
-                Some(prevIterationCount)
-              ),
-              prevIterationRDD
-            )
+            (calculateStatistics(prevIterationCount), prevIterationRDD)
           }
           rr.withNewIteration(level, newStatistics, newRDD)
       }
@@ -363,9 +341,7 @@ case class RecursiveRelationExec(
 
       val physicalRecursiveTerm =
         QueryExecution.prepareExecutedPlan(sqlContext.sparkSession, newLogicalRecursiveTerm)
-      physicalRecursiveTerms.offer(physicalRecursiveTerm)
-
-      executionIdLong.foreach(onUpdatePlan)
+      physicalRecursiveTerms += physicalRecursiveTerm
 
       // TODO: cache before count, as the RDD can be reused in the next iteration
       prevIterationRDD = physicalRecursiveTerm.execute().map(_.copy())
@@ -374,14 +350,10 @@ case class RecursiveRelationExec(
       level = level + 1
     }
 
-    executionIdLong.foreach(onUpdatePlan)
-
     if (accumulatedRDDs.isEmpty) {
       prevIterationRDD
-    } else if (accumulatedRDDs.size == 1) {
-      accumulatedRDDs.head
     } else {
-      sparkContext.union(accumulatedRDDs)
+      unionRDDs(accumulatedRDDs)
     }
   }
 }
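For reference, below is a minimal, self-contained sketch of the fixpoint loop that RecursiveRelationExec.doExecute drives, using plain Scala collections in place of RDDs. It is not part of the patch: recursiveUnion, step, and the transitive-closure example are illustrative names only, and the real operator re-plans the recursive term and unions RDDs rather than Seqs.

// Sketch of the recursive-CTE evaluation pattern: run the anchor once,
// then keep re-running the recursive term against the previous iteration's
// rows until an iteration produces no rows, accumulating every iteration.
object RecursiveCteSketch {
  def recursiveUnion[T](
      anchor: Seq[T],            // result of the anchor term
      step: Seq[T] => Seq[T],    // one evaluation of the recursive term
      levelLimit: Int): Seq[T] = {
    val accumulated = scala.collection.mutable.ArrayBuffer.empty[Seq[T]]
    var prevIteration = anchor
    var level = 0
    // Mirrors the operator's loop guard: iterate while the previous
    // iteration produced rows, failing once the configured level limit
    // (CTE_RECURSION_LEVEL_LIMIT in the patch) is exceeded.
    while (prevIteration.nonEmpty) {
      if (level > levelLimit) {
        throw new IllegalStateException(s"Recursion level limit $levelLimit reached")
      }
      accumulated += prevIteration
      prevIteration = step(prevIteration)
      level += 1
    }
    // Corresponds to unionRDDs over the accumulated per-iteration results.
    accumulated.flatten.toSeq
  }

  def main(args: Array[String]): Unit = {
    // Transitive closure of 1 -> 2 -> 3 -> 4, starting from node 1.
    val edges = Map(1 -> 2, 2 -> 3, 3 -> 4)
    val reachable = recursiveUnion[Int](
      anchor = Seq(1),
      step = prev => prev.flatMap(edges.get),
      levelLimit = 100)
    println(reachable) // List(1, 2, 3, 4)
  }
}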