diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index a4a58dfe1de53..4ae33311d5a24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
+import org.apache.spark.sql.execution.bucketing.DisableUnnecessaryBucketedScan
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
 import org.apache.spark.sql.execution.exchange._
@@ -82,17 +83,14 @@ case class AdaptiveSparkPlanExec(
   // The logical plan optimizer for re-optimizing the current logical plan.
   @transient private val optimizer = new AQEOptimizer(conf)
 
-  @transient private val removeRedundantProjects = RemoveRedundantProjects
-  @transient private val removeRedundantSorts = RemoveRedundantSorts
-  @transient private val ensureRequirements = EnsureRequirements
-
   // A list of physical plan rules to be applied before creation of query stages. The physical
   // plan should reach a final status of query stages (i.e., no more addition or removal of
   // Exchange nodes) after running these rules.
   private def queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq(
-    removeRedundantProjects,
-    removeRedundantSorts,
-    ensureRequirements
+    RemoveRedundantProjects,
+    RemoveRedundantSorts,
+    EnsureRequirements,
+    DisableUnnecessaryBucketedScan
   ) ++ context.session.sessionState.queryStagePrepRules
 
   // A list of physical optimizer rules to be applied to a new stage before its execution.
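A hedged, self-contained sketch (not part of the patch) of what registering the rule as an AQE query-stage preparation rule buys us: with AQE enabled, aggregating a bucketed table on a non-bucket column inserts a shuffle above the scan, so the bucketed scan is unnecessary and the rule should disable it. The object name, table name, and data below are illustrative, not from the patch.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper

object BucketedScanUnderAqeDemo extends AdaptiveSparkPlanHelper {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .config("spark.sql.adaptive.enabled", "true")
      .config("spark.sql.sources.bucketing.autoBucketedScan.enabled", "true")
      .getOrCreate()
    import spark.implicits._

    Seq((1, "a"), (2, "b")).toDF("i", "j")
      .write.bucketBy(8, "i").saveAsTable("t")

    // Grouping on "j" gives the bucketed output partitioning (on "i")
    // nothing to satisfy, so the scan falls back to a regular file scan.
    val df = spark.sql("SELECT j, COUNT(*) FROM t GROUP BY j")
    df.collect() // materialize so AQE finalizes the physical plan

    // AdaptiveSparkPlanExec and query stages are leaf nodes to plain
    // TreeNode.collect; AdaptiveSparkPlanHelper.collect descends into them.
    val scans = collect(df.queryExecution.executedPlan) {
      case s: FileSourceScanExec => s
    }
    assert(scans.nonEmpty && scans.forall(!_.bucketedScan))

    spark.stop()
  }
}
```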
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/DisableUnnecessaryBucketedScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/DisableUnnecessaryBucketedScan.scala
index 2bbd5f5d969dc..bb59f44abc761 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/DisableUnnecessaryBucketedScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/DisableUnnecessaryBucketedScan.scala
@@ -101,7 +101,9 @@ object DisableUnnecessaryBucketedScan extends Rule[SparkPlan] {
       case scan: FileSourceScanExec =>
         if (isBucketedScanWithoutFilter(scan)) {
           if (!withInterestingPartition || (withExchange && withAllowedNode)) {
-            scan.copy(disableBucketedScan = true)
+            val nonBucketedScan = scan.copy(disableBucketedScan = true)
+            scan.logicalLink.foreach(nonBucketedScan.setLogicalLink)
+            nonBucketedScan
           } else {
             scan
           }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala
index 70b74aed40eca..1fdd3be88f782 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala
@@ -21,6 +21,8 @@ import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.FileSourceScanExec
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
@@ -28,7 +30,8 @@ import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
 
 class DisableUnnecessaryBucketedScanWithoutHiveSupportSuite
   extends DisableUnnecessaryBucketedScanSuite
-  with SharedSparkSession {
+  with SharedSparkSession
+  with DisableAdaptiveExecutionSuite {
 
   protected override def beforeAll(): Unit = {
     super.beforeAll()
@@ -36,7 +39,22 @@ class DisableUnnecessaryBucketedScanWithoutHiveSupportSuite
   }
 }
 
-abstract class DisableUnnecessaryBucketedScanSuite extends QueryTest with SQLTestUtils {
+class DisableUnnecessaryBucketedScanWithoutHiveSupportSuiteAE
+  extends DisableUnnecessaryBucketedScanSuite
+  with SharedSparkSession
+  with EnableAdaptiveExecutionSuite {
+
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+    assert(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory")
+  }
+}
+
+abstract class DisableUnnecessaryBucketedScanSuite
+  extends QueryTest
+  with SQLTestUtils
+  with AdaptiveSparkPlanHelper {
+
   import testImplicits._
 
   private lazy val df1 =
@@ -51,7 +69,7 @@ abstract class DisableUnnecessaryBucketedScanSuite extends QueryTest with SQLTes
 
   def checkNumBucketedScan(query: String, expectedNumBucketedScan: Int): Unit = {
     val plan = sql(query).queryExecution.executedPlan
-    val bucketedScan = plan.collect { case s: FileSourceScanExec if s.bucketedScan => s }
+    val bucketedScan = collect(plan) { case s: FileSourceScanExec if s.bucketedScan => s }
     assert(bucketedScan.length == expectedNumBucketedScan)
   }
 
@@ -230,14 +248,14 @@ abstract class DisableUnnecessaryBucketedScanSuite extends QueryTest with SQLTes
         assertCached(spark.table("t1"))
         // Verify cached bucketed table scan not disabled
-        val partitioning = spark.table("t1").queryExecution.executedPlan
+        val partitioning = stripAQEPlan(spark.table("t1").queryExecution.executedPlan)
           .outputPartitioning
         assert(partitioning match {
           case HashPartitioning(Seq(column: AttributeReference), 8) if column.name == "i" => true
           case _ => false
         })
 
         val aggregateQueryPlan = sql("SELECT SUM(i) FROM t1 GROUP BY i").queryExecution.executedPlan
-        assert(aggregateQueryPlan.find(_.isInstanceOf[ShuffleExchangeExec]).isEmpty)
+        assert(find(aggregateQueryPlan)(_.isInstanceOf[ShuffleExchangeExec]).isEmpty)
       }
     }
   }
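Why the `setLogicalLink` line matters: AQE re-optimizes the logical plan between stages and relies on each physical operator's `logicalLink` tag to map it back to its logical counterpart, but TreeNode tags live in a per-instance map, so a plain case-class `copy()` starts with that map empty. Below is a hedged sketch of the pattern the fix follows; the helper name `withLogicalLinkOf` is mine, not Spark's.

```scala
import org.apache.spark.sql.execution.SparkPlan

// Carry the original node's logical link onto a rebuilt replacement, so AQE
// re-optimization can still associate the new physical node with its logical
// source. A direct copy() would otherwise drop the link silently.
def withLogicalLinkOf[T <: SparkPlan](original: SparkPlan)(rebuilt: T): T = {
  original.logicalLink.foreach(rebuilt.setLogicalLink)
  rebuilt
}

// Usage, mirroring the patched branch of DisableUnnecessaryBucketedScan:
//   withLogicalLinkOf(scan)(scan.copy(disableBucketedScan = true))
```

On the test side, the same AQE opacity explains the other changes: `AdaptiveSparkPlanExec` and query stages are leaf nodes to plain `TreeNode.collect`/`find`, so the suite mixes in `AdaptiveSparkPlanHelper` and uses its `collect`, `find`, and `stripAQEPlan`, which descend into (or unwrap) the adaptive plan, keeping the assertions valid with AQE both on and off.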