diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index b17e4294a1ef..760952d94815 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -41,8 +41,6 @@ pub enum AggregateFunction { ArrayAgg, /// N'th value in a group according to some ordering NthValue, - /// Grouping - Grouping, } impl AggregateFunction { @@ -53,7 +51,6 @@ impl AggregateFunction { Max => "MAX", ArrayAgg => "ARRAY_AGG", NthValue => "NTH_VALUE", - Grouping => "GROUPING", } } } @@ -73,8 +70,6 @@ impl FromStr for AggregateFunction { "min" => AggregateFunction::Min, "array_agg" => AggregateFunction::ArrayAgg, "nth_value" => AggregateFunction::NthValue, - // other - "grouping" => AggregateFunction::Grouping, _ => { return plan_err!("There is no built-in function named {name}"); } @@ -119,7 +114,6 @@ impl AggregateFunction { coerced_data_types[0].clone(), input_expr_nullable[0], )))), - AggregateFunction::Grouping => Ok(DataType::Int32), AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()), } } @@ -130,7 +124,6 @@ impl AggregateFunction { match self { AggregateFunction::Max | AggregateFunction::Min => Ok(true), AggregateFunction::ArrayAgg => Ok(false), - AggregateFunction::Grouping => Ok(true), AggregateFunction::NthValue => Ok(true), } } @@ -141,9 +134,7 @@ impl AggregateFunction { pub fn signature(&self) -> Signature { // note: the physical expression must accept the type returned by this function or the execution panics. match self { - AggregateFunction::Grouping | AggregateFunction::ArrayAgg => { - Signature::any(1, Volatility::Immutable) - } + AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable), AggregateFunction::Min | AggregateFunction::Max => { let valid = STRINGS .iter() diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 36a789d5b0ee..0f7464b96b3e 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -102,7 +102,6 @@ pub fn coerce_types( get_min_max_result_type(input_types) } AggregateFunction::NthValue => Ok(input_types.to_vec()), - AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), } } diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs new file mode 100644 index 000000000000..6fb7c3800f4e --- /dev/null +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::fmt; + +use arrow::datatypes::DataType; +use arrow::datatypes::Field; +use datafusion_common::{not_impl_err, Result}; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::function::StateFieldsArgs; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; + +make_udaf_expr_and_func!( + Grouping, + grouping, + expression, + "Returns 1 if the data is aggregated across the specified column or 0 for not aggregated in the result set.", + grouping_udaf +); + +pub struct Grouping { + signature: Signature, +} + +impl fmt::Debug for Grouping { + fn fmt(&self, f: &mut std::fmt::Formatter) -> fmt::Result { + f.debug_struct("Grouping") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Grouping { + fn default() -> Self { + Self::new() + } +} + +impl Grouping { + /// Create a new GROUPING aggregate function. + pub fn new() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Grouping { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "grouping" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + Ok(vec![Field::new( + format_state_name(args.name, "grouping"), + DataType::Int32, + true, + )]) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + not_impl_err!( + "physical plan is not yet implemented for GROUPING aggregate function" + ) + } +} diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 063e6000b4c9..fc485a284ab4 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -73,6 +73,7 @@ pub mod approx_percentile_cont_with_weight; pub mod average; pub mod bit_and_or_xor; pub mod bool_and_or; +pub mod grouping; pub mod string_agg; use crate::approx_percentile_cont::approx_percentile_cont_udaf; @@ -102,6 +103,7 @@ pub mod expr_fn { pub use super::covariance::covar_samp; pub use super::first_last::first_value; pub use super::first_last::last_value; + pub use super::grouping::grouping; pub use super::median::median; pub use super::regr::regr_avgx; pub use super::regr::regr_avgy; @@ -154,6 +156,7 @@ pub fn all_default_aggregate_functions() -> Vec> { bool_and_or::bool_and_udaf(), bool_and_or::bool_or_udaf(), average::avg_udaf(), + grouping::grouping_udaf(), ] } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 169418d2daa0..adbbbd3e631e 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -60,11 +60,6 @@ pub fn create_aggregate_expr( .collect::>>()?; let input_phy_exprs = input_phy_exprs.to_vec(); Ok(match (fun, distinct) { - (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), (AggregateFunction::ArrayAgg, false) => { let expr = input_phy_exprs[0].clone(); let nullable = expr.nullable(input_schema)?; diff --git a/datafusion/physical-expr/src/aggregate/grouping.rs b/datafusion/physical-expr/src/aggregate/grouping.rs deleted file mode 100644 index d43bcd5c7091..000000000000 --- a/datafusion/physical-expr/src/aggregate/grouping.rs +++ /dev/null @@ -1,103 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use std::any::Any; -use std::sync::Arc; - -use crate::aggregate::utils::down_cast_any_ref; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::datatypes::DataType; -use arrow::datatypes::Field; -use datafusion_common::{not_impl_err, Result}; -use datafusion_expr::Accumulator; - -use crate::expressions::format_state_name; - -/// GROUPING aggregate expression -/// Returns the amount of non-null values of the given expression. -#[derive(Debug)] -pub struct Grouping { - name: String, - data_type: DataType, - nullable: bool, - expr: Arc, -} - -impl Grouping { - /// Create a new GROUPING aggregate function. - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - nullable: true, - } - } -} - -impl AggregateExpr for Grouping { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Int32, self.nullable)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "grouping"), - DataType::Int32, - true, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn create_accumulator(&self) -> Result> { - not_impl_err!( - "physical plan is not yet implemented for GROUPING aggregate function" - ) - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for Grouping { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index ca5bf3293442..f0de7446f6f1 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -20,7 +20,6 @@ pub use datafusion_physical_expr_common::aggregate::AggregateExpr; pub(crate) mod array_agg; pub(crate) mod array_agg_distinct; pub(crate) mod array_agg_ordered; -pub(crate) mod grouping; pub(crate) mod nth_value; #[macro_use] pub(crate) mod min_max; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index b87b6daa64c7..1f2c955ad07e 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -38,7 +38,6 @@ pub use crate::aggregate::array_agg::ArrayAgg; pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg; pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg; pub use crate::aggregate::build_in::create_aggregate_expr; -pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; pub use crate::aggregate::nth_value::NthValueAgg; pub use crate::aggregate::stats::StatsType; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 7f4d6b9d927e..ce6c0c53c3fc 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -483,7 +483,7 @@ enum AggregateFunction { // APPROX_PERCENTILE_CONT = 14; // APPROX_MEDIAN = 15; // APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; - GROUPING = 17; + // GROUPING = 17; // MEDIAN = 18; // BIT_AND = 19; // BIT_OR = 20; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 33cd634c4aad..347654e52b73 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -535,7 +535,6 @@ impl serde::Serialize for AggregateFunction { Self::Min => "MIN", Self::Max => "MAX", Self::ArrayAgg => "ARRAY_AGG", - Self::Grouping => "GROUPING", Self::NthValueAgg => "NTH_VALUE_AGG", }; serializer.serialize_str(variant) @@ -551,7 +550,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "MIN", "MAX", "ARRAY_AGG", - "GROUPING", "NTH_VALUE_AGG", ]; @@ -596,7 +594,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "MIN" => Ok(AggregateFunction::Min), "MAX" => Ok(AggregateFunction::Max), "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), - "GROUPING" => Ok(AggregateFunction::Grouping), "NTH_VALUE_AGG" => Ok(AggregateFunction::NthValueAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 83b8b738c4f4..c74f172482b7 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1935,7 +1935,7 @@ pub enum AggregateFunction { /// APPROX_PERCENTILE_CONT = 14; /// APPROX_MEDIAN = 15; /// APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; - Grouping = 17, + /// GROUPING = 17; /// MEDIAN = 18; /// BIT_AND = 19; /// BIT_OR = 20; @@ -1964,7 +1964,6 @@ impl AggregateFunction { AggregateFunction::Min => "MIN", AggregateFunction::Max => "MAX", AggregateFunction::ArrayAgg => "ARRAY_AGG", - AggregateFunction::Grouping => "GROUPING", AggregateFunction::NthValueAgg => "NTH_VALUE_AGG", } } @@ -1974,7 +1973,6 @@ impl AggregateFunction { "MIN" => Some(Self::Min), "MAX" => Some(Self::Max), "ARRAY_AGG" => Some(Self::ArrayAgg), - "GROUPING" => Some(Self::Grouping), "NTH_VALUE_AGG" => Some(Self::NthValueAgg), _ => None, } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 609cbc1a286b..f4fb69280436 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -145,7 +145,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Min => Self::Min, protobuf::AggregateFunction::Max => Self::Max, protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, - protobuf::AggregateFunction::Grouping => Self::Grouping, protobuf::AggregateFunction::NthValueAgg => Self::NthValue, } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index ccc64119c8a1..7570040a1d08 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -117,7 +117,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Min => Self::Min, AggregateFunction::Max => Self::Max, AggregateFunction::ArrayAgg => Self::ArrayAgg, - AggregateFunction::Grouping => Self::Grouping, AggregateFunction::NthValue => Self::NthValueAgg, } } @@ -378,7 +377,6 @@ pub fn serialize_expr( AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, AggregateFunction::Min => protobuf::AggregateFunction::Min, AggregateFunction::Max => protobuf::AggregateFunction::Max, - AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, AggregateFunction::NthValue => { protobuf::AggregateFunction::NthValueAgg } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 375361261952..23cdc666e701 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -24,9 +24,9 @@ use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ ArrayAgg, BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, DistinctArrayAgg, - Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, - NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, - RowNumber, TryCastExpr, WindowShift, + InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, + NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, RowNumber, + TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -244,9 +244,7 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { let aggr_expr = expr.as_any(); let mut distinct = false; - let inner = if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Grouping - } else if aggr_expr.downcast_ref::().is_some() { + let inner = if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::ArrayAgg } else if aggr_expr.downcast_ref::().is_some() { distinct = true; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index fe3da3d05854..5fc3a9a8a197 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -43,8 +43,8 @@ use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::expr_fn::{ approx_median, approx_percentile_cont, approx_percentile_cont_with_weight, count, - count_distinct, covar_pop, covar_samp, first_value, median, stddev, stddev_pop, sum, - var_pop, var_sample, + count_distinct, covar_pop, covar_samp, first_value, grouping, median, stddev, + stddev_pop, sum, var_pop, var_sample, }; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; @@ -695,6 +695,7 @@ async fn roundtrip_expr_api() -> Result<()> { approx_median(lit(2)), approx_percentile_cont(lit(2), lit(0.5)), approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)), + grouping(lit(1)), bit_and(lit(2)), bit_or(lit(2)), bit_xor(lit(2)), diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index ec623a956186..aca0d040bb8d 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -37,10 +37,10 @@ use datafusion_sql::{ planner::{ParserOptions, SqlToRel}, }; -use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::{ approx_median::approx_median_udaf, count::count_udaf, }; +use datafusion_functions_aggregate::{average::avg_udaf, grouping::grouping_udaf}; use rstest::rstest; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; @@ -2693,7 +2693,8 @@ fn logical_plan_with_dialect_and_options( .with_udaf(sum_udaf()) .with_udaf(approx_median_udaf()) .with_udaf(count_udaf()) - .with_udaf(avg_udaf()); + .with_udaf(avg_udaf()) + .with_udaf(grouping_udaf()); let planner = SqlToRel::new_with_options(&context, options); let result = DFParser::parse_sql_with_dialect(sql, dialect); @@ -3097,8 +3098,8 @@ fn aggregate_with_rollup() { fn aggregate_with_rollup_with_grouping() { let sql = "SELECT id, state, age, grouping(state), grouping(age), grouping(state) + grouping(age), count(*) \ FROM person GROUP BY id, ROLLUP (state, age)"; - let expected = "Projection: person.id, person.state, person.age, GROUPING(person.state), GROUPING(person.age), GROUPING(person.state) + GROUPING(person.age), count(*)\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[GROUPING(person.state), GROUPING(person.age), count(*)]]\ + let expected = "Projection: person.id, person.state, person.age, grouping(person.state), grouping(person.age), grouping(person.state) + grouping(person.age), count(*)\ + \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[grouping(person.state), grouping(person.age), count(*)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -3118,9 +3119,9 @@ fn rank_partition_grouping() { from person group by rollup(state, last_name)"; - let expected = "Projection: sum(person.age) AS total_sum, person.state, person.last_name, GROUPING(person.state) + GROUPING(person.last_name) AS x, RANK() PARTITION BY [GROUPING(person.state) + GROUPING(person.last_name), CASE WHEN GROUPING(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS the_rank\ - \n WindowAggr: windowExpr=[[RANK() PARTITION BY [GROUPING(person.state) + GROUPING(person.last_name), CASE WHEN GROUPING(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n Aggregate: groupBy=[[ROLLUP (person.state, person.last_name)]], aggr=[[sum(person.age), GROUPING(person.state), GROUPING(person.last_name)]]\ + let expected = "Projection: sum(person.age) AS total_sum, person.state, person.last_name, grouping(person.state) + grouping(person.last_name) AS x, RANK() PARTITION BY [grouping(person.state) + grouping(person.last_name), CASE WHEN grouping(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS the_rank\ + \n WindowAggr: windowExpr=[[RANK() PARTITION BY [grouping(person.state) + grouping(person.last_name), CASE WHEN grouping(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n Aggregate: groupBy=[[ROLLUP (person.state, person.last_name)]], aggr=[[sum(person.age), grouping(person.state), grouping(person.last_name)]]\ \n TableScan: person"; quick_test(sql, expected); }