Add `convert_to_state` support to the groups accumulator in bool_op.rs.

* bool_op.rs: implement `convert_to_state` and `supports_convert_to_state`.
  `convert_to_state` reads the first input array as a boolean array, swaps
  its null buffer for the result of `filtered_null_mask(opt_filter, &values)`
  and returns that array as the single state column.
  `supports_convert_to_state` returns `true`.
* aggregate_skip_partial.slt: add `bool_and()` / `bool_or()` coverage, with
  and without FILTER clauses, under the skip-partial-aggregation settings,
  plus extra edge-case tests appended at the end of the file.

NOTE(review): the final statement drops `aggregate_test_100_bool` only;
`bool_aggregate_functions` is not dropped in the visible hunks. Confirm that
this is intended.

diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs
index be2b5e48a8db..f4b4c0c93215 100644
--- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs
+++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs
@@ -17,6 +17,7 @@
 
 use std::sync::Arc;
 
+use crate::aggregate::groups_accumulator::nulls::filtered_null_mask;
 use arrow::array::{ArrayRef, AsArray, BooleanArray, BooleanBufferBuilder};
 use arrow::buffer::BooleanBuffer;
 use datafusion_common::Result;
@@ -135,4 +136,22 @@ where
         // capacity is in bits, so convert to bytes
         self.values.capacity() / 8 + self.null_state.size()
     }
+
+    fn convert_to_state(
+        &self,
+        values: &[ArrayRef],
+        opt_filter: Option<&BooleanArray>,
+    ) -> Result<Vec<ArrayRef>> {
+        let values = values[0].as_boolean().clone();
+
+        let values_null_buffer_filtered = filtered_null_mask(opt_filter, &values);
+        let (values_buf, _) = values.into_parts();
+        let values_filtered = BooleanArray::new(values_buf, values_null_buffer_filtered);
+
+        Ok(vec![Arc::new(values_filtered)])
+    }
+
+    fn supports_convert_to_state(&self) -> bool {
+        true
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt
index ba378f4230f8..ab1c7e78f1ff 100644
--- a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt
+++ b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt
@@ -40,6 +40,22 @@ STORED AS CSV
 LOCATION '../../testing/data/csv/aggregate_test_100.csv'
 OPTIONS ('format.has_header' 'true');
 
+# Table to test `bool_and()`, `bool_or()` aggregate functions
+statement ok
+CREATE TABLE aggregate_test_100_bool (
+  v1 VARCHAR NOT NULL,
+  v2 BOOLEAN,
+  v3 BOOLEAN
+);
+
+statement ok
+INSERT INTO aggregate_test_100_bool
+SELECT
+  c1 as v1,
+  CASE WHEN c2 > 3 THEN TRUE WHEN c2 > 1 THEN FALSE ELSE NULL END as v2,
+  CASE WHEN c1='a' OR c1='b' THEN TRUE WHEN c1='c' OR c1='d' THEN FALSE ELSE NULL END as v3
+FROM aggregate_test_100;
+
 # Prepare settings to skip partial aggregation from the beginning
 statement ok
 set datafusion.execution.skip_partial_aggregation_probe_rows_threshold = 0;
@@ -117,6 +133,33 @@ GROUP BY 1, 2 ORDER BY 1 LIMIT 5;
 -2117946883 d -2117946883 NULL NULL NULL
 -2098805236 c -2098805236 NULL NULL NULL
 
+# FIXME: add bool_and(v3) column when issue fixed
+# ISSUE https://github.com/apache/datafusion/issues/11846
+query TBBB rowsort
+select v1, bool_or(v2), bool_and(v2), bool_or(v3)
+from aggregate_test_100_bool
+group by v1
+----
+a true false true
+b true false true
+c true false false
+d true false false
+e true false NULL
+
+query TBBB rowsort
+select v1,
+  bool_or(v2) FILTER (WHERE v1 = 'a' OR v1 = 'c' OR v1 = 'e'),
+  bool_or(v2) FILTER (WHERE v2 = false),
+  bool_or(v2) FILTER (WHERE v2 = NULL)
+from aggregate_test_100_bool
+group by v1
+----
+a true false NULL
+b NULL false NULL
+c true false NULL
+d NULL false NULL
+e true false NULL
+
 # Prepare settings to always skip aggregation after couple of batches
 statement ok
 set datafusion.execution.skip_partial_aggregation_probe_rows_threshold = 10;
@@ -223,6 +266,32 @@ c 2.666666666667 0.425241138254
 d 2.444444444444 0.541519476308
 e 3 0.505440263521
 
+# FIXME: add bool_and(v3) column when issue fixed
+# ISSUE https://github.com/apache/datafusion/issues/11846
+query TBBB rowsort
+select v1, bool_or(v2), bool_and(v2), bool_or(v3)
+from aggregate_test_100_bool
+group by v1
+----
+a true false true
+b true false true
+c true false false
+d true false false
+e true false NULL
+
+query TBBB rowsort
+select v1,
+  bool_or(v2) FILTER (WHERE v1 = 'a' OR v1 = 'c' OR v1 = 'e'),
+  bool_or(v2) FILTER (WHERE v2 = false),
+  bool_or(v2) FILTER (WHERE v2 = NULL)
+from aggregate_test_100_bool
+group by v1
+----
+a true false NULL
+b NULL false NULL
+c true false NULL
+d NULL false NULL
+e true false NULL
 
 # Enabling PG dialect for filtered aggregates tests
 statement ok
@@ -377,3 +446,48 @@ ORDER BY i;
 
 statement ok
 DROP TABLE decimal_table;
+
+# Extra tests for 'bool_*()' edge cases
+statement ok
+set datafusion.execution.skip_partial_aggregation_probe_rows_threshold = 0;
+
+statement ok
+set datafusion.execution.skip_partial_aggregation_probe_ratio_threshold = 0.0;
+
+statement ok
+set datafusion.execution.target_partitions = 1;
+
+statement ok
+set datafusion.execution.batch_size = 1;
+
+statement ok
+create table bool_aggregate_functions (
+  c1 boolean not null,
+  c2 boolean not null,
+  c3 boolean not null,
+  c4 boolean not null,
+  c5 boolean,
+  c6 boolean,
+  c7 boolean,
+  c8 boolean
+)
+as values
+  (true, true, false, false, true, true, null, null),
+  (true, false, true, false, false, null, false, null),
+  (true, true, false, false, null, true, false, null);
+
+query BBBBBBBB
+SELECT bool_and(c1), bool_and(c2), bool_and(c3), bool_and(c4), bool_and(c5), bool_and(c6), bool_and(c7), bool_and(c8) FROM bool_aggregate_functions
+----
+true false false false false true false NULL
+
+statement ok
+set datafusion.execution.skip_partial_aggregation_probe_rows_threshold = 2;
+
+query BBBBBBBB
+SELECT bool_and(c1), bool_and(c2), bool_and(c3), bool_and(c4), bool_and(c5), bool_and(c6), bool_and(c7), bool_and(c8) FROM bool_aggregate_functions
+----
+true false false false false true false NULL
+
+statement ok
+DROP TABLE aggregate_test_100_bool