AdaptiveSparkPlanExec.scala

@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
+import org.apache.spark.sql.execution.bucketing.DisableUnnecessaryBucketedScan
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
 import org.apache.spark.sql.execution.exchange._
@@ -82,17 +83,14 @@ case class AdaptiveSparkPlanExec(
   // The logical plan optimizer for re-optimizing the current logical plan.
   @transient private val optimizer = new AQEOptimizer(conf)
 
-  @transient private val removeRedundantProjects = RemoveRedundantProjects
-  @transient private val removeRedundantSorts = RemoveRedundantSorts
-  @transient private val ensureRequirements = EnsureRequirements
-
   // A list of physical plan rules to be applied before creation of query stages. The physical
   // plan should reach a final status of query stages (i.e., no more addition or removal of
   // Exchange nodes) after running these rules.
   private def queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq(
-    removeRedundantProjects,
-    removeRedundantSorts,
-    ensureRequirements
+    RemoveRedundantProjects,
+    RemoveRedundantSorts,
+    EnsureRequirements,
+    DisableUnnecessaryBucketedScan
   ) ++ context.session.sessionState.queryStagePrepRules
 
   // A list of physical optimizer rules to be applied to a new stage before its execution. These
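
With DisableUnnecessaryBucketedScan added to queryStagePreparationRules, the rule now also runs when adaptive query execution plans query stages, instead of only in the non-AQE code path. Below is a minimal end-to-end sketch, assuming a local session; the table t, its columns, and the app name are illustrative, spark.sql.adaptive.enabled is the standard AQE switch, and spark.sql.sources.bucketing.autoBucketedScan.enabled is assumed to be the Spark 3.1+ key gating this rule.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("auto-bucketed-scan-demo")
  // The rule now also applies under AQE via queryStagePreparationRules.
  .config("spark.sql.adaptive.enabled", "true")
  // Gates DisableUnnecessaryBucketedScan (default true).
  .config("spark.sql.sources.bucketing.autoBucketedScan.enabled", "true")
  .getOrCreate()

spark.range(100).selectExpr("id AS i", "id AS j")
  .write.bucketBy(8, "i").saveAsTable("t")

// No join or aggregate keyed on bucket column i, so bucketing gives this
// query no benefit; the rule can plan a regular (non-bucketed) scan instead.
spark.sql("SELECT SUM(j) FROM t").show()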

DisableUnnecessaryBucketedScan.scala

@@ -101,7 +101,9 @@ object DisableUnnecessaryBucketedScan extends Rule[SparkPlan] {
       case scan: FileSourceScanExec =>
         if (isBucketedScanWithoutFilter(scan)) {
           if (!withInterestingPartition || (withExchange && withAllowedNode)) {
-            scan.copy(disableBucketedScan = true)
+            val nonBucketedScan = scan.copy(disableBucketedScan = true)
+            scan.logicalLink.foreach(nonBucketedScan.setLogicalLink)
+            nonBucketedScan
           } else {
             scan
           }
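
This two-line fix matters because copy() produces a fresh tree node, and tree-node tags, including the logical-plan link that AdaptiveSparkPlanExec uses to map a physical node back to its logical plan during re-optimization, do not survive the copy. Re-setting the link keeps the rewritten scan usable inside AQE. A self-contained sketch of the same copy-and-relink pattern; MyScanRewrite is a hypothetical rule name, while the APIs it touches (Rule[SparkPlan], transformUp, logicalLink/setLogicalLink, FileSourceScanExec.copy) are the ones the change itself uses:

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}

object MyScanRewrite extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
    case scan: FileSourceScanExec if scan.bucketedScan =>
      // copy() returns a fresh TreeNode: all tags, including the logical
      // link, are dropped and must be carried over explicitly.
      val rewritten = scan.copy(disableBucketedScan = true)
      scan.logicalLink.foreach(rewritten.setLogicalLink)
      rewritten
  }
}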

DisableUnnecessaryBucketedScanSuite.scala

@@ -21,22 +21,40 @@ import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.FileSourceScanExec
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
 
 class DisableUnnecessaryBucketedScanWithoutHiveSupportSuite
   extends DisableUnnecessaryBucketedScanSuite
-  with SharedSparkSession {
+  with SharedSparkSession
+  with DisableAdaptiveExecutionSuite {
 
   protected override def beforeAll(): Unit = {
     super.beforeAll()
     assert(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory")
   }
 }
 
-abstract class DisableUnnecessaryBucketedScanSuite extends QueryTest with SQLTestUtils {
+class DisableUnnecessaryBucketedScanWithoutHiveSupportSuiteAE
+  extends DisableUnnecessaryBucketedScanSuite
+  with SharedSparkSession
+  with EnableAdaptiveExecutionSuite {
+
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+    assert(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory")
+  }
+}
+
+abstract class DisableUnnecessaryBucketedScanSuite
+  extends QueryTest
+  with SQLTestUtils
+  with AdaptiveSparkPlanHelper {
 
   import testImplicits._
 
   private lazy val df1 =
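
The suite hierarchy above runs the same abstract tests twice: once with AQE disabled (the existing suite, now mixing in DisableAdaptiveExecutionSuite) and once with it enabled (the new ...SuiteAE variant with EnableAdaptiveExecutionSuite). A sketch of how such a pin-the-conf trait can be written, assuming it simply wraps each test body in withSQLConf; this is not Spark's exact implementation (the real traits live in org.apache.spark.sql.execution.adaptive), and EnableAQEForSuite is a hypothetical name:

import org.scalactic.source.Position
import org.scalatest.Tag
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils

// Hypothetical trait: pins AQE on for every test registered in the suite.
trait EnableAQEForSuite extends SQLTestUtils {
  override protected def test(testName: String, testTags: Tag*)(testFun: => Any)
      (implicit pos: Position): Unit = {
    super.test(testName, testTags: _*) {
      withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true")(testFun)
    }
  }
}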
@@ -51,7 +69,7 @@ abstract class DisableUnnecessaryBucketedScanSuite
 
   def checkNumBucketedScan(query: String, expectedNumBucketedScan: Int): Unit = {
     val plan = sql(query).queryExecution.executedPlan
-    val bucketedScan = plan.collect { case s: FileSourceScanExec if s.bucketedScan => s }
+    val bucketedScan = collect(plan) { case s: FileSourceScanExec if s.bucketedScan => s }
     assert(bucketedScan.length == expectedNumBucketedScan)
   }
 
@@ -230,14 +248,14 @@ abstract class DisableUnnecessaryBucketedScanSuite
       assertCached(spark.table("t1"))
 
       // Verify cached bucketed table scan not disabled
-      val partitioning = spark.table("t1").queryExecution.executedPlan
+      val partitioning = stripAQEPlan(spark.table("t1").queryExecution.executedPlan)
         .outputPartitioning
       assert(partitioning match {
         case HashPartitioning(Seq(column: AttributeReference), 8) if column.name == "i" => true
         case _ => false
       })
       val aggregateQueryPlan = sql("SELECT SUM(i) FROM t1 GROUP BY i").queryExecution.executedPlan
-      assert(aggregateQueryPlan.find(_.isInstanceOf[ShuffleExchangeExec]).isEmpty)
+      assert(find(aggregateQueryPlan)(_.isInstanceOf[ShuffleExchangeExec]).isEmpty)
     }
   }
 }
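
These last changes swap TreeNode methods for their AdaptiveSparkPlanHelper counterparts. Under AQE the executed plan is rooted at AdaptiveSparkPlanExec, a leaf node whose sub-plans live inside query stages, so plain plan.collect and plan.find see nothing below the root; collect(plan){...} and find(plan)(...) recurse into stages, and stripAQEPlan unwraps the adaptive root (a no-op when AQE is off). A hedged usage sketch; the suite name, table t, and query are hypothetical, while the helper methods come from the real AdaptiveSparkPlanHelper trait:

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.test.SharedSparkSession

class AqeAwarePlanChecksSuite extends QueryTest
  with SharedSparkSession
  with AdaptiveSparkPlanHelper {

  test("plan inspection works with and without AQE") {
    withTable("t") {
      spark.range(10).selectExpr("id AS i").write.bucketBy(4, "i").saveAsTable("t")
      val plan = sql("SELECT i FROM t GROUP BY i").queryExecution.executedPlan
      // collect recurses into query stages, unlike plan.collect on an
      // AdaptiveSparkPlanExec root.
      val scans = collect(plan) { case s: FileSourceScanExec => s }
      assert(scans.length == 1)
      // stripAQEPlan peels off the adaptive root so properties such as
      // outputPartitioning can be read from the inner plan.
      val inner = stripAQEPlan(plan)
      assert(find(inner)(_.isInstanceOf[FileSourceScanExec]).isDefined)
    }
  }
}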