From 0d228d9269771723caef939b391ebbc12dcaa556 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Wed, 27 May 2026 20:46:39 +0530 Subject: [PATCH 1/3] Fix TopK DISTINCT aggregation preserving NULLs --- .../aggregate_statistics.rs | 84 ++++++++++++++ .../src/aggregates/topk_stream.rs | 40 ++++++- .../test_files/aggregates_topk.slt | 105 ++++++++++++++++++ 3 files changed, 225 insertions(+), 4 deletions(-) diff --git a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs b/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs index 808e163b08369..0fa60ae20d2be 100644 --- a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs @@ -553,3 +553,87 @@ async fn test_count_distinct_optimization() -> Result<()> { Ok(()) } + +/// Regression test for https://github.com/apache/datafusion/issues/22554 +/// +/// TopK aggregation for DISTINCT queries was unconditionally dropping NULL +/// group keys, producing wrong results with NULLS FIRST / NULLS LAST ordering. +#[tokio::test] +async fn topk_distinct_preserves_nulls() -> Result<()> { + let ctx = SessionContext::new_with_config(SessionConfig::new()); + + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("v", DataType::Utf8, true)])), + vec![Arc::new(StringArray::from(vec![None, Some(""), Some("a")]))], + )?; + let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; + ctx.register_table("t", Arc::new(table))?; + + // ASC NULLS FIRST LIMIT 1 → NULL should come first + let result = ctx + .sql("SELECT DISTINCT v FROM t ORDER BY v ASC NULLS FIRST LIMIT 1") + .await? + .collect() + .await?; + assert_batches_eq!(&["+---+", "| v |", "+---+", "| |", "+---+"], &result); + assert!(result[0].column(0).is_null(0), "first row should be NULL"); + + // ASC NULLS FIRST LIMIT 2 → NULL, then empty string + let result = ctx + .sql("SELECT DISTINCT v FROM t ORDER BY v ASC NULLS FIRST LIMIT 2") + .await? + .collect() + .await?; + assert_eq!(result[0].num_rows(), 2); + assert!(result[0].column(0).is_null(0)); + assert!(!result[0].column(0).is_null(1)); + + // ASC NULLS LAST LIMIT 1 → empty string (smallest non-null) + let result = ctx + .sql("SELECT DISTINCT v FROM t ORDER BY v ASC NULLS LAST LIMIT 1") + .await? + .collect() + .await?; + assert!( + !result[0].column(0).is_null(0), + "first row should NOT be NULL" + ); + + // Full result with NULLS LAST should include NULL at end + let result = ctx + .sql("SELECT DISTINCT v FROM t ORDER BY v ASC NULLS LAST LIMIT 3") + .await? + .collect() + .await?; + assert_eq!(result[0].num_rows(), 3); + assert!(result[0].column(0).is_null(2), "last row should be NULL"); + + // Integer column + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, true)])), + vec![Arc::new(Int64Array::from(vec![None, Some(3), Some(1)]))], + )?; + let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; + ctx.register_table("t_int", Arc::new(table))?; + + let result = ctx + .sql("SELECT DISTINCT v FROM t_int ORDER BY v ASC NULLS FIRST LIMIT 1") + .await? + .collect() + .await?; + assert!( + result[0].column(0).is_null(0), + "integer NULL should be first" + ); + + let result = ctx + .sql("SELECT DISTINCT v FROM t_int ORDER BY v DESC NULLS LAST LIMIT 2") + .await? + .collect() + .await?; + assert_eq!(result[0].num_rows(), 2); + assert!(!result[0].column(0).is_null(0)); + assert!(!result[0].column(0).is_null(1)); + + Ok(()) +} diff --git a/datafusion/physical-plan/src/aggregates/topk_stream.rs b/datafusion/physical-plan/src/aggregates/topk_stream.rs index 9128844f1d1ef..65a5ea1a71e1c 100644 --- a/datafusion/physical-plan/src/aggregates/topk_stream.rs +++ b/datafusion/physical-plan/src/aggregates/topk_stream.rs @@ -28,7 +28,8 @@ use crate::aggregates::{ use crate::metrics::BaselineMetrics; use crate::stream::EmptyRecordBatchStream; use crate::{RecordBatchStream, SendableRecordBatchStream}; -use arrow::array::{Array, ArrayRef, RecordBatch}; +use arrow::array::{Array, ArrayRef, RecordBatch, new_null_array}; +use arrow::compute::concat; use arrow::datatypes::SchemaRef; use arrow::util::pretty::print_batches; use datafusion_common::Result; @@ -46,6 +47,7 @@ pub struct GroupedTopKAggregateStream { partition: usize, row_count: usize, started: bool, + done: bool, schema: SchemaRef, input: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, @@ -53,6 +55,8 @@ pub struct GroupedTopKAggregateStream { aggregate_arguments: Vec>>, group_by: Arc, priority_map: PriorityMap, + /// Whether a NULL group key has been seen (only tracked for DISTINCT queries) + null_group_seen: bool, } impl GroupedTopKAggregateStream { @@ -109,6 +113,7 @@ impl GroupedTopKAggregateStream { Ok(GroupedTopKAggregateStream { partition, started: false, + done: false, row_count: 0, schema: agg_schema, input, @@ -117,6 +122,7 @@ impl GroupedTopKAggregateStream { aggregate_arguments, group_by, priority_map, + null_group_seen: false, }) } } @@ -128,6 +134,10 @@ impl RecordBatchStream for GroupedTopKAggregateStream { } impl GroupedTopKAggregateStream { + fn is_distinct(&self) -> bool { + self.aggregate_arguments.is_empty() + } + fn intern(&mut self, ids: &ArrayRef, vals: &ArrayRef) -> Result<()> { let _timer = self.group_by_metrics.time_calculating_group_ids.timer(); @@ -138,6 +148,9 @@ impl GroupedTopKAggregateStream { let has_nulls = vals.null_count() > 0; for row_idx in 0..len { if has_nulls && vals.is_null(row_idx) { + if self.is_distinct() { + self.null_group_seen = true; + } continue; } self.priority_map.insert(row_idx)?; @@ -153,6 +166,9 @@ impl Stream for GroupedTopKAggregateStream { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { + if self.done { + return Poll::Ready(None); + } let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); let emitting_time = self.group_by_metrics.emitting_time.clone(); while let Poll::Ready(res) = self.input.poll_next_unpin(cx) { @@ -209,17 +225,32 @@ impl Stream for GroupedTopKAggregateStream { // Release the input pipeline's resources before emitting. let input_schema = self.input.schema(); self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); - if self.priority_map.is_empty() { + if self.priority_map.is_empty() && !self.null_group_seen { trace!("partition {} emit None", self.partition); + self.done = true; return Poll::Ready(None); } let batch = { let _timer = emitting_time.timer(); - let mut cols = self.priority_map.emit()?; + let mut cols = if self.priority_map.is_empty() { + vec![] + } else { + self.priority_map.emit()? + }; // For DISTINCT case (no aggregate expressions), only use the group key column // since the schema only has one field and key/value are the same - if self.aggregate_arguments.is_empty() { + if self.is_distinct() { cols.truncate(1); + if self.null_group_seen { + let dt = self.schema.field(0).data_type(); + let null_arr = new_null_array(dt, 1); + if cols.is_empty() { + cols.push(null_arr); + } else { + cols[0] = + concat(&[cols[0].as_ref(), null_arr.as_ref()])?; + } + } } RecordBatch::try_new(Arc::clone(&self.schema), cols)? }; @@ -232,6 +263,7 @@ impl Stream for GroupedTopKAggregateStream { if log::log_enabled!(Level::Trace) { print_batches(std::slice::from_ref(&batch))?; } + self.done = true; return Poll::Ready(Some(Ok(batch))); } // inner had error, return to caller diff --git a/datafusion/sqllogictest/test_files/aggregates_topk.slt b/datafusion/sqllogictest/test_files/aggregates_topk.slt index 19ead8965ed01..c45c047c867fd 100644 --- a/datafusion/sqllogictest/test_files/aggregates_topk.slt +++ b/datafusion/sqllogictest/test_files/aggregates_topk.slt @@ -456,6 +456,111 @@ select count(*) from (select category from values_table group by category order ---- 3 +# Test DISTINCT with NULLs and NULLS FIRST ordering (issue #22554) +statement ok +create table nullable_vals (v varchar) as values (NULL), (''), ('a'), ('b'); + +# NULLS FIRST: NULL should be the first row returned by LIMIT +query T +select distinct v from nullable_vals order by v asc nulls first limit 1; +---- +NULL + +query T +select distinct v from nullable_vals order by v asc nulls first limit 2; +---- +NULL +(empty) + +query T +select distinct v from nullable_vals order by v asc nulls first limit 3; +---- +NULL +(empty) +a + +# NULLS LAST: non-null values come first +query T +select distinct v from nullable_vals order by v asc nulls last limit 1; +---- +(empty) + +query T +select distinct v from nullable_vals order by v asc nulls last limit 4; +---- +(empty) +a +b +NULL + +# DESC NULLS FIRST: NULL comes first +query T +select distinct v from nullable_vals order by v desc nulls first limit 1; +---- +NULL + +# DESC NULLS LAST: NULL comes last +query T +select distinct v from nullable_vals order by v desc nulls last limit 1; +---- +b + +query T +select distinct v from nullable_vals order by v desc nulls last limit 4; +---- +b +a +(empty) +NULL + +# Test with integer column containing NULLs +statement ok +create table nullable_ints (v int) as values (NULL), (3), (1), (2); + +query I +select distinct v from nullable_ints order by v asc nulls first limit 1; +---- +NULL + +query I +select distinct v from nullable_ints order by v asc nulls first limit 3; +---- +NULL +1 +2 + +query I +select distinct v from nullable_ints order by v desc nulls last limit 2; +---- +3 +2 + +query I +select distinct v from nullable_ints order by v asc nulls last limit 4; +---- +1 +2 +3 +NULL + +# Test with all-NULL column +statement ok +create table all_nulls (v varchar) as values (NULL), (NULL); + +query T +select distinct v from all_nulls order by v asc nulls first limit 1; +---- +NULL + +statement ok +drop table nullable_vals; + +statement ok +drop table nullable_ints; + +statement ok +drop table all_nulls; + statement ok drop table values_table; From 2a700cd793cfa1639f9e7142274c266eaa3bdf4b Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Wed, 27 May 2026 21:55:12 +0530 Subject: [PATCH 2/3] added explain test --- .../sqllogictest/test_files/aggregates_topk.slt | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/datafusion/sqllogictest/test_files/aggregates_topk.slt b/datafusion/sqllogictest/test_files/aggregates_topk.slt index c45c047c867fd..81c85c433b78a 100644 --- a/datafusion/sqllogictest/test_files/aggregates_topk.slt +++ b/datafusion/sqllogictest/test_files/aggregates_topk.slt @@ -460,6 +460,22 @@ select count(*) from (select category from values_table group by category order statement ok create table nullable_vals (v varchar) as values (NULL), (''), ('a'), ('b'); +# Verify this regression test exercises the TopK aggregation path +query TT +explain select distinct v from nullable_vals order by v asc nulls first limit 1; +---- +logical_plan +01)Sort: nullable_vals.v ASC NULLS FIRST, fetch=1 +02)--Aggregate: groupBy=[[nullable_vals.v]], aggr=[[]] +03)----TableScan: nullable_vals projection=[v] +physical_plan +01)SortPreservingMergeExec: [v@0 ASC], fetch=1 +02)--SortExec: TopK(fetch=1), expr=[v@0 ASC], preserve_partitioning=[true] +03)----AggregateExec: mode=FinalPartitioned, gby=[v@0 as v], aggr=[], lim=[1] +04)------RepartitionExec: partitioning=Hash([v@0], 4), input_partitions=1 +05)--------AggregateExec: mode=Partial, gby=[v@0 as v], aggr=[], lim=[1] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] + # NULLS FIRST: NULL should be the first row returned by LIMIT query T select distinct v from nullable_vals order by v asc nulls first limit 1; From 48629631cc9c72219903c86fedabfbee88d89d02 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Thu, 28 May 2026 11:25:13 +0530 Subject: [PATCH 3/3] refactor --- .../src/aggregates/topk_stream.rs | 68 +++++++++++-------- 1 file changed, 41 insertions(+), 27 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/topk_stream.rs b/datafusion/physical-plan/src/aggregates/topk_stream.rs index 65a5ea1a71e1c..97f4662c11342 100644 --- a/datafusion/physical-plan/src/aggregates/topk_stream.rs +++ b/datafusion/physical-plan/src/aggregates/topk_stream.rs @@ -55,7 +55,7 @@ pub struct GroupedTopKAggregateStream { aggregate_arguments: Vec>>, group_by: Arc, priority_map: PriorityMap, - /// Whether a NULL group key has been seen (only tracked for DISTINCT queries) + /// Whether a NULL group key has been seen for a group-by-only aggregation. null_group_seen: bool, } @@ -134,7 +134,7 @@ impl RecordBatchStream for GroupedTopKAggregateStream { } impl GroupedTopKAggregateStream { - fn is_distinct(&self) -> bool { + fn is_group_by_only(&self) -> bool { self.aggregate_arguments.is_empty() } @@ -146,17 +146,50 @@ impl GroupedTopKAggregateStream { .set_batch(Arc::clone(ids), Arc::clone(vals)); let has_nulls = vals.null_count() > 0; + if has_nulls && self.is_group_by_only() { + self.null_group_seen = true; + } for row_idx in 0..len { if has_nulls && vals.is_null(row_idx) { - if self.is_distinct() { - self.null_group_seen = true; - } continue; } self.priority_map.insert(row_idx)?; } Ok(()) } + + fn emit_columns(&mut self) -> Result> { + let mut cols = if self.priority_map.is_empty() { + vec![] + } else { + self.priority_map.emit()? + }; + + // GROUP BY-only aggregation covers DISTINCT-like queries. The group + // key and heap value are the same column, but the output schema has + // only the group key. + if self.is_group_by_only() { + cols.truncate(1); + if self.null_group_seen { + self.append_null_group(&mut cols)?; + } + } + + Ok(cols) + } + + fn append_null_group(&self, cols: &mut Vec) -> Result<()> { + let dt = self.schema.field(0).data_type(); + let null_arr = new_null_array(dt, 1); + if cols.is_empty() { + cols.push(null_arr); + } else { + // NULL group keys are tracked outside the heap, so append a + // one-row NULL array to the emitted non-NULL group key column. + cols[0] = concat(&[cols[0].as_ref(), null_arr.as_ref()])?; + } + Ok(()) + } } impl Stream for GroupedTopKAggregateStream { @@ -201,8 +234,8 @@ impl Stream for GroupedTopKAggregateStream { "Exactly 1 group value required" ); let group_by_values = Arc::clone(&group_by_values[0][0]); - let input_values = if self.aggregate_arguments.is_empty() { - // DISTINCT case: use group key as both key and value + let input_values = if self.is_group_by_only() { + // GROUP BY-only case: use group key as both key and value Arc::clone(&group_by_values) } else { // MIN/MAX case: evaluate aggregate expressions @@ -232,26 +265,7 @@ impl Stream for GroupedTopKAggregateStream { } let batch = { let _timer = emitting_time.timer(); - let mut cols = if self.priority_map.is_empty() { - vec![] - } else { - self.priority_map.emit()? - }; - // For DISTINCT case (no aggregate expressions), only use the group key column - // since the schema only has one field and key/value are the same - if self.is_distinct() { - cols.truncate(1); - if self.null_group_seen { - let dt = self.schema.field(0).data_type(); - let null_arr = new_null_array(dt, 1); - if cols.is_empty() { - cols.push(null_arr); - } else { - cols[0] = - concat(&[cols[0].as_ref(), null_arr.as_ref()])?; - } - } - } + let cols = self.emit_columns()?; RecordBatch::try_new(Arc::clone(&self.schema), cols)? }; let batch = batch.record_output(&self.baseline_metrics);