2 changes: 2 additions & 0 deletions datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -298,6 +298,7 @@ impl ExecutionPlan for AggregateExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
+        let batch_size = context.session_config().batch_size();
         let input = self.input.execute(partition, context)?;
 
         let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
@@ -318,6 +319,7 @@ impl ExecutionPlan for AggregateExec {
                     self.aggr_expr.clone(),
                     input,
                     baseline_metrics,
+                    batch_size,
                 )?))
             } else {
                 Ok(Box::pin(GroupedHashAggregateStream::new(
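The hunks above thread the session's configured batch size into the V2 aggregation stream. For context, that value comes from the SessionConfig; a minimal sketch of how a caller would set it, assuming the SessionConfig/SessionContext API of this DataFusion vintage:

    use datafusion::prelude::*;

    fn main() {
        // Chunk aggregate output into RecordBatches of at most 1024 rows.
        let config = SessionConfig::new().with_batch_size(1024);
        let _ctx = SessionContext::with_config(config);
    }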
121 changes: 72 additions & 49 deletions datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -85,7 +85,11 @@ pub(crate) struct GroupedHashAggregateStreamV2 {
 
     baseline_metrics: BaselineMetrics,
     random_state: RandomState,
-    finished: bool,
+    /// size to be used for resulting RecordBatches
+    batch_size: usize,
+    /// if the result is chunked into batches,
+    /// last offset is preserved for continuation.
+    row_group_skip_position: usize,
 }
 
 fn aggr_state_schema(aggr_expr: &[Arc<dyn AggregateExpr>]) -> Result<SchemaRef> {
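The two new fields replace the old finished flag with a cursor: row_group_skip_position records how many group states have already been emitted, and each poll after the input is exhausted emits the next window of at most batch_size groups. A minimal sketch of the windowing scheme over a plain slice (the helper is illustrative only, not part of this PR):

    /// Return the next chunk, or None once the cursor has moved past the end.
    /// Mirrors the PR's `skip > len` test: a cursor exactly at the end still
    /// yields one final (possibly empty) chunk before termination.
    fn next_chunk<T: Clone>(items: &[T], skip: usize, batch_size: usize) -> Option<Vec<T>> {
        if skip > items.len() {
            return None;
        }
        Some(items.iter().skip(skip).take(batch_size).cloned().collect())
    }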
@@ -105,6 +109,7 @@ impl GroupedHashAggregateStreamV2 {
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
         input: SendableRecordBatchStream,
         baseline_metrics: BaselineMetrics,
+        batch_size: usize,
     ) -> Result<Self> {
         let timer = baseline_metrics.elapsed_compute().timer();
 
@@ -135,7 +140,8 @@ impl GroupedHashAggregateStreamV2 {
             aggregate_expressions,
             aggr_state: Default::default(),
             random_state: Default::default(),
-            finished: false,
+            batch_size,
+            row_group_skip_position: 0,
         })
     }
 }
@@ -148,56 +154,62 @@ impl Stream for GroupedHashAggregateStreamV2 {
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
         let this = &mut *self;
-        if this.finished {
-            return Poll::Ready(None);
-        }
 
         let elapsed_compute = this.baseline_metrics.elapsed_compute();
 
         loop {
-            let result = match ready!(this.input.poll_next_unpin(cx)) {
-                Some(Ok(batch)) => {
-                    let timer = elapsed_compute.timer();
-                    let result = group_aggregate_batch(
-                        &this.mode,
-                        &this.random_state,
-                        &this.group_by,
-                        &mut this.accumulators,
-                        &this.group_schema,
-                        this.aggr_layout.clone(),
-                        batch,
-                        &mut this.aggr_state,
-                        &this.aggregate_expressions,
-                    );
-
-                    timer.done();
-
-                    match result {
-                        Ok(_) => continue,
-                        Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
-                    }
-                }
-                Some(Err(e)) => Err(e),
-                None => {
-                    this.finished = true;
-                    let timer = this.baseline_metrics.elapsed_compute().timer();
-                    let result = create_batch_from_map(
-                        &this.mode,
-                        &this.group_schema,
-                        &this.aggr_schema,
-                        &mut this.aggr_state,
-                        &mut this.accumulators,
-                        &this.schema,
-                    )
-                    .record_output(&this.baseline_metrics);
-
-                    timer.done();
-                    result
-                }
-            };
+            let result: ArrowResult<Option<RecordBatch>> =
+                match ready!(this.input.poll_next_unpin(cx)) {
+                    Some(Ok(batch)) => {
+                        let timer = elapsed_compute.timer();
+                        let result = group_aggregate_batch(
+                            &this.mode,
+                            &this.random_state,
+                            &this.group_by,
+                            &mut this.accumulators,
+                            &this.group_schema,
+                            this.aggr_layout.clone(),
+                            batch,
+                            &mut this.aggr_state,
+                            &this.aggregate_expressions,
+                        );
+
+                        timer.done();
+
+                        match result {
+                            Ok(_) => continue,
+                            Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
+                        }
+                    }
+                    Some(Err(e)) => Err(e),
+                    None => {
+                        let timer = this.baseline_metrics.elapsed_compute().timer();
+                        let result = create_batch_from_map(
+                            &this.mode,
+                            &this.group_schema,
+                            &this.aggr_schema,
+                            this.batch_size,
+                            this.row_group_skip_position,
+                            &mut this.aggr_state,
+                            &mut this.accumulators,
+                            &this.schema,
+                        );
+
+                        timer.done();
+                        result
+                    }
+                };
 
-            this.finished = true;
-            return Poll::Ready(Some(result));
+            this.row_group_skip_position += this.batch_size;
+            match result {
+                Ok(Some(result)) => {
+                    return Poll::Ready(Some(Ok(
+                        result.record_output(&this.baseline_metrics)
+                    )))
+                }
+                Ok(None) => return Poll::Ready(None),
+                Err(error) => return Poll::Ready(Some(Err(error))),
+            }
         }
     }
 }
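With the finished flag gone, termination is driven by the return value of create_batch_from_map: Ok(Some(batch)) yields a batch while the cursor advances by batch_size on each pass, and Ok(None) ends the stream. Walking through hypothetical numbers (2,500 accumulated groups, batch_size of 1,000):

    poll 1: skip    0, take 1000  ->  Poll::Ready(Some(Ok(batch of 1000 rows)))
    poll 2: skip 1000, take 1000  ->  Poll::Ready(Some(Ok(batch of 1000 rows)))
    poll 3: skip 2000, take 1000  ->  Poll::Ready(Some(Ok(batch of 500 rows)))
    poll 4: skip 3000 > 2500      ->  create_batch_from_map returns Ok(None),
                                      so poll_next returns Poll::Ready(None)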
@@ -419,23 +431,34 @@ fn create_group_rows(arrays: Vec<ArrayRef>, schema: &Schema) -> Vec<Vec<u8>> {
 }
 
 /// Create a RecordBatch with all group keys and accumulator' states or values.
+#[allow(clippy::too_many_arguments)]
 fn create_batch_from_map(
     mode: &AggregateMode,
     group_schema: &Schema,
     aggr_schema: &Schema,
+    batch_size: usize,
+    skip_items: usize,
     aggr_state: &mut AggregationState,
     accumulators: &mut [AccumulatorItemV2],
     output_schema: &Schema,
-) -> ArrowResult<RecordBatch> {
+) -> ArrowResult<Option<RecordBatch>> {
+    if skip_items > aggr_state.group_states.len() {
+        return Ok(None);
+    }
+
     if aggr_state.group_states.is_empty() {
-        return Ok(RecordBatch::new_empty(Arc::new(output_schema.to_owned())));
+        return Ok(Some(RecordBatch::new_empty(Arc::new(
+            output_schema.to_owned(),
+        ))));
     }
 
     let mut state_accessor = RowAccessor::new(aggr_schema, RowType::WordAligned);
 
     let (group_buffers, mut state_buffers): (Vec<_>, Vec<_>) = aggr_state
         .group_states
         .iter()
+        .skip(skip_items)
+        .take(batch_size)
         .map(|gs| (gs.group_by_values.clone(), gs.aggregation_buffer.clone()))
         .unzip();
 
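The two early returns encode the termination handshake with poll_next: a cursor strictly past the end of group_states means every chunk has been handed out (Ok(None)), while an empty aggregation still produces one empty batch on the first call so downstream operators receive the schema. A self-contained sketch of the batch-length sequence this produces (the helper is illustrative, not DataFusion API):

    /// Row counts of the batches create_batch_from_map would emit, ending
    /// with None when it signals end-of-stream.
    fn batch_lengths(total_groups: usize, batch_size: usize) -> Vec<Option<usize>> {
        let mut out = Vec::new();
        let mut skip = 0;
        loop {
            if skip > total_groups {
                out.push(None); // corresponds to Ok(None)
                return out;
            }
            // even zero remaining groups yield one (empty) batch
            out.push(Some(total_groups.saturating_sub(skip).min(batch_size)));
            skip += batch_size;
        }
    }

    // batch_lengths(2500, 1000) == [Some(1000), Some(1000), Some(500), None]
    // batch_lengths(0, 1000)    == [Some(0), None]

One consequence worth noting: when the group count is an exact multiple of batch_size, the `>` comparison lets one trailing empty batch through before Ok(None), which appears benign for consumers.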
@@ -471,7 +494,7 @@ fn create_batch_from_map(
         .map(|(col, desired_field)| cast(col, desired_field.data_type()))
         .collect::<ArrowResult<Vec<_>>>()?;
 
-    RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns)
+    RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns).map(Some)
 }
 
 fn read_as_batch(rows: &[Vec<u8>], schema: &Schema, row_type: RowType) -> Vec<ArrayRef> {
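Since the chunking is invisible to consumers, code that drains the SendableRecordBatchStream keeps working unchanged; it simply sees several smaller batches instead of one large one. A sketch of such a consumer, assuming the stream item type is ArrowResult<RecordBatch> as in this version of DataFusion:

    use datafusion::arrow::error::Result as ArrowResult;
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::physical_plan::SendableRecordBatchStream;
    use futures::StreamExt;

    async fn collect_all(
        mut stream: SendableRecordBatchStream,
    ) -> ArrowResult<Vec<RecordBatch>> {
        let mut batches = Vec::new();
        while let Some(batch) = stream.next().await {
            batches.push(batch?); // each chunk holds at most batch_size rows
        }
        Ok(batches)
    }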