Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,42 @@ SELECT bit_and(c5), bit_and(c6), bit_and(c7), bit_and(c8), bit_and(c9) FROM aggr
----
0 0 0 0 0

# csv_query_bit_and_distinct
query IIIII
SELECT bit_and(distinct c5), bit_and(distinct c6), bit_and(distinct c7), bit_and(distinct c8), bit_and(distinct c9) FROM aggregate_test_100
----
0 0 0 0 0

# csv_query_bit_or
query IIIII
SELECT bit_or(c5), bit_or(c6), bit_or(c7), bit_or(c8), bit_or(c9) FROM aggregate_test_100
----
-1 -1 255 65535 4294967295

# csv_query_bit_or_distinct
query IIIII
SELECT bit_or(distinct c5), bit_or(distinct c6), bit_or(distinct c7), bit_or(distinct c8), bit_or(distinct c9) FROM aggregate_test_100
----
-1 -1 255 65535 4294967295

# csv_query_bit_xor
query IIIII
SELECT bit_xor(c5), bit_xor(c6), bit_xor(c7), bit_xor(c8), bit_xor(c9) FROM aggregate_test_100
----
1632751011 5960911605712039654 148 54789 169634700

# csv_query_bit_xor_distinct (should be different than above)
query IIIII
SELECT bit_xor(distinct c5), bit_xor(distinct c6), bit_xor(distinct c7), bit_xor(distinct c8), bit_xor(distinct c9) FROM aggregate_test_100
----
1632751011 5960911605712039654 196 54789 169634700

# csv_query_bit_xor_distinct_expr
query I
SELECT bit_xor(distinct c5 % 2) FROM aggregate_test_100
----
-2

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

even though, as you note,BIT_AND(DISTINCT), BIT_OR(DISTINCT), BOOL_AND(DISTINCT) and BOOL_OR(DISTINCT) are the same as their non distinct versions, I think we should still add SQL coverage so that we don't accidentally break the feature during some future refactoring

# csv_query_covariance_1
query R
SELECT covar_pop(c2, c12) FROM aggregate_test_100
Expand Down Expand Up @@ -1496,12 +1520,24 @@ SELECT bool_and(c1), bool_and(c2), bool_and(c3), bool_and(c4), bool_and(c5), boo
----
true false false false false true false NULL

# query_bool_and_distinct
query BBBBBBBB
SELECT bool_and(distinct c1), bool_and(distinct c2), bool_and(distinct c3), bool_and(distinct c4), bool_and(distinct c5), bool_and(distinct c6), bool_and(distinct c7), bool_and(distinct c8) FROM bool_aggregate_functions
----
true false false false false true false NULL

# query_bool_or
query BBBBBBBB
SELECT bool_or(c1), bool_or(c2), bool_or(c3), bool_or(c4), bool_or(c5), bool_or(c6), bool_or(c7), bool_or(c8) FROM bool_aggregate_functions
----
true true true false true true false NULL

# query_bool_or_distinct
query BBBBBBBB
SELECT bool_or(distinct c1), bool_or(distinct c2), bool_or(distinct c3), bool_or(distinct c4), bool_or(distinct c5), bool_or(distinct c6), bool_or(distinct c7), bool_or(distinct c8) FROM bool_aggregate_functions
----
true true true false true true false NULL

statement ok
create table t as
select
Expand Down
222 changes: 215 additions & 7 deletions datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! Defines physical expressions that can evaluated at runtime during query execution

use ahash::RandomState;
use std::any::Any;
use std::convert::TryFrom;
use std::sync::Arc;
Expand All @@ -32,6 +33,7 @@ use arrow::{
};
use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue};
use datafusion_expr::Accumulator;
use std::collections::HashSet;

use crate::aggregate::row_accumulator::{
is_row_accumulator_support_dtype, RowAccumulator,
Expand Down Expand Up @@ -751,6 +753,170 @@ impl RowAccumulator for BitXorRowAccumulator {
}
}

/// Expression for a BIT_XOR(DISTINCT) aggregation.
#[derive(Debug, Clone)]
pub struct DistinctBitXor {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it is possible to avoid duplicated code for distinct aggregates -- they are all basically doing the same thing 🤔

name: String,
pub data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
nullable: bool,
}

impl DistinctBitXor {
/// Create a new DistinctBitXor aggregate function
pub fn new(
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
data_type: DataType,
) -> Self {
Self {
name: name.into(),
expr,
data_type,
nullable: true,
}
}
}

impl AggregateExpr for DistinctBitXor {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
self
}

fn field(&self) -> Result<Field> {
Ok(Field::new(
&self.name,
self.data_type.clone(),
self.nullable,
))
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(DistinctBitXorAccumulator::try_new(
&self.data_type,
)?))
}

fn state_fields(&self) -> Result<Vec<Field>> {
// State field is a List which stores items to rebuild hash set.
Ok(vec![Field::new_list(
format_state_name(&self.name, "bit_xor distinct"),
Field::new("item", self.data_type.clone(), true),
false,
)])
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.expr.clone()]
}

fn name(&self) -> &str {
&self.name
}
}

impl PartialEq<dyn Any> for DistinctBitXor {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
self.name == x.name
&& self.data_type == x.data_type
&& self.nullable == x.nullable
&& self.expr.eq(&x.expr)
})
.unwrap_or(false)
}
}

#[derive(Debug)]
struct DistinctBitXorAccumulator {
hash_values: HashSet<ScalarValue, RandomState>,
data_type: DataType,
}

impl DistinctBitXorAccumulator {
pub fn try_new(data_type: &DataType) -> Result<Self> {
Ok(Self {
hash_values: HashSet::default(),
data_type: data_type.clone(),
})
}
}

impl Accumulator for DistinctBitXorAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
// 1. Stores aggregate state in `ScalarValue::List`
// 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set
let state_out = {
let mut distinct_values = Vec::new();
self.hash_values
.iter()
.for_each(|distinct_value| distinct_values.push(distinct_value.clone()));
vec![ScalarValue::new_list(
Some(distinct_values),
self.data_type.clone(),
)]
};
Ok(state_out)
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}

let arr = &values[0];
(0..values[0].len()).try_for_each(|index| {
if !arr.is_null(index) {
let v = ScalarValue::try_from_array(arr, index)?;
self.hash_values.insert(v);
}
Ok(())
})
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}

let arr = &states[0];
(0..arr.len()).try_for_each(|index| {
let scalar = ScalarValue::try_from_array(arr, index)?;

if let ScalarValue::List(Some(scalar), _) = scalar {
scalar.iter().for_each(|scalar| {
if !ScalarValue::is_null(scalar) {
self.hash_values.insert(scalar.clone());
}
});
} else {
return Err(DataFusionError::Internal(
"Unexpected accumulator state".into(),
));
}
Ok(())
})
}

fn evaluate(&self) -> Result<ScalarValue> {
let mut bit_xor_value = ScalarValue::try_from(&self.data_type)?;
for distinct_value in self.hash_values.iter() {
bit_xor_value = bit_xor_value.bitxor(distinct_value)?;
}
Ok(bit_xor_value)
}

fn size(&self) -> usize {
std::mem::size_of_val(self) + ScalarValue::size_of_hashset(&self.hash_values)
- std::mem::size_of_val(&self.hash_values)
+ self.data_type.size()
- std::mem::size_of_val(&self.data_type)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -813,15 +979,20 @@ mod tests {

#[test]
fn bit_xor_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![4, 7, 15]));
generic_test_op!(a, DataType::Int32, BitXor, ScalarValue::from(12i32))
let a: ArrayRef = Arc::new(Int32Array::from(vec![4, 7, 4, 7, 15]));
generic_test_op!(a, DataType::Int32, BitXor, ScalarValue::from(15i32))
}

#[test]
fn bit_xor_i32_with_nulls() -> Result<()> {
let a: ArrayRef =
Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(5)]));
generic_test_op!(a, DataType::Int32, BitXor, ScalarValue::from(7i32))
let a: ArrayRef = Arc::new(Int32Array::from(vec![
Some(1),
Some(1),
None,
Some(3),
Some(5),
]));
generic_test_op!(a, DataType::Int32, BitXor, ScalarValue::from(6i32))
}

#[test]
Expand All @@ -832,7 +1003,44 @@ mod tests {

#[test]
fn bit_xor_u32() -> Result<()> {
let a: ArrayRef = Arc::new(UInt32Array::from(vec![4_u32, 7_u32, 15_u32]));
generic_test_op!(a, DataType::UInt32, BitXor, ScalarValue::from(12u32))
let a: ArrayRef =
Arc::new(UInt32Array::from(vec![4_u32, 7_u32, 4_u32, 7_u32, 15_u32]));
generic_test_op!(a, DataType::UInt32, BitXor, ScalarValue::from(15u32))
}

#[test]
fn bit_xor_distinct_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![4, 7, 4, 7, 15]));
generic_test_op!(a, DataType::Int32, DistinctBitXor, ScalarValue::from(12i32))
}

#[test]
fn bit_xor_distinct_i32_with_nulls() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![
Some(1),
Some(1),
None,
Some(3),
Some(5),
]));
generic_test_op!(a, DataType::Int32, DistinctBitXor, ScalarValue::from(7i32))
}

#[test]
fn bit_xor_distinct_i32_all_nulls() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
generic_test_op!(a, DataType::Int32, DistinctBitXor, ScalarValue::Int32(None))
}

#[test]
fn bit_xor_distinct_u32() -> Result<()> {
let a: ArrayRef =
Arc::new(UInt32Array::from(vec![4_u32, 7_u32, 4_u32, 7_u32, 15_u32]));
generic_test_op!(
a,
DataType::UInt32,
DistinctBitXor,
ScalarValue::from(12u32)
)
}
}
38 changes: 9 additions & 29 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,56 +65,36 @@ pub fn create_aggregate_expr(
name,
rt_type,
)),
(AggregateFunction::BitAnd, false) => Arc::new(expressions::BitAnd::new(
(AggregateFunction::BitAnd, _) => Arc::new(expressions::BitAnd::new(
input_phy_exprs[0].clone(),
name,
rt_type,
)),
(AggregateFunction::BitAnd, true) => {
return Err(DataFusionError::NotImplemented(
"BIT_AND(DISTINCT) aggregations are not available".to_string(),
));
}
(AggregateFunction::BitOr, false) => Arc::new(expressions::BitOr::new(
(AggregateFunction::BitOr, _) => Arc::new(expressions::BitOr::new(
input_phy_exprs[0].clone(),
name,
rt_type,
)),
(AggregateFunction::BitOr, true) => {
return Err(DataFusionError::NotImplemented(
"BIT_OR(DISTINCT) aggregations are not available".to_string(),
));
}
(AggregateFunction::BitXor, false) => Arc::new(expressions::BitXor::new(
input_phy_exprs[0].clone(),
name,
rt_type,
)),
(AggregateFunction::BitXor, true) => {
return Err(DataFusionError::NotImplemented(
"BIT_XOR(DISTINCT) aggregations are not available".to_string(),
));
}
(AggregateFunction::BoolAnd, false) => Arc::new(expressions::BoolAnd::new(
(AggregateFunction::BitXor, true) => Arc::new(expressions::DistinctBitXor::new(
input_phy_exprs[0].clone(),
name,
rt_type,
)),
(AggregateFunction::BoolAnd, true) => {
return Err(DataFusionError::NotImplemented(
"BOOL_AND(DISTINCT) aggregations are not available".to_string(),
));
}
(AggregateFunction::BoolOr, false) => Arc::new(expressions::BoolOr::new(
(AggregateFunction::BoolAnd, _) => Arc::new(expressions::BoolAnd::new(
input_phy_exprs[0].clone(),
name,
rt_type,
)),
(AggregateFunction::BoolOr, _) => Arc::new(expressions::BoolOr::new(
input_phy_exprs[0].clone(),
name,
rt_type,
)),
(AggregateFunction::BoolOr, true) => {
return Err(DataFusionError::NotImplemented(
"BOOL_OR(DISTINCT) aggregations are not available".to_string(),
));
}
(AggregateFunction::Sum, false) => {
let cast_to_sum_type = rt_type != input_phy_types[0];
Arc::new(expressions::Sum::new_with_pre_cast(
Expand Down
Loading