Introduce expr builder for aggregate function (#10560)
* expr builder

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fmt

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* build

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* upd user-guide

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix builder

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* Consolidate example in udaf_expr.rs, simplify filter API

* Add doc strings and examples

* Add tests and checks

* Improve documentation more

* fixup

* rm spce

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

---------

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
jayzhan211 and alamb committed Jun 9, 2024
1 parent 8b1f06b commit 24a0846
Showing 10 changed files with 467 additions and 62 deletions.
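For orientation, the sketch below condenses the fluent builder API this commit introduces, based on the `expr_fn_demo` example and the tests added in the diffs that follow. The column names (`price`, `quantity`, `ts`) are illustrative and assume a schema that contains them; this is a summary sketch, not additional code from the commit.

```rust
use datafusion::error::Result;
use datafusion::functions_aggregate::first_last::first_value_udaf;
use datafusion::prelude::*;
use datafusion_expr::AggregateExt;

/// Builds `FIRST_VALUE(price) FILTER (WHERE quantity > 100) ORDER BY ts DESC`
/// using the new `AggregateExt` builder methods.
fn first_value_example() -> Result<Expr> {
    first_value_udaf()
        .call(vec![col("price")])                     // FIRST_VALUE(price)
        .order_by(vec![col("ts").sort(false, false)]) // ORDER BY ts DESC NULLS LAST
        .filter(col("quantity").gt(lit(100)))         // FILTER (WHERE quantity > 100)
        .build()                                      // -> Expr::AggregateFunction
}
```

The resulting `Expr` can be passed to `DataFrame::aggregate` like any other aggregate expression, as the new tests in `datafusion/core/tests/expr_api/mod.rs` demonstrate.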
44 changes: 38 additions & 6 deletions datafusion-examples/examples/expr_api.rs
@@ -24,6 +24,7 @@ use arrow::record_batch::RecordBatch;
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::common::DFSchema;
use datafusion::error::Result;
use datafusion::functions_aggregate::first_last::first_value_udaf;
use datafusion::optimizer::simplify_expressions::ExprSimplifier;
use datafusion::physical_expr::{analyze, AnalysisContext, ExprBoundaries};
use datafusion::prelude::*;
@@ -32,7 +33,7 @@ use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr::BinaryExpr;
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::{ColumnarValue, ExprSchemable, Operator};
use datafusion_expr::{AggregateExt, ColumnarValue, ExprSchemable, Operator};

/// This example demonstrates the DataFusion [`Expr`] API.
///
@@ -44,11 +45,12 @@ use datafusion_expr::{ColumnarValue, ExprSchemable, Operator};
/// also comes with APIs for evaluation, simplification, and analysis.
///
/// The code in this example shows how to:
/// 1. Create [`Exprs`] using different APIs: [`main`]`
/// 2. Evaluate [`Exprs`] against data: [`evaluate_demo`]
/// 3. Simplify expressions: [`simplify_demo`]
/// 4. Analyze predicates for boundary ranges: [`range_analysis_demo`]
/// 5. Get the types of the expressions: [`expression_type_demo`]
/// 1. Create [`Expr`]s using different APIs: [`main`]
/// 2. Use the fluent API to easily create complex [`Expr`]s: [`expr_fn_demo`]
/// 3. Evaluate [`Expr`]s against data: [`evaluate_demo`]
/// 4. Simplify expressions: [`simplify_demo`]
/// 5. Analyze predicates for boundary ranges: [`range_analysis_demo`]
/// 6. Get the types of the expressions: [`expression_type_demo`]
#[tokio::main]
async fn main() -> Result<()> {
// The easiest way to create expressions is to use the
@@ -63,6 +65,9 @@ async fn main() -> Result<()> {
));
assert_eq!(expr, expr2);

// See how to build aggregate functions with the expr_fn API
expr_fn_demo()?;

// See how to evaluate expressions
evaluate_demo()?;

@@ -78,6 +83,33 @@ async fn main() -> Result<()> {
Ok(())
}

/// DataFusion's `expr_fn` API makes it easy to create [`Expr`]s for the
/// full range of expression types such as aggregates and window functions.
fn expr_fn_demo() -> Result<()> {
// Let's say you want to call the "first_value" aggregate function
let first_value = first_value_udaf();

// For example, to create the expression `FIRST_VALUE(price)`
// These expressions can be passed to `DataFrame::aggregate` and other
// APIs that take aggregate expressions.
let agg = first_value.call(vec![col("price")]);
assert_eq!(agg.to_string(), "first_value(price)");

// You can use the AggregateExt trait to create more complex aggregates
// such as `FIRST_VALUE(price) FILTER (WHERE quantity > 100) ORDER BY ts`
let agg = first_value
.call(vec![col("price")])
.order_by(vec![col("ts").sort(false, false)])
.filter(col("quantity").gt(lit(100)))
.build()?; // build the aggregate
assert_eq!(
agg.to_string(),
"first_value(price) FILTER (WHERE quantity > Int32(100)) ORDER BY [ts DESC NULLS LAST]"
);

Ok(())
}

/// DataFusion can also evaluate arbitrary expressions on Arrow arrays.
fn evaluate_demo() -> Result<()> {
// For example, let's say you have some integers in an array
190 changes: 187 additions & 3 deletions datafusion/core/tests/expr_api/mod.rs
@@ -15,14 +15,18 @@
// specific language governing permissions and limitations
// under the License.

use arrow::util::pretty::pretty_format_columns;
use arrow::util::pretty::{pretty_format_batches, pretty_format_columns};
use arrow_array::builder::{ListBuilder, StringBuilder};
use arrow_array::{ArrayRef, RecordBatch, StringArray, StructArray};
use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StructArray};
use arrow_schema::{DataType, Field};
use datafusion::prelude::*;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_common::{assert_contains, DFSchema, ScalarValue};
use datafusion_expr::AggregateExt;
use datafusion_functions::core::expr_ext::FieldAccessor;
use datafusion_functions_aggregate::first_last::first_value_udaf;
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_functions_array::expr_ext::{IndexAccessor, SliceAccessor};
use sqlparser::ast::NullTreatment;
/// Tests of using and evaluating `Expr`s outside the context of a LogicalPlan
use std::sync::{Arc, OnceLock};

@@ -162,6 +166,183 @@ fn test_list_range() {
);
}

#[tokio::test]
async fn test_aggregate_error() {
let err = first_value_udaf()
.call(vec![col("props")])
// not a sort column
.order_by(vec![col("id")])
.build()
.unwrap_err()
.to_string();
assert_contains!(
err,
"Error during planning: ORDER BY expressions must be Expr::Sort"
);
}

#[tokio::test]
async fn test_aggregate_ext_order_by() {
let agg = first_value_udaf().call(vec![col("props")]);

// ORDER BY id ASC
let agg_asc = agg
.clone()
.order_by(vec![col("id").sort(true, true)])
.build()
.unwrap()
.alias("asc");

// ORDER BY id DESC
let agg_desc = agg
.order_by(vec![col("id").sort(false, true)])
.build()
.unwrap()
.alias("desc");

evaluate_agg_test(
agg_asc,
vec![
"+-----------------+",
"| asc |",
"+-----------------+",
"| {a: 2021-02-01} |",
"+-----------------+",
],
)
.await;

evaluate_agg_test(
agg_desc,
vec![
"+-----------------+",
"| desc |",
"+-----------------+",
"| {a: 2021-02-03} |",
"+-----------------+",
],
)
.await;
}

#[tokio::test]
async fn test_aggregate_ext_filter() {
let agg = first_value_udaf()
.call(vec![col("i")])
.order_by(vec![col("i").sort(true, true)])
.filter(col("i").is_not_null())
.build()
.unwrap()
.alias("val");

#[rustfmt::skip]
evaluate_agg_test(
agg,
vec![
"+-----+",
"| val |",
"+-----+",
"| 5 |",
"+-----+",
],
)
.await;
}

#[tokio::test]
async fn test_aggregate_ext_distinct() {
let agg = sum_udaf()
.call(vec![lit(5)])
// distinct sum should be 5, not 15
.distinct()
.build()
.unwrap()
.alias("distinct");

evaluate_agg_test(
agg,
vec![
"+----------+",
"| distinct |",
"+----------+",
"| 5 |",
"+----------+",
],
)
.await;
}

#[tokio::test]
async fn test_aggregate_ext_null_treatment() {
let agg = first_value_udaf()
.call(vec![col("i")])
.order_by(vec![col("i").sort(true, true)]);

let agg_respect = agg
.clone()
.null_treatment(NullTreatment::RespectNulls)
.build()
.unwrap()
.alias("respect");

let agg_ignore = agg
.null_treatment(NullTreatment::IgnoreNulls)
.build()
.unwrap()
.alias("ignore");

evaluate_agg_test(
agg_respect,
vec![
"+---------+",
"| respect |",
"+---------+",
"| |",
"+---------+",
],
)
.await;

evaluate_agg_test(
agg_ignore,
vec![
"+--------+",
"| ignore |",
"+--------+",
"| 5 |",
"+--------+",
],
)
.await;
}

/// Evaluates the specified expr as an aggregate and compares the result to the
/// expected result.
async fn evaluate_agg_test(expr: Expr, expected_lines: Vec<&str>) {
let batch = test_batch();

let ctx = SessionContext::new();
let group_expr = vec![];
let agg_expr = vec![expr];
let result = ctx
.read_batch(batch)
.unwrap()
.aggregate(group_expr, agg_expr)
.unwrap()
.collect()
.await
.unwrap();

let result = pretty_format_batches(&result).unwrap().to_string();
let actual_lines = result.lines().collect::<Vec<_>>();

assert_eq!(
expected_lines, actual_lines,
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected_lines, actual_lines
);
}

/// Converts the `Expr` to a `PhysicalExpr`, evaluates it against the provided
/// `RecordBatch` and compares the result to the expected result.
fn evaluate_expr_test(expr: Expr, expected_lines: Vec<&str>) {
@@ -189,6 +370,8 @@ fn test_batch() -> RecordBatch {
TEST_BATCH
.get_or_init(|| {
let string_array: ArrayRef = Arc::new(StringArray::from(vec!["1", "2", "3"]));
let int_array: ArrayRef =
Arc::new(Int64Array::from_iter(vec![Some(10), None, Some(5)]));

// { a: "2021-02-01" } { a: "2021-02-02" } { a: "2021-02-03" }
let struct_array: ArrayRef = Arc::from(StructArray::from(vec![(
Expand All @@ -209,6 +392,7 @@ fn test_batch() -> RecordBatch {

RecordBatch::try_from_iter(vec![
("id", string_array),
("i", int_array),
("props", struct_array),
("list", list_array),
])
15 changes: 14 additions & 1 deletion datafusion/expr/src/expr.rs
@@ -255,19 +255,23 @@ pub enum Expr {
/// can be used. The first form consists of a series of boolean "when" expressions with
/// corresponding "then" expressions, and an optional "else" expression.
///
/// ```text
/// CASE WHEN condition THEN result
/// [WHEN ...]
/// [ELSE result]
/// END
/// ```
///
/// The second form uses a base expression and then a series of "when" clauses that match on a
/// literal value.
///
/// ```text
/// CASE expression
/// WHEN value THEN result
/// [WHEN ...]
/// [ELSE result]
/// END
/// ```
Case(Case),
/// Casts the expression to a given type and will return a runtime error if the expression cannot be cast.
/// This expression is guaranteed to have a fixed type.
@@ -279,7 +283,12 @@ pub enum Expr {
Sort(Sort),
/// Represents the call of a scalar function with a set of arguments.
ScalarFunction(ScalarFunction),
/// Represents the call of an aggregate built-in function with arguments.
/// Calls an aggregate function with arguments, and optional
/// `ORDER BY`, `FILTER`, `DISTINCT` and `NULL TREATMENT`.
///
/// See also [`AggregateExt`] to set these fields.
///
/// [`AggregateExt`]: crate::udaf::AggregateExt
AggregateFunction(AggregateFunction),
/// Represents the call of a window function with arguments.
WindowFunction(WindowFunction),
@@ -623,6 +632,10 @@ impl AggregateFunctionDefinition {
}

/// Aggregate function
///
/// See also [`AggregateExt`] to set these fields on `Expr`
///
/// [`AggregateExt`]: crate::udaf::AggregateExt
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct AggregateFunction {
/// Name of the function
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
@@ -82,7 +82,7 @@ pub use signature::{
ArrayFunctionSignature, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD,
};
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF};
pub use udaf::{AggregateExt, AggregateUDF, AggregateUDFImpl, ReversedUDAF};
pub use udf::{ScalarUDF, ScalarUDFImpl};
pub use udwf::{WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};