diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 9f86fb298877a..13e5b129765e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -232,10 +232,10 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
   }
 
   /**
-   * Returns a sequence containing the result of applying a partial function to all elements in this
-   * plan, also considering all the plans in its (nested) subqueries
+   * A variant of `collect`. This method applies the given partial function not only to all
+   * elements in this plan, but also to all the plans in its (nested) subqueries.
    */
-  def collectInPlanAndSubqueries[B](f: PartialFunction[PlanType, B]): Seq[B] =
+  def collectWithSubqueries[B](f: PartialFunction[PlanType, B]): Seq[B] =
     (this +: subqueriesAll).flatMap(_.collect(f))
 
   override def innerChildren: Seq[QueryPlan[_]] = subqueries
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala
index d96f8086a3e93..91ce187f4d270 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala
@@ -78,7 +78,7 @@ class QueryPlanSuite extends SparkFunSuite {
 
     val countRelationsInPlan = plan.collect({ case _: UnresolvedRelation => 1 }).sum
     val countRelationsInPlanAndSubqueries =
-      plan.collectInPlanAndSubqueries({ case _: UnresolvedRelation => 1 }).sum
+      plan.collectWithSubqueries({ case _: UnresolvedRelation => 1 }).sum
 
     assert(countRelationsInPlan == 2)
     assert(countRelationsInPlanAndSubqueries == 5)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
index e482bc9941ea9..e1b9c8f430c56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
@@ -87,7 +87,7 @@ object CollectMetricsExec {
    * Recursively collect all collected metrics from a query tree.
    */
   def collect(plan: SparkPlan): Map[String, Row] = {
-    val metrics = plan.collectInPlanAndSubqueries {
+    val metrics = plan.collectWithSubqueries {
       case collector: CollectMetricsExec => collector.name -> collector.collectedMetrics
     }
     metrics.toMap
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
index baa9f5ecafc68..cdf9ea4b31ee7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
@@ -1234,7 +1234,7 @@ abstract class DynamicPartitionPruningSuiteBase
 
       val plan = df.queryExecution.executedPlan
      val countSubqueryBroadcasts =
-        plan.collectInPlanAndSubqueries({ case _: SubqueryBroadcastExec => 1 }).sum
+        plan.collectWithSubqueries({ case _: SubqueryBroadcastExec => 1 }).sum
 
      assert(countSubqueryBroadcasts == 2)
    }
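
As a usage sketch (not part of the diff): the renamed `collectWithSubqueries` traverses the main plan plus every plan nested in its subquery expressions, whereas plain `collect` stops at the main plan tree. As the first hunk shows, the implementation simply prepends the plan itself to `subqueriesAll` and runs `collect` over each tree. The minimal standalone example below illustrates the difference; it is an assumption-laden sketch, not code from this change: the object name `CollectWithSubqueriesExample`, the temp view `t`, and the local SparkSession setup are all hypothetical, and it assumes a Spark build in which this rename has landed.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SubqueryExec

// Hypothetical demo: scalar subqueries are planned as SubqueryExec roots held
// inside expressions rather than as children of the main plan tree, so plain
// `collect` misses them while `collectWithSubqueries` walks into them as well.
object CollectWithSubqueriesExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
    import spark.implicits._

    Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "tag").createOrReplaceTempView("t")

    // A query with a scalar subquery, so the plan's `subqueriesAll` is non-empty.
    val plan = spark
      .sql("SELECT id FROM t WHERE id > (SELECT min(id) FROM t)")
      .queryExecution.executedPlan

    val inMainPlan = plan.collect { case s: SubqueryExec => s }.size
    val withSubqueries = plan.collectWithSubqueries { case s: SubqueryExec => s }.size

    // Expect inMainPlan == 0 and withSubqueries >= 1; exact counts depend on
    // the planner version and whether adaptive query execution is enabled.
    println(s"in main plan: $inMainPlan, with subqueries: $withSubqueries")
    spark.stop()
  }
}

The same pattern is what the touched call sites rely on: CollectMetricsExec gathers metrics from subquery trees too, and the DynamicPartitionPruning test counts SubqueryBroadcastExec nodes that only appear inside subqueries.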