diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 443229a3cb77..3fa602f12554 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -114,10 +114,10 @@ impl Default for SessionConfig { } /// A type map for storing extensions. -/// +/// /// Extensions are indexed by their type `T`. If multiple values of the same type are provided, only the last one /// will be kept. -/// +/// /// Extensions are opaque objects that are unknown to DataFusion itself but can be downcast by optimizer rules, /// execution plans, or other components that have access to the session config. /// They provide a flexible way to attach extra data or behavior to the session config. diff --git a/datafusion/functions/src/math/ceil.rs b/datafusion/functions/src/math/ceil.rs new file mode 100644 index 000000000000..68a76c7bc3b3 --- /dev/null +++ b/datafusion/functions/src/math/ceil.rs @@ -0,0 +1,174 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+use std::any::Any;
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, AsArray};
+use arrow::datatypes::{
+    DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, Float32Type,
+    Float64Type,
+};
+use datafusion_common::{exec_err, Result, ScalarValue};
+use datafusion_expr::interval_arithmetic::Interval;
+use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
+use datafusion_expr::{
+    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
+    TypeSignature, TypeSignatureClass, Volatility,
+};
+use datafusion_macros::user_doc;
+
+use super::decimal::{apply_decimal_op, ceil_decimal_value};
+
+#[user_doc(
+    doc_section(label = "Math Functions"),
+    description = "Returns the nearest integer greater than or equal to a number.",
+    syntax_example = "ceil(numeric_expression)",
+    standard_argument(name = "numeric_expression", prefix = "Numeric"),
+    sql_example = r#"```sql
+> SELECT ceil(3.14);
++------------+
+| ceil(3.14) |
++------------+
+| 4.0        |
++------------+
+```"#
+)]
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct CeilFunc {
+    signature: Signature,
+}
+
+impl Default for CeilFunc {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl CeilFunc {
+    pub fn new() -> Self {
+        let decimal_sig = Coercion::new_exact(TypeSignatureClass::Decimal);
+        Self {
+            signature: Signature::one_of(
+                vec![
+                    TypeSignature::Coercible(vec![decimal_sig]),
+                    TypeSignature::Uniform(1, vec![DataType::Float64, DataType::Float32]),
+                ],
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl ScalarUDFImpl for CeilFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "ceil"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        match &arg_types[0] {
+            DataType::Null => Ok(DataType::Float64),
+            other => Ok(other.clone()),
+        }
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let args = ColumnarValue::values_to_arrays(&args.args)?;
+        let value = &args[0];
+
+        let result: ArrayRef = match value.data_type() {
+            DataType::Float64 => Arc::new(
+                value
+                    .as_primitive::<Float64Type>()
+                    .unary::<_, Float64Type>(f64::ceil),
+            ),
+            DataType::Float32 => Arc::new(
+                value
+                    .as_primitive::<Float32Type>()
+                    .unary::<_, Float32Type>(f32::ceil),
+            ),
+            DataType::Null => {
+                return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)))
+            }
+            DataType::Decimal32(precision, scale) => {
+                apply_decimal_op::<Decimal32Type, _>(
+                    value,
+                    *precision,
+                    *scale,
+                    self.name(),
+                    ceil_decimal_value,
+                )?
+            }
+            DataType::Decimal64(precision, scale) => {
+                apply_decimal_op::<Decimal64Type, _>(
+                    value,
+                    *precision,
+                    *scale,
+                    self.name(),
+                    ceil_decimal_value,
+                )?
+            }
+            DataType::Decimal128(precision, scale) => {
+                apply_decimal_op::<Decimal128Type, _>(
+                    value,
+                    *precision,
+                    *scale,
+                    self.name(),
+                    ceil_decimal_value,
+                )?
+            }
+            DataType::Decimal256(precision, scale) => {
+                apply_decimal_op::<Decimal256Type, _>(
+                    value,
+                    *precision,
+                    *scale,
+                    self.name(),
+                    ceil_decimal_value,
+                )?
+            }
+            other => {
+                return exec_err!(
+                    "Unsupported data type {other:?} for function {}",
+                    self.name()
+                )
+            }
+        };
+
+        Ok(ColumnarValue::Array(result))
+    }
+
+    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
+        Ok(input[0].sort_properties)
+    }
+
+    fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
+        let data_type = inputs[0].data_type();
+        Interval::make_unbounded(&data_type)
+    }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        self.doc()
+    }
+}
diff --git a/datafusion/functions/src/math/decimal.rs b/datafusion/functions/src/math/decimal.rs
new file mode 100644
index 000000000000..b27130dbd3a0
--- /dev/null
+++ b/datafusion/functions/src/math/decimal.rs
@@ -0,0 +1,126 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
+use arrow::datatypes::{ArrowNativeTypeOp, DecimalType};
+use arrow::error::ArrowError;
+use arrow_buffer::ArrowNativeType;
+use datafusion_common::{DataFusionError, Result};
+
+/// Applies `op` to each value of the decimal `array`, passing `10^scale` as the
+/// second argument, and returns a new array with the same decimal data type.
+///
+/// Arrays with `scale <= 0` are returned unchanged. Every result is validated
+/// against `precision`; a value that no longer fits produces a
+/// "Decimal overflow while applying {fn_name}" error.
+pub(super) fn apply_decimal_op<T, F>(
+    array: &ArrayRef,
+    precision: u8,
+    scale: i8,
+    fn_name: &str,
+    op: F,
+) -> Result<ArrayRef>
+where
+    T: DecimalType,
+    T::Native: ArrowNativeType + ArrowNativeTypeOp,
+    F: Fn(T::Native, T::Native) -> T::Native,
+{
+    if scale <= 0 {
+        return Ok(Arc::clone(array));
+    }
+
+    let factor = decimal_scale_factor::<T>(scale, fn_name)?;
+    let decimal = array.as_primitive::<T>();
+    let data_type = array.data_type().clone();
+
+    let result: PrimitiveArray<T> = decimal.try_unary(|value| {
+        let new_value = op(value, factor);
+        T::validate_decimal_precision(new_value, precision, scale).map_err(|_| {
+            ArrowError::ComputeError(format!("Decimal overflow while applying {fn_name}"))
+        })?;
+        Ok::<_, ArrowError>(new_value)
+    })?;
+
+    let result = result.with_data_type(data_type);
+
+    Ok(Arc::new(result))
+}
+
+/// Computes `10^scale` in the native type of `T`, reporting an execution error
+/// if the power overflows. `scale` is cast with `as u32`, so callers must only
+/// pass a non-negative value (as `apply_decimal_op` does).
+fn decimal_scale_factor<T>(scale: i8, fn_name: &str) -> Result<T::Native>
+where
+    T: DecimalType,
+    T::Native: ArrowNativeType + ArrowNativeTypeOp,
+{
+    let base = <T::Native as ArrowNativeType>::from_usize(10).ok_or_else(|| {
+        DataFusionError::Execution(format!("Decimal overflow while applying {fn_name}"))
+    })?;
+
+    base.pow_checked(scale as u32).map_err(|_| {
+        DataFusionError::Execution(format!("Decimal overflow while applying {fn_name}"))
+    })
+}
+
+/// Rounds the unscaled decimal `value` up to the nearest multiple of `factor`.
+///
+/// A negative `value` is handled by subtracting its remainder; a non-negative
+/// one by adding `factor - remainder`. Both use wrapping arithmetic, so range
+/// checking is left to the caller.
+pub(super) fn ceil_decimal_value<T>(value: T, factor: T) -> T
+where
+    T: ArrowNativeTypeOp + std::ops::Rem<Output = T>,
+{
+    let remainder = value % factor;
+
+    if remainder == T::ZERO {
+        return value;
+    }
+
+    if value >= T::ZERO {
+        let increment = factor.sub_wrapping(remainder);
+        value.add_wrapping(increment)
+    } else {
+        value.sub_wrapping(remainder)
+    }
+}
+
+/// Rounds the unscaled decimal `value` down to the nearest multiple of `factor`.
+///
+/// A non-negative `value` is handled by subtracting its remainder; a negative
+/// one by subtracting `factor + remainder`. Both use wrapping arithmetic, so
+/// range checking is left to the caller.
+pub(super) fn floor_decimal_value<T>(value: T, factor: T) -> T
+where
+    T: ArrowNativeTypeOp + std::ops::Rem<Output = T>,
+{
+    let remainder = value % factor;
+
+    if remainder == T::ZERO {
+        return value;
+    }
+
+    if value >= T::ZERO {
+        value.sub_wrapping(remainder)
+    } else {
+        let adjustment = factor.add_wrapping(remainder);
+        value.sub_wrapping(adjustment)
+    }
+}
diff --git a/datafusion/functions/src/math/floor.rs b/datafusion/functions/src/math/floor.rs
new file mode 100644
index 000000000000..8a4a888555cd
--- /dev/null
+++ b/datafusion/functions/src/math/floor.rs
@@ -0,0 +1,174 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, AsArray};
+use arrow::datatypes::{
+    DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, Float32Type,
+    Float64Type,
+};
+use datafusion_common::{exec_err, Result, ScalarValue};
+use datafusion_expr::interval_arithmetic::Interval;
+use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
+use datafusion_expr::{
+    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
+    TypeSignature, TypeSignatureClass, Volatility,
+};
+use datafusion_macros::user_doc;
+
+use super::decimal::{apply_decimal_op, floor_decimal_value};
+
+#[user_doc(
+    doc_section(label = "Math Functions"),
+    description = "Returns the nearest integer less than or equal to a number.",
+    syntax_example = "floor(numeric_expression)",
+    standard_argument(name = "numeric_expression", prefix = "Numeric"),
+    sql_example = r#"```sql
+> SELECT floor(3.14);
++-------------+
+| floor(3.14) |
++-------------+
+| 3.0         |
++-------------+
+```"#
+)]
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct FloorFunc {
+    signature: Signature,
+}
+
+impl Default for FloorFunc {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl FloorFunc {
+    pub fn new() -> Self {
+        let decimal_sig = Coercion::new_exact(TypeSignatureClass::Decimal);
+        Self {
+            signature: Signature::one_of(
+                vec![
+                    TypeSignature::Coercible(vec![decimal_sig]),
+                    TypeSignature::Uniform(1, vec![DataType::Float64, DataType::Float32]),
+                ],
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl ScalarUDFImpl for FloorFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "floor"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        match &arg_types[0] {
+            DataType::Null => Ok(DataType::Float64),
+            other => Ok(other.clone()),
+        }
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let args = ColumnarValue::values_to_arrays(&args.args)?;
+        let value = &args[0];
+
+        let result: ArrayRef = match value.data_type() {
+            DataType::Float64 => Arc::new(
+                value
+                    .as_primitive::<Float64Type>()
+                    .unary::<_, Float64Type>(f64::floor),
+            ),
+            DataType::Float32 => Arc::new(
+                value
+                    .as_primitive::<Float32Type>()
+                    .unary::<_, Float32Type>(f32::floor),
+            ),
+            DataType::Null => {
+                return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)))
+            }
+            DataType::Decimal32(precision, scale) => {
+                apply_decimal_op::<Decimal32Type, _>(
+                    value,
+                    *precision,
+                    *scale,
+                    self.name(),
+                    floor_decimal_value,
+                )?
+            }
+            DataType::Decimal64(precision, scale) => {
+                apply_decimal_op::<Decimal64Type, _>(
+                    value,
+                    *precision,
+                    *scale,
+                    self.name(),
+                    floor_decimal_value,
+                )?
+            }
+            DataType::Decimal128(precision, scale) => {
+                apply_decimal_op::<Decimal128Type, _>(
+                    value,
+                    *precision,
+                    *scale,
+                    self.name(),
+                    floor_decimal_value,
+                )?
+            }
+            DataType::Decimal256(precision, scale) => {
+                apply_decimal_op::<Decimal256Type, _>(
+                    value,
+                    *precision,
+                    *scale,
+                    self.name(),
+                    floor_decimal_value,
+                )?
+            }
+            other => {
+                return exec_err!(
+                    "Unsupported data type {other:?} for function {}",
+                    self.name()
+                )
+            }
+        };
+
+        Ok(ColumnarValue::Array(result))
+    }
+
+    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
+        Ok(input[0].sort_properties)
+    }
+
+    fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
+        let data_type = inputs[0].data_type();
+        Interval::make_unbounded(&data_type)
+    }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        self.doc()
+    }
+}
diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs
index 4eb337a30110..610e773d68fd 100644
--- a/datafusion/functions/src/math/mod.rs
+++ b/datafusion/functions/src/math/mod.rs
@@ -23,8 +23,11 @@ use std::sync::Arc;
 
 pub mod abs;
 pub mod bounds;
+pub mod ceil;
 pub mod cot;
+mod decimal;
 pub mod factorial;
+pub mod floor;
 pub mod gcd;
 pub mod iszero;
 pub mod lcm;
@@ -104,14 +107,7 @@ make_math_unary_udf!(
     super::bounds::unbounded_bounds,
     super::get_cbrt_doc
 );
-make_math_unary_udf!(
-    CeilFunc,
-    ceil,
-    ceil,
-    super::ceil_order,
-    super::bounds::unbounded_bounds,
-    super::get_ceil_doc
-);
+make_udf_function!(ceil::CeilFunc, ceil);
 make_math_unary_udf!(
     CosFunc,
     cos,
@@ -146,14 +142,7 @@ make_math_unary_udf!(
     super::get_exp_doc
 );
 make_udf_function!(factorial::FactorialFunc, factorial);
-make_math_unary_udf!(
-    FloorFunc,
-    floor,
-    floor,
-    super::floor_order,
-    super::bounds::unbounded_bounds,
-    super::get_floor_doc
-);
+make_udf_function!(floor::FloorFunc, floor);
 make_udf_function!(log::LogFunc, log);
 make_udf_function!(gcd::GcdFunc, gcd);
 make_udf_function!(nans::IsNanFunc, isnan);
diff --git a/datafusion/functions/src/math/monotonicity.rs b/datafusion/functions/src/math/monotonicity.rs
index ffb3df4196d2..22b3bcb86399 100644
--- a/datafusion/functions/src/math/monotonicity.rs
+++ b/datafusion/functions/src/math/monotonicity.rs
@@ -309,30 +309,6 @@ pub fn ceil_order(input: &[ExprProperties]) -> Result<SortProperties> {
     Ok(input[0].sort_properties)
 }
 
-static DOCUMENTATION_CEIL: LazyLock<Documentation> = LazyLock::new(|| {
-    Documentation::builder(
-        DOC_SECTION_MATH,
-        "Returns the nearest integer greater than or equal to a number.",
-        "ceil(numeric_expression)",
-    )
-    .with_standard_argument("numeric_expression", Some("Numeric"))
-    .with_sql_example(
-        r#"```sql
-  > SELECT ceil(3.14);
-+------------+
-| ceil(3.14) |
-+------------+
-| 4.0        |
-+------------+
-```"#,
-    )
-    .build()
-});
-
-pub fn get_ceil_doc() -> &'static Documentation {
-    &DOCUMENTATION_CEIL
-}
-
 /// Non-increasing on \[0, π\] and then non-decreasing on \[π, 2π\].
 /// This pattern repeats periodically with a period of 2π.
 // TODO: Implement ordering rule of the ATAN2 function.
@@ -467,30 +443,6 @@ pub fn floor_order(input: &[ExprProperties]) -> Result<SortProperties> {
     Ok(input[0].sort_properties)
 }
 
-static DOCUMENTATION_FLOOR: LazyLock<Documentation> = LazyLock::new(|| {
-    Documentation::builder(
-        DOC_SECTION_MATH,
-        "Returns the nearest integer less than or equal to a number.",
-        "floor(numeric_expression)",
-    )
-    .with_standard_argument("numeric_expression", Some("Numeric"))
-    .with_sql_example(
-        r#"```sql
-> SELECT floor(3.14);
-+-------------+
-| floor(3.14) |
-+-------------+
-| 3.0         |
-+-------------+
-```"#,
-    )
-    .build()
-});
-
-pub fn get_floor_doc() -> &'static Documentation {
-    &DOCUMENTATION_FLOOR
-}
-
 /// Non-decreasing for x ≥ 0, undefined otherwise.
pub fn ln_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 8eac9bd0c955..3a9e238e6b38 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -317,6 +317,54 @@ select ceil(100.1234, 1) query error DataFusion error: This feature is not implemented: CEIL with datetime is not supported select ceil(100.1234 to year) +# ceil with decimal argument +query RRRR +select + ceil(arrow_cast(1.23,'Decimal128(10,2)')), + ceil(arrow_cast(-1.23,'Decimal128(10,2)')), + ceil(arrow_cast(123.00,'Decimal128(10,2)')), + ceil(arrow_cast(-123.00,'Decimal128(10,2)')); +---- +2 -1 123 -123 + +# ceil overflow with limited precision +query error Decimal overflow while applying ceil +select ceil(arrow_cast(9.23,'Decimal128(3,2)')); + +# ceil with decimal32 argument (ensure decimal output) +query TTTTTTTT +select + arrow_typeof(ceil(arrow_cast(9.01,'Decimal32(7,2)'))), + arrow_cast(ceil(arrow_cast(9.01,'Decimal32(7,2)')), 'Utf8'), + arrow_typeof(ceil(arrow_cast(-9.01,'Decimal32(7,2)'))), + arrow_cast(ceil(arrow_cast(-9.01,'Decimal32(7,2)')), 'Utf8'), + arrow_typeof(ceil(arrow_cast(10.00,'Decimal32(7,2)'))), + arrow_cast(ceil(arrow_cast(10.00,'Decimal32(7,2)')), 'Utf8'), + arrow_typeof(ceil(arrow_cast(-0.99,'Decimal32(7,2)'))), + arrow_cast(ceil(arrow_cast(-0.99,'Decimal32(7,2)')), 'Utf8'); +---- +Decimal32(7, 2) 10.00 Decimal32(7, 2) -9.00 Decimal32(7, 2) 10.00 Decimal32(7, 2) 0.00 + +# ceil with decimal64 zero scale +query TTTT +select + arrow_typeof(ceil(arrow_cast(123456789,'Decimal64(18,0)'))), + arrow_cast(ceil(arrow_cast(123456789,'Decimal64(18,0)')), 'Utf8'), + arrow_typeof(ceil(arrow_cast(-987654321,'Decimal64(18,0)'))), + arrow_cast(ceil(arrow_cast(-987654321,'Decimal64(18,0)')), 'Utf8'); +---- +Decimal64(18, 0) 123456789 Decimal64(18, 0) -987654321 + +# ceil with decimal256 argument +query TTTT 
+select + arrow_typeof(ceil(arrow_cast('9999999999999999999999999999999999.01','Decimal256(38,2)'))), + arrow_cast(ceil(arrow_cast('9999999999999999999999999999999999.01','Decimal256(38,2)')), 'Utf8'), + arrow_typeof(ceil(arrow_cast('-9999999999999999999999999999999999.01','Decimal256(38,2)'))), + arrow_cast(ceil(arrow_cast('-9999999999999999999999999999999999.01','Decimal256(38,2)')), 'Utf8'); +---- +Decimal256(38, 2) 10000000000000000000000000000000000.00 Decimal256(38, 2) -9999999999999999999999999999999999.00 + ## degrees # degrees scalar function @@ -464,6 +512,54 @@ select floor(a, 1) query error DataFusion error: This feature is not implemented: FLOOR with datetime is not supported select floor(a to year) +# floor with decimal argument +query RRRR +select + floor(arrow_cast(1.23,'Decimal128(10,2)')), + floor(arrow_cast(-1.23,'Decimal128(10,2)')), + floor(arrow_cast(123.00,'Decimal128(10,2)')), + floor(arrow_cast(-123.00,'Decimal128(10,2)')); +---- +1 -2 123 -123 + +# floor overflow with limited precision +query error Decimal overflow while applying floor +select floor(arrow_cast(-9.23,'Decimal128(3,2)')); + +# floor with decimal32 argument (ensure decimal output) +query TTTTTTTT +select + arrow_typeof(floor(arrow_cast(9.99,'Decimal32(7,2)'))), + arrow_cast(floor(arrow_cast(9.99,'Decimal32(7,2)')), 'Utf8'), + arrow_typeof(floor(arrow_cast(-9.01,'Decimal32(7,2)'))), + arrow_cast(floor(arrow_cast(-9.01,'Decimal32(7,2)')), 'Utf8'), + arrow_typeof(floor(arrow_cast(10.00,'Decimal32(7,2)'))), + arrow_cast(floor(arrow_cast(10.00,'Decimal32(7,2)')), 'Utf8'), + arrow_typeof(floor(arrow_cast(-0.01,'Decimal32(7,2)'))), + arrow_cast(floor(arrow_cast(-0.01,'Decimal32(7,2)')), 'Utf8'); +---- +Decimal32(7, 2) 9.00 Decimal32(7, 2) -10.00 Decimal32(7, 2) 10.00 Decimal32(7, 2) -1.00 + +# floor with decimal64 zero scale +query TTTT +select + arrow_typeof(floor(arrow_cast(123456789,'Decimal64(18,0)'))), + arrow_cast(floor(arrow_cast(123456789,'Decimal64(18,0)')), 'Utf8'), + 
arrow_typeof(floor(arrow_cast(-987654321,'Decimal64(18,0)'))), + arrow_cast(floor(arrow_cast(-987654321,'Decimal64(18,0)')), 'Utf8'); +---- +Decimal64(18, 0) 123456789 Decimal64(18, 0) -987654321 + +# floor with decimal256 argument +query TTTT +select + arrow_typeof(floor(arrow_cast('9999999999999999999999999999999999.99','Decimal256(38,2)'))), + arrow_cast(floor(arrow_cast('9999999999999999999999999999999999.99','Decimal256(38,2)')), 'Utf8'), + arrow_typeof(floor(arrow_cast('-9999999999999999999999999999999999.99','Decimal256(38,2)'))), + arrow_cast(floor(arrow_cast('-9999999999999999999999999999999999.99','Decimal256(38,2)')), 'Utf8'); +---- +Decimal256(38, 2) 9999999999999999999999999999999999.00 Decimal256(38, 2) -10000000000000000000000000000000000.00 + ## ln # ln scalar function diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index c5380b2e84b7..2d6bb2817cfc 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -294,7 +294,7 @@ ceil(numeric_expression) #### Example ```sql - > SELECT ceil(3.14); +> SELECT ceil(3.14); +------------+ | ceil(3.14) | +------------+