diff --git a/datafusion/physical-plan/src/aggregates/no_grouping.rs b/datafusion/physical-plan/src/aggregates/no_grouping.rs index 99417e4ee3e9..ecbf2ffc74cd 100644 --- a/datafusion/physical-plan/src/aggregates/no_grouping.rs +++ b/datafusion/physical-plan/src/aggregates/no_grouping.rs @@ -215,7 +215,7 @@ fn aggregate_batch( .try_for_each(|((accum, expr), filter)| { // 1.2 let batch = match filter { - Some(filter) => Cow::Owned(batch_filter(&batch, filter)?), + Some(filter) => Cow::Owned(batch_filter(&batch, filter, true)?), None => Cow::Borrowed(&batch), }; diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 568987b14798..8945441ad15f 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -15,11 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{ready, Context, Poll}; - use super::{ ColumnStatistics, DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, @@ -28,10 +23,16 @@ use crate::{ metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, DisplayFormatType, ExecutionPlan, }; +use arrow::array::MutableArrayData; +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{ready, Context, Poll}; use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_array::{make_array, ArrayRef, RecordBatchOptions}; use datafusion_common::cast::as_boolean_array; use datafusion_common::stats::Precision; use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; @@ -61,6 +62,9 @@ pub struct FilterExec { default_selectivity: u8, /// Properties equivalence properties, partitioning, etc. cache: PlanProperties, + /// Whether to allow an input batch to be returned unmodified in the case where + /// the predicate evaluates to true for all rows in the batch + reuse_input_batches: bool, } impl FilterExec { @@ -68,6 +72,15 @@ impl FilterExec { pub fn try_new( predicate: Arc, input: Arc, + ) -> Result { + Self::try_new_with_reuse_input_batches(predicate, input, true) + } + + /// Create a FilterExec on an input using the specified kernel to create filtered batches + pub fn try_new_with_reuse_input_batches( + predicate: Arc, + input: Arc, + reuse_input_batches: bool, ) -> Result { match predicate.data_type(input.schema().as_ref())? { DataType::Boolean => { @@ -80,6 +93,7 @@ impl FilterExec { metrics: ExecutionPlanMetricsSet::new(), default_selectivity, cache, + reuse_input_batches, }) } other => { @@ -283,6 +297,7 @@ impl ExecutionPlan for FilterExec { predicate: Arc::clone(&self.predicate), input: self.input.execute(partition, context)?, baseline_metrics, + reuse_input_batches: self.reuse_input_batches, })) } @@ -345,11 +360,15 @@ struct FilterExecStream { input: SendableRecordBatchStream, /// runtime metrics recording baseline_metrics: BaselineMetrics, + /// Whether to allow an input batch to be returned unmodified in the case where + /// the predicate evaluates to true for all rows in the batch + reuse_input_batches: bool, } pub(crate) fn batch_filter( batch: &RecordBatch, predicate: &Arc, + reuse_input_batches: bool, ) -> Result { predicate .evaluate(batch) @@ -357,7 +376,39 @@ pub(crate) fn batch_filter( .and_then(|array| { Ok(match as_boolean_array(&array) { // apply filter array to record batch - Ok(filter_array) => filter_record_batch(batch, filter_array)?, + Ok(filter_array) => { + if reuse_input_batches { + filter_record_batch(batch, filter_array)? + } else { + if filter_array.true_count() == batch.num_rows() { + // special case where we just make an exact copy + let arrays: Vec = batch + .columns() + .iter() + .map(|array| { + let capacity = array.len(); + let data = array.to_data(); + let mut mutable = MutableArrayData::new( + vec![&data], + false, + capacity, + ); + mutable.extend(0, 0, capacity); + make_array(mutable.freeze()) + }) + .collect(); + let options = RecordBatchOptions::new() + .with_row_count(Some(batch.num_rows())); + RecordBatch::try_new_with_options( + batch.schema().clone(), + arrays, + &options, + )? + } else { + filter_record_batch(batch, filter_array)? + } + } + } Err(_) => { return internal_err!( "Cannot create filter_array from non-boolean predicates" @@ -379,7 +430,8 @@ impl Stream for FilterExecStream { match ready!(self.input.poll_next_unpin(cx)) { Some(Ok(batch)) => { let timer = self.baseline_metrics.elapsed_compute().timer(); - let filtered_batch = batch_filter(&batch, &self.predicate)?; + let filtered_batch = + batch_filter(&batch, &self.predicate, self.reuse_input_batches)?; timer.done(); // skip entirely filtered batches if filtered_batch.num_rows() == 0 {