Skip to content

Commit

Permalink
Remove builtin count (#10893)
Browse files Browse the repository at this point in the history
* rm expr fn

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* rm function

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix query and fmt

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix example

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* Update datafusion/expr/src/test/function_stub.rs

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>

---------

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
jayzhan211 and alamb committed Jun 13, 2024
1 parent b7d2aea commit b627ca3
Show file tree
Hide file tree
Showing 28 changed files with 200 additions and 219 deletions.
6 changes: 0 additions & 6 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ use strum_macros::EnumIter;
// https://datafusion.apache.org/contributor-guide/index.html#how-to-add-a-new-aggregate-function
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)]
pub enum AggregateFunction {
/// Count
Count,
/// Minimum
Min,
/// Maximum
Expand Down Expand Up @@ -89,7 +87,6 @@ impl AggregateFunction {
pub fn name(&self) -> &str {
use AggregateFunction::*;
match self {
Count => "COUNT",
Min => "MIN",
Max => "MAX",
Avg => "AVG",
Expand Down Expand Up @@ -135,7 +132,6 @@ impl FromStr for AggregateFunction {
"bit_xor" => AggregateFunction::BitXor,
"bool_and" => AggregateFunction::BoolAnd,
"bool_or" => AggregateFunction::BoolOr,
"count" => AggregateFunction::Count,
"max" => AggregateFunction::Max,
"mean" => AggregateFunction::Avg,
"min" => AggregateFunction::Min,
Expand Down Expand Up @@ -190,7 +186,6 @@ impl AggregateFunction {
})?;

match self {
AggregateFunction::Count => Ok(DataType::Int64),
AggregateFunction::Max | AggregateFunction::Min => {
// For min and max agg function, the returned type is same as input type.
// The coerced_data_types is same with input_types.
Expand Down Expand Up @@ -249,7 +244,6 @@ impl AggregateFunction {
pub fn signature(&self) -> Signature {
// note: the physical expression must accept the type returned by this function or the execution panics.
match self {
AggregateFunction::Count => Signature::variadic_any(Volatility::Immutable),
AggregateFunction::Grouping | AggregateFunction::ArrayAgg => {
Signature::any(1, Volatility::Immutable)
}
Expand Down
13 changes: 0 additions & 13 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2135,18 +2135,6 @@ mod test {

use super::*;

#[test]
fn test_count_return_type() -> Result<()> {
let fun = find_df_window_func("count").unwrap();
let observed = fun.return_type(&[DataType::Utf8])?;
assert_eq!(DataType::Int64, observed);

let observed = fun.return_type(&[DataType::UInt64])?;
assert_eq!(DataType::Int64, observed);

Ok(())
}

#[test]
fn test_first_value_return_type() -> Result<()> {
let fun = find_df_window_func("first_value").unwrap();
Expand Down Expand Up @@ -2250,7 +2238,6 @@ mod test {
"nth_value",
"min",
"max",
"count",
"avg",
];
for name in names {
Expand Down
26 changes: 0 additions & 26 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,19 +192,6 @@ pub fn avg(expr: Expr) -> Expr {
))
}

/// Create an expression to represent the count() aggregate function
// TODO: Remove this and use `expr_fn::count` instead
pub fn count(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Count,
vec![expr],
false,
None,
None,
None,
))
}

/// Return a new expression with bitwise AND
pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
Expr::BinaryExpr(BinaryExpr::new(
Expand Down Expand Up @@ -250,19 +237,6 @@ pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr {
))
}

/// Create an expression to represent the count(distinct) aggregate function
// TODO: Remove this and use `expr_fn::count_distinct` instead
pub fn count_distinct(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Count,
vec![expr],
true,
None,
None,
None,
))
}

/// Create an in_list expression
pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
Expr::InList(InList::new(Box::new(expr), list, negated))
Expand Down
4 changes: 3 additions & 1 deletion datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2965,11 +2965,13 @@ mod tests {
use super::*;
use crate::builder::LogicalTableSource;
use crate::logical_plan::table_scan;
use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet};
use crate::{col, exists, in_subquery, lit, placeholder, GroupingSet};

use datafusion_common::tree_node::TreeNodeVisitor;
use datafusion_common::{not_impl_err, Constraint, ScalarValue};

use crate::test::function_stub::count;

fn employee_schema() -> Schema {
Schema::new(vec![
Field::new("id", DataType::Int32, false),
Expand Down
86 changes: 85 additions & 1 deletion datafusion/expr/src/test/function_stub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use crate::{
use arrow::datatypes::{
DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
};
use datafusion_common::{exec_err, Result};
use datafusion_common::{exec_err, not_impl_err, Result};

macro_rules! create_func {
($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
Expand Down Expand Up @@ -69,6 +69,19 @@ pub fn sum(expr: Expr) -> Expr {
))
}

create_func!(Count, count_udaf);

pub fn count(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new_udf(
count_udaf(),
vec![expr],
false,
None,
None,
None,
))
}

/// Stub `sum` used for optimizer testing
#[derive(Debug)]
pub struct Sum {
Expand Down Expand Up @@ -189,3 +202,74 @@ impl AggregateUDFImpl for Sum {
AggregateOrderSensitivity::Insensitive
}
}

/// Testing stub implementation of COUNT aggregate
pub struct Count {
signature: Signature,
aliases: Vec<String>,
}

impl std::fmt::Debug for Count {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("Count")
.field("name", &self.name())
.field("signature", &self.signature)
.finish()
}
}

impl Default for Count {
fn default() -> Self {
Self::new()
}
}

impl Count {
pub fn new() -> Self {
Self {
aliases: vec!["count".to_string()],
signature: Signature::variadic_any(Volatility::Immutable),
}
}
}

impl AggregateUDFImpl for Count {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"COUNT"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Int64)
}

fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
not_impl_err!("no impl for stub")
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
not_impl_err!("no impl for stub")
}

fn aliases(&self) -> &[String] {
&self.aliases
}

fn create_groups_accumulator(
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
not_impl_err!("no impl for stub")
}

fn reverse_expr(&self) -> ReversedUDAF {
ReversedUDAF::Identical
}
}
2 changes: 0 additions & 2 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ pub fn coerce_types(
check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?;

match agg_fun {
AggregateFunction::Count => Ok(input_types.to_vec()),
AggregateFunction::ArrayAgg => Ok(input_types.to_vec()),
AggregateFunction::Min | AggregateFunction::Max => {
// min and max support the dictionary data type
Expand Down Expand Up @@ -525,7 +524,6 @@ mod tests {
// test count, array_agg, approx_distinct, min, max.
// the coerced types is same with input types
let funs = vec![
AggregateFunction::Count,
AggregateFunction::ArrayAgg,
AggregateFunction::Min,
AggregateFunction::Max,
Expand Down
1 change: 1 addition & 0 deletions datafusion/optimizer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,6 @@ regex-syntax = "0.8.0"
[dev-dependencies]
arrow-buffer = { workspace = true }
ctor = { workspace = true }
datafusion-functions-aggregate = { workspace = true }
datafusion-sql = { workspace = true }
env_logger = { workspace = true }
42 changes: 12 additions & 30 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ use datafusion_expr::expr::{
AggregateFunction, AggregateFunctionDefinition, WindowFunction,
};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{
aggregate_function, lit, Expr, LogicalPlan, WindowFunctionDefinition,
};
use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition};

/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
///
Expand Down Expand Up @@ -56,37 +54,19 @@ fn is_wildcard(expr: &Expr) -> bool {
}

fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
match aggregate_function {
matches!(aggregate_function,
AggregateFunction {
func_def: AggregateFunctionDefinition::UDF(udf),
args,
..
} if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) => true,
AggregateFunction {
func_def:
AggregateFunctionDefinition::BuiltIn(
datafusion_expr::aggregate_function::AggregateFunction::Count,
),
args,
..
} if args.len() == 1 && is_wildcard(&args[0]) => true,
_ => false,
}
} if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]))
}

fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
let args = &window_function.args;
match window_function.fun {
WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Count,
) if args.len() == 1 && is_wildcard(&args[0]) => true,
matches!(window_function.fun,
WindowFunctionDefinition::AggregateUDF(ref udaf)
if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) =>
{
true
}
_ => false,
}
if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]))
}

fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
Expand Down Expand Up @@ -121,14 +101,16 @@ mod tests {
use arrow::datatypes::DataType;
use datafusion_common::ScalarValue;
use datafusion_expr::expr::Sort;
use datafusion_expr::test::function_stub::sum;
use datafusion_expr::{
col, count, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max,
out_ref_col, scalar_subquery, wildcard, AggregateFunction, WindowFrame,
WindowFrameBound, WindowFrameUnits,
col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max,
out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound,
WindowFrameUnits,
};
use datafusion_functions_aggregate::count::count_udaf;
use std::sync::Arc;

use datafusion_functions_aggregate::expr_fn::{count, sum};

fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
assert_analyzed_plan_eq_display_indent(
Arc::new(CountWildcardRule::new()),
Expand Down Expand Up @@ -239,7 +221,7 @@ mod tests {

let plan = LogicalPlanBuilder::from(table_scan)
.window(vec![Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count),
WindowFunctionDefinition::AggregateUDF(count_udaf()),
vec![wildcard()],
vec![],
vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
Expand Down
10 changes: 2 additions & 8 deletions datafusion/optimizer/src/decorrelate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -432,14 +432,8 @@ fn agg_exprs_evaluation_result_on_empty_batch(
Expr::AggregateFunction(expr::AggregateFunction {
func_def, ..
}) => match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
if matches!(fun, datafusion_expr::AggregateFunction::Count) {
Transformed::yes(Expr::Literal(ScalarValue::Int64(Some(
0,
))))
} else {
Transformed::yes(Expr::Literal(ScalarValue::Null))
}
AggregateFunctionDefinition::BuiltIn(_fun) => {
Transformed::yes(Expr::Literal(ScalarValue::Null))
}
AggregateFunctionDefinition::UDF(fun) => {
if fun.name() == "COUNT" {
Expand Down
6 changes: 4 additions & 2 deletions datafusion/optimizer/src/eliminate_group_by_constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,12 @@ mod tests {
use datafusion_common::Result;
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::{
col, count, lit, ColumnarValue, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl,
Signature, TypeSignature,
col, lit, ColumnarValue, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature,
TypeSignature,
};

use datafusion_functions_aggregate::expr_fn::count;

use std::sync::Arc;

#[derive(Debug)]
Expand Down

0 comments on commit b627ca3

Please sign in to comment.