25 changes: 21 additions & 4 deletions datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -295,6 +295,8 @@ impl AggregateExec {
self.aggr_expr.clone(),
input,
baseline_metrics,
context,
partition,
)?))
} else if self.row_aggregate_supported() {
Ok(StreamType::GroupedHashAggregateStreamV2(
@@ -737,7 +739,7 @@ mod tests {
use arrow::error::{ArrowError, Result as ArrowResult};
use arrow::record_batch::RecordBatch;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_physical_expr::expressions::{lit, ApproxDistinct, Count};
use datafusion_physical_expr::expressions::{lit, ApproxDistinct, Count, Median};
use datafusion_physical_expr::{AggregateExpr, PhysicalExpr, PhysicalSortExpr};
use futures::{FutureExt, Stream};
use std::any::Any;
@@ -1131,12 +1133,20 @@ mod tests {
);
let task_ctx = session_ctx.task_ctx();

let groups = PhysicalGroupBy {
let groups_none = PhysicalGroupBy::default();
let groups_some = PhysicalGroupBy {
expr: vec![(col("a", &input_schema)?, "a".to_string())],
null_expr: vec![],
groups: vec![vec![false]],
};

// use an aggregate whose accumulator allocates memory (median buffers all input values)
let aggregates_v0: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Median::new(
col("a", &input_schema)?,
"MEDIAN(a)".to_string(),
DataType::UInt32,
))];

// use slow-path in `hash.rs`
let aggregates_v1: Vec<Arc<dyn AggregateExpr>> =
vec![Arc::new(ApproxDistinct::new(
@@ -1152,10 +1162,14 @@
DataType::Float64,
))];

for (version, aggregates) in [(1, aggregates_v1), (2, aggregates_v2)] {
for (version, groups, aggregates) in [
(0, groups_none, aggregates_v0),
Review comment (Contributor): 👍 this is the test coverage 👍

(1, groups_some.clone(), aggregates_v1),
(2, groups_some, aggregates_v2),
] {
let partial_aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
groups.clone(),
groups,
aggregates,
input.clone(),
input_schema.clone(),
@@ -1165,6 +1179,9 @@

// ensure that we really got the version we wanted
match version {
0 => {
assert!(matches!(stream, StreamType::AggregateStream(_)));
}
1 => {
assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_)));
}
160 changes: 105 additions & 55 deletions datafusion/core/src/physical_plan/aggregates/no_grouping.rs
@@ -17,6 +17,9 @@

//! Aggregate without grouping columns

use crate::execution::context::TaskContext;
use crate::execution::memory_manager::proxy::MemoryConsumerProxy;
use crate::execution::MemoryConsumerId;
use crate::physical_plan::aggregates::{
aggregate_expressions, create_accumulators, finalize_aggregation, AccumulatorItem,
AggregateMode,
@@ -28,22 +31,31 @@ use arrow::error::{ArrowError, Result as ArrowResult};
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_physical_expr::{AggregateExpr, PhysicalExpr};
use futures::stream::BoxStream;
use std::sync::Arc;
use std::task::{Context, Poll};

use futures::{
ready,
stream::{Stream, StreamExt},
};
use futures::stream::{Stream, StreamExt};

/// stream struct for aggregation without grouping columns
pub(crate) struct AggregateStream {
stream: BoxStream<'static, ArrowResult<RecordBatch>>,
schema: SchemaRef,
}

/// Actual implementation of [`AggregateStream`].
///
/// This is wrapped into yet another struct because we need to interact with the async memory management subsystem
/// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with
/// [`futures::stream::unfold`]. The latter requires a state object, which is [`AggregateStreamInner`].
struct AggregateStreamInner {
schema: SchemaRef,
mode: AggregateMode,
input: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
accumulators: Vec<AccumulatorItem>,
memory_consumer: MemoryConsumerProxy,
finished: bool,
}
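
For readers unfamiliar with the pattern: `futures::stream::unfold` turns an initial state plus an async closure returning `Option<(item, state)>` into a `Stream`. A minimal, self-contained sketch of the same shape, using a toy counter state instead of the actual aggregation types:

use futures::executor::block_on;
use futures::stream::{self, BoxStream, StreamExt};

// Toy state object playing the role of `AggregateStreamInner`.
struct Counter {
    n: u32,
    finished: bool,
}

fn counter_stream() -> BoxStream<'static, u32> {
    let inner = Counter { n: 0, finished: false };
    let stream = stream::unfold(inner, |mut this| async move {
        if this.finished {
            return None; // end of stream
        }
        this.n += 1;
        if this.n == 3 {
            this.finished = true;
        }
        // Yield the next item together with the updated state.
        Some((this.n, this))
    });
    // Fuse so that polling after the final `None` keeps returning `None`,
    // which (per the comment in this diff) some consumers rely on.
    stream.fuse().boxed()
}

fn main() {
    assert_eq!(block_on(counter_stream().collect::<Vec<_>>()), vec![1, 2, 3]);
}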

@@ -55,19 +67,87 @@ impl AggregateStream {
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
input: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
context: Arc<TaskContext>,
partition: usize,
) -> datafusion_common::Result<Self> {
let aggregate_expressions = aggregate_expressions(&aggr_expr, &mode, 0)?;
let accumulators = create_accumulators(&aggr_expr)?;

Ok(Self {
schema,
let memory_consumer = MemoryConsumerProxy::new(
"AggregationState",
MemoryConsumerId::new(partition),
Arc::clone(&context.runtime_env().memory_manager),
);

let inner = AggregateStreamInner {
schema: Arc::clone(&schema),
mode,
input,
baseline_metrics,
aggregate_expressions,
accumulators,
memory_consumer,
finished: false,
})
};
let stream = futures::stream::unfold(inner, |mut this| async move {
if this.finished {
return None;
}

let elapsed_compute = this.baseline_metrics.elapsed_compute();

loop {
let result = match this.input.next().await {
Some(Ok(batch)) => {
let timer = elapsed_compute.timer();
let result = aggregate_batch(
&this.mode,
&batch,
&mut this.accumulators,
&this.aggregate_expressions,
);

timer.done();

// Account for the newly allocated memory.
// This happens AFTER the memory was actually used, which simplifies the accounting and means we may
// overshoot the budget a bit. It also means a record batch is accounted for either in full or not at all.
let result = match result {
Ok(allocated) => this.memory_consumer.alloc(allocated).await,
Err(e) => Err(e),
};

match result {
Ok(_) => continue,
Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
}
}
Some(Err(e)) => Err(e),
None => {
this.finished = true;
let timer = this.baseline_metrics.elapsed_compute().timer();
let result = finalize_aggregation(&this.accumulators, &this.mode)
.map_err(|e| ArrowError::ExternalError(Box::new(e)))
.and_then(|columns| {
RecordBatch::try_new(this.schema.clone(), columns)
})
.record_output(&this.baseline_metrics);

timer.done();

result
}
};

this.finished = true;
return Some((result, this));
}
});

// Some consumers poll this stream even after it has returned `None`, so fuse it to keep that safe.
let stream = stream.fuse();
let stream = Box::pin(stream);

Ok(Self { schema, stream })
}
}
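
The memory accounting here is deliberately "use first, account afterwards": by the time `memory_consumer.alloc(allocated).await` runs, `aggregate_batch` has already grown the accumulators. A toy sketch of that discipline under a fixed budget; `ToyConsumer` is hypothetical and only stands in for `MemoryConsumerProxy`:

use futures::executor::block_on;

// Hypothetical consumer with a hard budget. The real proxy negotiates with
// DataFusion's memory manager instead of just erroring out.
struct ToyConsumer {
    used: usize,
    budget: usize,
}

impl ToyConsumer {
    async fn alloc(&mut self, bytes: usize) -> Result<(), String> {
        if self.used + bytes > self.budget {
            return Err(format!(
                "cannot grow by {bytes} bytes: {} of {} already used",
                self.used, self.budget
            ));
        }
        self.used += bytes;
        Ok(())
    }
}

fn main() {
    let mut consumer = ToyConsumer { used: 0, budget: 1024 };
    // The state already grew by 600 bytes before each call, mirroring the
    // "account after use" ordering above; only the bookkeeping can fail.
    assert!(block_on(consumer.alloc(600)).is_ok());
    assert!(block_on(consumer.alloc(600)).is_err());
}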

@@ -79,49 +159,7 @@ impl Stream for AggregateStream {
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let this = &mut *self;
if this.finished {
return Poll::Ready(None);
}

let elapsed_compute = this.baseline_metrics.elapsed_compute();

loop {
let result = match ready!(this.input.poll_next_unpin(cx)) {
Some(Ok(batch)) => {
let timer = elapsed_compute.timer();
let result = aggregate_batch(
&this.mode,
&batch,
&mut this.accumulators,
&this.aggregate_expressions,
);

timer.done();

match result {
Ok(_) => continue,
Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
}
}
Some(Err(e)) => Err(e),
None => {
this.finished = true;
let timer = this.baseline_metrics.elapsed_compute().timer();
let result = finalize_aggregation(&this.accumulators, &this.mode)
.map_err(|e| ArrowError::ExternalError(Box::new(e)))
.and_then(|columns| {
RecordBatch::try_new(this.schema.clone(), columns)
})
.record_output(&this.baseline_metrics);

timer.done();
result
}
};

this.finished = true;
return Poll::Ready(Some(result));
}
this.stream.poll_next_unpin(cx)
}
}

@@ -131,13 +169,19 @@ impl RecordBatchStream for AggregateStream {
}
}

/// Perform aggregation (without grouping columns) for the given [`RecordBatch`].
///
/// If successful, this returns the number of additional bytes that were allocated during this process.
///
/// TODO: Make this a member function
fn aggregate_batch(
mode: &AggregateMode,
batch: &RecordBatch,
accumulators: &mut [AccumulatorItem],
expressions: &[Vec<Arc<dyn PhysicalExpr>>],
) -> Result<()> {
) -> Result<usize> {
let mut allocated = 0usize;

// 1.1 iterate accumulators and respective expressions together
// 1.2 evaluate expressions
// 1.3 update / merge accumulators with the expressions' values
@@ -155,11 +199,17 @@ fn aggregate_batch(
.collect::<Result<Vec<_>>>()?;

// 1.3
match mode {
let size_pre = accum.size();
let res = match mode {
AggregateMode::Partial => accum.update_batch(values),
AggregateMode::Final | AggregateMode::FinalPartitioned => {
accum.merge_batch(values)
}
}
})
};
let size_post = accum.size();
allocated += size_post.saturating_sub(size_pre);
res
})?;

Ok(allocated)
}
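
The `size_pre`/`size_post` bookkeeping in this function generalizes to any accumulator that can report its heap footprint via `size()`. A self-contained sketch of the measurement pattern, with a hypothetical `BufferingAccumulator` standing in for an allocating accumulator such as median's:

// Hypothetical accumulator that buffers every input value, similar in spirit
// to a median accumulator, which must retain all values it has seen.
struct BufferingAccumulator {
    values: Vec<u32>,
}

impl BufferingAccumulator {
    // Heap footprint: the struct itself plus the buffer's capacity.
    fn size(&self) -> usize {
        std::mem::size_of::<Self>()
            + self.values.capacity() * std::mem::size_of::<u32>()
    }

    fn update_batch(&mut self, batch: &[u32]) {
        self.values.extend_from_slice(batch);
    }
}

fn main() {
    let mut accum = BufferingAccumulator { values: Vec::new() };
    let mut allocated = 0usize;

    for batch in [[1u32, 2, 3], [4, 5, 6]] {
        // Measure before and after the update; the saturating difference is
        // what would be reported to the memory consumer.
        let size_pre = accum.size();
        accum.update_batch(&batch);
        let size_post = accum.size();
        allocated += size_post.saturating_sub(size_pre);
    }

    println!("grew by {allocated} bytes across both batches");
}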