[SPARK-32056][SQL][Follow-up] Coalesce partitions for repartition hint and SQL when AQE is enabled #28952

Closed · wants to merge 3 commits
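In short, this follow-up lets RepartitionByExpression carry an optional partition count: hints and SQL clauses that name columns but no number (REPARTITION(id), REPARTITION_BY_RANGE(id), DISTRIBUTE BY, CLUSTER BY) now leave the count unset so adaptive query execution can coalesce the shuffle, while an explicit count is still honored. Below is a minimal, self-contained sketch of the user-visible behavior; it is not part of the diff, the session setup and object name are assumptions, and the exact partition counts depend on data size and AQE settings.

// Sketch only: the config keys are the standard Spark 3.x AQE settings, and the
// "test" view mirrors the new test case added at the bottom of this diff.
import org.apache.spark.sql.SparkSession

object RepartitionCoalesceSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[4]")
      .appName("repartition-coalesce-sketch")
      .config("spark.sql.adaptive.enabled", "true")
      .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
      .config("spark.sql.adaptive.coalescePartitions.initialPartitionNum", "10")
      .config("spark.sql.shuffle.partitions", "10")
      .getOrCreate()

    spark.range(10).toDF("id").createTempView("test")

    // No partition number given: the shuffle starts from 10 partitions and AQE may coalesce it.
    val coalesced = spark.sql("SELECT /*+ REPARTITION(id) */ * FROM test")
    println(coalesced.rdd.getNumPartitions)  // typically < 10 with AQE enabled

    // Explicit partition number: the user's choice wins, no coalescing.
    val fixed = spark.sql("SELECT /*+ REPARTITION(10, id) */ * FROM test")
    println(fixed.rdd.getNumPartitions)      // exactly 10

    spark.stop()
  }
}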
@@ -183,7 +183,7 @@ object ResolveHints {
val hintName = hint.name.toUpperCase(Locale.ROOT)

def createRepartitionByExpression(
- numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
+ numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
val sortOrders = partitionExprs.filter(_.isInstanceOf[SortOrder])
if (sortOrders.nonEmpty) throw new IllegalArgumentException(
s"""Invalid partitionExprs specified: $sortOrders
@@ -208,11 +208,11 @@ object ResolveHints {
throw new AnalysisException(s"$hintName Hint expects a partition number as a parameter")

case param @ Seq(IntegerLiteral(numPartitions), _*) if shuffle =>
- createRepartitionByExpression(numPartitions, param.tail)
+ createRepartitionByExpression(Some(numPartitions), param.tail)
case param @ Seq(numPartitions: Int, _*) if shuffle =>
- createRepartitionByExpression(numPartitions, param.tail)
+ createRepartitionByExpression(Some(numPartitions), param.tail)
case param @ Seq(_*) if shuffle =>
- createRepartitionByExpression(conf.numShufflePartitions, param)
+ createRepartitionByExpression(None, param)
}
}

@@ -224,7 +224,7 @@ object ResolveHints {
val hintName = hint.name.toUpperCase(Locale.ROOT)

def createRepartitionByExpression(
- numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
+ numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
val invalidParams = partitionExprs.filter(!_.isInstanceOf[UnresolvedAttribute])
if (invalidParams.nonEmpty) {
throw new AnalysisException(s"$hintName Hint parameter should include columns, but " +
@@ -239,11 +239,11 @@ object ResolveHints {

hint.parameters match {
case param @ Seq(IntegerLiteral(numPartitions), _*) =>
- createRepartitionByExpression(numPartitions, param.tail)
+ createRepartitionByExpression(Some(numPartitions), param.tail)
case param @ Seq(numPartitions: Int, _*) =>
- createRepartitionByExpression(numPartitions, param.tail)
+ createRepartitionByExpression(Some(numPartitions), param.tail)
case param @ Seq(_*) =>
- createRepartitionByExpression(conf.numShufflePartitions, param)
+ createRepartitionByExpression(None, param)
}
}

@@ -163,7 +163,7 @@ class ResolveHintsSuite extends AnalysisTest {
checkAnalysis(
UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("a")), table("TaBlE")),
RepartitionByExpression(
- Seq(AttributeReference("a", IntegerType)()), testRelation, conf.numShufflePartitions))
+ Seq(AttributeReference("a", IntegerType)()), testRelation, None))

val e = intercept[IllegalArgumentException] {
checkAnalysis(
@@ -187,7 +187,7 @@ class ResolveHintsSuite extends AnalysisTest {
"REPARTITION_BY_RANGE", Seq(UnresolvedAttribute("a")), table("TaBlE")),
RepartitionByExpression(
Seq(SortOrder(AttributeReference("a", IntegerType)(), Ascending)),
- testRelation, conf.numShufflePartitions))
+ testRelation, None))

val errMsg2 = "REPARTITION Hint parameter should include columns, but"

@@ -746,7 +746,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
ctx: QueryOrganizationContext,
expressions: Seq[Expression],
query: LogicalPlan): LogicalPlan = {
- RepartitionByExpression(expressions, query, conf.numShufflePartitions)
+ RepartitionByExpression(expressions, query, None)
}

/**
@@ -199,20 +199,20 @@ class SparkSqlParserSuite extends AnalysisTest {
assertEqual(s"$baseSql distribute by a, b",
RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
basePlan,
- numPartitions = newConf.numShufflePartitions))
+ None))
assertEqual(s"$baseSql distribute by a sort by b",
Sort(SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
global = false,
RepartitionByExpression(UnresolvedAttribute("a") :: Nil,
basePlan,
- numPartitions = newConf.numShufflePartitions)))
+ None)))
assertEqual(s"$baseSql cluster by a, b",
Sort(SortOrder(UnresolvedAttribute("a"), Ascending) ::
SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
global = false,
RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
basePlan,
- numPartitions = newConf.numShufflePartitions)))
+ None)))
}

test("pipeline concatenation") {
@@ -23,7 +23,7 @@ import java.net.URI
import org.apache.log4j.Level

import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
- import org.apache.spark.sql.{QueryTest, Row, SparkSession, Strategy}
+ import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan}
@@ -130,6 +130,17 @@ class AdaptiveQueryExecSuite
assert(numShuffles === (numLocalReaders.length + numShufflesWithoutLocalReader))
}

private def checkInitialPartitionNum(df: Dataset[_], numPartition: Int): Unit = {
// repartition obeys initialPartitionNum when adaptiveExecutionEnabled
val plan = df.queryExecution.executedPlan
assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
case s: ShuffleExchangeExec => s
}
assert(shuffle.size == 1)
assert(shuffle(0).outputPartitioning.numPartitions == numPartition)
}

test("Change merge join to broadcast join") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
@@ -1040,14 +1051,8 @@ class AdaptiveQueryExecSuite
assert(partitionsNum1 < 10)
assert(partitionsNum2 < 10)

- // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
- val plan = df1.queryExecution.executedPlan
- assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
- val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
-   case s: ShuffleExchangeExec => s
- }
- assert(shuffle.size == 1)
- assert(shuffle(0).outputPartitioning.numPartitions == 10)
+ checkInitialPartitionNum(df1, 10)
+ checkInitialPartitionNum(df2, 10)
} else {
assert(partitionsNum1 === 10)
assert(partitionsNum2 === 10)
@@ -1081,14 +1086,8 @@ class AdaptiveQueryExecSuite
assert(partitionsNum1 < 10)
assert(partitionsNum2 < 10)

- // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
- val plan = df1.queryExecution.executedPlan
- assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
- val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
-   case s: ShuffleExchangeExec => s
- }
- assert(shuffle.size == 1)
- assert(shuffle(0).outputPartitioning.numPartitions == 10)
+ checkInitialPartitionNum(df1, 10)
+ checkInitialPartitionNum(df2, 10)
} else {
assert(partitionsNum1 === 10)
assert(partitionsNum2 === 10)
@@ -1100,4 +1099,52 @@
}
}
}

test("SPARK-31220, SPARK-32056: repartition using sql and hint with AQE") {
Seq(true, false).foreach { enableAQE =>
withTempView("test") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10",
SQLConf.SHUFFLE_PARTITIONS.key -> "10") {

spark.range(10).toDF.createTempView("test")

val df1 = spark.sql("SELECT /*+ REPARTITION(id) */ * from test")
val df2 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(id) */ * from test")
val df3 = spark.sql("SELECT * from test DISTRIBUTE BY id")
val df4 = spark.sql("SELECT * from test CLUSTER BY id")

val partitionsNum1 = df1.rdd.collectPartitions().length
val partitionsNum2 = df2.rdd.collectPartitions().length
val partitionsNum3 = df3.rdd.collectPartitions().length
val partitionsNum4 = df4.rdd.collectPartitions().length

if (enableAQE) {
assert(partitionsNum1 < 10)
assert(partitionsNum2 < 10)
assert(partitionsNum3 < 10)
assert(partitionsNum4 < 10)

checkInitialPartitionNum(df1, 10)
checkInitialPartitionNum(df2, 10)
checkInitialPartitionNum(df3, 10)
checkInitialPartitionNum(df4, 10)
} else {
assert(partitionsNum1 === 10)
assert(partitionsNum2 === 10)
assert(partitionsNum3 === 10)
assert(partitionsNum4 === 10)
}

// Don't coalesce partitions if the number of partitions is specified.
val df5 = spark.sql("SELECT /*+ REPARTITION(10, id) */ * from test")
val df6 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(10, id) */ * from test")
assert(df5.rdd.collectPartitions().length == 10)
assert(df6.rdd.collectPartitions().length == 10)
}
}
}
}
}
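For orientation, the crux of the diff is that the logical node's partition count becomes optional: None means "start from the session default and let AQE coalesce the result", while Some(n) pins the count. The following is a simplified sketch of that design; the class and member names are stand-ins, not the real Catalyst definitions.

// Simplified sketch of the Option[Int] design used in this PR; names are hypothetical.
case class RepartitionByExpressionSketch(
    partitionExprs: Seq[String],        // stand-in for Catalyst expressions
    optNumPartitions: Option[Int]) {

  // The shuffle is still planned with a concrete number of partitions...
  def numPartitions(defaultShufflePartitions: Int): Int =
    optNumPartitions.getOrElse(defaultShufflePartitions)

  // ...but only an explicit, user-specified count should stop AQE from
  // coalescing the resulting partitions afterwards.
  def userSpecifiedPartitioning: Boolean = optNumPartitions.isDefined
}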