[SPARK-32430][SQL] Extend SparkSessionExtensions to inject rules into AQE query stage preparation

### What changes were proposed in this pull request?

Provide a generic mechanism for plugins to inject rules into the AQE query stage preparation phase, which runs before query stages are created.

This goes along with https://issues.apache.org/jira/browse/SPARK-32332, where the current AQE implementation does not allow users to extend it properly for columnar processing.
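
For illustration, a plugin built against the new hook would register its rule roughly as in the sketch below; only `injectQueryStagePrepRule` itself comes from this PR, while the extension class and rule are hypothetical. The extension can then be enabled through `spark.sql.extensions` or `SparkSession.builder().withExtensions(...)`.

```scala
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical plugin entry point; register via spark.sql.extensions or withExtensions.
class MyAqeExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    // Runs during AQE query stage preparation, before query stages are created.
    extensions.injectQueryStagePrepRule(session => MyQueryPrepRule(session))
  }
}

// Hypothetical no-op rule standing in for real plugin logic.
case class MyQueryPrepRule(session: SparkSession) extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan
}
```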

### Why are the changes needed?

AQE creates new query stages, but a rule running on a new query stage has no access to that stage's parent plan, so some decisions cannot be made without knowing what the parent did. With this change, a preparation rule can attach tags to the plan up front, and rules that run later can read those tags to recover that context.
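
As a minimal sketch of that pattern (names here are illustrative; the unit test added below exercises the same idea with `MyQueryStagePrepRule` and `MyNewQueryStageRule`), a preparation rule could record each node's parent as a tag, and a rule that later runs on a new query stage could read the tag back:

```scala
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution.SparkPlan

object ParentInfo {
  // Illustrative tag; any payload the plugin needs can be attached this way.
  val parentTag: TreeNodeTag[String] = TreeNodeTag[String]("parentNodeName")
}

// Preparation rule: sees the whole plan, so it can record each node's parent
// before the plan is broken up into query stages.
case class RecordParentRule() extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = {
    plan.foreach { node =>
      node.children.foreach(_.setTagValue(ParentInfo.parentTag, node.nodeName))
    }
    plan
  }
}

// Later rule (e.g. a columnar override on a new query stage): the parent is no longer
// reachable from the stage's plan, but the tag written during preparation still is.
case class ReadParentRule() extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
    case node =>
      node.getTagValue(ParentInfo.parentTag).foreach { parent =>
        logInfo(s"${node.nodeName} was planned under $parent")
      }
      node
  }
}
```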

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

A new unit test is included in the PR.

Closes #29224 from andygrove/insert-aqe-rule.

Authored-by: Andy Grove <andygrove@nvidia.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
andygrove authored and dongjoon-hyun committed Jul 24, 2020
1 parent d3596c0 commit 64a01c0
Showing 5 changed files with 79 additions and 5 deletions.
SparkSessionExtensions.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.ColumnarRule
import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}

/**
* :: Experimental ::
@@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.ColumnarRule
* <li>Customized Parser.</li>
* <li>(External) Catalog listeners.</li>
* <li>Columnar Rules.</li>
* <li>Adaptive Query Stage Preparation Rules.</li>
* </ul>
*
* The extensions can be used by calling `withExtensions` on the [[SparkSession.Builder]], for
@@ -96,8 +97,10 @@ class SparkSessionExtensions {
type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface
type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)
type ColumnarRuleBuilder = SparkSession => ColumnarRule
type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan]

private[this] val columnarRuleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder]
private[this] val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder]

/**
* Build the override rules for columnar execution.
@@ -106,13 +109,28 @@ class SparkSessionExtensions {
columnarRuleBuilders.map(_.apply(session)).toSeq
}

/**
* Build the override rules for the query stage preparation phase of adaptive query execution.
*/
private[sql] def buildQueryStagePrepRules(session: SparkSession): Seq[Rule[SparkPlan]] = {
queryStagePrepRuleBuilders.map(_.apply(session)).toSeq
}

/**
* Inject a rule that can override the columnar execution of an executor.
*/
def injectColumnar(builder: ColumnarRuleBuilder): Unit = {
columnarRuleBuilders += builder
}

/**
* Inject a rule that can override the query stage preparation phase of adaptive query
* execution.
*/
def injectQueryStagePrepRule(builder: QueryStagePrepRuleBuilder): Unit = {
queryStagePrepRuleBuilders += builder
}

private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]

/**
AdaptiveSparkPlanExec.scala
@@ -90,7 +90,7 @@ case class AdaptiveSparkPlanExec(
// Exchange nodes) after running these rules.
private def queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq(
ensureRequirements
)
) ++ context.session.sessionState.queryStagePrepRules

// A list of physical optimizer rules to be applied to a new stage before its execution. These
// optimizations should be stage-independent.
BaseSessionStateBuilder.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution.{ColumnarRule, QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser}
import org.apache.spark.sql.execution.{ColumnarRule, QueryExecution, SparkOptimizer, SparkPlan, SparkPlanner, SparkSqlParser}
import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg
import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
import org.apache.spark.sql.execution.command.CommandCheck
@@ -286,6 +286,10 @@ abstract class BaseSessionStateBuilder(
extensions.buildColumnarRules(session)
}

protected def queryStagePrepRules: Seq[Rule[SparkPlan]] = {
extensions.buildQueryStagePrepRules(session)
}

/**
* Create a query execution object.
*/
@@ -337,7 +341,8 @@ abstract class BaseSessionStateBuilder(
() => resourceLoader,
createQueryExecution,
createClone,
columnarRules)
columnarRules,
queryStagePrepRules)
}
}

SessionState.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution._
import org.apache.spark.sql.streaming.StreamingQueryManager
@@ -73,7 +74,8 @@ private[sql] class SessionState(
resourceLoaderBuilder: () => SessionResourceLoader,
createQueryExecution: LogicalPlan => QueryExecution,
createClone: (SparkSession, SessionState) => SessionState,
val columnarRules: Seq[ColumnarRule]) {
val columnarRules: Seq[ColumnarRule],
val queryStagePrepRules: Seq[Rule[SparkPlan]]) {

// The following fields are lazy to avoid creating the Hive client when creating SessionState.
lazy val catalog: SessionCatalog = catalogBuilder()
SparkSessionExtensionSuite.scala
@@ -26,7 +26,9 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, UnresolvedHint}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE
@@ -145,6 +147,28 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
}
}

test("inject adaptive query prep rule") {
val extensions = create { extensions =>
// inject rule that will run during AQE query stage preparation and will add custom tags
// to the plan
extensions.injectQueryStagePrepRule(session => MyQueryStagePrepRule())
// inject rule that will run during AQE query stage optimization and will verify that the
// custom tags were written in the preparation phase
extensions.injectColumnar(session =>
MyColumarRule(MyNewQueryStageRule(), MyNewQueryStageRule()))
}
withSession(extensions) { session =>
session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, true)
assert(session.sessionState.queryStagePrepRules.contains(MyQueryStagePrepRule()))
assert(session.sessionState.columnarRules.contains(
MyColumarRule(MyNewQueryStageRule(), MyNewQueryStageRule())))
import session.sqlContext.implicits._
val data = Seq((100L), (200L), (300L)).toDF("vals").repartition(1)
val df = data.selectExpr("vals + 1")
df.collect()
}
}

test("inject columnar") {
val extensions = create { extensions =>
extensions.injectColumnar(session =>
@@ -731,6 +755,31 @@ class MyExtensions extends (SparkSessionExtensions => Unit) {
}
}

object QueryPrepRuleHelper {
val myPrepTag: TreeNodeTag[String] = TreeNodeTag[String]("myPrepTag")
val myPrepTagValue: String = "myPrepTagValue"
}

// this rule will run during AQE query preparation and will write custom tags to each node
case class MyQueryStagePrepRule() extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = plan.transformDown {
case plan =>
plan.setTagValue(QueryPrepRuleHelper.myPrepTag, QueryPrepRuleHelper.myPrepTagValue)
plan
}
}

// this rule will run during AQE query stage optimization and will verify custom tags were
// already written during query preparation phase
case class MyNewQueryStageRule() extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = plan.transformDown {
case plan if !plan.isInstanceOf[AdaptiveSparkPlanExec] =>
assert(plan.getTagValue(QueryPrepRuleHelper.myPrepTag).get ==
QueryPrepRuleHelper.myPrepTagValue)
plan
}
}

case class MyRule2(spark: SparkSession) extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan
}
