Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 139 additions & 2 deletions datafusion/pruning/src/pruning_predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1484,6 +1484,81 @@ fn build_predicate_expression(
required_columns,
unhandled_hook,
);
} else if !in_list.negated() {
// Large non-negated IN list: extract all-literal values and build a
// coarse range predicate `col >= min(list) AND col <= max(list)`.
// This won't prune as tightly as expanding every value, but it
// eliminates row groups whose entire range is outside the IN list's span.
//
// Example: `c1 IN (3, 99, 7, 42, ...)` with 50 values
// → `c1 >= 3 AND c1 <= 99`
// → prunes row groups where max < 3 or min > 99
//
// Negated lists (`NOT IN`) have no useful range: `c1 NOT IN (1..100)`
// is satisfied by any value outside [1,100], so a range predicate
// would incorrectly prune row groups that contain values outside it.
// Scan for min/max, bailing out early on any non-literal or
// incomparable value (e.g. mixed timestamp units/timezones).
let mut min_val: Option<&ScalarValue> = None;
let mut max_val: Option<&ScalarValue> = None;
let mut all_literals = true;
'scan: for e in in_list.list() {
let Some(lit) = e.downcast_ref::<phys_expr::Literal>() else {
all_literals = false;
break;
};
let v = lit.value();
if let Some(prev) = min_val {
match v.partial_cmp(prev) {
Some(std::cmp::Ordering::Less) => min_val = Some(v),
Some(_) => {}
None => {
all_literals = false;
break 'scan;
}
}
} else {
min_val = Some(v);
}
if let Some(prev) = max_val {
match v.partial_cmp(prev) {
Some(std::cmp::Ordering::Greater) => max_val = Some(v),
Some(_) => {}
None => {
all_literals = false;
break 'scan;
}
}
} else {
max_val = Some(v);
}
}
if all_literals && let (Some(min_val), Some(max_val)) = (min_val, max_val) {
let min_lit = Arc::new(phys_expr::Literal::new(min_val.clone()))
as Arc<dyn PhysicalExpr>;
let max_lit = Arc::new(phys_expr::Literal::new(max_val.clone()))
as Arc<dyn PhysicalExpr>;
let range_expr = Arc::new(phys_expr::BinaryExpr::new(
Arc::new(phys_expr::BinaryExpr::new(
Arc::clone(in_list.expr()),
Operator::GtEq,
min_lit,
)),
Operator::And,
Arc::new(phys_expr::BinaryExpr::new(
Arc::clone(in_list.expr()),
Operator::LtEq,
max_lit,
)),
)) as Arc<dyn PhysicalExpr>;
return build_predicate_expression(
&range_expr,
schema,
required_columns,
unhandled_hook,
);
}
return unhandled_hook.handle(expr);
} else {
return unhandled_hook.handle(expr);
}
Expand Down Expand Up @@ -3215,10 +3290,24 @@ mod tests {
fn row_group_predicate_in_list_to_many_values() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
// test c1 in(1..21)
// in pruning.rs has MAX_LIST_VALUE_SIZE_REWRITE = 20, more than this value will be rewrite
// always true
// in pruning.rs has MAX_LIST_VALUE_SIZE_REWRITE = 20,
// falls back to range predicate c1 >= 1 AND c1 <= 21
let expr = col("c1").in_list((1..=21).map(lit).collect(), false);

let expected_expr = "c1_null_count@1 != row_count@2 AND c1_max@0 >= 1 AND c1_null_count@1 != row_count@2 AND c1_min@3 <= 21";
let predicate_expr =
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
assert_eq!(predicate_expr.to_string(), expected_expr);

Ok(())
}

#[test]
fn row_group_predicate_in_list_large_negated() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
// NOT IN with large list: no useful range pruning, falls back to true
let expr = col("c1").in_list((1..=21).map(lit).collect(), true);

let expected_expr = "true";
let predicate_expr =
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
Expand All @@ -3227,6 +3316,54 @@ mod tests {
Ok(())
}

#[test]
fn row_group_predicate_in_list_large_unsorted() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
// Values out of order — min/max should still be computed correctly
let values = vec![lit(50), lit(3), lit(99), lit(7), lit(1)];
// Pad to exceed the threshold
let mut all_values = values;
for i in 100..116i32 {
all_values.push(lit(i));
}
let expr = col("c1").in_list(all_values, false);

let expected_expr = "c1_null_count@1 != row_count@2 AND c1_max@0 >= 1 AND c1_null_count@1 != row_count@2 AND c1_min@3 <= 115";
let predicate_expr =
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
assert_eq!(predicate_expr.to_string(), expected_expr);

Ok(())
}

#[test]
fn row_group_predicate_in_list_large_mixed_timestamp_units() -> Result<()> {
// An IN list whose values mix timestamp units (Second vs Millisecond) are
// incomparable via partial_cmp → falls back to true (no range pruning).
use arrow::datatypes::TimeUnit;
let schema = Schema::new(vec![Field::new(
"ts",
DataType::Timestamp(TimeUnit::Second, None),
false,
)]);

// Build a list that exceeds the threshold (21+ values) mixing
// TimestampSecond and TimestampMillisecond — partial_cmp returns None.
let mut values: Vec<Expr> = (0..20i64)
.map(|i| lit(ScalarValue::TimestampSecond(Some(i * 1000), None)))
.collect();
// Add one incomparable value (different unit)
values.push(lit(ScalarValue::TimestampMillisecond(Some(999_999), None)));

let expr = col("ts").in_list(values, false);
// Incomparable types → bail out → true
let expected_expr = "true";
let predicate_expr =
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
assert_eq!(predicate_expr.to_string(), expected_expr);
Ok(())
}

#[test]
fn row_group_predicate_cast_int_int() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
Expand Down
83 changes: 83 additions & 0 deletions datafusion/sqllogictest/test_files/parquet.slt
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,89 @@ reset datafusion.execution.listing_table_ignore_subdirectory;
statement ok
reset datafusion.execution.parquet.coerce_int96;

###
### Large IN list pruning — range predicate fallback
###

statement ok
set datafusion.explain.physical_plan_only = true;

statement ok
set datafusion.execution.target_partitions = 1;

# Create a table with two row groups: [1..10] and [91..100]
statement ok
COPY (
SELECT i AS val FROM generate_series(1, 10) t(i)
UNION ALL
SELECT i AS val FROM generate_series(91, 100) t(i)
)
TO 'test_files/scratch/parquet/in_list_prune.parquet'
STORED AS PARQUET
OPTIONS ('max_row_group_size' 10);

statement ok
CREATE EXTERNAL TABLE in_list_prune
STORED AS PARQUET
LOCATION 'test_files/scratch/parquet/in_list_prune.parquet';

# Small IN list (≤ 20): exact per-value pruning — second row group [91..100] pruned
# because none of 1..5 fall in [91,100]
query I
SELECT val FROM in_list_prune WHERE val IN (1, 2, 3, 4, 5) ORDER BY val;
----
1
2
3
4
5

# Large IN list (> 20): range pruning — values span [1, 25], so row group [91..100]
# is pruned (its min=91 > 25)
query I
SELECT val FROM in_list_prune
WHERE val IN (1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25)
ORDER BY val;
----
1
2
3
4
5
6
7
8
9
10

# Large IN list where range spans both row groups: both row groups kept
# values include 5 (in RG1) and 95 (in RG2), so neither is pruned
query I
SELECT val FROM in_list_prune
WHERE val IN (5,95,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118)
ORDER BY val;
----
5
95

# Verify pruning is applied: explain shows pruning_predicate uses range
query TT
EXPLAIN SELECT val FROM in_list_prune
WHERE val IN (1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21);
----
physical_plan
01)FilterExec: val@0 IN (SET) ([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21])
02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/in_list_prune.parquet]]}, projection=[val], file_type=parquet, predicate=val@0 IN (SET) ([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]), pruning_predicate=val_null_count@1 != row_count@2 AND val_max@0 >= 1 AND val_null_count@1 != row_count@2 AND val_min@3 <= 21, required_guarantees=[val in (1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 21, 3, 4, 5, 6, 7, 8, 9)]

statement ok
DROP TABLE in_list_prune;

statement ok
reset datafusion.explain.physical_plan_only;

statement ok
set datafusion.execution.target_partitions = 4;

# Config reset
statement ok
RESET datafusion.catalog.create_default_catalog_and_schema;
Expand Down
Loading