diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index bfe90181ff..2ad9c1f609 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -534,6 +534,18 @@ object CometConf extends ShimCometConf { .checkValue(v => v > 0, "Write buffer size must be positive") .createWithDefault(1) + val COMET_SHUFFLE_MAX_BUFFERED_BATCHES: ConfigEntry[Int] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.maxBufferedBatches") + .category(CATEGORY_SHUFFLE) + .doc("Maximum number of batches to buffer in memory before spilling to disk during " + + "native shuffle. Setting this to a small value causes earlier spilling, which reduces " + + "peak memory usage on executors at the cost of more disk I/O. " + + "The default value of 0 disables this limit and spills only when the memory pool is " + + "exhausted.") + .intConf + .checkValue(v => v >= 0, "Max buffered batches must be non-negative") + .createWithDefault(0) + val COMET_SHUFFLE_PREFER_DICTIONARY_RATIO: ConfigEntry[Double] = conf( "spark.comet.shuffle.preferDictionary.ratio") .category(CATEGORY_SHUFFLE) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 5af31fcc22..b6e05fccd1 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1352,6 +1352,7 @@ impl PhysicalPlanner { }?; let write_buffer_size = writer.write_buffer_size as usize; + let max_buffered_batches = writer.max_buffered_batches as usize; let shuffle_writer = Arc::new(ShuffleWriterExec::try_new( Arc::clone(&child.native_plan), partitioning, @@ -1360,6 +1361,7 @@ impl PhysicalPlanner { writer.output_index_file.clone(), writer.tracing_enabled, write_buffer_size, + max_buffered_batches, )?); Ok(( diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 344b9f0f21..5e23aad061 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -294,6 +294,9 @@ message ShuffleWriter { // Size of the write buffer in bytes used when writing shuffle data to disk. // Larger values may improve write performance but use more memory. int32 write_buffer_size = 8; + // Maximum number of batches to buffer before spilling to disk. + // 0 means no limit (spill only when memory pool is exhausted). + int32 max_buffered_batches = 9; } message ParquetWriter { diff --git a/native/shuffle/benches/shuffle_writer.rs b/native/shuffle/benches/shuffle_writer.rs index 27abd919fa..8ff1f024d5 100644 --- a/native/shuffle/benches/shuffle_writer.rs +++ b/native/shuffle/benches/shuffle_writer.rs @@ -153,6 +153,7 @@ fn create_shuffle_writer_exec( "/tmp/index.out".to_string(), false, 1024 * 1024, + 0, // max_buffered_batches: no limit ) .unwrap() } diff --git a/native/shuffle/src/partitioners/multi_partition.rs b/native/shuffle/src/partitioners/multi_partition.rs index 42290c5510..f7bd83be66 100644 --- a/native/shuffle/src/partitioners/multi_partition.rs +++ b/native/shuffle/src/partitioners/multi_partition.rs @@ -124,6 +124,8 @@ pub(crate) struct MultiPartitionShuffleRepartitioner { tracing_enabled: bool, /// Size of the write buffer in bytes write_buffer_size: usize, + /// Maximum number of batches to buffer before spilling (0 = no limit) + max_buffered_batches: usize, } impl MultiPartitionShuffleRepartitioner { @@ -140,6 +142,7 @@ impl MultiPartitionShuffleRepartitioner { codec: CompressionCodec, tracing_enabled: bool, write_buffer_size: usize, + max_buffered_batches: usize, ) -> datafusion::common::Result { let num_output_partitions = partitioning.partition_count(); assert_ne!( @@ -189,6 +192,7 @@ impl MultiPartitionShuffleRepartitioner { reservation, tracing_enabled, write_buffer_size, + max_buffered_batches, }) } @@ -397,6 +401,12 @@ impl MultiPartitionShuffleRepartitioner { partition_row_indices: &[u32], partition_starts: &[u32], ) -> datafusion::common::Result<()> { + // Spill before buffering if we've reached the configured batch count limit. + if self.max_buffered_batches > 0 && self.buffered_batches.len() >= self.max_buffered_batches + { + self.spill()?; + } + let mut mem_growth: usize = input.get_array_memory_size(); let buffered_partition_idx = self.buffered_batches.len() as u32; self.buffered_batches.push(input); diff --git a/native/shuffle/src/shuffle_writer.rs b/native/shuffle/src/shuffle_writer.rs index e649aaac69..95a09610a4 100644 --- a/native/shuffle/src/shuffle_writer.rs +++ b/native/shuffle/src/shuffle_writer.rs @@ -68,6 +68,8 @@ pub struct ShuffleWriterExec { tracing_enabled: bool, /// Size of the write buffer in bytes write_buffer_size: usize, + /// Maximum number of batches to buffer before spilling (0 = no limit) + max_buffered_batches: usize, } impl ShuffleWriterExec { @@ -81,6 +83,7 @@ impl ShuffleWriterExec { output_index_file: String, tracing_enabled: bool, write_buffer_size: usize, + max_buffered_batches: usize, ) -> Result { let cache = PlanProperties::new( EquivalenceProperties::new(Arc::clone(&input.schema())), @@ -99,6 +102,7 @@ impl ShuffleWriterExec { codec, tracing_enabled, write_buffer_size, + max_buffered_batches, }) } } @@ -163,6 +167,7 @@ impl ExecutionPlan for ShuffleWriterExec { self.output_index_file.clone(), self.tracing_enabled, self.write_buffer_size, + self.max_buffered_batches, )?)), _ => panic!("ShuffleWriterExec wrong number of children"), } @@ -190,6 +195,7 @@ impl ExecutionPlan for ShuffleWriterExec { self.codec.clone(), self.tracing_enabled, self.write_buffer_size, + self.max_buffered_batches, ) .map_err(|e| ArrowError::ExternalError(Box::new(e))), ) @@ -210,6 +216,7 @@ async fn external_shuffle( codec: CompressionCodec, tracing_enabled: bool, write_buffer_size: usize, + max_buffered_batches: usize, ) -> Result { with_trace_async("external_shuffle", tracing_enabled, || async { let schema = input.schema(); @@ -238,6 +245,7 @@ async fn external_shuffle( codec, tracing_enabled, write_buffer_size, + max_buffered_batches, )?), }; @@ -362,6 +370,7 @@ mod test { CompressionCodec::Lz4Frame, false, 1024 * 1024, // write_buffer_size: 1MB default + 0, // max_buffered_batches: no limit ) .unwrap(); @@ -466,6 +475,7 @@ mod test { "/tmp/index.out".to_string(), false, 1024 * 1024, // write_buffer_size: 1MB default + 0, // max_buffered_batches: no limit ) .unwrap(); @@ -525,6 +535,7 @@ mod test { index_file.clone(), false, 1024 * 1024, + 0, // max_buffered_batches: no limit ) .unwrap(); diff --git a/native/shuffle/src/writers/partition_writer.rs b/native/shuffle/src/writers/partition_writer.rs index 48017871db..4de307de62 100644 --- a/native/shuffle/src/writers/partition_writer.rs +++ b/native/shuffle/src/writers/partition_writer.rs @@ -26,7 +26,6 @@ use std::fs::{File, OpenOptions}; struct SpillFile { temp_file: RefCountedTempFile, - file: File, } pub(crate) struct PartitionWriter { @@ -53,26 +52,28 @@ impl PartitionWriter { runtime: &RuntimeEnv, ) -> datafusion::common::Result<()> { if self.spill_file.is_none() { - // Spill file is not yet created, create it let spill_file = runtime .disk_manager .create_tmp_file("shuffle writer spill")?; - let spill_data = OpenOptions::new() - .write(true) - .create(true) - .truncate(true) - .open(spill_file.path()) - .map_err(|e| { - DataFusionError::Execution(format!("Error occurred while spilling {e}")) - })?; + // Create the file (truncating any pre-existing content) + File::create(spill_file.path()).map_err(|e| { + DataFusionError::Execution(format!("Error occurred while spilling {e}")) + })?; self.spill_file = Some(SpillFile { temp_file: spill_file, - file: spill_data, }); } Ok(()) } + fn open_spill_file_for_append(&self) -> datafusion::common::Result { + OpenOptions::new() + .write(true) + .append(true) + .open(self.spill_file.as_ref().unwrap().temp_file.path()) + .map_err(|e| DataFusionError::Execution(format!("Error occurred while spilling {e}"))) + } + pub(crate) fn spill( &mut self, iter: &mut PartitionedBatchIterator, @@ -84,10 +85,13 @@ impl PartitionWriter { if let Some(batch) = iter.next() { self.ensure_spill_file_created(runtime)?; + // Open the file for this spill and close it when done, so we don't + // hold open one FD per partition across multiple spill events. + let mut spill_data = self.open_spill_file_for_append()?; let total_bytes_written = { let mut buf_batch_writer = BufBatchWriter::new( &mut self.shuffle_block_writer, - &mut self.spill_file.as_mut().unwrap().file, + &mut spill_data, write_buffer_size, batch_size, ); @@ -104,6 +108,7 @@ impl PartitionWriter { buf_batch_writer.flush(&metrics.encode_time, &metrics.write_time)?; bytes_written }; + // spill_data is dropped here, closing the file descriptor Ok(total_bytes_written) } else { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index 3fc222bd19..a80d8b2fa4 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -192,6 +192,8 @@ class CometNativeShuffleWriter[K, V]( CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get) shuffleWriterBuilder.setWriteBufferSize( CometConf.COMET_SHUFFLE_WRITE_BUFFER_SIZE.get().max(Int.MaxValue).toInt) + shuffleWriterBuilder.setMaxBufferedBatches( + CometConf.COMET_SHUFFLE_MAX_BUFFERED_BATCHES.get()) outputPartitioning match { case p if isSinglePartitioning(p) =>