diff --git a/datafusion/core/src/execution/memory_manager.rs b/datafusion/core/src/execution/memory_manager.rs index f148e331e60f..48d4ca3c3d32 100644 --- a/datafusion/core/src/execution/memory_manager.rs +++ b/datafusion/core/src/execution/memory_manager.rs @@ -195,6 +195,13 @@ pub trait MemoryConsumer: Send + Sync { Ok(()) } + /// Grow without spilling to the disk. It grows the memory directly + /// so it should be only used when the consumer already allocated the + /// memory and it is safe to grow without spilling. + fn grow(&self, required: usize) { + self.memory_manager().record_free_then_acquire(0, required); + } + /// Return `freed` memory to the memory manager, /// may wake up other requesters waiting for their minimum memory quota. fn shrink(&self, freed: usize) { diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs index 4586106bbfcc..763c7c553f41 100644 --- a/datafusion/core/src/physical_plan/sorts/sort.rs +++ b/datafusion/core/src/physical_plan/sorts/sort.rs @@ -50,7 +50,7 @@ use futures::lock::Mutex; use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt}; use log::{debug, error}; use std::any::Any; -use std::cmp::min; +use std::cmp::{min, Ordering}; use std::fmt; use std::fmt::{Debug, Formatter}; use std::fs::File; @@ -124,20 +124,27 @@ impl ExternalSorter { // calls to `timer.done()` below. let _timer = tracking_metrics.elapsed_compute().timer(); let partial = sort_batch(input, self.schema.clone(), &self.expr, self.fetch)?; - // The resulting batch might be smaller than the input batch if there - // is an propagated limit. 
- - if self.fetch.is_some() { - let new_size = batch_byte_size(&partial.sorted_batch); - let size_delta = size.checked_sub(new_size).ok_or_else(|| { - DataFusionError::Internal(format!( - "The size of the sorted batch is larger than the size of the input batch: {} > {}", - new_size, - size - )) - })?; - self.shrink(size_delta); - self.metrics.mem_used().sub(size_delta); + + // The resulting batch might be smaller (or larger, see #3747) than the input + // batch due to either a propagated limit or the re-construction of arrays. So + // for being reliable, we need to reflect the memory usage of the partial batch. + let new_size = batch_byte_size(&partial.sorted_batch); + match new_size.cmp(&size) { + Ordering::Greater => { + // We don't have to call try_grow here, since we have already used the + // memory (so spilling right here wouldn't help at all for the current + // operation). But we still have to record it so that other requesters + // would know about this unexpected increase in memory consumption. 
+ let new_size_delta = new_size - size; + self.grow(new_size_delta); + self.metrics.mem_used().add(new_size_delta); + } + Ordering::Less => { + let size_delta = size - new_size; + self.shrink(size_delta); + self.metrics.mem_used().sub(size_delta); + } + Ordering::Equal => {} } in_mem_batches.push(partial); } diff --git a/datafusion/core/tests/sql/decimal.rs b/datafusion/core/tests/sql/decimal.rs index 9d32f1c318be..2e3e3d2abdfa 100644 --- a/datafusion/core/tests/sql/decimal.rs +++ b/datafusion/core/tests/sql/decimal.rs @@ -690,6 +690,48 @@ async fn decimal_sort() -> Result<()> { ]; assert_batches_eq!(expected, &actual); + let sql = "select * from decimal_simple where c1 >= 0.00004 order by c1 limit 10"; + let actual = execute_to_batches(&ctx, sql).await; + assert_eq!( + &DataType::Decimal128(10, 6), + actual[0].schema().field(0).data_type() + ); + let expected = vec![ + "+----------+----------------+-----+-------+-----------+", + "| c1 | c2 | c3 | c4 | c5 |", + "+----------+----------------+-----+-------+-----------+", + "| 0.000040 | 0.000000000004 | 5 | true | 0.0000440 |", + "| 0.000040 | 0.000000000004 | 12 | false | 0.0000400 |", + "| 0.000040 | 0.000000000004 | 14 | true | 0.0000400 |", + "| 0.000040 | 0.000000000004 | 8 | false | 0.0000440 |", + "| 0.000050 | 0.000000000005 | 9 | true | 0.0000520 |", + "| 0.000050 | 0.000000000005 | 4 | true | 0.0000780 |", + "| 0.000050 | 0.000000000005 | 8 | false | 0.0000330 |", + "| 0.000050 | 0.000000000005 | 100 | true | 0.0000680 |", + "| 0.000050 | 0.000000000005 | 1 | false | 0.0001000 |", + "+----------+----------------+-----+-------+-----------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "select * from decimal_simple where c1 >= 0.00004 order by c1 limit 5"; + let actual = execute_to_batches(&ctx, sql).await; + assert_eq!( + &DataType::Decimal128(10, 6), + actual[0].schema().field(0).data_type() + ); + let expected = vec![ + "+----------+----------------+----+-------+-----------+", + "| c1 | c2 
| c3 | c4 | c5 |", + "+----------+----------------+----+-------+-----------+", + "| 0.000040 | 0.000000000004 | 5 | true | 0.0000440 |", + "| 0.000040 | 0.000000000004 | 12 | false | 0.0000400 |", + "| 0.000040 | 0.000000000004 | 14 | true | 0.0000400 |", + "| 0.000040 | 0.000000000004 | 8 | false | 0.0000440 |", + "| 0.000050 | 0.000000000005 | 9 | true | 0.0000520 |", + "+----------+----------------+----+-------+-----------+", + ]; + assert_batches_eq!(expected, &actual); + let sql = "select * from decimal_simple where c1 >= 0.00004 order by c1 desc"; let actual = execute_to_batches(&ctx, sql).await; assert_eq!(