Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 72 additions & 50 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -621,12 +621,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
_ => {
let left_expr = self.sql_to_expr(*left, schema, planner_context)?;
let right_expr = self.sql_to_expr(*right, schema, planner_context)?;
plan_quantified_op(
&left_expr,
&right_expr,
&compare_op,
SetQuantifier::Any,
)
plan_any_op(left_expr, right_expr, &compare_op)
}
},
SQLExpr::AllOp {
Expand All @@ -645,12 +640,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
_ => {
let left_expr = self.sql_to_expr(*left, schema, planner_context)?;
let right_expr = self.sql_to_expr(*right, schema, planner_context)?;
plan_quantified_op(
&left_expr,
&right_expr,
&compare_op,
SetQuantifier::All,
)
plan_all_op(&left_expr, &right_expr, &compare_op)
}
},
#[expect(deprecated)]
Expand Down Expand Up @@ -1259,20 +1249,73 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
}
}

/// Plans `needle <compare_op> ANY/ALL(haystack)` with proper SQL NULL semantics.
/// Builds a CASE expression that handles NULL semantics for `x <op> ANY(arr)`:
///
/// ```text
/// CASE
/// WHEN <min_or_max>(arr) IS NOT NULL THEN <comparison>
/// WHEN arr IS NOT NULL THEN FALSE -- empty or all-null array
/// ELSE NULL -- NULL array
/// END
/// ```
fn any_op_with_null_handling(bound: Expr, comparison: Expr, arr: Expr) -> Result<Expr> {
when(bound.is_not_null(), comparison)
.when(arr.is_not_null(), lit(false))
.otherwise(lit(ScalarValue::Boolean(None)))
}

/// Plans a `<left> <op> ANY(<right>)` expression for non-subquery operands.
fn plan_any_op(
left_expr: Expr,
right_expr: Expr,
compare_op: &BinaryOperator,
) -> Result<Expr> {
match compare_op {
BinaryOperator::Eq => Ok(array_has(right_expr, left_expr)),
BinaryOperator::NotEq => {
let min = array_min(right_expr.clone());
let max = array_max(right_expr.clone());
// NOT EQ is true when either bound differs from left
let comparison = min
.not_eq(left_expr.clone())
.or(max.clone().not_eq(left_expr));
any_op_with_null_handling(max, comparison, right_expr)
}
BinaryOperator::Gt => {
let min = array_min(right_expr.clone());
any_op_with_null_handling(min.clone(), min.lt(left_expr), right_expr)
}
BinaryOperator::Lt => {
let max = array_max(right_expr.clone());
any_op_with_null_handling(max.clone(), max.gt(left_expr), right_expr)
}
BinaryOperator::GtEq => {
let min = array_min(right_expr.clone());
any_op_with_null_handling(min.clone(), min.lt_eq(left_expr), right_expr)
}
BinaryOperator::LtEq => {
let max = array_max(right_expr.clone());
any_op_with_null_handling(max.clone(), max.gt_eq(left_expr), right_expr)
}
_ => plan_err!(
"Unsupported AnyOp: '{compare_op}', only '=', '<>', '>', '<', '>=', '<=' are supported"
),
}
}

/// Plans `needle <compare_op> ALL(haystack)` with proper SQL NULL semantics.
///
/// CASE/WHEN structure:
/// WHEN arr IS NULL → NULL
/// WHEN empty → vacuous_result (ANY:false, ALL:true)
/// WHEN empty → TRUE
/// WHEN lhs IS NULL → NULL
/// WHEN decisive_condition → decisive_result (ANY:true match found, ALL:false violation found)
/// WHEN decisive_condition → FALSE
/// WHEN has_nulls → NULL
/// ELSE → vacuous_result
fn plan_quantified_op(
/// ELSE → TRUE
fn plan_all_op(
needle: &Expr,
haystack: &Expr,
compare_op: &BinaryOperator,
quantifier: SetQuantifier,
) -> Result<Expr> {
let null_arr_check = haystack.clone().is_null();
let empty_check = cardinality(haystack.clone()).eq(lit(0u64));
Expand All @@ -1282,61 +1325,40 @@ fn plan_quantified_op(
let has_nulls =
array_position(haystack.clone(), lit(ScalarValue::Null), lit(1i64)).is_not_null();

let decisive_condition = match (compare_op, quantifier) {
(BinaryOperator::Eq, SetQuantifier::Any)
| (BinaryOperator::NotEq, SetQuantifier::All) => {
array_has(haystack.clone(), needle.clone())
}
(BinaryOperator::Eq, SetQuantifier::All)
| (BinaryOperator::NotEq, SetQuantifier::Any) => {
let decisive_condition = match compare_op {
BinaryOperator::NotEq => array_has(haystack.clone(), needle.clone()),
BinaryOperator::Eq => {
let all_equal = array_min(haystack.clone())
.eq(needle.clone())
.and(array_max(haystack.clone()).eq(needle.clone()));
Expr::Not(Box::new(all_equal))
}
(BinaryOperator::Gt, SetQuantifier::Any) => {
needle.clone().gt(array_min(haystack.clone()))
}
(BinaryOperator::Gt, SetQuantifier::All) => {
BinaryOperator::Gt => {
Expr::Not(Box::new(needle.clone().gt(array_max(haystack.clone()))))
}
(BinaryOperator::Lt, SetQuantifier::Any) => {
needle.clone().lt(array_max(haystack.clone()))
}
(BinaryOperator::Lt, SetQuantifier::All) => {
BinaryOperator::Lt => {
Expr::Not(Box::new(needle.clone().lt(array_min(haystack.clone()))))
}
(BinaryOperator::GtEq, SetQuantifier::Any) => {
needle.clone().gt_eq(array_min(haystack.clone()))
}
(BinaryOperator::GtEq, SetQuantifier::All) => {
BinaryOperator::GtEq => {
Expr::Not(Box::new(needle.clone().gt_eq(array_max(haystack.clone()))))
}
(BinaryOperator::LtEq, SetQuantifier::Any) => {
needle.clone().lt_eq(array_max(haystack.clone()))
}
(BinaryOperator::LtEq, SetQuantifier::All) => {
BinaryOperator::LtEq => {
Expr::Not(Box::new(needle.clone().lt_eq(array_min(haystack.clone()))))
}
_ => {
return plan_err!(
"Unsupported {quantifier}Op: '{compare_op}', only '=', '<>', '>', '<', '>=', '<=' are supported"
"Unsupported AllOp: '{compare_op}', only '=', '<>', '>', '<', '>=', '<=' are supported"
);
}
};

let (vacuous_result, decisive_result) = match quantifier {
SetQuantifier::Any => (false, true),
SetQuantifier::All => (true, false),
};

let null_bool = lit(ScalarValue::Boolean(None));
when(null_arr_check, null_bool.clone())
.when(empty_check, lit(vacuous_result))
.when(empty_check, lit(true))
.when(null_lhs_check, null_bool.clone())
.when(decisive_condition, lit(decisive_result))
.when(decisive_condition, lit(false))
.when(has_nulls, null_bool)
.otherwise(lit(vacuous_result))
.otherwise(lit(true))
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ fn roundtrip_statement_postgres_any_array_expr() -> Result<(), DataFusionError>
sql: "select left from array where 1 = any(left);",
parser_dialect: GenericDialect {},
unparser_dialect: UnparserPostgreSqlDialect {},
expected: @r#"SELECT "array"."left" FROM "array" WHERE CASE WHEN "array"."left" IS NULL THEN NULL WHEN (cardinality("array"."left") = 0) THEN false WHEN 1 IS NULL THEN NULL WHEN 1 = ANY("array"."left") THEN true WHEN array_position("array"."left", NULL, 1) IS NOT NULL THEN NULL ELSE false END"#,
expected: @r#"SELECT "array"."left" FROM "array" WHERE 1 = ANY("array"."left")"#,
);
Ok(())
}
Expand Down
70 changes: 17 additions & 53 deletions datafusion/sqllogictest/test_files/array/array_has.slt
Original file line number Diff line number Diff line change
Expand Up @@ -517,18 +517,16 @@ logical_plan
03)----SubqueryAlias: test
04)------SubqueryAlias: t
05)--------Projection:
06)----------Filter: __common_expr_3 IS NULL AND Boolean(NULL) OR __common_expr_3 IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) IS NOT DISTINCT FROM Boolean(true) AND __common_expr_3 IS NOT NULL
07)------------Projection: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) AS __common_expr_3
08)--------------TableScan: generate_series() projection=[value]
06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")])
07)------------TableScan: generate_series() projection=[value]
physical_plan
01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
03)----CoalescePartitionsExec
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
05)--------FilterExec: __common_expr_3@0 IS NULL AND NULL OR __common_expr_3@0 IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) IS NOT DISTINCT FROM true AND __common_expr_3@0 IS NOT NULL, projection=[]
06)----------ProjectionExec: expr=[substr(md5(CAST(value@0 AS Utf8View)), 1, 32) as __common_expr_3]
07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
05)--------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), projection=[]
06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
07)------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]

query I
with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
Expand Down Expand Up @@ -756,26 +754,26 @@ select 5 <= any(make_array());
false

# Mixed NULL + non-NULL array where no non-NULL element satisfies the condition
# These return NULL because NULLs leave the result indeterminate
# These return false (NULLs are skipped by array_min/array_max)
query B
select 5 > any(make_array(6, NULL));
----
NULL
false

query B
select 5 < any(make_array(3, NULL));
----
NULL
false

query B
select 5 >= any(make_array(6, NULL));
----
NULL
false

query B
select 5 <= any(make_array(3, NULL));
----
NULL
false

# Mixed NULL + non-NULL array where a non-NULL element satisfies the condition
query B
Expand Down Expand Up @@ -806,38 +804,33 @@ true
query B
select 5 <> any(make_array(5, NULL));
----
NULL
false

# All-NULL array: all operators should return NULL (unknown comparison)
# All-NULL array: all operators should return false
query B
select 5 > any(make_array(NULL::INT, NULL::INT));
----
NULL
false

query B
select 5 < any(make_array(NULL::INT, NULL::INT));
----
NULL
false

query B
select 5 >= any(make_array(NULL::INT, NULL::INT));
----
NULL
false

query B
select 5 <= any(make_array(NULL::INT, NULL::INT));
----
NULL
false

query B
select 5 <> any(make_array(NULL::INT, NULL::INT));
----
NULL

query B
select 5 = any(make_array(NULL::INT, NULL::INT));
----
NULL
false

# NULL left operand: should return NULL for non-empty arrays
query B
Expand Down Expand Up @@ -897,35 +890,6 @@ select 5 <> any(NULL::INT[]);
----
NULL

query B
select 5 = any(NULL::INT[]);
----
NULL

# NULL = ANY with non-empty array
query B
select NULL = any(make_array(1, 2, 3));
----
NULL

# = ANY with no match, no NULLs
query B
select 5 = any(make_array(1, 2, 3));
----
false

# = ANY with mixed NULL (satisfying) returns TRUE
query B
select 5 = any(make_array(5, NULL));
----
true

# = ANY with mixed NULL (non-satisfying): NULLs leave result indeterminate
query B
select 5 = any(make_array(1, 2, NULL));
----
NULL

statement ok
DROP TABLE any_op_test;

Expand Down
Loading