diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 5d6a8bcfdef2..0581f6423a32 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -27,15 +27,15 @@ use arrow::datatypes::{ Decimal64Type, Float64Type, Int64Type, }; use arrow::error::ArrowError; +use datafusion_common::types::{logical_float64, logical_int64, NativeType}; use datafusion_common::utils::take_function_args; -use datafusion_common::{exec_err, plan_datafusion_err, Result, ScalarValue}; +use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::type_coercion::is_decimal; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, + lit, Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, + ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, }; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; #[user_doc( @@ -67,8 +67,27 @@ impl Default for PowerFunc { impl PowerFunc { pub fn new() -> Self { + let integer = Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Integer], + NativeType::Int64, + ); + let decimal = Coercion::new_exact(TypeSignatureClass::Decimal); + let float = Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ); Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![integer.clone(), integer.clone()]), + TypeSignature::Coercible(vec![decimal.clone(), integer.clone()]), + TypeSignature::Coercible(vec![decimal.clone(), float.clone()]), + TypeSignature::Coercible(vec![float.clone(), float.clone()]), + ], + Volatility::Immutable, + ), aliases: vec![String::from("pow")], } } @@ -153,6 +172,7 @@ impl ScalarUDFImpl for PowerFunc { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "power" } @@ -162,49 +182,20 @@ impl ScalarUDFImpl for PowerFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + if arg_types[0].is_null() { + Ok(DataType::Int64) + } else { + Ok(arg_types[0].clone()) + } } fn aliases(&self) -> &[String] { &self.aliases } - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let [arg1, arg2] = take_function_args(self.name(), arg_types)?; - - fn coerced_type_base(name: &str, data_type: &DataType) -> Result { - match data_type { - DataType::Null => Ok(DataType::Int64), - d if d.is_floating() => Ok(DataType::Float64), - d if d.is_integer() => Ok(DataType::Int64), - d if is_decimal(d) => Ok(d.clone()), - other => { - exec_err!("Unsupported data type {other:?} for {} function", name) - } - } - } - - fn coerced_type_exp(name: &str, data_type: &DataType) -> Result { - match data_type { - DataType::Null => Ok(DataType::Int64), - d if d.is_floating() => Ok(DataType::Float64), - d if d.is_integer() => Ok(DataType::Int64), - d if is_decimal(d) => Ok(DataType::Float64), - other => { - exec_err!("Unsupported data type {other:?} for {} function", name) - } - } - } - - Ok(vec![ - coerced_type_base(self.name(), arg1)?, - coerced_type_exp(self.name(), arg2)?, - ]) - } - fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let base = &args.args[0].to_array(args.number_rows)?; - let exponent = &args.args[1]; + let [base, exponent] = take_function_args(self.name(), &args.args)?; + let base = base.to_array(args.number_rows)?; let arr: ArrayRef = match (base.data_type(), exponent.data_type()) { (DataType::Float64, _) => { @@ -227,110 +218,104 @@ impl ScalarUDFImpl for PowerFunc { )? } (DataType::Decimal32(precision, scale), DataType::Int64) => - calculate_binary_decimal_math::( - &base, - exponent, - |b, e| pow_decimal_int(b, *scale, e), + calculate_binary_decimal_math::( + &base, + exponent, + |b, e| pow_decimal_int(b, *scale, e), *precision, *scale, )?, (DataType::Decimal32(precision, scale), DataType::Float64) => - calculate_binary_decimal_math::( - &base, - exponent, - |b, e| pow_decimal_float(b, *scale, e), + calculate_binary_decimal_math::( + &base, + exponent, + |b, e| pow_decimal_float(b, *scale, e), *precision, *scale, )?, (DataType::Decimal64(precision, scale), DataType::Int64) => - calculate_binary_decimal_math::( - &base, - exponent, - |b, e| pow_decimal_int(b, *scale, e), + calculate_binary_decimal_math::( + &base, + exponent, + |b, e| pow_decimal_int(b, *scale, e), *precision, *scale, )?, (DataType::Decimal64(precision, scale), DataType::Float64) => - calculate_binary_decimal_math::( - &base, - exponent, - |b, e| pow_decimal_float(b, *scale, e), + calculate_binary_decimal_math::( + &base, + exponent, + |b, e| pow_decimal_float(b, *scale, e), *precision, *scale, )?, (DataType::Decimal128(precision, scale), DataType::Int64) => - calculate_binary_decimal_math::( - &base, - exponent, - |b, e| pow_decimal_int(b, *scale, e), + calculate_binary_decimal_math::( + &base, + exponent, + |b, e| pow_decimal_int(b, *scale, e), *precision, *scale, )?, (DataType::Decimal128(precision, scale), DataType::Float64) => - calculate_binary_decimal_math::< - Decimal128Type, - Float64Type, - Decimal128Type, - _, - >(&base, exponent, |b, e| - pow_decimal_float(b, *scale, e), + calculate_binary_decimal_math::( + &base, + exponent, + |b, e| pow_decimal_float(b, *scale, e), *precision, *scale, )?, (DataType::Decimal256(precision, scale),DataType::Int64) => - calculate_binary_decimal_math::( - &base, - exponent, - |b, e| pow_decimal_int(b, *scale, e), + calculate_binary_decimal_math::( + &base, + exponent, + |b, e| pow_decimal_int(b, *scale, e), *precision, *scale, )?, - (DataType::Decimal256(precision, scale), DataType::Float64) => - calculate_binary_decimal_math::< - Decimal256Type, - Float64Type, - Decimal256Type, - _, - >(&base, exponent, |b, e| - pow_decimal_float(b, *scale, e) , - *precision, + (DataType::Decimal256(precision, scale), DataType::Float64) => + calculate_binary_decimal_math::( + &base, + exponent, + |b, e| pow_decimal_float(b, *scale, e), + *precision, *scale, )?, (base_type, exp_type) => { - return exec_err!( - "Unsupported data types for base {base_type:?} and exponent {exp_type:?} for function {}", - self.name() - ) + return internal_err!("Unsupported data types for base {base_type:?} and exponent {exp_type:?} for power") } }; Ok(ColumnarValue::Array(arr)) } /// Simplify the `power` function by the relevant rules: - /// 1. Power(a, 0) ===> 0 + /// 1. Power(a, 0) ===> 1 /// 2. Power(a, 1) ===> a /// 3. Power(a, Log(a, b)) ===> b fn simplify( &self, - mut args: Vec, + args: Vec, info: &dyn SimplifyInfo, ) -> Result { - let exponent = args.pop().ok_or_else(|| { - plan_datafusion_err!("Expected power to have 2 arguments, got 0") - })?; - let base = args.pop().ok_or_else(|| { - plan_datafusion_err!("Expected power to have 2 arguments, got 1") - })?; - + let [base, exponent] = take_function_args("power", args)?; + let base_type = info.get_data_type(&base)?; let exponent_type = info.get_data_type(&exponent)?; + + // Null propagation + if base_type.is_null() || exponent_type.is_null() { + let return_type = self.return_type(&[base_type, exponent_type])?; + return Ok(ExprSimplifyResult::Simplified(lit( + ScalarValue::Null.cast_to(&return_type)? + ))); + } + match exponent { Expr::Literal(value, _) if value == ScalarValue::new_zero(&exponent_type)? => { - Ok(ExprSimplifyResult::Simplified(Expr::Literal( - ScalarValue::new_one(&info.get_data_type(&base)?)?, - None, - ))) + Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one( + &base_type, + )?))) } Expr::Literal(value, _) if value == ScalarValue::new_one(&exponent_type)? => { Ok(ExprSimplifyResult::Simplified(base)) @@ -358,241 +343,6 @@ fn is_log(func: &ScalarUDF) -> bool { #[cfg(test)] mod tests { use super::*; - use arrow::array::{Array, Decimal128Array, Float64Array, Int64Array}; - use arrow::datatypes::{Field, DECIMAL128_MAX_SCALE}; - use arrow_buffer::NullBuffer; - use datafusion_common::cast::{ - as_decimal128_array, as_float64_array, as_int64_array, - }; - use datafusion_common::config::ConfigOptions; - use std::sync::Arc; - - #[cfg(test)] - #[ctor::ctor] - fn init() { - // Enable RUST_LOG logging configuration for test - let _ = env_logger::try_init(); - } - - #[test] - fn test_power_f64() { - let arg_fields = vec![ - Field::new("a", DataType::Float64, true).into(), - Field::new("a", DataType::Float64, true).into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - 2.0, 2.0, 3.0, 5.0, - ]))), // base - ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - 3.0, 2.0, 4.0, 4.0, - ]))), // exponent - ], - arg_fields, - number_rows: 4, - return_field: Field::new("f", DataType::Float64, true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = PowerFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function power"); - - match result { - ColumnarValue::Array(arr) => { - let floats = as_float64_array(&arr) - .expect("failed to convert result to a Float64Array"); - assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), 8.0); - assert_eq!(floats.value(1), 4.0); - assert_eq!(floats.value(2), 81.0); - assert_eq!(floats.value(3), 625.0); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } - } - - #[test] - fn test_power_i64() { - let arg_fields = vec![ - Field::new("a", DataType::Int64, true).into(), - Field::new("a", DataType::Int64, true).into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base - ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent - ], - arg_fields, - number_rows: 4, - return_field: Field::new("f", DataType::Int64, true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = PowerFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function power"); - - match result { - ColumnarValue::Array(arr) => { - let ints = as_int64_array(&arr) - .expect("failed to convert result to a Int64Array"); - - assert_eq!(ints.len(), 4); - assert_eq!(ints.value(0), 8); - assert_eq!(ints.value(1), 4); - assert_eq!(ints.value(2), 81); - assert_eq!(ints.value(3), 625); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } - } - - #[test] - fn test_power_i128() { - let arg_fields = vec![ - Field::new( - "a", - DataType::Decimal128(DECIMAL128_MAX_SCALE as u8, 0), - true, - ) - .into(), - Field::new("a", DataType::Int64, true).into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(Arc::new( - Decimal128Array::from(vec![2, 2, 3, 5, 0, 5]) - .with_precision_and_scale(DECIMAL128_MAX_SCALE as u8, 0) - .unwrap(), - )), // base - ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4, 4, 0]))), // exponent - ], - arg_fields, - number_rows: 6, - return_field: Field::new( - "f", - DataType::Decimal128(DECIMAL128_MAX_SCALE as u8, 0), - true, - ) - .into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = PowerFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function power"); - - match result { - ColumnarValue::Array(arr) => { - let ints = as_decimal128_array(&arr) - .expect("failed to convert result to an array"); - - assert_eq!(ints.len(), 6); - assert_eq!(ints.value(0), i128::from(8)); - assert_eq!(ints.value(1), i128::from(4)); - assert_eq!(ints.value(2), i128::from(81)); - assert_eq!(ints.value(3), i128::from(625)); - assert_eq!(ints.value(4), i128::from(0)); - assert_eq!(ints.value(5), i128::from(1)); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } - } - - #[test] - fn test_power_array_null() { - let arg_fields = vec![ - Field::new("a", DataType::Int64, true).into(), - Field::new("a", DataType::Int64, true).into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 2]))), // base - ColumnarValue::Array(Arc::new(Int64Array::from_iter_values_with_nulls( - vec![1, 2, 3], - Some(NullBuffer::from(vec![true, false, true])), - ))), // exponent - ], - arg_fields, - number_rows: 1, - return_field: Field::new("f", DataType::Int64, true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = PowerFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function power"); - - match result { - ColumnarValue::Array(arr) => { - let ints = - as_int64_array(&arr).expect("failed to convert result to an array"); - - assert_eq!(ints.len(), 3); - assert!(!ints.is_null(0)); - assert_eq!(ints.value(0), i64::from(2)); - assert!(ints.is_null(1)); - assert!(!ints.is_null(2)); - assert_eq!(ints.value(2), i64::from(8)); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } - } - - #[test] - fn test_power_decimal_with_scale() { - // 2.5 ^ 4 = 39 - // 2.5 is 25 in Decimal128(2, 1) by parsing rules - // Signature is Decimal128(2, 1) -> Int64 -> Decimal128(2, 1), therefore - // result is 390 in Decimal128(2, 1) aka 39 in unscaled Decimal128(2, 0) - let arg_fields = vec![ - Field::new( - "a", - DataType::Decimal128(DECIMAL128_MAX_SCALE as u8, 0), - true, - ) - .into(), - Field::new("a", DataType::Int64, true).into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Decimal128( - Some(i128::from(25)), - 2, - 1, - )), // base - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), // exponent - ], - arg_fields, - number_rows: 1, - return_field: Field::new("f", DataType::Decimal128(2, 1), true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = PowerFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function power"); - - match result { - ColumnarValue::Array(arr) => { - let ints = as_decimal128_array(&arr) - .expect("failed to convert result to an array"); - - assert_eq!(ints.len(), 1); - assert_eq!(ints.value(0), i128::from(390)); - // Signature stays the same as input - assert_eq!(*arr.data_type(), DataType::Decimal128(2, 1)); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } - } #[test] fn test_pow_decimal128_helper() { diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index f34e1156a785..bfacc8a39d2c 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -699,6 +699,60 @@ select lcm(2, 9223372036854775803); query error DataFusion error: Arrow error: Arithmetic overflow: Overflow happened on: 2107754225 \^ 1221660777 select power(2107754225, 1221660777); +query R rowsort +select power(base::double, exponent::double) +from values + (2.0, 2.0), + (5.0, 4.0), + (2.0, 3.0), + (3.0, 4.0) as t(base, exponent); +---- +4 +625 +8 +81 + +query I rowsort +select power(base::bigint, exponent::bigint) +from values + (2, 2), + (5, 4), + (2, 3), + (3, 4), + (2, NULL) as t(base, exponent); +---- +4 +625 +8 +81 +NULL + +query RT rowsort +select + power(base::decimal(38, 0), exponent::decimal(38, 0)), + arrow_typeof(power(base::decimal(38, 0), exponent::decimal(38, 0))) +from values + (0, 4), + (5, 0), + (2, 2), + (5, 4), + (2, 3), + (3, 4) as t(base, exponent); +---- +0 Decimal128(38, 0) +1 Decimal128(38, 0) +4 Decimal128(38, 0) +625 Decimal128(38, 0) +8 Decimal128(38, 0) +81 Decimal128(38, 0) + +query RT +select + pow(2.5::decimal(2, 1), 4::bigint), + arrow_typeof(pow(2.5::decimal(2, 1), 4::bigint)); +---- +39 Decimal128(2, 1) + # factorial overflow query error DataFusion error: Arrow error: Compute error: Overflow happened on FACTORIAL\(350943270\) select FACTORIAL(350943270); diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 8eac9bd0c955..3d4d6d11e7b2 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1775,7 +1775,7 @@ CREATE TABLE test( (-14, -14, -14.5, -14.5), (NULL, NULL, NULL, NULL); -query IIRRIR rowsort +query IRRRIR rowsort SELECT power(i32, exp_i) as power_i32, power(i64, exp_f) as power_i64, pow(f32, exp_i) as power_f32, @@ -1883,7 +1883,7 @@ D false # test string_temporal_coercion query BBBBBBBBBB -select +select arrow_cast(to_timestamp('2020-01-01 01:01:11.1234567890Z'), 'Timestamp(Second, None)') == '2020-01-01T01:01:11', arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)') == arrow_cast('2020-01-02T01:01:11', 'LargeUtf8'), arrow_cast(to_timestamp('2020-01-03 01:01:11.1234567890Z'), 'Time32(Second)') == '01:01:11',