[SPARK-41220][SQL] Range partitioner sample supports column pruning #38756

Closed · wants to merge 2 commits
@@ -22,6 +22,7 @@ import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, IntegerType}

@@ -394,8 +395,14 @@ object KeyGroupedPartitioning {
*
* This class extends expression primarily so that transformations over expression will descend
* into its child.
*
* If `planForSample` is present, the shuffle exchange will use the given plan to sample rows
* for the range partitioner.
*/
case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case class RangePartitioning(
ordering: Seq[SortOrder],
numPartitions: Int,
@transient planForSample: Option[LogicalPlan] = None)
extends Expression with Partitioning with Unevaluable {

override def children: Seq[SortOrder] = ordering
@@ -441,6 +448,11 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
RangeShuffleSpec(this.numPartitions, distribution)

override lazy val canonicalized: Expression =
RangePartitioning(ordering.map(_.canonicalized.asInstanceOf[SortOrder]), numPartitions)

override protected def stringArgs: Iterator[Any] = Iterator(ordering, numPartitions)

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): RangePartitioning =
copy(ordering = newChildren.asInstanceOf[Seq[SortOrder]])
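Not part of the diff, but a minimal standalone sketch (hypothetical values) of why the extra constructor argument is safe for existing call sites: `planForSample` defaults to `None`, and the new `canonicalized` and `stringArgs` overrides keep the sample plan out of plan equality and out of the rendered plan tree.

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.plans.physical.RangePartitioning
import org.apache.spark.sql.types.IntegerType

val c1 = AttributeReference("c1", IntegerType)()
val order = SortOrder(c1, Ascending)
val withoutSample = RangePartitioning(Seq(order), 10)                        // old two-argument call sites still compile
val withSample = RangePartitioning(Seq(order), 10, Some(LocalRelation(c1)))  // hypothetical pruned sample plan
assert(withoutSample.canonicalized == withSample.canonicalized)              // sample plan excluded from plan equality
assert(withoutSample.simpleString(maxFields = 100) ==
  withSample.simpleString(maxFields = 100))                                  // and hidden from the rendered plan tree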
@@ -435,6 +435,13 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val OPTIMIZE_SAMPLE_FOR_RANGE_PARTITION_ENABLED =
buildConf("spark.sql.optimizer.optimizeSampleForRangePartition.enabled")
.doc("When set to true, Spark would try to optimize sample plan for range partitioning.")
.version("3.4.0")
.booleanConf
.createWithDefault(true)

val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed")
.doc("When set to true Spark SQL will automatically select a compression codec for each " +
"column based on statistics of the data.")
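For context, a minimal usage sketch (not part of the diff; assumes a spark-shell session and a hypothetical output path). The flag behaves like any other SQL conf, and a global sort over a wide dataset is the case the optimization targets:

// Hypothetical example: toggle the new flag and trigger a range shuffle via a global sort.
spark.conf.set("spark.sql.optimizer.optimizeSampleForRangePartition.enabled", "true")
val wide = spark.range(0, 1000000).selectExpr("id AS c1", "repeat('x', 100) AS payload")
// The ORDER BY introduces RangePartitioning; with the flag on, only `c1` should be
// materialized by the sampling job instead of the full (c1, payload) rows.
wide.orderBy("c1").write.mode("overwrite").parquet("/tmp/spark-41220-demo")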
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan}
import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan}
import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, OptimizeSampleForRangePartitioning}
import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata}
import org.apache.spark.sql.internal.SQLConf
@@ -428,6 +428,7 @@ object QueryExecution {
PlanSubqueries(sparkSession),
RemoveRedundantProjects,
EnsureRequirements(),
OptimizeSampleForRangePartitioning,
// `ReplaceHashWithSortAgg` needs to be added after `EnsureRequirements` to guarantee the
// sort order of each node is checked to be valid.
ReplaceHashWithSortAgg,
@@ -183,7 +183,9 @@ case class AdaptiveSparkPlanExec(

@transient val initialPlan = context.session.withActive {
applyPhysicalRules(
inputPlan, queryStagePreparationRules, Some((planChangeLogger, "AQE Preparations")))
inputPlan,
queryStagePreparationRules ++ Seq(OptimizeSampleForRangePartitioning),
Some((planChangeLogger, "AQE Preparations")))
}

@volatile private var currentPhysicalPlan = initialPlan
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.exchange

import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.catalyst.plans.physical.RangePartitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.internal.SQLConf

/**
* [[RangePartitioning]] samples the input to compute range bounds. By default, the plan used for
* sampling is the same as the plan being shuffled.
* This rule decouples the sample plan from the shuffle plan. Since the sample plan only depends on
* the sort order expressions, it can be optimized to prune unnecessary columns.
*
* Note that this rule does not optimize plans that contain more than one shuffle exchange.
*/
object OptimizeSampleForRangePartitioning extends Rule[SparkPlan] {

private def hasBenefit(plan: SparkPlan): Boolean = {
if (conf.getConf(SQLConf.OPTIMIZE_SAMPLE_FOR_RANGE_PARTITION_ENABLED)) {
var numRangePartitioning = 0
var numShuffleWithoutGlobalSort = 0
plan.foreach {
case ShuffleExchangeExec(r: RangePartitioning, child, _) if child.logicalLink.isDefined &&
!child.logicalLink.get.isStreaming &&
r.ordering.flatMap(_.references).size < child.outputSet.size =>
numRangePartitioning += 1
case _: ShuffleExchangeExec => numShuffleWithoutGlobalSort += 1
case _ =>
}
numRangePartitioning == 1 && numShuffleWithoutGlobalSort == 0
} else {
false
}
}

override def apply(plan: SparkPlan): SparkPlan = {
if (!hasBenefit(plan)) {
return plan
}

plan.transform {
case shuffle @ ShuffleExchangeExec(r: RangePartitioning, child, _) =>
val planForSample = child.logicalLink.map { p =>
val named = r.ordering.zipWithIndex.map { case (order, i) =>
Alias(order.child, s"_sort_$i")()
}
Project(named, p.clone())
}

val rangePartitioningWithSample = r.copy(planForSample = planForSample)
shuffle.copy(outputPartitioning = rangePartitioningWithSample)
}
}
}
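A rough way to observe the rule's effect (illustrative only; assumes a table `t(c1 INT, c2 INT, c3 STRING)` like the one created in the test suite below, and AQE disabled so the shuffle is directly visible in the executed plan; with AQE on, use AdaptiveSparkPlanHelper as the suite does):

import org.apache.spark.sql.catalyst.plans.physical.RangePartitioning
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

val df = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(c1) */ * FROM t")
val samplePlan = df.queryExecution.executedPlan.collectFirst {
  case ShuffleExchangeExec(r: RangePartitioning, _, _) => r.planForSample
}.flatten
// Expected when the rule applies: Some(Project [c1 AS _sort_0] ...), i.e. the sample
// scans only the sort key column instead of all columns of `t`.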
@@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
@@ -271,16 +272,31 @@ object ShuffleExchangeExec {
// For HashPartitioning, the partitioning key is already a valid partition ID, as we use
// `HashPartitioning.partitionIdExpression` to produce partitioning key.
new PartitionIdPassthrough(n)
case RangePartitioning(sortingExpressions, numPartitions) =>
case RangePartitioning(sortingExpressions, numPartitions, planForSample) =>
// Extract only fields used for sorting to avoid collecting large fields that do not
// affect the sorting result when deciding partition bounds in RangePartitioner
val rddForSampling = rdd.mapPartitionsInternal { iter =>
val projection =
UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
val mutablePair = new MutablePair[InternalRow, Null]()
// Internally, RangePartitioner runs a job on the RDD that samples keys to compute
// partition bounds. To get accurate samples, we need to copy the mutable keys.
iter.map(row => mutablePair.update(projection(row).copy(), null))
val rddForSampling = if (planForSample.isEmpty) {
rdd.mapPartitionsInternal { iter =>
val projection =
UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
val mutablePair = new MutablePair[InternalRow, Null]()
// Internally, RangePartitioner runs a job on the RDD that samples keys to compute
// partition bounds. To get accurate samples, we need to copy the mutable keys.
iter.map(row => mutablePair.update(projection(row).copy(), null))
}
} else {
val sample = planForSample.get
assert(sample.output.size == sortingExpressions.size)
// Re-optimize the sample plan. This query execution for sampling is still part of
// the original plan, so we do not need to assign a new execution id.
QueryExecution.prepareExecutedPlan(SparkSession.getActiveSession.orNull, sample)
.execute().mapPartitionsInternal { iter =>
val mutablePair = new MutablePair[InternalRow, Null]()
// Internally, RangePartitioner runs a job on the RDD that samples keys to compute
// partition bounds. To get accurate samples, we need to copy the mutable keys.
iter.map(row => mutablePair.update(row.copy(), null))
}
}
// Construct ordering on extracted sort key.
val orderingAttributes = sortingExpressions.zipWithIndex.map { case (ord, i) =>
@@ -315,7 +331,7 @@ object ShuffleExchangeExec {
case h: HashPartitioning =>
val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
row => projection(row).getInt(0)
case RangePartitioning(sortingExpressions, _) =>
case RangePartitioning(sortingExpressions, _, _) =>
val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
row => projection(row)
case SinglePartition => identity
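As background for why the sampling RDD only needs the sort keys, a self-contained sketch with the plain RDD API (not part of the diff): constructing a RangePartitioner runs a sampling job over whatever keyed RDD it is given, so any extra column kept in that RDD is also collected during sampling.

import org.apache.spark.RangePartitioner

// Keys only: this is the shape both branches above produce, (sortKeyRow, null) pairs.
val keys = spark.sparkContext.parallelize(1 to 1000000).map(k => (k, null))
// Constructing the partitioner triggers a sampling job over `keys` to compute range bounds.
val partitioner = new RangePartitioner(8, keys)
// Each key is then mapped to one of (at most) 8 ordered, roughly equally sized ranges.
val bucket = partitioner.getPartition(500000)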
@@ -56,7 +56,7 @@ abstract class DistributionAndOrderingSuiteBase
partitionValues)
case PartitioningCollection(partitionings) =>
PartitioningCollection(partitionings.map(resolvePartitioning(_, plan)))
case RangePartitioning(ordering, numPartitions) =>
case RangePartitioning(ordering, numPartitions, _) =>
RangePartitioning(ordering.map(resolveAttrs(_, plan).asInstanceOf[SortOrder]), numPartitions)
case p @ SinglePartition =>
p
@@ -993,7 +993,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {

val projects = collect(planned) { case p: ProjectExec => p }
assert(projects.exists(_.outputPartitioning match {
case RangePartitioning(Seq(SortOrder(ar: AttributeReference, _, _, _)), _) =>
case RangePartitioning(Seq(SortOrder(ar: AttributeReference, _, _, _)), _, _) =>
ar.name == "id1"
case _ => false
}))
@@ -0,0 +1,115 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.exchange

import org.apache.spark.sql.{DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.plans.physical.RangePartitioning
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

class OptimizeSampleForRangePartitioningSuite
extends QueryTest with SharedSparkSession with AdaptiveSparkPlanHelper {

protected override def beforeAll(): Unit = {
super.beforeAll()
val data = Seq(
Seq(9, null, "x"),
Seq(null, 3, "y"),
Seq(1, 7, null),
Seq(null, 0, null),
Seq(1, null, null),
Seq(null, null, "b"),
Seq(5, 3, "z"),
Seq(7, 1, "a"),
Seq(5, 1, "b"),
Seq(8, 2, "a"))
val schema = new StructType().add("c1", IntegerType, nullable = true)
.add("c2", IntegerType, nullable = true)
.add("c3", StringType, nullable = true)
val rdd = spark.sparkContext.parallelize(data)
spark.createDataFrame(rdd.map(s => Row.fromSeq(s)), schema)
.write.format("parquet").saveAsTable("t")
}

protected override def afterAll(): Unit = {
spark.sql("DROP TABLE IF EXISTS t")
super.afterAll()
}

private def checkQuery(query: => DataFrame, optimized: Boolean): Unit = {
withSQLConf(SQLConf.OPTIMIZE_SAMPLE_FOR_RANGE_PARTITION_ENABLED.key -> "true") {
val df = query
assert(collect(df.queryExecution.executedPlan) {
case shuffle @ ShuffleExchangeExec(r: RangePartitioning, _, _)
if r.planForSample.isDefined => shuffle
}.nonEmpty == optimized)

var expected: Array[Row] = null
withSQLConf(SQLConf.OPTIMIZE_SAMPLE_FOR_RANGE_PARTITION_ENABLED.key -> "false") {
expected = query.collect()
}
checkAnswer(df, expected)
}
}

test("Optimize range partitioning") {
Seq(
("", "ORDER BY c1"),
("", " ORDER BY c1, c2"),
("/*+ repartition_by_range(c1) */", ""),
("/*+ repartition_by_range(c1, c2) */", "")).foreach { case (head, tail) =>
checkQuery(
sql(s"SELECT $head * FROM t $tail"),
true)

checkQuery(
sql(s"SELECT $head * FROM t WHERE c1 > 4 $tail"),
true)

checkQuery(
sql(s"SELECT $head * FROM t WHERE c2 > 1 $tail"),
true)

checkQuery(
sql(s"SELECT $head * FROM (SELECT * FROM t WHERE c2 > 1 $tail) WHERE c1 > rand()"),
true)
}
}

test("Do not optimize range partitioning") {
Seq(
("", "ORDER BY c1"),
("/*+ repartition_by_range(c1) */", "")).foreach { case (head, tail) =>
// more than one shuffle
checkQuery(
sql(s"SELECT $head c1 FROM t GROUP BY c1 $tail"),
false)
}

Seq(
("", "ORDER BY c1, c2"),
("/*+ repartition_by_range(c1, c2) */", "")).foreach { case (head, tail) =>
// the sort order references are the same as the query output, so there is nothing to prune
checkQuery(
sql(s"SELECT $head c1, c2 FROM t $tail"),
false)
}
}
}