[SPARK-40211][CORE][SQL] Allow customize initial partitions number in take() behavior

[SPARK-40211](https://issues.apache.org/jira/browse/SPARK-40211): add an `initialNumPartitions` config parameter to allow customizing the initial number of partitions to try in `take()`.

Currently, the initial number of partitions to try is hardcoded to `1`, which can cause unnecessary overhead: `take()` may have to launch several jobs before it collects enough rows. Setting the new configuration to a high value effectively mitigates this “run multiple jobs” overhead, while a higher-than-1-but-still-small value (say, 10) gives a middle-ground trade-off between extra jobs and scanning more partitions than needed.
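
As a quick illustration (not part of this patch), a user could raise both the RDD-side and SQL-side knobs when building a session; the value `10` and the local-mode setup below are only example choices:

```scala
import org.apache.spark.sql.SparkSession

// Example only: raise the initial partition count probed by take()/tail().
val spark = SparkSession.builder()
  .appName("initialNumPartitions-demo")
  .master("local[*]")
  // RDD side: read by RDD.take() and AsyncRDDActions.takeAsync()
  .config("spark.rdd.limit.initialNumPartitions", 10L)
  // SQL side: used by Dataset.take() and Dataset.tail() (SparkPlan.executeTake path)
  .config("spark.sql.limit.initialNumPartitions", 10L)
  .getOrCreate()

// 100 partitions with 10 rows each: with the default of 1 this take(50) needs
// more than one job, while starting at 10 partitions it finishes in a single job.
val rows = spark.range(0, 1000, 1, numPartitions = 100).take(50)
```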

No user-facing change.

Covered by new unit tests in `RDDSuite` and `ConfigBehaviorSuite`.

Closes apache#37661 from liuzqt/SPARK-40211.

Authored-by: Ziqi Liu <ziqi.liu@databricks.com>
Signed-off-by: Josh Rosen <joshrosen@databricks.com>
liuzqt authored and chenzhx committed Nov 4, 2022
1 parent 30dd33b commit 234b171
Showing 7 changed files with 118 additions and 18 deletions.
@@ -1870,6 +1870,13 @@ package object config
.intConf
.createWithDefault(10)

private[spark] val RDD_LIMIT_INITIAL_NUM_PARTITIONS =
ConfigBuilder("spark.rdd.limit.initialNumPartitions")
.version("3.4.0")
.intConf
.checkValue(_ > 0, "value should be positive")
.createWithDefault(1)

private[spark] val RDD_LIMIT_SCALE_UP_FACTOR =
ConfigBuilder("spark.rdd.limit.scaleUpFactor")
.version("2.1.0")
17 changes: 10 additions & 7 deletions core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -25,6 +25,7 @@ import scala.reflect.ClassTag

import org.apache.spark.{ComplexFutureAction, FutureAction, JobSubmitter}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.{RDD_LIMIT_INITIAL_NUM_PARTITIONS, RDD_LIMIT_SCALE_UP_FACTOR}
import org.apache.spark.util.ThreadUtils

/**
@@ -72,6 +73,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
val results = new ArrayBuffer[T]
val totalParts = self.partitions.length

val scaleUpFactor = Math.max(self.conf.get(RDD_LIMIT_SCALE_UP_FACTOR), 2)

/*
Recursively triggers jobs to scan partitions until either the requested
number of elements are retrieved, or the partitions to scan are exhausted.
@@ -84,18 +87,18 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
} else {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1L
var numPartsToTry = self.conf.get(RDD_LIMIT_INITIAL_NUM_PARTITIONS)
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
// by 50%. We also cap the estimation in the end.
if (results.size == 0) {
numPartsToTry = partsScanned * 4L
// If we didn't find any rows after the previous iteration, multiply by
// limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need
// to try, but overestimate it by 50%. We also cap the estimation in the end.
if (results.isEmpty) {
numPartsToTry = partsScanned * scaleUpFactor
} else {
// the left side of max is >=1 whenever partsScanned >= 2
numPartsToTry = Math.max(1,
(1.5 * num * partsScanned / results.size).toInt - partsScanned)
numPartsToTry = Math.min(numPartsToTry, partsScanned * 4L)
numPartsToTry = Math.min(numPartsToTry, partsScanned * scaleUpFactor)
}
}

8 changes: 4 additions & 4 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1430,12 +1430,12 @@ abstract class RDD[T: ClassTag](
while (buf.size < num && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1L
var numPartsToTry = conf.get(RDD_LIMIT_INITIAL_NUM_PARTITIONS)
val left = num - buf.size
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate
// it by 50%. We also cap the estimation in the end.
// If we didn't find any rows after the previous iteration, multiply by
// limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need
// to try, but overestimate it by 50%. We also cap the estimation in the end.
if (buf.isEmpty) {
numPartsToTry = partsScanned * scaleUpFactor
} else {
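
To make the loop above concrete, here is a small standalone sketch (not code from this patch) of the partition-growth schedule shared by `RDD.take`, `takeAsync` and the SQL-side `executeTake`, under the simplifying assumption that every scanned partition returns no rows, which is the worst case for the number of jobs; in practice the loop stops as soon as enough rows are collected and interpolates when some rows were found.

```scala
// Hypothetical helper: count the jobs take() would launch if every partition were empty.
def jobsToScanAll(totalParts: Long, initialNumPartitions: Long, scaleUpFactor: Long): Int = {
  var partsScanned = 0L
  var jobs = 0
  while (partsScanned < totalParts) {
    val numPartsToTry =
      if (partsScanned == 0) initialNumPartitions   // first iteration: the new config
      else partsScanned * scaleUpFactor             // empty-result branch: multiply and retry
    partsScanned += math.min(numPartsToTry, totalParts - partsScanned) // capped at totalParts
    jobs += 1
  }
  jobs
}

jobsToScanAll(100, initialNumPartitions = 1, scaleUpFactor = 4)    // scans 1, 4, 20, 75 -> 4 jobs
jobsToScanAll(100, initialNumPartitions = 1000, scaleUpFactor = 4) // scans all 100 at once -> 1 job
```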
39 changes: 38 additions & 1 deletion core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.rdd

import java.io.{File, IOException, ObjectInputStream, ObjectOutputStream}
import java.lang.management.ManagementFactory
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, HashMap}
@@ -32,8 +33,9 @@ import org.scalatest.concurrent.Eventually

import org.apache.spark._
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.internal.config.RDD_PARALLEL_LISTING_THRESHOLD
import org.apache.spark.internal.config.{RDD_LIMIT_INITIAL_NUM_PARTITIONS, RDD_PARALLEL_LISTING_THRESHOLD}
import org.apache.spark.rdd.RDDSuiteUtils._
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.util.{ThreadUtils, Utils}

class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually {
@@ -1245,6 +1247,41 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually {
assert(numPartsPerLocation(locations(1)) > 0.4 * numCoalescedPartitions)
}

test("SPARK-40211: customize initialNumPartitions for take") {
val totalElements = 100
val numToTake = 50
val rdd = sc.parallelize(0 to totalElements, totalElements)
import scala.language.reflectiveCalls
val jobCountListener = new SparkListener {
private var count: AtomicInteger = new AtomicInteger(0)
def getCount: Int = count.get
def reset(): Unit = count.set(0)
override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
count.incrementAndGet()
}
}
sc.addSparkListener(jobCountListener)
// with default RDD_LIMIT_INITIAL_NUM_PARTITIONS = 1, expecting multiple jobs
rdd.take(numToTake)
sc.listenerBus.waitUntilEmpty()
assert(jobCountListener.getCount > 1)
jobCountListener.reset()
rdd.takeAsync(numToTake).get()
sc.listenerBus.waitUntilEmpty()
assert(jobCountListener.getCount > 1)

// setting RDD_LIMIT_INITIAL_NUM_PARTITIONS to large number(1000), expecting only 1 job
sc.conf.set(RDD_LIMIT_INITIAL_NUM_PARTITIONS, 1000)
jobCountListener.reset()
rdd.take(numToTake)
sc.listenerBus.waitUntilEmpty()
assert(jobCountListener.getCount == 1)
jobCountListener.reset()
rdd.takeAsync(numToTake).get()
sc.listenerBus.waitUntilEmpty()
assert(jobCountListener.getCount == 1)
}

// NOTE
// Below tests calling sc.stop() have to be the last tests in this suite. If there are tests
// running after them and if they access sc those tests will fail as sc is already closed, because
@@ -503,6 +503,16 @@ object SQLConf {
.bytesConf(ByteUnit.BYTE)
.createWithDefaultString("10MB")

val LIMIT_INITIAL_NUM_PARTITIONS = buildConf("spark.sql.limit.initialNumPartitions")
.internal()
.doc("Initial number of partitions to try when executing a take on a query. Higher values " +
"lead to more partitions read. Lower values might lead to longer execution times as more " +
"jobs will be run")
.version("3.4.0")
.intConf
.checkValue(_ > 0, "value should be positive")
.createWithDefault(1)

val LIMIT_SCALE_UP_FACTOR = buildConf("spark.sql.limit.scaleUpFactor")
.internal()
.doc("Minimal increase rate in number of partitions between attempts when executing a take " +
@@ -3918,6 +3928,8 @@ class SQLConf extends Serializable with Logging {

def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD)

def limitInitialNumPartitions: Int = getConf(LIMIT_INITIAL_NUM_PARTITIONS)

def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR)

def advancedPartitionPredicatePushdownEnabled: Boolean =
@@ -392,7 +392,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
if (n == 0) {
return new Array[InternalRow](0)
}

val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2)
// TODO: refactor and reuse the code from RDD's take()
val childRDD = getByteArrayRdd(n, takeFromEnd)

val buf = if (takeFromEnd) new ListBuffer[InternalRow] else new ArrayBuffer[InternalRow]
@@ -403,12 +404,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
while (buf.length < n && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1L
var numPartsToTry = conf.limitInitialNumPartitions
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate
// it by 50%. We also cap the estimation in the end.
val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2)
// If we didn't find any rows after the previous iteration, multiply by
// limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need
// to try, but overestimate it by 50%. We also cap the estimation in the end.
if (buf.isEmpty) {
numPartsToTry = partsScanned * limitScaleUpFactor
} else {
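
For completeness, a hedged usage sketch of the SQL-side knob set at runtime (it assumes an existing `SparkSession` named `spark`; the value `100` and the data shape are arbitrary examples):

```scala
// Example only: widen the first scan before collecting a small prefix/suffix.
spark.conf.set("spark.sql.limit.initialNumPartitions", 100L)

val df = spark.range(0, 10000, 1, numPartitions = 1000)
val head = df.take(50)  // the first job may now scan up to 100 partitions instead of 1
val last = df.tail(50)  // tail() honors the same config, as the test below demonstrates
```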
@@ -17,8 +17,11 @@

package org.apache.spark.sql

import java.util.concurrent.atomic.AtomicInteger

import org.apache.commons.math3.stat.inference.ChiSquareTest

import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -68,4 +71,42 @@ class ConfigBehaviorSuite extends QueryTest with SharedSparkSession {
}
}

test("SPARK-40211: customize initialNumPartitions for take") {
val totalElements = 100
val numToTake = 50
import scala.language.reflectiveCalls
val jobCountListener = new SparkListener {
private var count: AtomicInteger = new AtomicInteger(0)
def getCount: Int = count.get
def reset(): Unit = count.set(0)
override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
count.incrementAndGet()
}
}
spark.sparkContext.addSparkListener(jobCountListener)
val df = spark.range(0, totalElements, 1, totalElements)

// with default LIMIT_INITIAL_NUM_PARTITIONS = 1, expecting multiple jobs
df.take(numToTake)
spark.sparkContext.listenerBus.waitUntilEmpty()
assert(jobCountListener.getCount > 1)
jobCountListener.reset()
df.tail(numToTake)
spark.sparkContext.listenerBus.waitUntilEmpty()
assert(jobCountListener.getCount > 1)

// setting LIMIT_INITIAL_NUM_PARTITIONS to large number(1000), expecting only 1 job

withSQLConf(SQLConf.LIMIT_INITIAL_NUM_PARTITIONS.key -> "1000") {
jobCountListener.reset()
df.take(numToTake)
spark.sparkContext.listenerBus.waitUntilEmpty()
assert(jobCountListener.getCount == 1)
jobCountListener.reset()
df.tail(numToTake)
spark.sparkContext.listenerBus.waitUntilEmpty()
assert(jobCountListener.getCount == 1)
}
}

}
