diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index aecec35f2e16..e382176525f8 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -42,9 +42,9 @@ use datafusion_execution::config::SessionConfig; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::Expr::Wildcard; use datafusion_expr::{ - avg, col, count, exists, expr, in_subquery, lit, max, out_ref_col, scalar_subquery, - sum, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunction, + array_agg, avg, col, count, exists, expr, in_subquery, lit, max, out_ref_col, + scalar_subquery, sum, AggregateFunction, Expr, ExprSchemable, WindowFrame, + WindowFrameBound, WindowFrameUnits, WindowFunction, }; use datafusion_physical_expr::var_provider::{VarProvider, VarType}; @@ -1340,3 +1340,23 @@ async fn use_var_provider() -> Result<()> { dataframe.collect().await?; Ok(()) } + +#[tokio::test] +async fn test_array_agg() -> Result<()> { + let df = create_test_table("test") + .await? + .aggregate(vec![], vec![array_agg(col("a"))])?; + + let results = df.collect().await?; + + let expected = vec![ + "+-------------------------------------+", + "| ARRAY_AGG(test.a) |", + "+-------------------------------------+", + "| [abcDEF, abc123, CBAdef, 123AbcDef] |", + "+-------------------------------------+", + ]; + assert_batches_eq!(expected, &results); + + Ok(()) +} diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index cb5317da4408..7c769490af29 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -136,6 +136,17 @@ pub fn sum(expr: Expr) -> Expr { )) } +/// Create an expression to represent the array_agg() aggregate function +pub fn array_agg(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new( + aggregate_function::AggregateFunction::ArrayAgg, + vec![expr], + false, + None, + None, + )) +} + /// Create an expression to represent the avg() aggregate function pub fn avg(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new(