Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions datafusion/core/tests/parquet/row_group_pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2078,3 +2078,26 @@ async fn test_limit_pruning_exceeds_fully_matched() -> datafusion_common::error:
.await;
Ok(())
}

#[tokio::test]
async fn prune_like_prefix() {
// UTF8 scenario: 2 row groups (5 rows each)
// RG1: ["a","b","c","d",NULL] => min="a", max="d"
// RG2: ["e","f","g","h","i"] => min="e", max="i"
//
// LIKE 'a%' => build_like_match produces: "a" <= max AND min <= "a" (actually min < "b")
// RG1: "a" <= "d" ✓, "a" < "b" ✓ => matched
// RG2: "a" <= "i" ✓, "e" < "b" ✗ => pruned
RowGroupPruningTest::new()
.with_scenario(Scenario::UTF8)
.with_query("SELECT * FROM t WHERE utf8 LIKE 'a%'")
.with_expected_errors(Some(0))
.with_matched_by_stats(Some(1))
.with_pruned_by_stats(Some(1))
.with_pruned_files(Some(0))
.with_matched_by_bloom_filter(Some(1))
.with_pruned_by_bloom_filter(Some(0))
.with_expected_rows(1) // only "a" matches LIKE 'a%'
.test_row_group_prune()
.await;
}
156 changes: 126 additions & 30 deletions datafusion/expr-common/src/casts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
//! to different data types, originally extracted from the optimizer's
//! unwrap_cast module to be shared between logical and physical layers.

use std::borrow::Cow;
use std::cmp::Ordering;

use arrow::datatypes::{
Expand All @@ -31,23 +32,57 @@ use arrow::datatypes::{
use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
use datafusion_common::ScalarValue;

/// Convert a literal value from one data type to another
pub fn try_cast_literal_to_type(
lit_value: &ScalarValue,
/// Accepts either an owned [`ScalarValue`] or a `&ScalarValue` as the literal
/// argument to [`try_cast_literal_to_type`].
///
/// `std` does not provide a blanket `From<&T> for Cow<'_, T>`, and
/// [`ScalarValue`] lives in another crate, so we cannot rely on
/// `impl Into<Cow<'_, ScalarValue>>`. This small trait fills that gap: passing
/// an owned value lets the cast *move* its string contents instead of
/// re-allocating them.
pub trait IntoScalarCow<'a> {
fn into_scalar_cow(self) -> Cow<'a, ScalarValue>;
}

impl<'a> IntoScalarCow<'a> for ScalarValue {
fn into_scalar_cow(self) -> Cow<'a, ScalarValue> {
Cow::Owned(self)
}
}

impl<'a> IntoScalarCow<'a> for &'a ScalarValue {
fn into_scalar_cow(self) -> Cow<'a, ScalarValue> {
Cow::Borrowed(self)
}
}

/// Convert a literal value from one data type to another.
///
/// Accepts either an owned [`ScalarValue`] or a `&ScalarValue`. When an owned
/// value is passed, string casts move the underlying `String` into the new
/// value instead of re-allocating it.
pub fn try_cast_literal_to_type<'a>(
lit_value: impl IntoScalarCow<'a>,
target_type: &DataType,
) -> Option<ScalarValue> {
let lit_data_type = lit_value.data_type();
if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) {
let lit_value = lit_value.into_scalar_cow();
if !is_supported_type(&lit_value.data_type()) || !is_supported_type(target_type) {
return None;
}
if lit_value.is_null() {
// null value can be cast to any type of null value
return ScalarValue::try_from(target_type).ok();
}
try_cast_numeric_literal(lit_value, target_type)
.or_else(|| try_cast_string_literal(lit_value, target_type))
.or_else(|| try_cast_dictionary(lit_value, target_type))
.or_else(|| try_cast_binary(lit_value, target_type))
// The numeric/dictionary/binary casts only need a reference. The string
// cast goes last so it is free to consume `lit_value` and move the string
// out of an owned `Cow` rather than re-allocating it.
if let Some(value) = try_cast_numeric_literal(&lit_value, target_type)
.or_else(|| try_cast_dictionary(&lit_value, target_type))
.or_else(|| try_cast_binary(&lit_value, target_type))
{
return Some(value);
}
try_cast_string_literal(lit_value, target_type)
}

/// Returns true if unwrap_cast_in_comparison supports this data type
Expand Down Expand Up @@ -332,17 +367,26 @@ fn try_cast_numeric_literal(
}

fn try_cast_string_literal(
lit_value: &ScalarValue,
lit_value: Cow<'_, ScalarValue>,
target_type: &DataType,
) -> Option<ScalarValue> {
let string_value = lit_value.try_as_str()?.map(|s| s.to_string());
let scalar_value = match target_type {
DataType::Utf8 => ScalarValue::Utf8(string_value),
DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value),
DataType::Utf8View => ScalarValue::Utf8View(string_value),
// Resolve the target string variant first so we bail out for non-string
// targets before consuming the value.
let wrap: fn(Option<String>) -> ScalarValue = match target_type {
DataType::Utf8 => ScalarValue::Utf8,
DataType::LargeUtf8 => ScalarValue::LargeUtf8,
DataType::Utf8View => ScalarValue::Utf8View,
_ => return None,
};
Some(scalar_value)
// Move the string out of an owned value; clone a borrowed one.
let string_value = match lit_value {
Cow::Owned(
ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) | ScalarValue::Utf8View(s),
) => s,
Cow::Owned(_) => return None,
Cow::Borrowed(value) => value.try_as_str()?.map(|s| s.to_string()),
};
Some(wrap(string_value))
}

/// Attempt to cast to/from a dictionary type by wrapping/unwrapping the dictionary
Expand Down Expand Up @@ -774,7 +818,7 @@ mod tests {
fn test_try_cast_literal_to_timestamp() {
// same timestamp
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampNanosecond(Some(123456), None),
ScalarValue::TimestampNanosecond(Some(123456), None),
&DataType::Timestamp(TimeUnit::Nanosecond, None),
)
.unwrap();
Expand All @@ -786,7 +830,7 @@ mod tests {

// TimestampNanosecond to TimestampMicrosecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampNanosecond(Some(123456), None),
ScalarValue::TimestampNanosecond(Some(123456), None),
&DataType::Timestamp(TimeUnit::Microsecond, None),
)
.unwrap();
Expand All @@ -798,7 +842,7 @@ mod tests {

// TimestampNanosecond to TimestampMillisecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampNanosecond(Some(123456), None),
ScalarValue::TimestampNanosecond(Some(123456), None),
&DataType::Timestamp(TimeUnit::Millisecond, None),
)
.unwrap();
Expand All @@ -807,7 +851,7 @@ mod tests {

// TimestampNanosecond to TimestampSecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampNanosecond(Some(123456), None),
ScalarValue::TimestampNanosecond(Some(123456), None),
&DataType::Timestamp(TimeUnit::Second, None),
)
.unwrap();
Expand All @@ -816,7 +860,7 @@ mod tests {

// TimestampMicrosecond to TimestampNanosecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampMicrosecond(Some(123), None),
ScalarValue::TimestampMicrosecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Nanosecond, None),
)
.unwrap();
Expand All @@ -828,7 +872,7 @@ mod tests {

// TimestampMicrosecond to TimestampMillisecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampMicrosecond(Some(123), None),
ScalarValue::TimestampMicrosecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Millisecond, None),
)
.unwrap();
Expand All @@ -837,15 +881,15 @@ mod tests {

// TimestampMicrosecond to TimestampSecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampMicrosecond(Some(123456789), None),
ScalarValue::TimestampMicrosecond(Some(123456789), None),
&DataType::Timestamp(TimeUnit::Second, None),
)
.unwrap();
assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None));

// TimestampMillisecond to TimestampNanosecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampMillisecond(Some(123), None),
ScalarValue::TimestampMillisecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Nanosecond, None),
)
.unwrap();
Expand All @@ -856,7 +900,7 @@ mod tests {

// TimestampMillisecond to TimestampMicrosecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampMillisecond(Some(123), None),
ScalarValue::TimestampMillisecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Microsecond, None),
)
.unwrap();
Expand All @@ -866,15 +910,15 @@ mod tests {
);
// TimestampMillisecond to TimestampSecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampMillisecond(Some(123456789), None),
ScalarValue::TimestampMillisecond(Some(123456789), None),
&DataType::Timestamp(TimeUnit::Second, None),
)
.unwrap();
assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None));

// TimestampSecond to TimestampNanosecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampSecond(Some(123), None),
ScalarValue::TimestampSecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Nanosecond, None),
)
.unwrap();
Expand All @@ -885,7 +929,7 @@ mod tests {

// TimestampSecond to TimestampMicrosecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampSecond(Some(123), None),
ScalarValue::TimestampSecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Microsecond, None),
)
.unwrap();
Expand All @@ -896,7 +940,7 @@ mod tests {

// TimestampSecond to TimestampMillisecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampSecond(Some(123), None),
ScalarValue::TimestampSecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Millisecond, None),
)
.unwrap();
Expand All @@ -907,7 +951,7 @@ mod tests {

// overflow
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampSecond(Some(i64::MAX), None),
ScalarValue::TimestampSecond(Some(i64::MAX), None),
&DataType::Timestamp(TimeUnit::Millisecond, None),
)
.unwrap();
Expand All @@ -930,6 +974,58 @@ mod tests {
}
}

#[test]
fn test_try_cast_literal_to_type_owned_moves_strings() {
// an owned string value can be passed directly (no `&`) and is moved
// into the target string variant
let cases = [
(DataType::Utf8, ScalarValue::Utf8(Some("abc".to_owned()))),
(
DataType::LargeUtf8,
ScalarValue::LargeUtf8(Some("abc".to_owned())),
),
(
DataType::Utf8View,
ScalarValue::Utf8View(Some("abc".to_owned())),
),
];
for (target_type, expected) in cases {
let actual = try_cast_literal_to_type(
ScalarValue::Utf8(Some("abc".to_owned())),
&target_type,
);
assert_eq!(actual, Some(expected));
}

// owned non-string casts and dictionary wrapping still work
assert_eq!(
try_cast_literal_to_type(ScalarValue::Int32(Some(1)), &DataType::Int64),
Some(ScalarValue::Int64(Some(1)))
);
assert_eq!(
try_cast_literal_to_type(
ScalarValue::Utf8(Some("abc".to_owned())),
&DataType::Dictionary(
Box::new(DataType::Int32),
Box::new(DataType::Utf8)
),
),
Some(ScalarValue::Dictionary(
Box::new(DataType::Int32),
Box::new(ScalarValue::Utf8(Some("abc".to_owned()))),
))
);

// unsupported owned cast returns None
assert_eq!(
try_cast_literal_to_type(
ScalarValue::Utf8(Some("abc".to_owned())),
&DataType::Int32
),
None
);
}

#[test]
fn test_try_cast_to_dictionary_type() {
fn dictionary_type(t: DataType) -> DataType {
Expand Down
29 changes: 20 additions & 9 deletions datafusion/pruning/src/pruning_predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ use datafusion_common::{
ScalarValue, internal_datafusion_err, plan_datafusion_err, plan_err,
tree_node::{Transformed, TreeNode},
};
use datafusion_expr_common::casts::try_cast_literal_to_type;
use datafusion_expr_common::operator::Operator;
use datafusion_physical_expr::utils::{Guarantee, LiteralGuarantee};
use datafusion_physical_expr::{PhysicalExprRef, expressions as phys_expr};
Expand Down Expand Up @@ -1816,6 +1817,20 @@ fn extract_string_literal(expr: &Arc<dyn PhysicalExpr>) -> Option<&str> {
None
}

/// Wrap a string in a `Literal` whose `ScalarValue` matches `target_type`.
///
/// Returns `None` if `target_type` is not a supported cast target for a string,
/// in which case the caller should skip pruning rather than emit a literal
/// whose type does not match the column statistics. The owned `value` is moved
/// into the literal without re-allocating.
fn string_literal_as(
value: String,
target_type: &DataType,
) -> Option<Arc<dyn PhysicalExpr>> {
let scalar = try_cast_literal_to_type(ScalarValue::Utf8(Some(value)), target_type)?;
Some(Arc::new(phys_expr::Literal::new(scalar)))
}

/// Convert `column LIKE literal` where P is a constant prefix of the literal
/// to a range check on the column: `P <= column && column < P'`, where P' is the
/// lowest string after all P* strings.
Expand All @@ -1835,6 +1850,8 @@ fn build_like_match(
let min_column_expr = expr_builder.min_column_expr().ok()?;
let max_column_expr = expr_builder.max_column_expr().ok()?;
let scalar_expr = expr_builder.scalar_expr();
// Synthesized bounds must match the column type (e.g. `Utf8View`).
let target_type = expr_builder.field.data_type();
// check that the scalar is a string literal
let s = extract_string_literal(scalar_expr)?;
// ANSI SQL specifies two wildcards: % and _. % matches zero or more characters, _ matches exactly one character.
Expand All @@ -1846,18 +1863,12 @@ fn build_like_match(
}
let (lower_bound, upper_bound) = if has_wildcard {
let incremented_prefix = increment_utf8(&decoded_prefix)?;
let lower_bound_lit = Arc::new(phys_expr::Literal::new(ScalarValue::Utf8(Some(
decoded_prefix,
))));
let upper_bound_lit = Arc::new(phys_expr::Literal::new(ScalarValue::Utf8(Some(
incremented_prefix,
))));
let lower_bound_lit = string_literal_as(decoded_prefix, target_type)?;
let upper_bound_lit = string_literal_as(incremented_prefix, target_type)?;
(lower_bound_lit, upper_bound_lit)
} else {
// the like expression is a literal and can be converted into a comparison
let bound = Arc::new(phys_expr::Literal::new(ScalarValue::Utf8(Some(
decoded_prefix,
))));
let bound = string_literal_as(decoded_prefix, target_type)?;
(Arc::clone(&bound), bound)
};
let lower_bound_expr = Arc::new(phys_expr::BinaryExpr::new(
Expand Down
Loading