diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 7f63d79a21ed6..4388f6e12671a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1176,6 +1176,15 @@ object SQLConf {
     .longConf
     .createWithDefault(4 * 1024 * 1024)
 
+  val FILES_MIN_PARTITION_NUM = buildConf("spark.sql.files.minPartitionNum")
+    .doc("The suggested (not guaranteed) minimum number of split file partitions. " +
+      "If not set, the default value is `spark.default.parallelism`. This configuration is " +
+      "effective only when using file-based sources such as Parquet, JSON and ORC.")
+    .version("3.1.0")
+    .intConf
+    .checkValue(v => v > 0, "The min partition number must be a positive integer.")
+    .createOptional
+
   val IGNORE_CORRUPT_FILES = buildConf("spark.sql.files.ignoreCorruptFiles")
     .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
       "encountering corrupted files and the contents that have been read will still be returned. " +
@@ -2782,6 +2791,8 @@ class SQLConf extends Serializable with Logging {
 
   def filesOpenCostInBytes: Long = getConf(FILES_OPEN_COST_IN_BYTES)
 
+  def filesMinPartitionNum: Option[Int] = getConf(FILES_MIN_PARTITION_NUM)
+
   def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES)
 
   def ignoreMissingFiles: Boolean = getConf(IGNORE_MISSING_FILES)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala
index b4fc94e097aa8..095940772ae78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala
@@ -88,9 +88,10 @@ object FilePartition extends Logging {
       selectedPartitions: Seq[PartitionDirectory]): Long = {
     val defaultMaxSplitBytes = sparkSession.sessionState.conf.filesMaxPartitionBytes
     val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes
-    val defaultParallelism = sparkSession.sparkContext.defaultParallelism
+    val minPartitionNum = sparkSession.sessionState.conf.filesMinPartitionNum
+      .getOrElse(sparkSession.sparkContext.defaultParallelism)
     val totalBytes = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum
-    val bytesPerCore = totalBytes / defaultParallelism
+    val bytesPerCore = totalBytes / minPartitionNum
 
     Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore))
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index 812305ba24403..8a6e6b5ee801d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -528,6 +528,41 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession with Pre
     }
   }
 
+  test("SPARK-32019: Add spark.sql.files.minPartitionNum config") {
+    withSQLConf(SQLConf.FILES_MIN_PARTITION_NUM.key -> "1") {
+      val table =
+        createTable(files = Seq(
+          "file1" -> 1,
+          "file2" -> 1,
+          "file3" -> 1
+        ))
+      assert(table.rdd.partitions.length == 1)
+    }
+
+    withSQLConf(SQLConf.FILES_MIN_PARTITION_NUM.key -> "10") {
+      val table =
+        createTable(files = Seq(
+          "file1" -> 1,
+          "file2" -> 1,
+          "file3" -> 1
+        ))
+      assert(table.rdd.partitions.length == 3)
+    }
+
+    withSQLConf(SQLConf.FILES_MIN_PARTITION_NUM.key -> "16") {
+      val partitions = (1 to 100).map(i => s"file$i" -> 128 * 1024 * 1024)
+      val table = createTable(files = partitions)
+      // partition is limited by filesMaxPartitionBytes(128MB)
+      assert(table.rdd.partitions.length == 100)
+    }
+
+    withSQLConf(SQLConf.FILES_MIN_PARTITION_NUM.key -> "32") {
+      val partitions = (1 to 800).map(i => s"file$i" -> 4 * 1024 * 1024)
+      val table = createTable(files = partitions)
+      assert(table.rdd.partitions.length == 50)
+    }
+  }
+
   // Helpers for checking the arguments passed to the FileFormat.
 
   protected val checkPartitionSchema =
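
For reference, the partition counts asserted in the new test follow directly from the patched maxSplitBytes formula. Below is a minimal standalone sketch (not part of the patch), assuming the default spark.sql.files.maxPartitionBytes (128MB) and spark.sql.files.openCostInBytes (4MB) that the test relies on. MinPartitionNumSketch and estimatePartitions are ad hoc names for illustration; estimatePartitions is a simplified stand-in for the bin packing done in FilePartition.getFilePartitions and only holds for equal-sized files no larger than the split size.

    // Standalone illustration (not part of the patch); names here are ad hoc.
    object MinPartitionNumSketch {
      // Defaults assumed by the tests above (the stock Spark values).
      val maxPartitionBytes: Long = 128L * 1024 * 1024 // spark.sql.files.maxPartitionBytes
      val openCostInBytes: Long = 4L * 1024 * 1024     // spark.sql.files.openCostInBytes

      // Mirrors the patched FilePartition.maxSplitBytes: bytesPerCore now divides
      // by minPartitionNum instead of spark.default.parallelism.
      def maxSplitBytes(fileSizes: Seq[Long], minPartitionNum: Int): Long = {
        val totalBytes = fileSizes.map(_ + openCostInBytes).sum
        val bytesPerCore = totalBytes / minPartitionNum
        math.min(maxPartitionBytes, math.max(openCostInBytes, bytesPerCore))
      }

      // Simplified stand-in for FilePartition.getFilePartitions' packing:
      // valid only for equal-sized files that each fit within one split.
      def estimatePartitions(fileSizes: Seq[Long], minPartitionNum: Int): Int = {
        val splitBytes = maxSplitBytes(fileSizes, minPartitionNum)
        val filesPerPartition = math.max(1L, splitBytes / (fileSizes.head + openCostInBytes))
        math.ceil(fileSizes.size.toDouble / filesPerPartition).toInt
      }

      def main(args: Array[String]): Unit = {
        // 800 x 4MB files, minPartitionNum = 32: bytesPerCore = 200MB, so the split
        // size is capped at 128MB; 16 files (4MB data + 4MB open cost each) fit per
        // partition -> 800 / 16 = 50 partitions.
        println(estimatePartitions(Seq.fill(800)(4L * 1024 * 1024), minPartitionNum = 32))   // 50

        // 100 x 128MB files, minPartitionNum = 16: each file already fills a 128MB
        // split, so every file becomes its own partition -> 100 partitions.
        println(estimatePartitions(Seq.fill(100)(128L * 1024 * 1024), minPartitionNum = 16)) // 100
      }
    }

The first two test cases work the same way: with three 1-byte files, minPartitionNum = 1 gives a split size of roughly 12MB (dominated by the open cost), so all three files pack into one partition, while minPartitionNum = 10 drives the split size down to the 4MB open-cost floor, so each file gets its own partition.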