diff --git a/datafusion/src/physical_plan/expressions/binary.rs b/datafusion/src/physical_plan/expressions/binary.rs index bd593fd6ecb5..4f0f6753d455 100644 --- a/datafusion/src/physical_plan/expressions/binary.rs +++ b/datafusion/src/physical_plan/expressions/binary.rs @@ -23,23 +23,22 @@ use arrow::compute::kernels::arithmetic::{ add, divide, divide_scalar, modulus, modulus_scalar, multiply, subtract, }; use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene}; -use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq}; use arrow::compute::kernels::comparison::{ - eq_bool, eq_bool_scalar, gt_bool, gt_bool_scalar, gt_eq_bool, gt_eq_bool_scalar, - lt_bool, lt_bool_scalar, lt_eq_bool, lt_eq_bool_scalar, neq_bool, neq_bool_scalar, + eq_bool_scalar, gt_bool_scalar, gt_eq_bool_scalar, lt_bool_scalar, lt_eq_bool_scalar, + neq_bool_scalar, }; use arrow::compute::kernels::comparison::{ - eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar, + eq_dyn, gt_dyn, gt_eq_dyn, lt_dyn, lt_eq_dyn, neq_dyn, }; use arrow::compute::kernels::comparison::{ - eq_utf8, gt_eq_utf8, gt_utf8, like_utf8, lt_eq_utf8, lt_utf8, neq_utf8, nlike_utf8, - regexp_is_match_utf8, + eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar, }; use arrow::compute::kernels::comparison::{ eq_utf8_scalar, gt_eq_utf8_scalar, gt_utf8_scalar, like_utf8_scalar, lt_eq_utf8_scalar, lt_utf8_scalar, neq_utf8_scalar, nlike_utf8_scalar, regexp_is_match_utf8_scalar, }; +use arrow::compute::kernels::comparison::{like_utf8, nlike_utf8, regexp_is_match_utf8}; use arrow::datatypes::{ArrowNumericType, DataType, Schema, TimeUnit}; use arrow::record_batch::RecordBatch; @@ -357,6 +356,17 @@ macro_rules! binary_array_op_scalar { }}; } +/// Calls a dynamic comparison operation from Arrow and converts the +/// error and return type appropriately +fn call_dyn_cmp(left: Arc, right: Arc, f: F) -> Result +where + F: Fn(&dyn Array, &dyn Array) -> arrow::error::Result, +{ + f(left.as_ref(), right.as_ref()) + .map(|a| Arc::new(a) as ArrayRef) + .map_err(|e| e.into()) +} + /// The binary_array_op macro includes types that extend beyond the primitive, /// such as Utf8 strings. #[macro_export] @@ -755,12 +765,12 @@ impl BinaryExpr { match &self.op { Operator::Like => binary_string_array_op!(left, right, like), Operator::NotLike => binary_string_array_op!(left, right, nlike), - Operator::Lt => binary_array_op!(left, right, lt), - Operator::LtEq => binary_array_op!(left, right, lt_eq), - Operator::Gt => binary_array_op!(left, right, gt), - Operator::GtEq => binary_array_op!(left, right, gt_eq), - Operator::Eq => binary_array_op!(left, right, eq), - Operator::NotEq => binary_array_op!(left, right, neq), + Operator::Lt => call_dyn_cmp(left, right, lt_dyn), + Operator::LtEq => call_dyn_cmp(left, right, lt_eq_dyn), + Operator::Gt => call_dyn_cmp(left, right, gt_dyn), + Operator::GtEq => call_dyn_cmp(left, right, gt_eq_dyn), + Operator::Eq => call_dyn_cmp(left, right, eq_dyn), + Operator::NotEq => call_dyn_cmp(left, right, neq_dyn), Operator::IsDistinctFrom => binary_array_op!(left, right, is_distinct_from), Operator::IsNotDistinctFrom => { binary_array_op!(left, right, is_not_distinct_from)