Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3791,6 +3791,92 @@ mod tests {
Ok(())
}

/// A `skip_partial_aggregation_probe_ratio_threshold` of 1.0 must switch the
/// skip-partial-aggregation feature off: even when every input row forms its
/// own group (cardinality ratio of exactly 1.0), the
/// `skipped_aggregation_rows` metric has to stay at zero.
#[tokio::test]
async fn test_skip_aggregation_disabled_at_threshold_one() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("key", DataType::Int32, true),
        Field::new("val", DataType::Int32, true),
    ]));

    // Two batches are needed. The first one reaches the probe-rows threshold
    // so the skip decision is evaluated (5 groups / 5 rows = 1.0); the second
    // one is the batch that would be skipped if a `>=` comparison let
    // threshold=1.0 still trigger at 100% cardinality.
    // Every key is unique across both batches; `val` is always 0.
    let input_data = [vec![1, 2, 3, 4, 5], vec![6, 7, 8, 9, 10]]
        .into_iter()
        .map(|keys: Vec<i32>| {
            let vals = vec![0_i32; keys.len()];
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(keys)),
                    Arc::new(Int32Array::from(vals)),
                ],
            )
            .unwrap()
        })
        .collect::<Vec<_>>();

    // Plan: partial COUNT(val) grouped by `key` over the in-memory batches.
    let key_expr = col("key", &schema)?;
    let group_by = PhysicalGroupBy::new_single(vec![(key_expr, "key".to_string())]);

    let count_val = AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
        .schema(Arc::clone(&schema))
        .alias(String::from("COUNT(val)"))
        .build()
        .map(Arc::new)?;
    let aggr_expr = vec![count_val];

    let input: Arc<dyn ExecutionPlan> =
        TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
    let aggregate_exec = Arc::new(AggregateExec::try_new(
        AggregateMode::Partial,
        group_by,
        aggr_expr,
        vec![None],
        input,
        schema,
    )?);

    let session_config = SessionConfig::default()
        // A probe-rows threshold of 1 makes the ratio check fire right away.
        .set(
            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
            &ScalarValue::Int64(Some(1)),
        )
        // The setting under test: a ratio threshold of 1.0 must disable skipping.
        .set(
            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
            &ScalarValue::Float64(Some(1.0)),
        );

    let task_ctx = Arc::new(TaskContext::default().with_session_config(session_config));
    collect(aggregate_exec.execute(0, task_ctx)?).await?;

    // If the metric is absent altogether, treat that as zero skipped rows.
    let skipped_rows = aggregate_exec
        .metrics()
        .unwrap()
        .sum_by_name("skipped_aggregation_rows")
        .map_or(0, |m| m.as_usize());

    assert_eq!(
        skipped_rows, 0,
        "threshold=1.0 should disable skip aggregation, but {skipped_rows} rows were skipped"
    );

    Ok(())
}

#[test]
fn group_exprs_nullable() -> Result<()> {
let input_schema = Arc::new(Schema::new(vec![
Expand Down
34 changes: 20 additions & 14 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ impl SkipAggregationProbe {
self.num_groups = num_groups;
if self.input_rows >= self.probe_rows_threshold {
self.should_skip = self.num_groups as f64 / self.input_rows as f64
>= self.probe_ratio_threshold;
> self.probe_ratio_threshold;
// Set is_locked to true only if we have decided to skip, otherwise we can try to skip
// during processing the next record_batch.
self.is_locked = self.should_skip;
Expand Down Expand Up @@ -644,14 +644,20 @@ impl GroupedHashAggregateStream {
options.skip_partial_aggregation_probe_rows_threshold;
let probe_ratio_threshold =
options.skip_partial_aggregation_probe_ratio_threshold;
let skipped_aggregation_rows = MetricBuilder::new(&agg.metrics)
.with_category(MetricCategory::Rows)
.counter("skipped_aggregation_rows", partition);
Some(SkipAggregationProbe::new(
probe_rows_threshold,
probe_ratio_threshold,
skipped_aggregation_rows,
))
// A threshold >= 1.0 means the ratio (num_groups / input_rows) can
// never exceed it, so the feature is effectively disabled.
if probe_ratio_threshold >= 1.0 {
None
} else {
let skipped_aggregation_rows = MetricBuilder::new(&agg.metrics)
.with_category(MetricCategory::Rows)
.counter("skipped_aggregation_rows", partition);
Some(SkipAggregationProbe::new(
probe_rows_threshold,
probe_ratio_threshold,
skipped_aggregation_rows,
))
}
} else {
None
};
Expand Down Expand Up @@ -1630,11 +1636,11 @@ mod tests {
],
)?;

// Batch 2: 350 rows with 350 unique NEW groups (starting from group 10)
// After batch 2, total: 450 rows, 360 groups
// Ratio: 360/450 = 0.8 (80%) >= 0.8 -> SHOULD decide to skip
let batch2_rows = 350;
let batch2_groups = 350;
// Batch 2: 360 rows with 360 unique NEW groups (starting from group 10)
// After batch 2, total: 460 rows, 370 groups
// Ratio: 370/460 ≈ 0.804 (80.4%) > 0.8 -> SHOULD decide to skip
let batch2_rows = 360;
let batch2_groups = 360;
let group_ids_batch2: Vec<i32> = (batch1_groups..(batch1_groups + batch2_groups))
.map(|x| x as i32)
.collect();
Expand Down
Loading