diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index daf092ecd4cf9..ba7811acd8f3c 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -621,12 +621,7 @@ impl SqlToRel<'_, S> { _ => { let left_expr = self.sql_to_expr(*left, schema, planner_context)?; let right_expr = self.sql_to_expr(*right, schema, planner_context)?; - plan_quantified_op( - &left_expr, - &right_expr, - &compare_op, - SetQuantifier::Any, - ) + plan_any_op(left_expr, right_expr, &compare_op) } }, SQLExpr::AllOp { @@ -645,12 +640,7 @@ impl SqlToRel<'_, S> { _ => { let left_expr = self.sql_to_expr(*left, schema, planner_context)?; let right_expr = self.sql_to_expr(*right, schema, planner_context)?; - plan_quantified_op( - &left_expr, - &right_expr, - &compare_op, - SetQuantifier::All, - ) + plan_all_op(&left_expr, &right_expr, &compare_op) } }, #[expect(deprecated)] @@ -1259,20 +1249,73 @@ impl SqlToRel<'_, S> { } } -/// Plans `needle ANY/ALL(haystack)` with proper SQL NULL semantics. +/// Builds a CASE expression that handles NULL semantics for `x <op> ANY(arr)`: +/// +/// ```text +/// CASE +/// WHEN <array_min|array_max>(arr) IS NOT NULL THEN <comparison> +/// WHEN arr IS NOT NULL THEN FALSE -- empty or all-null array +/// ELSE NULL -- NULL array +/// END +/// ``` +fn any_op_with_null_handling(bound: Expr, comparison: Expr, arr: Expr) -> Result { + when(bound.is_not_null(), comparison) + .when(arr.is_not_null(), lit(false)) + .otherwise(lit(ScalarValue::Boolean(None))) +} + +/// Plans a `<left_expr> <compare_op> ANY(<right_expr>)` expression for non-subquery operands. 
+fn plan_any_op( + left_expr: Expr, + right_expr: Expr, + compare_op: &BinaryOperator, +) -> Result { + match compare_op { + BinaryOperator::Eq => Ok(array_has(right_expr, left_expr)), + BinaryOperator::NotEq => { + let min = array_min(right_expr.clone()); + let max = array_max(right_expr.clone()); + // NOT EQ is true when either bound differs from left + let comparison = min + .not_eq(left_expr.clone()) + .or(max.clone().not_eq(left_expr)); + any_op_with_null_handling(max, comparison, right_expr) + } + BinaryOperator::Gt => { + let min = array_min(right_expr.clone()); + any_op_with_null_handling(min.clone(), min.lt(left_expr), right_expr) + } + BinaryOperator::Lt => { + let max = array_max(right_expr.clone()); + any_op_with_null_handling(max.clone(), max.gt(left_expr), right_expr) + } + BinaryOperator::GtEq => { + let min = array_min(right_expr.clone()); + any_op_with_null_handling(min.clone(), min.lt_eq(left_expr), right_expr) + } + BinaryOperator::LtEq => { + let max = array_max(right_expr.clone()); + any_op_with_null_handling(max.clone(), max.gt_eq(left_expr), right_expr) + } + _ => plan_err!( + "Unsupported AnyOp: '{compare_op}', only '=', '<>', '>', '<', '>=', '<=' are supported" + ), + } +} + +/// Plans `needle ALL(haystack)` with proper SQL NULL semantics. 
/// /// CASE/WHEN structure: /// WHEN arr IS NULL → NULL -/// WHEN empty → vacuous_result (ANY:false, ALL:true) +/// WHEN empty → TRUE /// WHEN lhs IS NULL → NULL -/// WHEN decisive_condition → decisive_result (ANY:true match found, ALL:false violation found) +/// WHEN decisive_condition → FALSE /// WHEN has_nulls → NULL -/// ELSE → vacuous_result -fn plan_quantified_op( +/// ELSE → TRUE +fn plan_all_op( needle: &Expr, haystack: &Expr, compare_op: &BinaryOperator, - quantifier: SetQuantifier, ) -> Result { let null_arr_check = haystack.clone().is_null(); let empty_check = cardinality(haystack.clone()).eq(lit(0u64)); @@ -1282,61 +1325,40 @@ fn plan_quantified_op( let has_nulls = array_position(haystack.clone(), lit(ScalarValue::Null), lit(1i64)).is_not_null(); - let decisive_condition = match (compare_op, quantifier) { - (BinaryOperator::Eq, SetQuantifier::Any) - | (BinaryOperator::NotEq, SetQuantifier::All) => { - array_has(haystack.clone(), needle.clone()) - } - (BinaryOperator::Eq, SetQuantifier::All) - | (BinaryOperator::NotEq, SetQuantifier::Any) => { + let decisive_condition = match compare_op { + BinaryOperator::NotEq => array_has(haystack.clone(), needle.clone()), + BinaryOperator::Eq => { let all_equal = array_min(haystack.clone()) .eq(needle.clone()) .and(array_max(haystack.clone()).eq(needle.clone())); Expr::Not(Box::new(all_equal)) } - (BinaryOperator::Gt, SetQuantifier::Any) => { - needle.clone().gt(array_min(haystack.clone())) - } - (BinaryOperator::Gt, SetQuantifier::All) => { + BinaryOperator::Gt => { Expr::Not(Box::new(needle.clone().gt(array_max(haystack.clone())))) } - (BinaryOperator::Lt, SetQuantifier::Any) => { - needle.clone().lt(array_max(haystack.clone())) - } - (BinaryOperator::Lt, SetQuantifier::All) => { + BinaryOperator::Lt => { Expr::Not(Box::new(needle.clone().lt(array_min(haystack.clone())))) } - (BinaryOperator::GtEq, SetQuantifier::Any) => { - needle.clone().gt_eq(array_min(haystack.clone())) - } - (BinaryOperator::GtEq, 
SetQuantifier::All) => { + BinaryOperator::GtEq => { Expr::Not(Box::new(needle.clone().gt_eq(array_max(haystack.clone())))) } - (BinaryOperator::LtEq, SetQuantifier::Any) => { - needle.clone().lt_eq(array_max(haystack.clone())) - } - (BinaryOperator::LtEq, SetQuantifier::All) => { + BinaryOperator::LtEq => { Expr::Not(Box::new(needle.clone().lt_eq(array_min(haystack.clone())))) } _ => { return plan_err!( - "Unsupported {quantifier}Op: '{compare_op}', only '=', '<>', '>', '<', '>=', '<=' are supported" + "Unsupported AllOp: '{compare_op}', only '=', '<>', '>', '<', '>=', '<=' are supported" ); } }; - let (vacuous_result, decisive_result) = match quantifier { - SetQuantifier::Any => (false, true), - SetQuantifier::All => (true, false), - }; - let null_bool = lit(ScalarValue::Boolean(None)); when(null_arr_check, null_bool.clone()) - .when(empty_check, lit(vacuous_result)) + .when(empty_check, lit(true)) .when(null_lhs_check, null_bool.clone()) - .when(decisive_condition, lit(decisive_result)) + .when(decisive_condition, lit(false)) .when(has_nulls, null_bool) - .otherwise(lit(vacuous_result)) + .otherwise(lit(true)) } #[cfg(test)] diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 6c260dc019251..62912c7ff86c9 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -370,7 +370,7 @@ fn roundtrip_statement_postgres_any_array_expr() -> Result<(), DataFusionError> sql: "select left from array where 1 = any(left);", parser_dialect: GenericDialect {}, unparser_dialect: UnparserPostgreSqlDialect {}, - expected: @r#"SELECT "array"."left" FROM "array" WHERE CASE WHEN "array"."left" IS NULL THEN NULL WHEN (cardinality("array"."left") = 0) THEN false WHEN 1 IS NULL THEN NULL WHEN 1 = ANY("array"."left") THEN true WHEN array_position("array"."left", NULL, 1) IS NOT NULL THEN NULL ELSE false END"#, + expected: @r#"SELECT "array"."left" FROM "array" WHERE 1 = 
ANY("array"."left")"#, ); Ok(()) } diff --git a/datafusion/sqllogictest/test_files/array/array_has.slt b/datafusion/sqllogictest/test_files/array/array_has.slt index abfd697a42d54..e343c1b1fae41 100644 --- a/datafusion/sqllogictest/test_files/array/array_has.slt +++ b/datafusion/sqllogictest/test_files/array/array_has.slt @@ -517,18 +517,16 @@ logical_plan 03)----SubqueryAlias: test 04)------SubqueryAlias: t 05)--------Projection: -06)----------Filter: __common_expr_3 IS NULL AND Boolean(NULL) OR __common_expr_3 IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) IS NOT DISTINCT FROM Boolean(true) AND __common_expr_3 IS NOT NULL -07)------------Projection: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) AS __common_expr_3 -08)--------------TableScan: generate_series() projection=[value] +06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) +07)------------TableScan: generate_series() projection=[value] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] -05)--------FilterExec: __common_expr_3@0 IS NULL AND NULL OR __common_expr_3@0 IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) IS NOT DISTINCT FROM true AND __common_expr_3@0 IS NOT NULL, projection=[] -06)----------ProjectionExec: expr=[substr(md5(CAST(value@0 AS Utf8View)), 1, 32) as __common_expr_3] -07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] +05)--------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) 
([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), projection=[] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] query I with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) @@ -756,26 +754,26 @@ select 5 <= any(make_array()); false # Mixed NULL + non-NULL array where no non-NULL element satisfies the condition -# These return NULL because NULLs leave the result indeterminate +# These return false (NULLs are skipped by array_min/array_max) query B select 5 > any(make_array(6, NULL)); ---- -NULL +false query B select 5 < any(make_array(3, NULL)); ---- -NULL +false query B select 5 >= any(make_array(6, NULL)); ---- -NULL +false query B select 5 <= any(make_array(3, NULL)); ---- -NULL +false # Mixed NULL + non-NULL array where a non-NULL element satisfies the condition query B @@ -806,38 +804,33 @@ true query B select 5 <> any(make_array(5, NULL)); ---- -NULL +false -# All-NULL array: all operators should return NULL (unknown comparison) +# All-NULL array: all operators should return false query B select 5 > any(make_array(NULL::INT, NULL::INT)); ---- -NULL +false query B select 5 < any(make_array(NULL::INT, NULL::INT)); ---- -NULL +false query B select 5 >= any(make_array(NULL::INT, NULL::INT)); ---- -NULL +false query B select 5 <= any(make_array(NULL::INT, NULL::INT)); ---- -NULL +false query B select 5 <> any(make_array(NULL::INT, NULL::INT)); ---- -NULL - -query B -select 5 = any(make_array(NULL::INT, NULL::INT)); ----- -NULL +false # NULL left operand: should return NULL for non-empty arrays query B @@ -897,35 +890,6 @@ select 5 <> any(NULL::INT[]); ---- NULL -query B -select 5 = any(NULL::INT[]); ----- -NULL - -# NULL = ANY with non-empty array -query B -select NULL = any(make_array(1, 2, 3)); ----- -NULL - -# = ANY with no match, no NULLs -query B -select 5 = 
any(make_array(1, 2, 3)); ----- -false - -# = ANY with mixed NULL (satisfying) returns TRUE -query B -select 5 = any(make_array(5, NULL)); ----- -true - -# = ANY with mixed NULL (non-satisfying): NULLs leave result indeterminate -query B -select 5 = any(make_array(1, 2, NULL)); ----- -NULL - statement ok DROP TABLE any_op_test;