From a78598d702bb3456333ee028c82436b615f065d9 Mon Sep 17 00:00:00 2001 From: buraksenb Date: Wed, 16 Oct 2024 18:56:37 +0300 Subject: [PATCH 1/9] removed last uses of make_function_scalar_inputs --- datafusion/functions/src/math/log.rs | 24 ++++++------- datafusion/functions/src/math/round.rs | 50 +++++++++++++------------- 2 files changed, 38 insertions(+), 36 deletions(-) diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 07ff8e2166ff..b6af6057b81b 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -22,8 +22,8 @@ use std::sync::{Arc, OnceLock}; use super::power::PowerFunc; -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use arrow::datatypes::DataType; +use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, Result, ScalarValue, @@ -139,11 +139,11 @@ impl ScalarUDFImpl for LogFunc { // note in f64::log params order is different than in sql. e.g in sql log(base, x) == f64::log(x, base) let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => match base { - ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { - Arc::new(make_function_scalar_inputs!(x, "x", Float64Array, { - |value: f64| f64::log(value, base as f64) - })) - } + ColumnarValue::Scalar(ScalarValue::Float64(Some(base))) => Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float64Type>(|value: f64| f64::log(value, base)), + ), ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( x, base, @@ -158,11 +158,11 @@ impl ScalarUDFImpl for LogFunc { }, DataType::Float32 => match base { - ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { - Arc::new(make_function_scalar_inputs!(x, "x", Float32Array, { - |value: f32| f32::log(value, base) - })) - } + ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float32Type>(|value: f32| f32::log(value, base)), + ), ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( x, base, diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index ae8eee0dbba2..27aa5b0078ce 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -20,10 +20,10 @@ use std::sync::{Arc, OnceLock}; use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, Float32Array, Float64Array, Int32Array}; +use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array, Int32Array}; use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Float32, Float64, Int32}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{ exec_datafusion_err, exec_err, DataFusionError, Result, ScalarValue, }; @@ -148,17 +148,18 @@ pub fn round(args: &[ArrayRef]) -> Result { ) })?; - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float64Array, - { - |value: f64| { - (value * 10.0_f64.powi(decimal_places)).round() - / 10.0_f64.powi(decimal_places) - } - } - )) as ArrayRef) + Ok(Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float64Type>(|value: f64| { + if value == 0_f64 { + 0_f64 + } else { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + } + }), + ) as ArrayRef) } ColumnarValue::Array(decimal_places) => { let options = CastOptions { @@ -197,17 +198,18 @@ pub fn round(args: &[ArrayRef]) -> Result { ) })?; - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float32Array, - { - |value: f32| { - (value * 10.0_f32.powi(decimal_places)).round() - / 10.0_f32.powi(decimal_places) - } - } - )) as ArrayRef) + Ok(Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float32Type>(|value: f32| { + if value == 0_f32 { + 0_f32 + } else { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + } + }), + ) as ArrayRef) } ColumnarValue::Array(_) => { let ColumnarValue::Array(decimal_places) = From 53f54d4053cfabb6d8b5b6a898593a85b9f60724 Mon Sep 17 00:00:00 2001 From: buraksenb Date: Thu, 17 Oct 2024 09:03:07 +0300 Subject: [PATCH 2/9] delete make_function_scalar_inputs --- datafusion/functions/src/macros.rs | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 85ffaa868f24..36aca0278867 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -383,19 +383,6 @@ macro_rules! make_math_binary_udf { }; } -macro_rules! make_function_scalar_inputs { - ($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARRAY_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$ARRAY_TYPE>() - }}; -} - macro_rules! make_function_inputs2 { ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE); From 58d3d86c793a119804a1322d75eb9c3607a64df2 Mon Sep 17 00:00:00 2001 From: buraksenb Date: Thu, 17 Oct 2024 14:48:29 +0300 Subject: [PATCH 3/9] fix --- datafusion/functions/src/math/log.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index b6af6057b81b..731afdaf8110 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -140,8 +140,7 @@ impl ScalarUDFImpl for LogFunc { let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => match base { ColumnarValue::Scalar(ScalarValue::Float64(Some(base))) => Arc::new( - args[0] - .as_primitive::() + x.as_primitive::() .unary::<_, Float64Type>(|value: f64| f64::log(value, base)), ), ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( @@ -159,8 +158,7 @@ impl ScalarUDFImpl for LogFunc { DataType::Float32 => match base { ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => Arc::new( - args[0] - .as_primitive::() + x.as_primitive::() .unary::<_, Float32Type>(|value: f32| f32::log(value, base)), ), ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( From 7f94ee3ad47735f3b9cee6a6b97ff9b01c474459 Mon Sep 17 00:00:00 2001 From: buraksenb Date: Fri, 18 Oct 2024 16:57:48 +0300 Subject: [PATCH 4/9] refactored other macros --- datafusion/functions/src/macros.rs | 134 +++++++++---------------- datafusion/functions/src/math/log.rs | 39 ++++--- datafusion/functions/src/math/nanvl.rs | 32 +++--- datafusion/functions/src/math/power.rs | 23 +++-- datafusion/functions/src/math/round.rs | 55 +++++----- 5 files changed, 113 insertions(+), 170 deletions(-) diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 36aca0278867..83bbdc852b4e 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -112,26 +112,6 @@ macro_rules! make_stub_package { }; } -/// Invokes a function on each element of an array and returns the result as a new array -/// -/// $ARG: ArrayRef -/// $NAME: name of the function (for error messages) -/// $ARGS_TYPE: the type of array to cast the argument to -/// $RETURN_TYPE: the type of array to return -/// $FUNC: the function to apply to each element of $ARG -macro_rules! make_function_scalar_inputs_return_type { - ($ARG: expr, $NAME:expr, $ARG_TYPE:ident, $RETURN_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARG_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$RETURN_TYPE>() - }}; -} - /// Downcast an argument to a specific array type, returning an internal error /// if the cast fails /// @@ -168,9 +148,9 @@ macro_rules! make_math_unary_udf { use std::any::Any; use std::sync::Arc; - use arrow::array::{ArrayRef, Float32Array, Float64Array}; - use arrow::datatypes::DataType; - use datafusion_common::{exec_err, DataFusionError, Result}; + use arrow::array::{ArrayRef, AsArray}; + use arrow::datatypes::{DataType, Float32Type, Float64Type}; + use datafusion_common::{exec_err, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ @@ -231,24 +211,28 @@ macro_rules! make_math_unary_udf { fn invoke(&self, col_args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(col_args)?; let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => { - Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - self.name(), - Float64Array, - Float64Array, - { f64::$UNARY_FUNC } - )) - } - DataType::Float32 => { - Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - self.name(), - Float32Array, - Float32Array, - { f32::$UNARY_FUNC } - )) - } + DataType::Float64 => Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float64Type>(|x: f64| { + if x == 0_f64 { + 0_f64 + } else { + f64::$UNARY_FUNC(x) + } + }), + ) as ArrayRef, + DataType::Float32 => Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float32Type>(|x: f32| { + if x == 0_f32 { + 0_f32 + } else { + f32::$UNARY_FUNC(x) + } + }), + ) as ArrayRef, other => { return exec_err!( "Unsupported data type {other:?} for function {}", @@ -286,9 +270,9 @@ macro_rules! make_math_binary_udf { use std::any::Any; use std::sync::Arc; - use arrow::array::{ArrayRef, Float32Array, Float64Array}; - use arrow::datatypes::DataType; - use datafusion_common::{exec_err, DataFusionError, Result}; + use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; + use arrow::datatypes::{DataType, Float32Type, Float64Type}; + use datafusion_common::{exec_err, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature; use datafusion_expr::{ @@ -347,23 +331,24 @@ macro_rules! make_math_binary_udf { fn invoke(&self, col_args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(col_args)?; let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float64Array, - { f64::$BINARY_FUNC } - )), - - DataType::Float32 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float32Array, - { f32::$BINARY_FUNC } - )), + DataType::Float64 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result: PrimitiveArray = + arrow::compute::binary(y, x, |y, x| { + f64::$BINARY_FUNC(y, x) + })?; + Arc::new(result) as ArrayRef + } + DataType::Float32 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result: PrimitiveArray = + arrow::compute::binary(y, x, |y, x| { + f32::$BINARY_FUNC(y, x) + })?; + Arc::new(result) as ArrayRef + } other => { return exec_err!( "Unsupported data type {other:?} for function {}", @@ -382,30 +367,3 @@ macro_rules! make_math_binary_udf { } }; } - -macro_rules! make_function_inputs2 { - ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ - let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE); - let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE); - - arg1.iter() - .zip(arg2.iter()) - .map(|(a1, a2)| match (a1, a2) { - (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), - _ => None, - }) - .collect::<$ARRAY_TYPE>() - }}; - ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE1:ident, $ARRAY_TYPE2:ident, $FUNC: block) => {{ - let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE1); - let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE2); - - arg1.iter() - .zip(arg2.iter()) - .map(|(a1, a2)| match (a1, a2) { - (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), - _ => None, - }) - .collect::<$ARRAY_TYPE1>() - }}; -} diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 731afdaf8110..64fe1704b3f3 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -22,11 +22,10 @@ use std::sync::{Arc, OnceLock}; use super::power::PowerFunc; -use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array}; +use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{ - exec_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, Result, - ScalarValue, + exec_err, internal_err, plan_datafusion_err, plan_err, Result, ScalarValue, }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; @@ -143,14 +142,13 @@ impl ScalarUDFImpl for LogFunc { x.as_primitive::() .unary::<_, Float64Type>(|value: f64| f64::log(value, base)), ), - ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( - x, - base, - "x", - "base", - Float64Array, - { f64::log } - )), + ColumnarValue::Array(base) => { + let x = x.as_primitive::(); + let base = base.as_primitive::(); + let result: PrimitiveArray = + arrow::compute::binary(x, base, f64::log)?; + Arc::new(result) as ArrayRef + } _ => { return exec_err!("log function requires a scalar or array for base") } @@ -161,14 +159,13 @@ impl ScalarUDFImpl for LogFunc { x.as_primitive::() .unary::<_, Float32Type>(|value: f32| f32::log(value, base)), ), - ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( - x, - base, - "x", - "base", - Float32Array, - { f32::log } - )), + ColumnarValue::Array(base) => { + let x = x.as_primitive::(); + let base = base.as_primitive::(); + let result: PrimitiveArray = + arrow::compute::binary(x, base, f32::log)?; + Arc::new(result) as ArrayRef + } _ => { return exec_err!("log function requires a scalar or array for base") } @@ -253,9 +250,9 @@ fn is_pow(func: &ScalarUDF) -> bool { #[cfg(test)] mod tests { - use std::collections::HashMap; - use super::*; + use arrow::array::{Float32Array, Float64Array}; + use std::collections::HashMap; use arrow::compute::SortOptions; use datafusion_common::cast::{as_float32_array, as_float64_array}; diff --git a/datafusion/functions/src/math/nanvl.rs b/datafusion/functions/src/math/nanvl.rs index b82ee0d45744..99eba8ca05e9 100644 --- a/datafusion/functions/src/math/nanvl.rs +++ b/datafusion/functions/src/math/nanvl.rs @@ -18,11 +18,11 @@ use std::any::Any; use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use arrow::datatypes::DataType; +use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array, PrimitiveArray}; use arrow::datatypes::DataType::{Float32, Float64}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{exec_err, Result}; use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ @@ -113,14 +113,11 @@ fn nanvl(args: &[ArrayRef]) -> Result { } }; - Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Float64Array, - { compute_nanvl } - )) as ArrayRef) + let x = args[0].as_primitive() as &Float64Array; + let y = args[1].as_primitive() as &Float64Array; + let result: PrimitiveArray = + arrow::compute::binary(x, y, compute_nanvl)?; + Ok(Arc::new(result) as ArrayRef) } Float32 => { let compute_nanvl = |x: f32, y: f32| { @@ -131,14 +128,11 @@ fn nanvl(args: &[ArrayRef]) -> Result { } }; - Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Float32Array, - { compute_nanvl } - )) as ArrayRef) + let x = args[0].as_primitive() as &Float32Array; + let y = args[1].as_primitive() as &Float32Array; + let result: PrimitiveArray = + arrow::compute::binary(x, y, compute_nanvl)?; + Ok(Arc::new(result) as ArrayRef) } other => exec_err!("Unsupported data type {other:?} for function nanvl"), } diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index a99afaec97f7..436375898fcf 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -17,7 +17,7 @@ //! Math function: `power()`. -use arrow::datatypes::{ArrowNativeTypeOp, DataType}; +use arrow::datatypes::{ArrowNativeTypeOp, DataType, Float64Type}; use datafusion_common::{ arrow_datafusion_err, exec_datafusion_err, exec_err, plan_datafusion_err, @@ -28,7 +28,7 @@ use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ColumnarValue, Documentation, Expr, ScalarUDF, TypeSignature}; -use arrow::array::{ArrayRef, Float64Array, Int64Array}; +use arrow::array::{ArrayRef, AsArray, Int64Array, PrimitiveArray}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::{Arc, OnceLock}; @@ -90,15 +90,15 @@ impl ScalarUDFImpl for PowerFunc { let args = ColumnarValue::values_to_arrays(args)?; let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "base", - "exponent", - Float64Array, - { f64::powf } - )), - + DataType::Float64 => { + let bases = args[0].as_primitive::(); + let exponents = args[1].as_primitive::(); + let result: PrimitiveArray = + arrow::compute::binary(bases, exponents, |base, exp| { + f64::powf(base, exp) + })?; + Arc::new(result) as ArrayRef + } DataType::Int64 => { let bases = downcast_arg!(&args[0], "base", Int64Array); let exponents = downcast_arg!(&args[1], "exponent", Int64Array); @@ -195,6 +195,7 @@ fn is_log(func: &ScalarUDF) -> bool { #[cfg(test)] mod tests { + use arrow::array::Float64Array; use datafusion_common::cast::{as_float64_array, as_int64_array}; use super::*; diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index 27aa5b0078ce..076e2fbe93d7 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -20,13 +20,11 @@ use std::sync::{Arc, OnceLock}; use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array, Int32Array}; +use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::DataType::{Float32, Float64, Int32}; -use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use datafusion_common::{ - exec_datafusion_err, exec_err, DataFusionError, Result, ScalarValue, -}; +use arrow::datatypes::{DataType, Float32Type, Float64Type, Int32Type}; +use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; @@ -170,20 +168,18 @@ pub fn round(args: &[ArrayRef]) -> Result { .map_err(|e| { exec_datafusion_err!("Invalid values for decimal places: {e}") })?; - Ok(Arc::new(make_function_inputs2!( - &args[0], + + let values = args[0].as_primitive::(); + let decimal_places = decimal_places.as_primitive::(); + let result: PrimitiveArray = arrow::compute::binary( + values, decimal_places, - "value", - "decimal_places", - Float64Array, - Int32Array, - { - |value: f64, decimal_places: i32| { - (value * 10.0_f64.powi(decimal_places)).round() - / 10.0_f64.powi(decimal_places) - } - } - )) as ArrayRef) + |value, decimal_places| { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + }, + )?; + Ok(Arc::new(result) as ArrayRef) } _ => { exec_err!("round function requires a scalar or array for decimal_places") @@ -220,20 +216,17 @@ pub fn round(args: &[ArrayRef]) -> Result { panic!("Unexpected result of ColumnarValue::Array.cast") }; - Ok(Arc::new(make_function_inputs2!( - &args[0], + let values = args[0].as_primitive::(); + let decimal_places = decimal_places.as_primitive::(); + let result: PrimitiveArray = arrow::compute::binary( + values, decimal_places, - "value", - "decimal_places", - Float32Array, - Int32Array, - { - |value: f32, decimal_places: i32| { - (value * 10.0_f32.powi(decimal_places)).round() - / 10.0_f32.powi(decimal_places) - } - } - )) as ArrayRef) + |value, decimal_places| { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + }, + )?; + Ok(Arc::new(result) as ArrayRef) } _ => { exec_err!("round function requires a scalar or array for decimal_places") From aa419e6a2cfd731625a3e607a62c862bd9674cc8 Mon Sep 17 00:00:00 2001 From: buraksenb Date: Sat, 19 Oct 2024 18:22:43 +0300 Subject: [PATCH 5/9] fix unary CI --- datafusion/functions/src/macros.rs | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 83bbdc852b4e..a9f2ac630117 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -214,24 +214,12 @@ macro_rules! make_math_unary_udf { DataType::Float64 => Arc::new( args[0] .as_primitive::() - .unary::<_, Float64Type>(|x: f64| { - if x == 0_f64 { - 0_f64 - } else { - f64::$UNARY_FUNC(x) - } - }), + .unary::<_, Float64Type>(|x: f64| f64::$UNARY_FUNC(x)), ) as ArrayRef, DataType::Float32 => Arc::new( args[0] .as_primitive::() - .unary::<_, Float32Type>(|x: f32| { - if x == 0_f32 { - 0_f32 - } else { - f32::$UNARY_FUNC(x) - } - }), + .unary::<_, Float32Type>(|x: f32| f32::$UNARY_FUNC(x)), ) as ArrayRef, other => { return exec_err!( From 26d43555043ef9e6bbcb14321ed587bad07eea97 Mon Sep 17 00:00:00 2001 From: buraksenb Date: Sun, 20 Oct 2024 12:33:50 +0300 Subject: [PATCH 6/9] fix base f32/f64 mismatch not caught by tests --- datafusion/functions/src/math/log.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 64fe1704b3f3..a91c64aaeca2 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -138,10 +138,11 @@ impl ScalarUDFImpl for LogFunc { // note in f64::log params order is different than in sql. e.g in sql log(base, x) == f64::log(x, base) let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => match base { - ColumnarValue::Scalar(ScalarValue::Float64(Some(base))) => Arc::new( - x.as_primitive::() - .unary::<_, Float64Type>(|value: f64| f64::log(value, base)), - ), + ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { + Arc::new(x.as_primitive::().unary::<_, Float64Type>( + |value: f64| f64::log(value, base as f64), + )) + } ColumnarValue::Array(base) => { let x = x.as_primitive::(); let base = base.as_primitive::(); From 137cd4d6ff8f59ae77032d789fb6b50b9787ceda Mon Sep 17 00:00:00 2001 From: buraksenb Date: Sun, 20 Oct 2024 17:44:08 +0300 Subject: [PATCH 7/9] import order changes --- datafusion/functions/src/math/nanvl.rs | 9 +++++---- datafusion/functions/src/math/power.rs | 12 +++++------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/datafusion/functions/src/math/nanvl.rs b/datafusion/functions/src/math/nanvl.rs index 99eba8ca05e9..3dacb13031bd 100644 --- a/datafusion/functions/src/math/nanvl.rs +++ b/datafusion/functions/src/math/nanvl.rs @@ -18,10 +18,11 @@ use std::any::Any; use std::sync::{Arc, OnceLock}; +use crate::utils::make_scalar_function; + use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array, PrimitiveArray}; use arrow::datatypes::DataType::{Float32, Float64}; use arrow::datatypes::{DataType, Float32Type, Float64Type}; - use datafusion_common::{exec_err, Result}; use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::TypeSignature::Exact; @@ -29,8 +30,6 @@ use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; -use crate::utils::make_scalar_function; - #[derive(Debug)] pub struct NanvlFunc { signature: Signature, @@ -140,10 +139,12 @@ fn nanvl(args: &[ArrayRef]) -> Result { #[cfg(test)] mod test { + use std::sync::Arc; + use crate::math::nanvl::nanvl; + use arrow::array::{ArrayRef, Float32Array, Float64Array}; use datafusion_common::cast::{as_float32_array, as_float64_array}; - use std::sync::Arc; #[test] fn test_nanvl_f64() { diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 436375898fcf..189fd1b097b5 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -16,9 +16,13 @@ // under the License. //! Math function: `power()`. +use std::any::Any; +use std::sync::{Arc, OnceLock}; -use arrow::datatypes::{ArrowNativeTypeOp, DataType, Float64Type}; +use super::log::LogFunc; +use arrow::array::{ArrayRef, AsArray, Int64Array, PrimitiveArray}; +use arrow::datatypes::{ArrowNativeTypeOp, DataType, Float64Type}; use datafusion_common::{ arrow_datafusion_err, exec_datafusion_err, exec_err, plan_datafusion_err, DataFusionError, Result, ScalarValue, @@ -27,13 +31,7 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ColumnarValue, Documentation, Expr, ScalarUDF, TypeSignature}; - -use arrow::array::{ArrayRef, AsArray, Int64Array, PrimitiveArray}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::{Arc, OnceLock}; - -use super::log::LogFunc; #[derive(Debug)] pub struct PowerFunc { From cc5e237e7cc0e3e489dcc5c72593f3244dcfb753 Mon Sep 17 00:00:00 2001 From: berkaysynnada Date: Mon, 21 Oct 2024 11:09:53 +0300 Subject: [PATCH 8/9] Update log.rs --- datafusion/functions/src/math/log.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index a91c64aaeca2..b854d6198f8c 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -251,10 +251,11 @@ fn is_pow(func: &ScalarUDF) -> bool { #[cfg(test)] mod tests { - use super::*; - use arrow::array::{Float32Array, Float64Array}; use std::collections::HashMap; + use super::*; + + use arrow::array::{Float32Array, Float64Array}; use arrow::compute::SortOptions; use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_common::DFSchema; From c4f1ba710216e7ebdf90a61064958dcb52a8bfb5 Mon Sep 17 00:00:00 2001 From: berkaysynnada Date: Mon, 21 Oct 2024 12:59:57 +0300 Subject: [PATCH 9/9] stylistic changes --- datafusion/functions/src/macros.rs | 24 +++++++------- datafusion/functions/src/math/log.rs | 20 ++++++++---- datafusion/functions/src/math/nanvl.rs | 16 ++++----- datafusion/functions/src/math/power.rs | 15 +++++---- datafusion/functions/src/math/round.rs | 45 ++++++++++---------------- 5 files changed, 59 insertions(+), 61 deletions(-) diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index a9f2ac630117..621ab92db21a 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -258,7 +258,7 @@ macro_rules! make_math_binary_udf { use std::any::Any; use std::sync::Arc; - use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; + use arrow::array::{ArrayRef, AsArray}; use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{exec_err, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; @@ -322,20 +322,22 @@ macro_rules! make_math_binary_udf { DataType::Float64 => { let y = args[0].as_primitive::(); let x = args[1].as_primitive::(); - let result: PrimitiveArray = - arrow::compute::binary(y, x, |y, x| { - f64::$BINARY_FUNC(y, x) - })?; - Arc::new(result) as ArrayRef + let result = arrow::compute::binary::<_, _, _, Float64Type>( + y, + x, + |y, x| f64::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ } DataType::Float32 => { let y = args[0].as_primitive::(); let x = args[1].as_primitive::(); - let result: PrimitiveArray = - arrow::compute::binary(y, x, |y, x| { - f32::$BINARY_FUNC(y, x) - })?; - Arc::new(result) as ArrayRef + let result = arrow::compute::binary::<_, _, _, Float32Type>( + y, + x, + |y, x| f32::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ } other => { return exec_err!( diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index b854d6198f8c..93b5683e1946 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -22,7 +22,7 @@ use std::sync::{Arc, OnceLock}; use super::power::PowerFunc; -use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; +use arrow::array::{ArrayRef, AsArray}; use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, Result, ScalarValue, @@ -146,9 +146,12 @@ impl ScalarUDFImpl for LogFunc { ColumnarValue::Array(base) => { let x = x.as_primitive::(); let base = base.as_primitive::(); - let result: PrimitiveArray = - arrow::compute::binary(x, base, f64::log)?; - Arc::new(result) as ArrayRef + let result = arrow::compute::binary::<_, _, _, Float64Type>( + x, + base, + f64::log, + )?; + Arc::new(result) as _ } _ => { return exec_err!("log function requires a scalar or array for base") @@ -163,9 +166,12 @@ impl ScalarUDFImpl for LogFunc { ColumnarValue::Array(base) => { let x = x.as_primitive::(); let base = base.as_primitive::(); - let result: PrimitiveArray = - arrow::compute::binary(x, base, f32::log)?; - Arc::new(result) as ArrayRef + let result = arrow::compute::binary::<_, _, _, Float32Type>( + x, + base, + f32::log, + )?; + Arc::new(result) as _ } _ => { return exec_err!("log function requires a scalar or array for base") diff --git a/datafusion/functions/src/math/nanvl.rs b/datafusion/functions/src/math/nanvl.rs index 3dacb13031bd..cfd21256dd96 100644 --- a/datafusion/functions/src/math/nanvl.rs +++ b/datafusion/functions/src/math/nanvl.rs @@ -20,10 +20,10 @@ use std::sync::{Arc, OnceLock}; use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array, PrimitiveArray}; +use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array}; use arrow::datatypes::DataType::{Float32, Float64}; use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ @@ -114,9 +114,9 @@ fn nanvl(args: &[ArrayRef]) -> Result { let x = args[0].as_primitive() as &Float64Array; let y = args[1].as_primitive() as &Float64Array; - let result: PrimitiveArray = - arrow::compute::binary(x, y, compute_nanvl)?; - Ok(Arc::new(result) as ArrayRef) + arrow::compute::binary::<_, _, _, Float64Type>(x, y, compute_nanvl) + .map(|res| Arc::new(res) as _) + .map_err(DataFusionError::from) } Float32 => { let compute_nanvl = |x: f32, y: f32| { @@ -129,9 +129,9 @@ fn nanvl(args: &[ArrayRef]) -> Result { let x = args[0].as_primitive() as &Float32Array; let y = args[1].as_primitive() as &Float32Array; - let result: PrimitiveArray = - arrow::compute::binary(x, y, compute_nanvl)?; - Ok(Arc::new(result) as ArrayRef) + arrow::compute::binary::<_, _, _, Float32Type>(x, y, compute_nanvl) + .map(|res| Arc::new(res) as _) + .map_err(DataFusionError::from) } other => exec_err!("Unsupported data type {other:?} for function nanvl"), } diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 189fd1b097b5..16ce0b8df39b 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -21,7 +21,7 @@ use std::sync::{Arc, OnceLock}; use super::log::LogFunc; -use arrow::array::{ArrayRef, AsArray, Int64Array, PrimitiveArray}; +use arrow::array::{ArrayRef, AsArray, Int64Array}; use arrow::datatypes::{ArrowNativeTypeOp, DataType, Float64Type}; use datafusion_common::{ arrow_datafusion_err, exec_datafusion_err, exec_err, plan_datafusion_err, @@ -91,11 +91,12 @@ impl ScalarUDFImpl for PowerFunc { DataType::Float64 => { let bases = args[0].as_primitive::(); let exponents = args[1].as_primitive::(); - let result: PrimitiveArray = - arrow::compute::binary(bases, exponents, |base, exp| { - f64::powf(base, exp) - })?; - Arc::new(result) as ArrayRef + let result = arrow::compute::binary::<_, _, _, Float64Type>( + bases, + exponents, + f64::powf, + )?; + Arc::new(result) as _ } DataType::Int64 => { let bases = downcast_arg!(&args[0], "base", Int64Array); @@ -114,7 +115,7 @@ impl ScalarUDFImpl for PowerFunc { _ => Ok(None), }) .collect::>() - .map(Arc::new)? as ArrayRef + .map(Arc::new)? as _ } other => { diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index 076e2fbe93d7..87ad5ecb6938 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -146,18 +146,13 @@ pub fn round(args: &[ArrayRef]) -> Result { ) })?; - Ok(Arc::new( - args[0] - .as_primitive::() - .unary::<_, Float64Type>(|value: f64| { - if value == 0_f64 { - 0_f64 - } else { - (value * 10.0_f64.powi(decimal_places)).round() - / 10.0_f64.powi(decimal_places) - } - }), - ) as ArrayRef) + let result = args[0] + .as_primitive::() + .unary::<_, Float64Type>(|value: f64| { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + }); + Ok(Arc::new(result) as _) } ColumnarValue::Array(decimal_places) => { let options = CastOptions { @@ -171,7 +166,7 @@ pub fn round(args: &[ArrayRef]) -> Result { let values = args[0].as_primitive::(); let decimal_places = decimal_places.as_primitive::(); - let result: PrimitiveArray = arrow::compute::binary( + let result = arrow::compute::binary::<_, _, _, Float64Type>( values, decimal_places, |value, decimal_places| { @@ -179,7 +174,7 @@ pub fn round(args: &[ArrayRef]) -> Result { / 10.0_f64.powi(decimal_places) }, )?; - Ok(Arc::new(result) as ArrayRef) + Ok(Arc::new(result) as _) } _ => { exec_err!("round function requires a scalar or array for decimal_places") @@ -193,19 +188,13 @@ pub fn round(args: &[ArrayRef]) -> Result { "Invalid value for decimal places: {decimal_places}: {e}" ) })?; - - Ok(Arc::new( - args[0] - .as_primitive::() - .unary::<_, Float32Type>(|value: f32| { - if value == 0_f32 { - 0_f32 - } else { - (value * 10.0_f32.powi(decimal_places)).round() - / 10.0_f32.powi(decimal_places) - } - }), - ) as ArrayRef) + let result = args[0] + .as_primitive::() + .unary::<_, Float32Type>(|value: f32| { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + }); + Ok(Arc::new(result) as _) } ColumnarValue::Array(_) => { let ColumnarValue::Array(decimal_places) = @@ -226,7 +215,7 @@ pub fn round(args: &[ArrayRef]) -> Result { / 10.0_f32.powi(decimal_places) }, )?; - Ok(Arc::new(result) as ArrayRef) + Ok(Arc::new(result) as _) } _ => { exec_err!("round function requires a scalar or array for decimal_places")