[SPARK-40211][CORE][SQL] Allow customize initial partitions number in take() behavior #37661

Closed · wants to merge 7 commits
Changes from 6 commits
core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1956,6 +1956,12 @@ package object config {
.intConf
.createWithDefault(10)

private[spark] val RDD_LIMIT_INITIAL_NUM_PARTITIONS =
ConfigBuilder("spark.rdd.limit.initialNumPartitions")
.version("3.4.0")
.intConf
.createWithDefault(1)

private[spark] val RDD_LIMIT_SCALE_UP_FACTOR =
ConfigBuilder("spark.rdd.limit.scaleUpFactor")
.version("2.1.0")
17 changes: 10 additions & 7 deletions core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -25,6 +25,7 @@ import scala.reflect.ClassTag

import org.apache.spark.{ComplexFutureAction, FutureAction, JobSubmitter}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.{RDD_LIMIT_INITIAL_NUM_PARTITIONS, RDD_LIMIT_SCALE_UP_FACTOR}
import org.apache.spark.util.ThreadUtils

/**
@@ -72,6 +73,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging {
val results = new ArrayBuffer[T]
val totalParts = self.partitions.length

val scaleUpFactor = Math.max(self.conf.get(RDD_LIMIT_SCALE_UP_FACTOR), 2)

/*
Recursively triggers jobs to scan partitions until either the requested
number of elements are retrieved, or the partitions to scan are exhausted.
@@ -84,18 +87,18 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging {
} else {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1L
var numPartsToTry = Math.max(self.conf.get(RDD_LIMIT_INITIAL_NUM_PARTITIONS), 1)
Contributor:
Enforce it in the config itself and always use self.conf.get(RDD_LIMIT_INITIAL_NUM_PARTITIONS)?

For RDD_LIMIT_INITIAL_NUM_PARTITIONS:

...
.intConf
.checkValue(_ > 0, "")
...

Contributor Author:
Nice idea, modified accordingly.
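For reference, a minimal sketch of what the entry could look like with the suggested validation (the error message is illustrative, not necessarily the wording that landed in the follow-up commit):

private[spark] val RDD_LIMIT_INITIAL_NUM_PARTITIONS =
  ConfigBuilder("spark.rdd.limit.initialNumPartitions")
    .version("3.4.0")
    .intConf
    .checkValue(_ > 0, "The initial number of partitions must be positive.")
    .createWithDefault(1)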

if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
// by 50%. We also cap the estimation in the end.
if (results.size == 0) {
numPartsToTry = partsScanned * 4L
// If we didn't find any rows after the previous iteration, multiply by
// limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need
// to try, but overestimate it by 50%. We also cap the estimation in the end.
if (results.isEmpty) {
numPartsToTry = partsScanned * scaleUpFactor
Contributor:
It looks like this is fixing a pre-existing bug where the RDD_LIMIT_SCALE_UP_FACTOR wasn't being applied to the AsyncRDDActions version of take(). Nice!

} else {
// the left side of max is >=1 whenever partsScanned >= 2
numPartsToTry = Math.max(1,
(1.5 * num * partsScanned / results.size).toInt - partsScanned)
numPartsToTry = Math.min(numPartsToTry, partsScanned * 4L)
numPartsToTry = Math.min(numPartsToTry, partsScanned * scaleUpFactor)
}
}

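To make the empty-result branch concrete, here is a small self-contained sketch (not part of the patch) of how the partition counts per job would grow when no rows are found, assuming spark.rdd.limit.initialNumPartitions = 4, a scale-up factor of 4, and 1000 total partitions:

object TakeScanGrowth {
  def main(args: Array[String]): Unit = {
    val initialNumPartitions = 4L
    val scaleUpFactor = 4L
    val totalParts = 1000L
    var partsScanned = 0L
    var numPartsToTry = initialNumPartitions
    while (partsScanned < totalParts) {
      // Each job scans at most the remaining partitions.
      val p = math.min(numPartsToTry, totalParts - partsScanned)
      println(s"job scans $p partitions (already scanned: $partsScanned)")
      partsScanned += p
      // No rows found yet: multiply by the scale-up factor and retry.
      numPartsToTry = partsScanned * scaleUpFactor
    }
  }
}

With these numbers the jobs scan 4, 16, 80, 400 and then the remaining 500 partitions, so raising the initial value mainly saves the first few small jobs.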
8 changes: 4 additions & 4 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1445,12 +1445,12 @@ abstract class RDD[T: ClassTag](
while (buf.size < num && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1L
var numPartsToTry = Math.max(conf.get(RDD_LIMIT_INITIAL_NUM_PARTITIONS), 1)
val left = num - buf.size
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate
// it by 50%. We also cap the estimation in the end.
// If we didn't find any rows after the previous iteration, multiply by
// limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need
// to try, but overestimate it by 50%. We also cap the estimation in the end.
if (buf.isEmpty) {
numPartsToTry = partsScanned * scaleUpFactor
} else {
39 changes: 38 additions & 1 deletion core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.rdd

import java.io.{File, IOException, ObjectInputStream, ObjectOutputStream}
import java.lang.management.ManagementFactory
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, HashMap}
@@ -32,8 +33,9 @@ import org.scalatest.concurrent.Eventually

import org.apache.spark._
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.internal.config.RDD_PARALLEL_LISTING_THRESHOLD
import org.apache.spark.internal.config.{RDD_LIMIT_INITIAL_NUM_PARTITIONS, RDD_PARALLEL_LISTING_THRESHOLD}
import org.apache.spark.rdd.RDDSuiteUtils._
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.util.{ThreadUtils, Utils}

class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually {
@@ -1255,6 +1257,41 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually {
assert(numPartsPerLocation(locations(1)) > 0.4 * numCoalescedPartitions)
}

test("SPARK-40211: customize initialNumPartitions for take") {
val totalElements = 100
val numToTake = 50
val rdd = sc.parallelize(0 to totalElements, totalElements)
import scala.language.reflectiveCalls
val jobCountListener = new SparkListener {
private var count: AtomicInteger = new AtomicInteger(0)
def getCount: Int = count.get
def reset(): Unit = count.set(0)
override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
count.incrementAndGet()
}
}
sc.addSparkListener(jobCountListener)
// with default RDD_LIMIT_INITIAL_NUM_PARTITIONS = 1, expecting multiple jobs
rdd.take(numToTake)
sc.listenerBus.waitUntilEmpty()
assert(jobCountListener.getCount > 1)
jobCountListener.reset()
rdd.takeAsync(numToTake).get()
sc.listenerBus.waitUntilEmpty()
assert(jobCountListener.getCount > 1)

// setting RDD_LIMIT_INITIAL_NUM_PARTITIONS to large number(1000), expecting only 1 job
sc.conf.set(RDD_LIMIT_INITIAL_NUM_PARTITIONS, 1000)
jobCountListener.reset()
rdd.take(numToTake)
sc.listenerBus.waitUntilEmpty()
assert(jobCountListener.getCount == 1)
jobCountListener.reset()
rdd.takeAsync(numToTake).get()
sc.listenerBus.waitUntilEmpty()
assert(jobCountListener.getCount == 1)
}

// NOTE
// Below tests calling sc.stop() have to be the last tests in this suite. If there are tests
// running after them and if they access sc those tests will fail as sc is already closed, because
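Outside of tests, the new RDD-level knob would be set through the usual SparkConf mechanism; a minimal sketch, with the value 10 chosen purely for illustration:

import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setMaster("local[2]")
  .setAppName("take-initial-partitions-example")
  .set("spark.rdd.limit.initialNumPartitions", "10")
val sc = new SparkContext(conf)

// With a larger initial scan, take() is more likely to finish in a single job.
val firstFifty = sc.parallelize(0 until 1000, 100).take(50)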
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -526,6 +526,15 @@ object SQLConf {
.checkValue(_ >= 1, "The shuffle hash join factor cannot be negative.")
.createWithDefault(3)

val LIMIT_INITIAL_NUM_PARTITIONS = buildConf("spark.sql.limit.initialNumPartitions")
.internal()
.doc("Initial number of partitions to try when executing a take on a query. Higher values " +
liuzqt marked this conversation as resolved.
Show resolved Hide resolved
"lead to more partitions read. Lower values might lead to longer execution times as more" +
"jobs will be run")
.version("3.4.0")
.intConf
.createWithDefault(1)

val LIMIT_SCALE_UP_FACTOR = buildConf("spark.sql.limit.scaleUpFactor")
.internal()
.doc("Minimal increase rate in number of partitions between attempts when executing a take " +
@@ -4300,6 +4309,8 @@ class SQLConf extends Serializable with Logging {

def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD)

def limitInitialNumPartitions: Int = getConf(LIMIT_INITIAL_NUM_PARTITIONS)

def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR)

def advancedPartitionPredicatePushdownEnabled: Boolean =
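The SQL-side entry is marked internal(), but like other SQL configs it can still be set at runtime; a minimal sketch, assuming an existing SparkSession named spark and an illustrative value of 1000:

// Start every take()/tail() scan with up to 1000 partitions instead of 1.
spark.conf.set("spark.sql.limit.initialNumPartitions", 1000)

// In tests it can be scoped to a block, as the new suite below does:
// withSQLConf(SQLConf.LIMIT_INITIAL_NUM_PARTITIONS.key -> "1000") { ... }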
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -469,7 +469,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
if (n == 0) {
return new Array[InternalRow](0)
}

val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2)
// TODO: refactor and reuse the code from RDD's take()
val childRDD = getByteArrayRdd(n, takeFromEnd)

val buf = if (takeFromEnd) new ListBuffer[InternalRow] else new ArrayBuffer[InternalRow]
@@ -478,12 +479,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
while (buf.length < n && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1L
var numPartsToTry = Math.max(conf.limitInitialNumPartitions, 1)
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate
// it by 50%. We also cap the estimation in the end.
val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2)
// If we didn't find any rows after the previous iteration, multiply by
// limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need
// to try, but overestimate it by 50%. We also cap the estimation in the end.
if (buf.isEmpty) {
numPartsToTry = partsScanned * limitScaleUpFactor
} else {
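As a worked example of the interpolation described in the comment above (the numbers are purely illustrative): suppose n = 50 rows were requested, 4 partitions have been scanned, 10 rows were found so far, and limitScaleUpFactor = 4.

val n = 50
val partsScanned = 4L
val rowsFound = 10
val limitScaleUpFactor = 4
// Interpolate with a 50% overestimate: (1.5 * 50 * 4 / 10).toInt - 4 = 26.
var numPartsToTry = Math.max(1L, (1.5 * n * partsScanned / rowsFound).toInt - partsScanned)
// Cap by the scale-up factor: 4 * 4 = 16, so the next job scans 16 more partitions.
numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)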
sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
@@ -17,8 +17,11 @@

package org.apache.spark.sql

import java.util.concurrent.atomic.AtomicInteger

import org.apache.commons.math3.stat.inference.ChiSquareTest

import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -68,4 +71,42 @@ class ConfigBehaviorSuite extends QueryTest with SharedSparkSession {
}
}

test("SPARK-40211: customize initialNumPartitions for take") {
val totalElements = 100
val numToTake = 50
import scala.language.reflectiveCalls
val jobCountListener = new SparkListener {
private var count: AtomicInteger = new AtomicInteger(0)
def getCount: Int = count.get
def reset(): Unit = count.set(0)
override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
count.incrementAndGet()
}
}
spark.sparkContext.addSparkListener(jobCountListener)
val df = spark.range(0, totalElements, 1, totalElements)

// with default LIMIT_INITIAL_NUM_PARTITIONS = 1, expecting multiple jobs
df.take(numToTake)
spark.sparkContext.listenerBus.waitUntilEmpty()
assert(jobCountListener.getCount > 1)
jobCountListener.reset()
df.tail(numToTake)
spark.sparkContext.listenerBus.waitUntilEmpty()
assert(jobCountListener.getCount > 1)

// setting LIMIT_INITIAL_NUM_PARTITIONS to large number(1000), expecting only 1 job

withSQLConf(SQLConf.LIMIT_INITIAL_NUM_PARTITIONS.key -> "1000") {
jobCountListener.reset()
df.take(numToTake)
spark.sparkContext.listenerBus.waitUntilEmpty()
assert(jobCountListener.getCount == 1)
jobCountListener.reset()
df.tail(numToTake)
spark.sparkContext.listenerBus.waitUntilEmpty()
assert(jobCountListener.getCount == 1)
}
}

}