Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 149 additions & 63 deletions datafusion/physical-expr/benches/case_when.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,38 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::builder::{Int32Builder, StringBuilder};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::array::builder::StringBuilder;
use arrow::array::{Array, ArrayRef, Int32Array};
use arrow::datatypes::{Field, Schema};
use arrow::record_batch::RecordBatch;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use datafusion_common::ScalarValue;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, CaseExpr, Column, Literal};
use datafusion_physical_expr::expressions::{case, col, lit, BinaryExpr};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use std::sync::Arc;

fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
Arc::new(Column::new(name, index))
fn make_x_cmp_y(
x: &Arc<dyn PhysicalExpr>,
op: Operator,
y: i32,
) -> Arc<dyn PhysicalExpr> {
Arc::new(BinaryExpr::new(Arc::clone(x), op, lit(y)))
}

fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
}
/// Create a record batch with the given number of rows and columns.
/// Columns are named `c<i>` where `i` is the column index.
///
/// The minimum value for `column_count` is `3`.
/// `c0` contains incrementing int32 values
/// `c1` contains strings with one null inserted every 7 rows
/// `c2` contains strings with one null inserted every 9 rows
/// `c3` to `cn`, is present, contain unspecified int32 values
fn make_batch(row_count: usize, column_count: usize) -> RecordBatch {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it worth a comment saying that this column could is always 3 or more?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add a description and assertion

assert!(column_count >= 3);

fn criterion_benchmark(c: &mut Criterion) {
// create input data
let mut c1 = Int32Builder::new();
let mut c2 = StringBuilder::new();
let mut c3 = StringBuilder::new();
for i in 0..1000 {
c1.append_value(i);
for i in 0..row_count {
if i % 7 == 0 {
c2.append_null();
} else {
Expand All @@ -51,72 +58,151 @@ fn criterion_benchmark(c: &mut Criterion) {
c3.append_value(format!("other string {i}"));
}
}
let c1 = Arc::new(c1.finish());
let c1 = Arc::new(Int32Array::from_iter_values(0..row_count as i32));
let c2 = Arc::new(c2.finish());
let c3 = Arc::new(c3.finish());
let schema = Schema::new(vec![
Field::new("c1", DataType::Int32, true),
Field::new("c2", DataType::Utf8, true),
Field::new("c3", DataType::Utf8, true),
]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2, c3]).unwrap();

// use same predicate for all benchmarks
let predicate = Arc::new(BinaryExpr::new(
make_col("c1", 0),
Operator::LtEq,
make_lit_i32(500),
));
let mut columns: Vec<ArrayRef> = vec![c1, c2, c3];
for _ in 3..column_count {
columns.push(Arc::new(Int32Array::from_iter_values(0..row_count as i32)));
}

// CASE WHEN c1 <= 500 THEN 1 ELSE 0 END
c.bench_function("case_when: scalar or scalar", |b| {
let expr = Arc::new(
CaseExpr::try_new(
None,
vec![(predicate.clone(), make_lit_i32(1))],
Some(make_lit_i32(0)),
let fields = columns
.iter()
.enumerate()
.map(|(i, c)| {
Field::new(
format!("c{}", i + 1),
c.data_type().clone(),
c.is_nullable(),
)
.unwrap(),
);
b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
});
})
.collect::<Vec<_>>();

// CASE WHEN c1 <= 500 THEN c2 [ELSE NULL] END
c.bench_function("case_when: column or null", |b| {
let expr = Arc::new(
CaseExpr::try_new(None, vec![(predicate.clone(), make_col("c2", 1))], None)
let schema = Arc::new(Schema::new(fields));
RecordBatch::try_new(Arc::clone(&schema), columns).unwrap()
}

fn criterion_benchmark(c: &mut Criterion) {
run_benchmarks(c, &make_batch(8192, 3));
run_benchmarks(c, &make_batch(8192, 50));
run_benchmarks(c, &make_batch(8192, 100));
}

fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) {
let c1 = col("c1", &batch.schema()).unwrap();
let c2 = col("c2", &batch.schema()).unwrap();
let c3 = col("c3", &batch.schema()).unwrap();

// No expression, when/then/else, literal values
c.bench_function(
format!(
"case_when {}x{}: CASE WHEN c1 <= 500 THEN 1 ELSE 0 END",
batch.num_rows(),
batch.num_columns()
)
.as_str(),
|b| {
let expr = Arc::new(
case(
None,
vec![(make_x_cmp_y(&c1, Operator::LtEq, 500), lit(1))],
Some(lit(0)),
)
.unwrap(),
);
b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
});
);
b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap()))
},
);

// No expression, when/then/else, column reference values
c.bench_function(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are nice labels:

Benchmarking case_when 8192x3: CASE WHEN c1 == 0 THEN 0 WHEN c1 == 1 THEN 1 ... WHEN c1 == n THEN n ELSE n + 1 EN...: Warming up for 3.0000 s

format!(
"case_when {}x{}: CASE WHEN c1 <= 500 THEN c2 ELSE c3 END",
batch.num_rows(),
batch.num_columns()
)
.as_str(),
|b| {
let expr = Arc::new(
case(
None,
vec![(make_x_cmp_y(&c1, Operator::LtEq, 500), Arc::clone(&c2))],
Some(Arc::clone(&c3)),
)
.unwrap(),
);
b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap()))
},
);

// CASE WHEN c1 <= 500 THEN c2 ELSE c3 END
c.bench_function("case_when: expr or expr", |b| {
// No expression, when/then, implicit else
c.bench_function(
format!(
"case_when {}x{}: CASE WHEN c1 <= 500 THEN c2 [ELSE NULL] END",
batch.num_rows(),
batch.num_columns()
)
.as_str(),
|b| {
let expr = Arc::new(
case(
None,
vec![(make_x_cmp_y(&c1, Operator::LtEq, 500), Arc::clone(&c2))],
None,
)
.unwrap(),
);
b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap()))
},
);

// With expression, two when/then branches
c.bench_function(
format!(
"case_when {}x{}: CASE c1 WHEN 1 THEN c2 WHEN 2 THEN c3 END",
batch.num_rows(),
batch.num_columns()
)
.as_str(),
|b| {
let expr = Arc::new(
case(
Some(Arc::clone(&c1)),
vec![(lit(1), Arc::clone(&c2)), (lit(2), Arc::clone(&c3))],
None,
)
.unwrap(),
);
b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap()))
},
);

// Many when/then branches where all are effectively reachable
c.bench_function(format!("case_when {}x{}: CASE WHEN c1 == 0 THEN 0 WHEN c1 == 1 THEN 1 ... WHEN c1 == n THEN n ELSE n + 1 END", batch.num_rows(), batch.num_columns()).as_str(), |b| {
let when_thens = (0..batch.num_rows() as i32).map(|i| (make_x_cmp_y(&c1, Operator::Eq, i), lit(i))).collect();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that is a lot of when_thens!

Copy link
Contributor Author

@pepijnve pepijnve Oct 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Intentionally so. This is a torture test benchmark to really stress the code.

The first 'all reachable' one is really a worst case scenario test case. This is intended to be able to measure improvements in the processing that's being done in each loop iteration. Filtering, scattering, etc.

The second 'few reachable' one is intended to measure the short circuiting behaviour.

let expr = Arc::new(
CaseExpr::try_new(
case(
None,
vec![(predicate.clone(), make_col("c2", 1))],
Some(make_col("c3", 2)),
when_thens,
Some(lit(batch.num_rows() as i32))
)
.unwrap(),
.unwrap(),
);
b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap()))
});

// CASE c1 WHEN 1 THEN c2 WHEN 2 THEN c3 END
c.bench_function("case_when: CASE expr", |b| {
// Many when/then branches where all but the first few are effectively unreachable
c.bench_function(format!("case_when {}x{}: CASE WHEN c1 < 0 THEN 0 WHEN c1 < 1000 THEN 1 ... WHEN c1 < n * 1000 THEN n ELSE n + 1 END", batch.num_rows(), batch.num_columns()).as_str(), |b| {
let when_thens = (0..batch.num_rows() as i32).map(|i| (make_x_cmp_y(&c1, Operator::Eq, i * 1000), lit(i))).collect();
let expr = Arc::new(
CaseExpr::try_new(
Some(make_col("c1", 0)),
vec![
(make_lit_i32(1), make_col("c2", 1)),
(make_lit_i32(2), make_col("c3", 2)),
],
case(
None,
when_thens,
Some(lit(batch.num_rows() as i32))
)
.unwrap(),
.unwrap(),
);
b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap()))
});
}

Expand Down