diff --git a/datafusion/src/physical_plan/math_expressions.rs b/datafusion/src/physical_plan/math_expressions.rs index 72b4f102f662..0e0bed2deac2 100644 --- a/datafusion/src/physical_plan/math_expressions.rs +++ b/datafusion/src/physical_plan/math_expressions.rs @@ -17,37 +17,21 @@ //! Math expressions -use arrow::array::{make_array, Array, ArrayData, Float32Array, Float64Array}; -use arrow::buffer::Buffer; -use arrow::datatypes::{DataType, ToByteSlice}; - use super::{ColumnarValue, ScalarValue}; use crate::error::{DataFusionError, Result}; - -macro_rules! compute_op { - ($ARRAY:expr, $FUNC:ident, $TYPE:ident) => {{ - let len = $ARRAY.len(); - let result = (0..len) - .map(|i| $ARRAY.value(i).$FUNC() as f64) - .collect::>(); - let data = ArrayData::new( - DataType::Float64, - len, - Some($ARRAY.null_count()), - $ARRAY.data().null_buffer().cloned(), - 0, - vec![Buffer::from(result.to_byte_slice())], - vec![], - ); - Ok(make_array(data)) - }}; -} +use arrow::array::{Float32Array, Float64Array}; +use arrow::datatypes::DataType; +use std::sync::Arc; macro_rules! downcast_compute_op { ($ARRAY:expr, $NAME:expr, $FUNC:ident, $TYPE:ident) => {{ let n = $ARRAY.as_any().downcast_ref::<$TYPE>(); match n { - Some(array) => compute_op!(array, $FUNC, $TYPE), + Some(array) => { + let res: $TYPE = + arrow::compute::kernels::arity::unary(array, |x| x.$FUNC()); + Ok(Arc::new(res)) + } _ => Err(DataFusionError::Internal(format!( "Invalid data type for {}", $NAME diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 5c90f8ac162b..88f163b9e34c 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -631,7 +631,7 @@ async fn sqrt_f32_vs_f64() -> Result<()> { // sqrt(f32)'s plan passes let sql = "SELECT avg(sqrt(c11)) FROM aggregate_test_100"; let actual = execute(&mut ctx, sql).await; - let expected = vec![vec!["0.6584408485889435"]]; + let expected = vec![vec!["0.6584407806396484"]]; assert_eq!(actual, expected); let sql = "SELECT avg(sqrt(CAST(c11 AS double))) FROM aggregate_test_100";