Skip to content

Commit

Permalink
Extend rule-exclusion to Optimizer sub-classes, esp. SparkOptimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
maryannxue committed Jul 25, 2018
1 parent a2161ef commit 3730053
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 17 deletions.
Expand Up @@ -46,6 +46,13 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)

/**
 * Builds a `FixedPoint` whose argument is read from `SQLConf.get.optimizerMaxIterations`.
 *
 * This is a `def`, so the active `SQLConf` is consulted on every call rather than being
 * captured once.
 */
protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations)

/**
* Defines the default rule batches in the Optimizer.
*
* Implementations of this class should override this method, and [[nonExcludableRules]] if
* necessary, instead of [[batches]]. The rule batches that eventually run in the Optimizer,
* i.e., returned by [[batches]], will be (defaultBatches - (excludedRules - nonExcludableRules)).
*/
def defaultBatches: Seq[Batch] = {
val operatorOptimizationRuleSet =
Seq(
Expand Down Expand Up @@ -160,6 +167,14 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
UpdateNullabilityInAttributeReferences)
}

/**
* Defines rules that cannot be excluded from the Optimizer even if they are specified in
* SQL config "excludedRules".
*
* Implementations of this class can override this method if necessary. The rule batches
* that eventually run in the Optimizer, i.e., returned by [[batches]], will be
* (defaultBatches - (excludedRules - nonExcludableRules)).
*/
def nonExcludableRules: Seq[String] =
EliminateDistinct.ruleName ::
EliminateSubqueryAliases.ruleName ::
Expand Down Expand Up @@ -202,7 +217,14 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
*/
def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil

override def batches: Seq[Batch] = {
/**
* Returns (defaultBatches - (excludedRules - nonExcludableRules)), the rule batches that
* eventually run in the Optimizer.
*
* Implementations of this class should override [[defaultBatches]], and [[nonExcludableRules]]
* if necessary, instead of this method.
*/
final override def batches: Seq[Batch] = {
val excludedRulesConf =
SQLConf.get.optimizerExcludedRules.toSeq.flatMap(Utils.stringToSeq)
val excludedRules = excludedRulesConf.filter { ruleName =>
Expand Down
Expand Up @@ -47,7 +47,7 @@ class OptimizerExtendableSuite extends SparkFunSuite {
DummyRule) :: Nil
}

override def batches: Seq[Batch] = super.batches ++ myBatches
override def defaultBatches: Seq[Batch] = super.defaultBatches ++ myBatches
}

test("Extending batches possible") {
Expand Down
Expand Up @@ -28,8 +28,10 @@ class OptimizerRuleExclusionSuite extends PlanTest {

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

private def verifyExcludedRules(excludedRuleNames: Seq[String]) {
val optimizer = new SimpleTestOptimizer()
private def verifyExcludedRules(optimizer: Optimizer, rulesToExclude: Seq[String]) {
val nonExcludableRules = optimizer.nonExcludableRules

val excludedRuleNames = rulesToExclude.filter(!nonExcludableRules.contains(_))
// Batches whose rules are all to be excluded should be removed as a whole.
val excludedBatchNames = optimizer.batches
.filter(batch => batch.rules.forall(rule => excludedRuleNames.contains(rule.ruleName)))
Expand All @@ -38,21 +40,31 @@ class OptimizerRuleExclusionSuite extends PlanTest {
// Configure the optimizer with the FULL list the caller asked to exclude, including any
// non-excludable names. Passing the pre-filtered `excludedRuleNames` would never ask the
// optimizer to drop a non-excludable rule, so the "retained" check below would be vacuous.
// `mkString` also avoids the leading comma that a `foldLeft` over "" produces.
withSQLConf(OPTIMIZER_EXCLUDED_RULES.key -> rulesToExclude.mkString(",")) {
  // Evaluate the effective batches once, under the overridden config.
  val batches = optimizer.batches
  // Verify removed batches.
  assert(batches.forall(batch => !excludedBatchNames.contains(batch.name)))
  // Verify removed rules.
  assert(
    batches
      .forall(batch => batch.rules.forall(rule => !excludedRuleNames.contains(rule.ruleName))))
  // Verify non-excludable rules retained, even when they were explicitly requested above.
  nonExcludableRules.foreach { nonExcludableRule =>
    assert(
      batches
        .exists(batch => batch.rules.exists(rule => rule.ruleName == nonExcludableRule)))
  }
}
}

// Runs the shared exclusion checks against a fresh SimpleTestOptimizer with a single
// rule name requested for exclusion.
test("Exclude a single rule from multiple batches") {
  val rulesToExclude = Seq(PushPredicateThroughJoin.ruleName)
  verifyExcludedRules(new SimpleTestOptimizer(), rulesToExclude)
}

test("Exclude multiple rules from single or multiple batches") {
verifyExcludedRules(
new SimpleTestOptimizer(),
Seq(
CombineUnions.ruleName,
RemoveLiteralFromGroupExpressions.ruleName,
Expand All @@ -61,27 +73,42 @@ class OptimizerRuleExclusionSuite extends PlanTest {

// Runs the shared exclusion checks with two rule names taken from rule objects plus one
// literal name ("DummyRuleName") that, per the test title, matches no existing rule.
test("Exclude non-existent rule with other valid rules") {
  val rulesToExclude =
    LimitPushDown.ruleName ::
      InferFiltersFromConstraints.ruleName ::
      "DummyRuleName" :: Nil
  verifyExcludedRules(new SimpleTestOptimizer(), rulesToExclude)
}

test("Try to exclude a non-excludable rule") {
val excludedRules = Seq(
ReplaceIntersectWithSemiJoin.ruleName,
PullupCorrelatedPredicates.ruleName)
verifyExcludedRules(
new SimpleTestOptimizer(),
Seq(
ReplaceIntersectWithSemiJoin.ruleName,
PullupCorrelatedPredicates.ruleName))
}

val optimizer = new SimpleTestOptimizer()
test("Custom optimizer") {
val optimizer = new SimpleTestOptimizer() {
override def defaultBatches: Seq[Batch] =
Batch("push", Once,
PushDownPredicate,
PushPredicateThroughJoin,
PushProjectionThroughUnion) ::
Batch("pull", Once,
PullupCorrelatedPredicates) :: Nil

withSQLConf(
OPTIMIZER_EXCLUDED_RULES.key -> excludedRules.foldLeft("")((l, r) => l + "," + r)) {
excludedRules.foreach { excludedRule =>
assert(
optimizer.batches
.exists(batch => batch.rules.exists(rule => rule.ruleName == excludedRule)))
}
override def nonExcludableRules: Seq[String] =
PushDownPredicate.ruleName ::
PullupCorrelatedPredicates.ruleName :: Nil
}

verifyExcludedRules(
optimizer,
Seq(
PushDownPredicate.ruleName,
PushProjectionThroughUnion.ruleName,
PullupCorrelatedPredicates.ruleName))
}

test("Verify optimized plan after excluding CombineUnions rule") {
Expand Down
Expand Up @@ -44,7 +44,7 @@ class OptimizerStructuralIntegrityCheckerSuite extends PlanTest {
EmptyFunctionRegistry,
new SQLConf())) {
val newBatch = Batch("OptimizeRuleBreakSI", Once, OptimizeRuleBreakSI)
override def batches: Seq[Batch] = Seq(newBatch) ++ super.batches
override def defaultBatches: Seq[Batch] = Seq(newBatch) ++ super.defaultBatches
}

test("check for invalid plan after execution of rule") {
Expand Down
Expand Up @@ -28,13 +28,16 @@ class SparkOptimizer(
experimentalMethods: ExperimentalMethods)
extends Optimizer(catalog) {

override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+
override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++
postHocOptimizationBatches :+
Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)

/**
 * Rule names that must stay in this optimizer even when listed for exclusion: everything the
 * parent declares non-excludable, followed by the name of [[ExtractPythonUDFFromAggregate]].
 */
override def nonExcludableRules: Seq[String] = {
  val inherited = super.nonExcludableRules
  inherited :+ ExtractPythonUDFFromAggregate.ruleName
}

/**
* Optimization batches that are executed before the regular optimization batches (also before
* the finish analysis batch).
Expand Down

0 comments on commit 3730053

Please sign in to comment.