From f678d375d364871abcce60a8121b13c03a218650 Mon Sep 17 00:00:00 2001
From: Mustafa Akur
Date: Wed, 31 May 2023 09:43:40 +0300
Subject: [PATCH] Bug fix, first multiple batches. Add unit test

---
 .../physical-expr/src/aggregate/first_last.rs | 54 +++++++++++++++++--
 1 file changed, 51 insertions(+), 3 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs
index f65360c75199c..5dd9620ce0a68 100644
--- a/datafusion/physical-expr/src/aggregate/first_last.rs
+++ b/datafusion/physical-expr/src/aggregate/first_last.rs
@@ -112,25 +112,35 @@ impl PartialEq for FirstValue {
 #[derive(Debug)]
 struct FirstValueAccumulator {
     first: ScalarValue,
+    // At the beginning, `is_set` is `false`, this means `first` is not seen yet.
+    // Once we see (`is_set=true`) first value, we do not update `first`.
+    is_set: bool,
 }
 
 impl FirstValueAccumulator {
     /// Creates a new `FirstValueAccumulator` for the given `data_type`.
     pub fn try_new(data_type: &DataType) -> Result<Self> {
-        ScalarValue::try_from(data_type).map(|value| Self { first: value })
+        ScalarValue::try_from(data_type).map(|value| Self {
+            first: value,
+            is_set: false,
+        })
     }
 }
 
 impl Accumulator for FirstValueAccumulator {
     fn state(&self) -> Result<Vec<ScalarValue>> {
-        Ok(vec![self.first.clone()])
+        Ok(vec![
+            self.first.clone(),
+            ScalarValue::Boolean(Some(self.is_set)),
+        ])
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
         // If we have seen first value, we shouldn't update it
         let values = &values[0];
-        if !values.is_empty() {
+        if !values.is_empty() && !self.is_set {
             self.first = ScalarValue::try_from_array(values, 0)?;
+            self.is_set = true;
         }
         Ok(())
     }
@@ -270,3 +280,41 @@ impl Accumulator for LastValueAccumulator {
         std::mem::size_of_val(self) - std::mem::size_of_val(&self.last) + self.last.size()
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::aggregate::first_last::{FirstValueAccumulator, LastValueAccumulator};
+    use arrow_array::{ArrayRef, Int64Array};
+    use arrow_schema::DataType;
+    use datafusion_common::{Result, ScalarValue};
+    use datafusion_expr::Accumulator;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_first_last_value_value() -> Result<()> {
+        let mut first_accumulator = FirstValueAccumulator::try_new(&DataType::Int64)?;
+        let mut last_accumulator = LastValueAccumulator::try_new(&DataType::Int64)?;
+        // first value in the tuple is start of the range (inclusive),
+        // second value in the tuple is end of the range (exclusive)
+        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
+        // create 3 ArrayRefs between each interval e.g from 0 to 9, 1 to 10, 2 to 12
+        let arrs = ranges
+            .into_iter()
+            .map(|(start, end)| {
+                Arc::new(Int64Array::from((start..end).collect::<Vec<_>>())) as ArrayRef
+            })
+            .collect::<Vec<_>>();
+        for arr in arrs {
+            // Once first_value is set, accumulator should remember it.
+            // It shouldn't update first_value for each new batch
+            first_accumulator.update_batch(&[arr.clone()])?;
+            // last_value should be updated for each new batch.
+            last_accumulator.update_batch(&[arr])?;
+        }
+        // First Value comes from the first value of the first batch which is 0
+        assert_eq!(first_accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
+        // Last value comes from the last value of the last batch which is 12
+        assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12)));
+        Ok(())
+    }
+}