From 08dbd6ebd7c50e40f3114fb25a1094f569359704 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Fri, 28 Nov 2025 15:09:57 +0530 Subject: [PATCH 1/6] Add Decimal128 support to Ceil and Floor --- datafusion/functions/src/math/ceil.rs | 273 ++++++++++++++++++ datafusion/functions/src/math/floor.rs | 273 ++++++++++++++++++ datafusion/functions/src/math/mod.rs | 20 +- datafusion/sqllogictest/test_files/scalar.slt | 20 ++ 4 files changed, 570 insertions(+), 16 deletions(-) create mode 100644 datafusion/functions/src/math/ceil.rs create mode 100644 datafusion/functions/src/math/floor.rs diff --git a/datafusion/functions/src/math/ceil.rs b/datafusion/functions/src/math/ceil.rs new file mode 100644 index 000000000000..82a76bdbdd6d --- /dev/null +++ b/datafusion/functions/src/math/ceil.rs @@ -0,0 +1,273 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, AsArray, PrimitiveArray}; +use arrow::datatypes::{DataType, Decimal128Type, Float32Type, Float64Type}; +use arrow::error::ArrowError; +use datafusion_common::utils::take_function_args; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct CeilFunc { + signature: Signature, +} + +impl Default for CeilFunc { + fn default() -> Self { + Self::new() + } +} + +impl CeilFunc { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for CeilFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "ceil" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match arg_types[0] { + DataType::Float32 => Ok(DataType::Float32), + DataType::Decimal128(precision, scale) => { + Ok(DataType::Decimal128(precision, scale)) + } + _ => Ok(DataType::Float64), + } + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [arg] = take_function_args(self.name(), arg_types)?; + + let coerced = match arg { + DataType::Null => DataType::Float64, + DataType::Float32 | DataType::Float64 => arg.clone(), + DataType::Decimal128(_, _) => arg.clone(), + DataType::Float16 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => DataType::Float64, + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ) + } + }; + + Ok(vec![coerced]) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let value = &args[0]; + + let result: ArrayRef = match value.data_type() { + DataType::Float64 => Arc::new( + value + .as_primitive::() + .unary::<_, Float64Type>(f64::ceil), + ), + DataType::Float32 => Arc::new( + value + .as_primitive::() + .unary::<_, Float32Type>(f32::ceil), + ), + DataType::Decimal128(_, scale) => { + apply_decimal_op(value, *scale, ceil_decimal_value)? + } + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ) + } + }; + + Ok(ColumnarValue::Array(result)) + } + + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + super::ceil_order(input) + } + + fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result { + super::bounds::unbounded_bounds(inputs) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(super::get_ceil_doc()) + } +} + +fn apply_decimal_op( + array: &ArrayRef, + scale: i8, + op: fn(i128, i128) -> std::result::Result, +) -> Result { + if scale <= 0 { + return Ok(Arc::clone(array)); + } + + let factor = decimal_scale_factor(scale)?; + let decimal = array.as_primitive::(); + let data_type = array.data_type().clone(); + + let result: PrimitiveArray = decimal + .try_unary(|value| op(value, factor))? + .with_data_type(data_type); + + Ok(Arc::new(result)) +} + +fn decimal_scale_factor(scale: i8) -> Result { + if scale < 0 { + return exec_err!("Decimal scale {scale} must be non-negative"); + } + let exponent = scale as u32; + + if let Some(value) = 10_i128.checked_pow(exponent) { + Ok(value) + } else { + exec_err!("Decimal scale {scale} is too large for ceil") + } +} + +fn ceil_decimal_value( + value: i128, + factor: i128, +) -> std::result::Result { + let remainder = value % factor; + + if remainder == 0 { + return Ok(value); + } + + if value >= 0 { + let increment = factor - remainder; + value.checked_add(increment).ok_or_else(|| { + ArrowError::ComputeError("Decimal128 overflow while applying ceil".into()) + }) + } else { + value.checked_sub(remainder).ok_or_else(|| { + ArrowError::ComputeError("Decimal128 overflow while applying ceil".into()) + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Decimal128Array; + use arrow::datatypes::Field; + use datafusion_common::cast::as_decimal128_array; + use datafusion_common::config::ConfigOptions; + + #[test] + fn test_decimal128_ceil() { + let data_type = DataType::Decimal128(10, 2); + let input = Decimal128Array::from(vec![ + Some(1234), + Some(-1234), + Some(1200), + Some(-1200), + None, + ]) + .with_precision_and_scale(10, 2) + .unwrap(); + + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::new(input) as ArrayRef)], + arg_fields: vec![Field::new("a", data_type.clone(), true).into()], + number_rows: 5, + return_field: Field::new("f", data_type.clone(), true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = CeilFunc::new() + .invoke_with_args(args) + .expect("ceil evaluation succeeded"); + + let ColumnarValue::Array(result) = result else { + panic!("expected array result"); + }; + + let values = as_decimal128_array(&result).unwrap(); + assert_eq!(result.data_type(), &DataType::Decimal128(10, 2)); + assert_eq!(values.value(0), 1300); + assert_eq!(values.value(1), -1200); + assert_eq!(values.value(2), 1200); + assert_eq!(values.value(3), -1200); + assert!(values.is_null(4)); + } + + #[test] + fn test_decimal128_ceil_zero_scale() { + let data_type = DataType::Decimal128(6, 0); + let input = Decimal128Array::from(vec![Some(12), Some(-13), None]) + .with_precision_and_scale(6, 0) + .unwrap(); + + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::new(input) as ArrayRef)], + arg_fields: vec![Field::new("a", data_type.clone(), true).into()], + number_rows: 3, + return_field: Field::new("f", data_type.clone(), true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = CeilFunc::new() + .invoke_with_args(args) + .expect("ceil evaluation succeeded"); + + let ColumnarValue::Array(result) = result else { + panic!("expected array result"); + }; + + let values = as_decimal128_array(&result).unwrap(); + assert_eq!(values.value(0), 12); + assert_eq!(values.value(1), -13); + assert!(values.is_null(2)); + } +} diff --git a/datafusion/functions/src/math/floor.rs b/datafusion/functions/src/math/floor.rs new file mode 100644 index 000000000000..73f30351d524 --- /dev/null +++ b/datafusion/functions/src/math/floor.rs @@ -0,0 +1,273 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, AsArray, PrimitiveArray}; +use arrow::datatypes::{DataType, Decimal128Type, Float32Type, Float64Type}; +use arrow::error::ArrowError; +use datafusion_common::utils::take_function_args; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct FloorFunc { + signature: Signature, +} + +impl Default for FloorFunc { + fn default() -> Self { + Self::new() + } +} + +impl FloorFunc { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for FloorFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "floor" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match arg_types[0] { + DataType::Float32 => Ok(DataType::Float32), + DataType::Decimal128(precision, scale) => { + Ok(DataType::Decimal128(precision, scale)) + } + _ => Ok(DataType::Float64), + } + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [arg] = take_function_args(self.name(), arg_types)?; + + let coerced = match arg { + DataType::Null => DataType::Float64, + DataType::Float32 | DataType::Float64 => arg.clone(), + DataType::Decimal128(_, _) => arg.clone(), + DataType::Float16 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => DataType::Float64, + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ) + } + }; + + Ok(vec![coerced]) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let value = &args[0]; + + let result: ArrayRef = match value.data_type() { + DataType::Float64 => Arc::new( + value + .as_primitive::() + .unary::<_, Float64Type>(f64::floor), + ), + DataType::Float32 => Arc::new( + value + .as_primitive::() + .unary::<_, Float32Type>(f32::floor), + ), + DataType::Decimal128(_, scale) => { + apply_decimal_op(value, *scale, floor_decimal_value)? + } + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ) + } + }; + + Ok(ColumnarValue::Array(result)) + } + + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + super::floor_order(input) + } + + fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result { + super::bounds::unbounded_bounds(inputs) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(super::get_floor_doc()) + } +} + +fn apply_decimal_op( + array: &ArrayRef, + scale: i8, + op: fn(i128, i128) -> std::result::Result, +) -> Result { + if scale <= 0 { + return Ok(Arc::clone(array)); + } + + let factor = decimal_scale_factor(scale)?; + let decimal = array.as_primitive::(); + let data_type = array.data_type().clone(); + + let result: PrimitiveArray = decimal + .try_unary(|value| op(value, factor))? + .with_data_type(data_type); + + Ok(Arc::new(result)) +} + +fn decimal_scale_factor(scale: i8) -> Result { + if scale < 0 { + return exec_err!("Decimal scale {scale} must be non-negative"); + } + let exponent = scale as u32; + + if let Some(value) = 10_i128.checked_pow(exponent) { + Ok(value) + } else { + exec_err!("Decimal scale {scale} is too large for floor") + } +} + +fn floor_decimal_value( + value: i128, + factor: i128, +) -> std::result::Result { + let remainder = value % factor; + + if remainder == 0 { + return Ok(value); + } + + if value >= 0 { + value.checked_sub(remainder).ok_or_else(|| { + ArrowError::ComputeError("Decimal128 overflow while applying floor".into()) + }) + } else { + let adjustment = factor + remainder; + value.checked_sub(adjustment).ok_or_else(|| { + ArrowError::ComputeError("Decimal128 overflow while applying floor".into()) + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Decimal128Array; + use arrow::datatypes::Field; + use datafusion_common::cast::as_decimal128_array; + use datafusion_common::config::ConfigOptions; + + #[test] + fn test_decimal128_floor() { + let data_type = DataType::Decimal128(10, 2); + let input = Decimal128Array::from(vec![ + Some(1234), + Some(-1234), + Some(1200), + Some(-1200), + None, + ]) + .with_precision_and_scale(10, 2) + .unwrap(); + + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::new(input) as ArrayRef)], + arg_fields: vec![Field::new("a", data_type.clone(), true).into()], + number_rows: 5, + return_field: Field::new("f", data_type.clone(), true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = FloorFunc::new() + .invoke_with_args(args) + .expect("floor evaluation succeeded"); + + let ColumnarValue::Array(result) = result else { + panic!("expected array result"); + }; + + let values = as_decimal128_array(&result).unwrap(); + assert_eq!(result.data_type(), &DataType::Decimal128(10, 2)); + assert_eq!(values.value(0), 1200); + assert_eq!(values.value(1), -1300); + assert_eq!(values.value(2), 1200); + assert_eq!(values.value(3), -1200); + assert!(values.is_null(4)); + } + + #[test] + fn test_decimal128_floor_zero_scale() { + let data_type = DataType::Decimal128(6, 0); + let input = Decimal128Array::from(vec![Some(12), Some(-13), None]) + .with_precision_and_scale(6, 0) + .unwrap(); + + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::new(input) as ArrayRef)], + arg_fields: vec![Field::new("a", data_type.clone(), true).into()], + number_rows: 3, + return_field: Field::new("f", data_type.clone(), true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = FloorFunc::new() + .invoke_with_args(args) + .expect("floor evaluation succeeded"); + + let ColumnarValue::Array(result) = result else { + panic!("expected array result"); + }; + + let values = as_decimal128_array(&result).unwrap(); + assert_eq!(values.value(0), 12); + assert_eq!(values.value(1), -13); + assert!(values.is_null(2)); + } +} diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 4eb337a30110..66370a21e62f 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -23,8 +23,10 @@ use std::sync::Arc; pub mod abs; pub mod bounds; +pub mod ceil; pub mod cot; pub mod factorial; +pub mod floor; pub mod gcd; pub mod iszero; pub mod lcm; @@ -104,14 +106,7 @@ make_math_unary_udf!( super::bounds::unbounded_bounds, super::get_cbrt_doc ); -make_math_unary_udf!( - CeilFunc, - ceil, - ceil, - super::ceil_order, - super::bounds::unbounded_bounds, - super::get_ceil_doc -); +make_udf_function!(ceil::CeilFunc, ceil); make_math_unary_udf!( CosFunc, cos, @@ -146,14 +141,7 @@ make_math_unary_udf!( super::get_exp_doc ); make_udf_function!(factorial::FactorialFunc, factorial); -make_math_unary_udf!( - FloorFunc, - floor, - floor, - super::floor_order, - super::bounds::unbounded_bounds, - super::get_floor_doc -); +make_udf_function!(floor::FloorFunc, floor); make_udf_function!(log::LogFunc, log); make_udf_function!(gcd::GcdFunc, gcd); make_udf_function!(nans::IsNanFunc, isnan); diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 8eac9bd0c955..ad2673e1b9a0 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -317,6 +317,16 @@ select ceil(100.1234, 1) query error DataFusion error: This feature is not implemented: CEIL with datetime is not supported select ceil(100.1234 to year) +# ceil with decimal argument +query RRRR +select + ceil(arrow_cast(1.23,'Decimal128(10,2)')), + ceil(arrow_cast(-1.23,'Decimal128(10,2)')), + ceil(arrow_cast(123.00,'Decimal128(10,2)')), + ceil(arrow_cast(-123.00,'Decimal128(10,2)')); +---- +2 -1 123 -123 + ## degrees # degrees scalar function @@ -464,6 +474,16 @@ select floor(a, 1) query error DataFusion error: This feature is not implemented: FLOOR with datetime is not supported select floor(a to year) +# floor with decimal argument +query RRRR +select + floor(arrow_cast(1.23,'Decimal128(10,2)')), + floor(arrow_cast(-1.23,'Decimal128(10,2)')), + floor(arrow_cast(123.00,'Decimal128(10,2)')), + floor(arrow_cast(-123.00,'Decimal128(10,2)')); +---- +1 -2 123 -123 + ## ln # ln scalar function From 47b8f26d5fab1991a49fa806ef8f7405a2621285 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Sat, 29 Nov 2025 00:39:18 +0530 Subject: [PATCH 2/6] created new module for decimal shared generic code --- datafusion/functions/src/math/ceil.rs | 244 ++++++------------ datafusion/functions/src/math/decimal.rs | 207 +++++++++++++++ datafusion/functions/src/math/floor.rs | 244 ++++++------------ datafusion/functions/src/math/mod.rs | 1 + datafusion/functions/src/math/monotonicity.rs | 48 ---- datafusion/sqllogictest/test_files/scalar.slt | 68 +++++ 6 files changed, 420 insertions(+), 392 deletions(-) create mode 100644 datafusion/functions/src/math/decimal.rs diff --git a/datafusion/functions/src/math/ceil.rs b/datafusion/functions/src/math/ceil.rs index 82a76bdbdd6d..0845f5c5e4cd 100644 --- a/datafusion/functions/src/math/ceil.rs +++ b/datafusion/functions/src/math/ceil.rs @@ -18,18 +18,37 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{Array, ArrayRef, AsArray, PrimitiveArray}; -use arrow::datatypes::{DataType, Decimal128Type, Float32Type, Float64Type}; -use arrow::error::ArrowError; -use datafusion_common::utils::take_function_args; +use arrow::array::{ArrayRef, AsArray}; +use arrow::compute::cast; +use arrow::datatypes::{ + DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, Float32Type, + Float64Type, +}; use datafusion_common::{exec_err, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, - Volatility, + Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, TypeSignatureClass, Volatility, }; - +use datafusion_macros::user_doc; + +use super::decimal::{apply_decimal_op, ceil_decimal_value}; + +#[user_doc( + doc_section(label = "Math Functions"), + description = "Returns the nearest integer greater than or equal to a number.", + syntax_example = "ceil(numeric_expression)", + standard_argument(name = "numeric_expression", prefix = "Numeric"), + sql_example = r#"```sql +> SELECT ceil(3.14); ++------------+ +| ceil(3.14) | ++------------+ +| 4.0 | ++------------+ +```"# +)] #[derive(Debug, PartialEq, Eq, Hash)] pub struct CeilFunc { signature: Signature, @@ -43,8 +62,15 @@ impl Default for CeilFunc { impl CeilFunc { pub fn new() -> Self { + let decimal_sig = Coercion::new_exact(TypeSignatureClass::Decimal); Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![decimal_sig]), + TypeSignature::Uniform(1, vec![DataType::Float64, DataType::Float32]), + ], + Volatility::Immutable, + ), } } } @@ -65,40 +91,22 @@ impl ScalarUDFImpl for CeilFunc { fn return_type(&self, arg_types: &[DataType]) -> Result { match arg_types[0] { DataType::Float32 => Ok(DataType::Float32), + DataType::Decimal32(precision, scale) => { + Ok(DataType::Decimal32(precision, scale)) + } + DataType::Decimal64(precision, scale) => { + Ok(DataType::Decimal64(precision, scale)) + } DataType::Decimal128(precision, scale) => { Ok(DataType::Decimal128(precision, scale)) } + DataType::Decimal256(precision, scale) => { + Ok(DataType::Decimal256(precision, scale)) + } _ => Ok(DataType::Float64), } } - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let [arg] = take_function_args(self.name(), arg_types)?; - - let coerced = match arg { - DataType::Null => DataType::Float64, - DataType::Float32 | DataType::Float64 => arg.clone(), - DataType::Decimal128(_, _) => arg.clone(), - DataType::Float16 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => DataType::Float64, - other => { - return exec_err!( - "Unsupported data type {other:?} for function {}", - self.name() - ) - } - }; - - Ok(vec![coerced]) - } - fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let args = ColumnarValue::values_to_arrays(&args.args)?; let value = &args[0]; @@ -114,9 +122,31 @@ impl ScalarUDFImpl for CeilFunc { .as_primitive::() .unary::<_, Float32Type>(f32::ceil), ), - DataType::Decimal128(_, scale) => { - apply_decimal_op(value, *scale, ceil_decimal_value)? - } + DataType::Null => cast(value.as_ref(), &DataType::Float64)?, + DataType::Decimal32(_, scale) => apply_decimal_op::( + value, + *scale, + self.name(), + ceil_decimal_value, + )?, + DataType::Decimal64(_, scale) => apply_decimal_op::( + value, + *scale, + self.name(), + ceil_decimal_value, + )?, + DataType::Decimal128(_, scale) => apply_decimal_op::( + value, + *scale, + self.name(), + ceil_decimal_value, + )?, + DataType::Decimal256(_, scale) => apply_decimal_op::( + value, + *scale, + self.name(), + ceil_decimal_value, + )?, other => { return exec_err!( "Unsupported data type {other:?} for function {}", @@ -129,145 +159,15 @@ impl ScalarUDFImpl for CeilFunc { } fn output_ordering(&self, input: &[ExprProperties]) -> Result { - super::ceil_order(input) + Ok(input[0].sort_properties) } fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result { - super::bounds::unbounded_bounds(inputs) + let data_type = inputs[0].data_type(); + Interval::make_unbounded(&data_type) } fn documentation(&self) -> Option<&Documentation> { - Some(super::get_ceil_doc()) - } -} - -fn apply_decimal_op( - array: &ArrayRef, - scale: i8, - op: fn(i128, i128) -> std::result::Result, -) -> Result { - if scale <= 0 { - return Ok(Arc::clone(array)); - } - - let factor = decimal_scale_factor(scale)?; - let decimal = array.as_primitive::(); - let data_type = array.data_type().clone(); - - let result: PrimitiveArray = decimal - .try_unary(|value| op(value, factor))? - .with_data_type(data_type); - - Ok(Arc::new(result)) -} - -fn decimal_scale_factor(scale: i8) -> Result { - if scale < 0 { - return exec_err!("Decimal scale {scale} must be non-negative"); - } - let exponent = scale as u32; - - if let Some(value) = 10_i128.checked_pow(exponent) { - Ok(value) - } else { - exec_err!("Decimal scale {scale} is too large for ceil") - } -} - -fn ceil_decimal_value( - value: i128, - factor: i128, -) -> std::result::Result { - let remainder = value % factor; - - if remainder == 0 { - return Ok(value); - } - - if value >= 0 { - let increment = factor - remainder; - value.checked_add(increment).ok_or_else(|| { - ArrowError::ComputeError("Decimal128 overflow while applying ceil".into()) - }) - } else { - value.checked_sub(remainder).ok_or_else(|| { - ArrowError::ComputeError("Decimal128 overflow while applying ceil".into()) - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow::array::Decimal128Array; - use arrow::datatypes::Field; - use datafusion_common::cast::as_decimal128_array; - use datafusion_common::config::ConfigOptions; - - #[test] - fn test_decimal128_ceil() { - let data_type = DataType::Decimal128(10, 2); - let input = Decimal128Array::from(vec![ - Some(1234), - Some(-1234), - Some(1200), - Some(-1200), - None, - ]) - .with_precision_and_scale(10, 2) - .unwrap(); - - let args = ScalarFunctionArgs { - args: vec![ColumnarValue::Array(Arc::new(input) as ArrayRef)], - arg_fields: vec![Field::new("a", data_type.clone(), true).into()], - number_rows: 5, - return_field: Field::new("f", data_type.clone(), true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - - let result = CeilFunc::new() - .invoke_with_args(args) - .expect("ceil evaluation succeeded"); - - let ColumnarValue::Array(result) = result else { - panic!("expected array result"); - }; - - let values = as_decimal128_array(&result).unwrap(); - assert_eq!(result.data_type(), &DataType::Decimal128(10, 2)); - assert_eq!(values.value(0), 1300); - assert_eq!(values.value(1), -1200); - assert_eq!(values.value(2), 1200); - assert_eq!(values.value(3), -1200); - assert!(values.is_null(4)); - } - - #[test] - fn test_decimal128_ceil_zero_scale() { - let data_type = DataType::Decimal128(6, 0); - let input = Decimal128Array::from(vec![Some(12), Some(-13), None]) - .with_precision_and_scale(6, 0) - .unwrap(); - - let args = ScalarFunctionArgs { - args: vec![ColumnarValue::Array(Arc::new(input) as ArrayRef)], - arg_fields: vec![Field::new("a", data_type.clone(), true).into()], - number_rows: 3, - return_field: Field::new("f", data_type.clone(), true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - - let result = CeilFunc::new() - .invoke_with_args(args) - .expect("ceil evaluation succeeded"); - - let ColumnarValue::Array(result) = result else { - panic!("expected array result"); - }; - - let values = as_decimal128_array(&result).unwrap(); - assert_eq!(values.value(0), 12); - assert_eq!(values.value(1), -13); - assert!(values.is_null(2)); + self.doc() } } diff --git a/datafusion/functions/src/math/decimal.rs b/datafusion/functions/src/math/decimal.rs new file mode 100644 index 000000000000..78cfd31800af --- /dev/null +++ b/datafusion/functions/src/math/decimal.rs @@ -0,0 +1,207 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::ops::Rem; +use std::sync::Arc; + +use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; +use arrow::datatypes::DecimalType; +use arrow::error::ArrowError; +use arrow_buffer::i256; +use datafusion_common::{exec_err, Result}; + +/// Operations required to manipulate the native representation of Arrow decimal arrays. +pub(super) trait DecimalNative: + Copy + Rem + PartialEq + PartialOrd +{ + fn zero() -> Self; + fn checked_add(self, other: Self) -> Option; + fn checked_sub(self, other: Self) -> Option; + fn checked_pow10(exp: u32) -> Option; +} + +impl DecimalNative for i32 { + fn zero() -> Self { + 0 + } + + fn checked_add(self, other: Self) -> Option { + self.checked_add(other) + } + + fn checked_sub(self, other: Self) -> Option { + self.checked_sub(other) + } + + fn checked_pow10(exp: u32) -> Option { + 10_i32.checked_pow(exp) + } +} + +impl DecimalNative for i64 { + fn zero() -> Self { + 0 + } + + fn checked_add(self, other: Self) -> Option { + self.checked_add(other) + } + + fn checked_sub(self, other: Self) -> Option { + self.checked_sub(other) + } + + fn checked_pow10(exp: u32) -> Option { + 10_i64.checked_pow(exp) + } +} + +impl DecimalNative for i128 { + fn zero() -> Self { + 0 + } + + fn checked_add(self, other: Self) -> Option { + self.checked_add(other) + } + + fn checked_sub(self, other: Self) -> Option { + self.checked_sub(other) + } + + fn checked_pow10(exp: u32) -> Option { + 10_i128.checked_pow(exp) + } +} + +impl DecimalNative for i256 { + fn zero() -> Self { + i256::ZERO + } + + fn checked_add(self, other: Self) -> Option { + self.checked_add(other) + } + + fn checked_sub(self, other: Self) -> Option { + self.checked_sub(other) + } + + fn checked_pow10(exp: u32) -> Option { + i256::from_i128(10).checked_pow(exp) + } +} + +pub(super) fn apply_decimal_op( + array: &ArrayRef, + scale: i8, + fn_name: &str, + op: F, +) -> Result +where + T: DecimalType, + T::Native: DecimalNative, + F: Fn(T::Native, T::Native) -> std::result::Result, +{ + if scale <= 0 { + return Ok(Arc::clone(array)); + } + + let factor = decimal_scale_factor::(scale, fn_name)?; + let decimal = array.as_primitive::(); + let data_type = array.data_type().clone(); + + let result: PrimitiveArray = decimal + .try_unary(|value| op(value, factor))? + .with_data_type(data_type); + + Ok(Arc::new(result)) +} + +fn decimal_scale_factor(scale: i8, fn_name: &str) -> Result +where + T: DecimalType, + T::Native: DecimalNative, +{ + if scale < 0 { + return exec_err!("Decimal scale {scale} must be non-negative"); + } + + if let Some(value) = T::Native::checked_pow10(scale as u32) { + Ok(value) + } else { + exec_err!("Decimal scale {scale} is too large for {fn_name}") + } +} + +pub(super) fn ceil_decimal_value( + value: T, + factor: T, +) -> std::result::Result +where + T: DecimalNative, +{ + let remainder = value % factor; + + if remainder == T::zero() { + return Ok(value); + } + + if value >= T::zero() { + let increment = factor + .checked_sub(remainder) + .ok_or_else(|| overflow_err("ceil"))?; + value + .checked_add(increment) + .ok_or_else(|| overflow_err("ceil")) + } else { + value + .checked_sub(remainder) + .ok_or_else(|| overflow_err("ceil")) + } +} + +pub(super) fn floor_decimal_value( + value: T, + factor: T, +) -> std::result::Result +where + T: DecimalNative, +{ + let remainder = value % factor; + + if remainder == T::zero() { + return Ok(value); + } + + if value >= T::zero() { + value + .checked_sub(remainder) + .ok_or_else(|| overflow_err("floor")) + } else { + let adjustment = factor + .checked_add(remainder) + .ok_or_else(|| overflow_err("floor"))?; + value + .checked_sub(adjustment) + .ok_or_else(|| overflow_err("floor")) + } +} + +fn overflow_err(name: &str) -> ArrowError { + ArrowError::ComputeError(format!("Decimal overflow while applying {name}")) +} diff --git a/datafusion/functions/src/math/floor.rs b/datafusion/functions/src/math/floor.rs index 73f30351d524..df283b200e7a 100644 --- a/datafusion/functions/src/math/floor.rs +++ b/datafusion/functions/src/math/floor.rs @@ -18,18 +18,37 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{Array, ArrayRef, AsArray, PrimitiveArray}; -use arrow::datatypes::{DataType, Decimal128Type, Float32Type, Float64Type}; -use arrow::error::ArrowError; -use datafusion_common::utils::take_function_args; +use arrow::array::{ArrayRef, AsArray}; +use arrow::compute::cast; +use arrow::datatypes::{ + DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, Float32Type, + Float64Type, +}; use datafusion_common::{exec_err, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, - Volatility, + Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, TypeSignatureClass, Volatility, }; - +use datafusion_macros::user_doc; + +use super::decimal::{apply_decimal_op, floor_decimal_value}; + +#[user_doc( + doc_section(label = "Math Functions"), + description = "Returns the nearest integer less than or equal to a number.", + syntax_example = "floor(numeric_expression)", + standard_argument(name = "numeric_expression", prefix = "Numeric"), + sql_example = r#"```sql +> SELECT floor(3.14); ++-------------+ +| floor(3.14) | ++-------------+ +| 3.0 | ++-------------+ +```"# +)] #[derive(Debug, PartialEq, Eq, Hash)] pub struct FloorFunc { signature: Signature, @@ -43,8 +62,15 @@ impl Default for FloorFunc { impl FloorFunc { pub fn new() -> Self { + let decimal_sig = Coercion::new_exact(TypeSignatureClass::Decimal); Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![decimal_sig]), + TypeSignature::Uniform(1, vec![DataType::Float64, DataType::Float32]), + ], + Volatility::Immutable, + ), } } } @@ -65,40 +91,22 @@ impl ScalarUDFImpl for FloorFunc { fn return_type(&self, arg_types: &[DataType]) -> Result { match arg_types[0] { DataType::Float32 => Ok(DataType::Float32), + DataType::Decimal32(precision, scale) => { + Ok(DataType::Decimal32(precision, scale)) + } + DataType::Decimal64(precision, scale) => { + Ok(DataType::Decimal64(precision, scale)) + } DataType::Decimal128(precision, scale) => { Ok(DataType::Decimal128(precision, scale)) } + DataType::Decimal256(precision, scale) => { + Ok(DataType::Decimal256(precision, scale)) + } _ => Ok(DataType::Float64), } } - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let [arg] = take_function_args(self.name(), arg_types)?; - - let coerced = match arg { - DataType::Null => DataType::Float64, - DataType::Float32 | DataType::Float64 => arg.clone(), - DataType::Decimal128(_, _) => arg.clone(), - DataType::Float16 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => DataType::Float64, - other => { - return exec_err!( - "Unsupported data type {other:?} for function {}", - self.name() - ) - } - }; - - Ok(vec![coerced]) - } - fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let args = ColumnarValue::values_to_arrays(&args.args)?; let value = &args[0]; @@ -114,9 +122,31 @@ impl ScalarUDFImpl for FloorFunc { .as_primitive::() .unary::<_, Float32Type>(f32::floor), ), - DataType::Decimal128(_, scale) => { - apply_decimal_op(value, *scale, floor_decimal_value)? - } + DataType::Null => cast(value.as_ref(), &DataType::Float64)?, + DataType::Decimal32(_, scale) => apply_decimal_op::( + value, + *scale, + self.name(), + floor_decimal_value, + )?, + DataType::Decimal64(_, scale) => apply_decimal_op::( + value, + *scale, + self.name(), + floor_decimal_value, + )?, + DataType::Decimal128(_, scale) => apply_decimal_op::( + value, + *scale, + self.name(), + floor_decimal_value, + )?, + DataType::Decimal256(_, scale) => apply_decimal_op::( + value, + *scale, + self.name(), + floor_decimal_value, + )?, other => { return exec_err!( "Unsupported data type {other:?} for function {}", @@ -129,145 +159,15 @@ impl ScalarUDFImpl for FloorFunc { } fn output_ordering(&self, input: &[ExprProperties]) -> Result { - super::floor_order(input) + Ok(input[0].sort_properties) } fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result { - super::bounds::unbounded_bounds(inputs) + let data_type = inputs[0].data_type(); + Interval::make_unbounded(&data_type) } fn documentation(&self) -> Option<&Documentation> { - Some(super::get_floor_doc()) - } -} - -fn apply_decimal_op( - array: &ArrayRef, - scale: i8, - op: fn(i128, i128) -> std::result::Result, -) -> Result { - if scale <= 0 { - return Ok(Arc::clone(array)); - } - - let factor = decimal_scale_factor(scale)?; - let decimal = array.as_primitive::(); - let data_type = array.data_type().clone(); - - let result: PrimitiveArray = decimal - .try_unary(|value| op(value, factor))? - .with_data_type(data_type); - - Ok(Arc::new(result)) -} - -fn decimal_scale_factor(scale: i8) -> Result { - if scale < 0 { - return exec_err!("Decimal scale {scale} must be non-negative"); - } - let exponent = scale as u32; - - if let Some(value) = 10_i128.checked_pow(exponent) { - Ok(value) - } else { - exec_err!("Decimal scale {scale} is too large for floor") - } -} - -fn floor_decimal_value( - value: i128, - factor: i128, -) -> std::result::Result { - let remainder = value % factor; - - if remainder == 0 { - return Ok(value); - } - - if value >= 0 { - value.checked_sub(remainder).ok_or_else(|| { - ArrowError::ComputeError("Decimal128 overflow while applying floor".into()) - }) - } else { - let adjustment = factor + remainder; - value.checked_sub(adjustment).ok_or_else(|| { - ArrowError::ComputeError("Decimal128 overflow while applying floor".into()) - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow::array::Decimal128Array; - use arrow::datatypes::Field; - use datafusion_common::cast::as_decimal128_array; - use datafusion_common::config::ConfigOptions; - - #[test] - fn test_decimal128_floor() { - let data_type = DataType::Decimal128(10, 2); - let input = Decimal128Array::from(vec![ - Some(1234), - Some(-1234), - Some(1200), - Some(-1200), - None, - ]) - .with_precision_and_scale(10, 2) - .unwrap(); - - let args = ScalarFunctionArgs { - args: vec![ColumnarValue::Array(Arc::new(input) as ArrayRef)], - arg_fields: vec![Field::new("a", data_type.clone(), true).into()], - number_rows: 5, - return_field: Field::new("f", data_type.clone(), true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - - let result = FloorFunc::new() - .invoke_with_args(args) - .expect("floor evaluation succeeded"); - - let ColumnarValue::Array(result) = result else { - panic!("expected array result"); - }; - - let values = as_decimal128_array(&result).unwrap(); - assert_eq!(result.data_type(), &DataType::Decimal128(10, 2)); - assert_eq!(values.value(0), 1200); - assert_eq!(values.value(1), -1300); - assert_eq!(values.value(2), 1200); - assert_eq!(values.value(3), -1200); - assert!(values.is_null(4)); - } - - #[test] - fn test_decimal128_floor_zero_scale() { - let data_type = DataType::Decimal128(6, 0); - let input = Decimal128Array::from(vec![Some(12), Some(-13), None]) - .with_precision_and_scale(6, 0) - .unwrap(); - - let args = ScalarFunctionArgs { - args: vec![ColumnarValue::Array(Arc::new(input) as ArrayRef)], - arg_fields: vec![Field::new("a", data_type.clone(), true).into()], - number_rows: 3, - return_field: Field::new("f", data_type.clone(), true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - - let result = FloorFunc::new() - .invoke_with_args(args) - .expect("floor evaluation succeeded"); - - let ColumnarValue::Array(result) = result else { - panic!("expected array result"); - }; - - let values = as_decimal128_array(&result).unwrap(); - assert_eq!(values.value(0), 12); - assert_eq!(values.value(1), -13); - assert!(values.is_null(2)); + self.doc() } } diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 66370a21e62f..610e773d68fd 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -25,6 +25,7 @@ pub mod abs; pub mod bounds; pub mod ceil; pub mod cot; +mod decimal; pub mod factorial; pub mod floor; pub mod gcd; diff --git a/datafusion/functions/src/math/monotonicity.rs b/datafusion/functions/src/math/monotonicity.rs index ffb3df4196d2..22b3bcb86399 100644 --- a/datafusion/functions/src/math/monotonicity.rs +++ b/datafusion/functions/src/math/monotonicity.rs @@ -309,30 +309,6 @@ pub fn ceil_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } -static DOCUMENTATION_CEIL: LazyLock = LazyLock::new(|| { - Documentation::builder( - DOC_SECTION_MATH, - "Returns the nearest integer greater than or equal to a number.", - "ceil(numeric_expression)", - ) - .with_standard_argument("numeric_expression", Some("Numeric")) - .with_sql_example( - r#"```sql - > SELECT ceil(3.14); -+------------+ -| ceil(3.14) | -+------------+ -| 4.0 | -+------------+ -```"#, - ) - .build() -}); - -pub fn get_ceil_doc() -> &'static Documentation { - &DOCUMENTATION_CEIL -} - /// Non-increasing on \[0, π\] and then non-decreasing on \[π, 2π\]. /// This pattern repeats periodically with a period of 2π. // TODO: Implement ordering rule of the ATAN2 function. @@ -467,30 +443,6 @@ pub fn floor_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } -static DOCUMENTATION_FLOOR: LazyLock = LazyLock::new(|| { - Documentation::builder( - DOC_SECTION_MATH, - "Returns the nearest integer less than or equal to a number.", - "floor(numeric_expression)", - ) - .with_standard_argument("numeric_expression", Some("Numeric")) - .with_sql_example( - r#"```sql -> SELECT floor(3.14); -+-------------+ -| floor(3.14) | -+-------------+ -| 3.0 | -+-------------+ -```"#, - ) - .build() -}); - -pub fn get_floor_doc() -> &'static Documentation { - &DOCUMENTATION_FLOOR -} - /// Non-decreasing for x ≥ 0, undefined otherwise. pub fn ln_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index ad2673e1b9a0..1194e57b0861 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -327,6 +327,40 @@ select ---- 2 -1 123 -123 +# ceil with decimal32 argument (ensure decimal output) +query TTTTTTTT +select + arrow_typeof(ceil(arrow_cast(9.01,'Decimal32(7,2)'))), + arrow_cast(ceil(arrow_cast(9.01,'Decimal32(7,2)')), 'Utf8'), + arrow_typeof(ceil(arrow_cast(-9.01,'Decimal32(7,2)'))), + arrow_cast(ceil(arrow_cast(-9.01,'Decimal32(7,2)')), 'Utf8'), + arrow_typeof(ceil(arrow_cast(10.00,'Decimal32(7,2)'))), + arrow_cast(ceil(arrow_cast(10.00,'Decimal32(7,2)')), 'Utf8'), + arrow_typeof(ceil(arrow_cast(-0.99,'Decimal32(7,2)'))), + arrow_cast(ceil(arrow_cast(-0.99,'Decimal32(7,2)')), 'Utf8'); +---- +Decimal32(7, 2) 10.00 Decimal32(7, 2) -9.00 Decimal32(7, 2) 10.00 Decimal32(7, 2) 0.00 + +# ceil with decimal64 zero scale +query TTTT +select + arrow_typeof(ceil(arrow_cast(123456789,'Decimal64(18,0)'))), + arrow_cast(ceil(arrow_cast(123456789,'Decimal64(18,0)')), 'Utf8'), + arrow_typeof(ceil(arrow_cast(-987654321,'Decimal64(18,0)'))), + arrow_cast(ceil(arrow_cast(-987654321,'Decimal64(18,0)')), 'Utf8'); +---- +Decimal64(18, 0) 123456789 Decimal64(18, 0) -987654321 + +# ceil with decimal256 argument +query TTTT +select + arrow_typeof(ceil(arrow_cast('9999999999999999999999999999999999.01','Decimal256(38,2)'))), + arrow_cast(ceil(arrow_cast('9999999999999999999999999999999999.01','Decimal256(38,2)')), 'Utf8'), + arrow_typeof(ceil(arrow_cast('-9999999999999999999999999999999999.01','Decimal256(38,2)'))), + arrow_cast(ceil(arrow_cast('-9999999999999999999999999999999999.01','Decimal256(38,2)')), 'Utf8'); +---- +Decimal256(38, 2) 10000000000000000000000000000000000.00 Decimal256(38, 2) -9999999999999999999999999999999999.00 + ## degrees # degrees scalar function @@ -484,6 +518,40 @@ select ---- 1 -2 123 -123 +# floor with decimal32 argument (ensure decimal output) +query TTTTTTTT +select + arrow_typeof(floor(arrow_cast(9.99,'Decimal32(7,2)'))), + arrow_cast(floor(arrow_cast(9.99,'Decimal32(7,2)')), 'Utf8'), + arrow_typeof(floor(arrow_cast(-9.01,'Decimal32(7,2)'))), + arrow_cast(floor(arrow_cast(-9.01,'Decimal32(7,2)')), 'Utf8'), + arrow_typeof(floor(arrow_cast(10.00,'Decimal32(7,2)'))), + arrow_cast(floor(arrow_cast(10.00,'Decimal32(7,2)')), 'Utf8'), + arrow_typeof(floor(arrow_cast(-0.01,'Decimal32(7,2)'))), + arrow_cast(floor(arrow_cast(-0.01,'Decimal32(7,2)')), 'Utf8'); +---- +Decimal32(7, 2) 9.00 Decimal32(7, 2) -10.00 Decimal32(7, 2) 10.00 Decimal32(7, 2) -1.00 + +# floor with decimal64 zero scale +query TTTT +select + arrow_typeof(floor(arrow_cast(123456789,'Decimal64(18,0)'))), + arrow_cast(floor(arrow_cast(123456789,'Decimal64(18,0)')), 'Utf8'), + arrow_typeof(floor(arrow_cast(-987654321,'Decimal64(18,0)'))), + arrow_cast(floor(arrow_cast(-987654321,'Decimal64(18,0)')), 'Utf8'); +---- +Decimal64(18, 0) 123456789 Decimal64(18, 0) -987654321 + +# floor with decimal256 argument +query TTTT +select + arrow_typeof(floor(arrow_cast('9999999999999999999999999999999999.99','Decimal256(38,2)'))), + arrow_cast(floor(arrow_cast('9999999999999999999999999999999999.99','Decimal256(38,2)')), 'Utf8'), + arrow_typeof(floor(arrow_cast('-9999999999999999999999999999999999.99','Decimal256(38,2)'))), + arrow_cast(floor(arrow_cast('-9999999999999999999999999999999999.99','Decimal256(38,2)')), 'Utf8'); +---- +Decimal256(38, 2) 9999999999999999999999999999999999.00 Decimal256(38, 2) -10000000000000000000000000000000000.00 + ## ln # ln scalar function From 7ac92ff4a88c5e365113fac4b186fb3e7b462654 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Sat, 29 Nov 2025 12:15:39 +0530 Subject: [PATCH 3/6] Update datafusion/functions/src/math/ceil.rs Co-authored-by: Jeffrey Vo --- datafusion/functions/src/math/ceil.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions/src/math/ceil.rs b/datafusion/functions/src/math/ceil.rs index 0845f5c5e4cd..1454d46e17c9 100644 --- a/datafusion/functions/src/math/ceil.rs +++ b/datafusion/functions/src/math/ceil.rs @@ -122,7 +122,7 @@ impl ScalarUDFImpl for CeilFunc { .as_primitive::() .unary::<_, Float32Type>(f32::ceil), ), - DataType::Null => cast(value.as_ref(), &DataType::Float64)?, + DataType::Null => return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))), DataType::Decimal32(_, scale) => apply_decimal_op::( value, *scale, From ae2fbd01d27b52f9116459c3b519e3f7e9263d23 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Sat, 29 Nov 2025 12:44:26 +0530 Subject: [PATCH 4/6] used alreay present traits for decimal --- datafusion/execution/src/config.rs | 4 +- datafusion/functions/src/math/ceil.rs | 25 +-- datafusion/functions/src/math/decimal.rs | 148 ++++-------------- datafusion/functions/src/math/floor.rs | 25 +-- datafusion/sqllogictest/test_files/scalar.slt | 32 ++-- 5 files changed, 66 insertions(+), 168 deletions(-) diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 443229a3cb77..3fa602f12554 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -114,10 +114,10 @@ impl Default for SessionConfig { } /// A type map for storing extensions. -/// +/// /// Extensions are indexed by their type `T`. If multiple values of the same type are provided, only the last one /// will be kept. -/// +/// /// Extensions are opaque objects that are unknown to DataFusion itself but can be downcast by optimizer rules, /// execution plans, or other components that have access to the session config. /// They provide a flexible way to attach extra data or behavior to the session config. diff --git a/datafusion/functions/src/math/ceil.rs b/datafusion/functions/src/math/ceil.rs index 1454d46e17c9..1b61d44a7c0d 100644 --- a/datafusion/functions/src/math/ceil.rs +++ b/datafusion/functions/src/math/ceil.rs @@ -19,12 +19,11 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ArrayRef, AsArray}; -use arrow::compute::cast; use arrow::datatypes::{ DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, Float32Type, Float64Type, }; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ @@ -89,21 +88,9 @@ impl ScalarUDFImpl for CeilFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match arg_types[0] { - DataType::Float32 => Ok(DataType::Float32), - DataType::Decimal32(precision, scale) => { - Ok(DataType::Decimal32(precision, scale)) - } - DataType::Decimal64(precision, scale) => { - Ok(DataType::Decimal64(precision, scale)) - } - DataType::Decimal128(precision, scale) => { - Ok(DataType::Decimal128(precision, scale)) - } - DataType::Decimal256(precision, scale) => { - Ok(DataType::Decimal256(precision, scale)) - } - _ => Ok(DataType::Float64), + match &arg_types[0] { + DataType::Null => Ok(DataType::Float64), + other => Ok(other.clone()), } } @@ -122,7 +109,9 @@ impl ScalarUDFImpl for CeilFunc { .as_primitive::() .unary::<_, Float32Type>(f32::ceil), ), - DataType::Null => return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))), + DataType::Null => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))) + } DataType::Decimal32(_, scale) => apply_decimal_op::( value, *scale, diff --git a/datafusion/functions/src/math/decimal.rs b/datafusion/functions/src/math/decimal.rs index 78cfd31800af..060fbdb4e0d5 100644 --- a/datafusion/functions/src/math/decimal.rs +++ b/datafusion/functions/src/math/decimal.rs @@ -19,92 +19,10 @@ use std::ops::Rem; use std::sync::Arc; use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; -use arrow::datatypes::DecimalType; +use arrow::datatypes::{ArrowNativeTypeOp, DecimalType}; use arrow::error::ArrowError; -use arrow_buffer::i256; -use datafusion_common::{exec_err, Result}; - -/// Operations required to manipulate the native representation of Arrow decimal arrays. -pub(super) trait DecimalNative: - Copy + Rem + PartialEq + PartialOrd -{ - fn zero() -> Self; - fn checked_add(self, other: Self) -> Option; - fn checked_sub(self, other: Self) -> Option; - fn checked_pow10(exp: u32) -> Option; -} - -impl DecimalNative for i32 { - fn zero() -> Self { - 0 - } - - fn checked_add(self, other: Self) -> Option { - self.checked_add(other) - } - - fn checked_sub(self, other: Self) -> Option { - self.checked_sub(other) - } - - fn checked_pow10(exp: u32) -> Option { - 10_i32.checked_pow(exp) - } -} - -impl DecimalNative for i64 { - fn zero() -> Self { - 0 - } - - fn checked_add(self, other: Self) -> Option { - self.checked_add(other) - } - - fn checked_sub(self, other: Self) -> Option { - self.checked_sub(other) - } - - fn checked_pow10(exp: u32) -> Option { - 10_i64.checked_pow(exp) - } -} - -impl DecimalNative for i128 { - fn zero() -> Self { - 0 - } - - fn checked_add(self, other: Self) -> Option { - self.checked_add(other) - } - - fn checked_sub(self, other: Self) -> Option { - self.checked_sub(other) - } - - fn checked_pow10(exp: u32) -> Option { - 10_i128.checked_pow(exp) - } -} - -impl DecimalNative for i256 { - fn zero() -> Self { - i256::ZERO - } - - fn checked_add(self, other: Self) -> Option { - self.checked_add(other) - } - - fn checked_sub(self, other: Self) -> Option { - self.checked_sub(other) - } - - fn checked_pow10(exp: u32) -> Option { - i256::from_i128(10).checked_pow(exp) - } -} +use arrow_buffer::ArrowNativeType; +use datafusion_common::{DataFusionError, Result}; pub(super) fn apply_decimal_op( array: &ArrayRef, @@ -114,7 +32,7 @@ pub(super) fn apply_decimal_op( ) -> Result where T: DecimalType, - T::Native: DecimalNative, + T::Native: ArrowNativeType + ArrowNativeTypeOp, F: Fn(T::Native, T::Native) -> std::result::Result, { if scale <= 0 { @@ -135,17 +53,19 @@ where fn decimal_scale_factor(scale: i8, fn_name: &str) -> Result where T: DecimalType, - T::Native: DecimalNative, + T::Native: ArrowNativeType + ArrowNativeTypeOp, { - if scale < 0 { - return exec_err!("Decimal scale {scale} must be non-negative"); - } - - if let Some(value) = T::Native::checked_pow10(scale as u32) { - Ok(value) - } else { - exec_err!("Decimal scale {scale} is too large for {fn_name}") - } + let base = ::from_usize(10).ok_or_else(|| { + DataFusionError::Execution(format!( + "Decimal scale {scale} is too large for {fn_name}" + )) + })?; + + base.pow_checked(scale as u32).map_err(|_| { + DataFusionError::Execution(format!( + "Decimal scale {scale} is too large for {fn_name}" + )) + }) } pub(super) fn ceil_decimal_value( @@ -153,25 +73,25 @@ pub(super) fn ceil_decimal_value( factor: T, ) -> std::result::Result where - T: DecimalNative, + T: ArrowNativeTypeOp + Rem, { let remainder = value % factor; - if remainder == T::zero() { + if remainder == T::ZERO { return Ok(value); } - if value >= T::zero() { + if value >= T::ZERO { let increment = factor - .checked_sub(remainder) - .ok_or_else(|| overflow_err("ceil"))?; + .sub_checked(remainder) + .map_err(|_| overflow_err("ceil"))?; value - .checked_add(increment) - .ok_or_else(|| overflow_err("ceil")) + .add_checked(increment) + .map_err(|_| overflow_err("ceil")) } else { value - .checked_sub(remainder) - .ok_or_else(|| overflow_err("ceil")) + .sub_checked(remainder) + .map_err(|_| overflow_err("ceil")) } } @@ -180,25 +100,25 @@ pub(super) fn floor_decimal_value( factor: T, ) -> std::result::Result where - T: DecimalNative, + T: ArrowNativeTypeOp + Rem, { let remainder = value % factor; - if remainder == T::zero() { + if remainder == T::ZERO { return Ok(value); } - if value >= T::zero() { + if value >= T::ZERO { value - .checked_sub(remainder) - .ok_or_else(|| overflow_err("floor")) + .sub_checked(remainder) + .map_err(|_| overflow_err("floor")) } else { let adjustment = factor - .checked_add(remainder) - .ok_or_else(|| overflow_err("floor"))?; + .add_checked(remainder) + .map_err(|_| overflow_err("floor"))?; value - .checked_sub(adjustment) - .ok_or_else(|| overflow_err("floor")) + .sub_checked(adjustment) + .map_err(|_| overflow_err("floor")) } } diff --git a/datafusion/functions/src/math/floor.rs b/datafusion/functions/src/math/floor.rs index df283b200e7a..4a2d8d8b3bf5 100644 --- a/datafusion/functions/src/math/floor.rs +++ b/datafusion/functions/src/math/floor.rs @@ -19,12 +19,11 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ArrayRef, AsArray}; -use arrow::compute::cast; use arrow::datatypes::{ DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, Float32Type, Float64Type, }; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ @@ -89,21 +88,9 @@ impl ScalarUDFImpl for FloorFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match arg_types[0] { - DataType::Float32 => Ok(DataType::Float32), - DataType::Decimal32(precision, scale) => { - Ok(DataType::Decimal32(precision, scale)) - } - DataType::Decimal64(precision, scale) => { - Ok(DataType::Decimal64(precision, scale)) - } - DataType::Decimal128(precision, scale) => { - Ok(DataType::Decimal128(precision, scale)) - } - DataType::Decimal256(precision, scale) => { - Ok(DataType::Decimal256(precision, scale)) - } - _ => Ok(DataType::Float64), + match &arg_types[0] { + DataType::Null => Ok(DataType::Float64), + other => Ok(other.clone()), } } @@ -122,7 +109,9 @@ impl ScalarUDFImpl for FloorFunc { .as_primitive::() .unary::<_, Float32Type>(f32::floor), ), - DataType::Null => cast(value.as_ref(), &DataType::Float64)?, + DataType::Null => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))) + } DataType::Decimal32(_, scale) => apply_decimal_op::( value, *scale, diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 1194e57b0861..363177ce00ad 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -331,13 +331,13 @@ select query TTTTTTTT select arrow_typeof(ceil(arrow_cast(9.01,'Decimal32(7,2)'))), - arrow_cast(ceil(arrow_cast(9.01,'Decimal32(7,2)')), 'Utf8'), + ceil(arrow_cast(9.01,'Decimal32(7,2)')), arrow_typeof(ceil(arrow_cast(-9.01,'Decimal32(7,2)'))), - arrow_cast(ceil(arrow_cast(-9.01,'Decimal32(7,2)')), 'Utf8'), + ceil(arrow_cast(-9.01,'Decimal32(7,2)')), arrow_typeof(ceil(arrow_cast(10.00,'Decimal32(7,2)'))), - arrow_cast(ceil(arrow_cast(10.00,'Decimal32(7,2)')), 'Utf8'), + ceil(arrow_cast(10.00,'Decimal32(7,2)')), arrow_typeof(ceil(arrow_cast(-0.99,'Decimal32(7,2)'))), - arrow_cast(ceil(arrow_cast(-0.99,'Decimal32(7,2)')), 'Utf8'); + ceil(arrow_cast(-0.99,'Decimal32(7,2)')); ---- Decimal32(7, 2) 10.00 Decimal32(7, 2) -9.00 Decimal32(7, 2) 10.00 Decimal32(7, 2) 0.00 @@ -345,9 +345,9 @@ Decimal32(7, 2) 10.00 Decimal32(7, 2) -9.00 Decimal32(7, 2) 10.00 Decimal32(7, 2 query TTTT select arrow_typeof(ceil(arrow_cast(123456789,'Decimal64(18,0)'))), - arrow_cast(ceil(arrow_cast(123456789,'Decimal64(18,0)')), 'Utf8'), + ceil(arrow_cast(123456789,'Decimal64(18,0)')), arrow_typeof(ceil(arrow_cast(-987654321,'Decimal64(18,0)'))), - arrow_cast(ceil(arrow_cast(-987654321,'Decimal64(18,0)')), 'Utf8'); + ceil(arrow_cast(-987654321,'Decimal64(18,0)')); ---- Decimal64(18, 0) 123456789 Decimal64(18, 0) -987654321 @@ -355,9 +355,9 @@ Decimal64(18, 0) 123456789 Decimal64(18, 0) -987654321 query TTTT select arrow_typeof(ceil(arrow_cast('9999999999999999999999999999999999.01','Decimal256(38,2)'))), - arrow_cast(ceil(arrow_cast('9999999999999999999999999999999999.01','Decimal256(38,2)')), 'Utf8'), + ceil(arrow_cast('9999999999999999999999999999999999.01','Decimal256(38,2)')), arrow_typeof(ceil(arrow_cast('-9999999999999999999999999999999999.01','Decimal256(38,2)'))), - arrow_cast(ceil(arrow_cast('-9999999999999999999999999999999999.01','Decimal256(38,2)')), 'Utf8'); + ceil(arrow_cast('-9999999999999999999999999999999999.01','Decimal256(38,2)')); ---- Decimal256(38, 2) 10000000000000000000000000000000000.00 Decimal256(38, 2) -9999999999999999999999999999999999.00 @@ -522,13 +522,13 @@ select query TTTTTTTT select arrow_typeof(floor(arrow_cast(9.99,'Decimal32(7,2)'))), - arrow_cast(floor(arrow_cast(9.99,'Decimal32(7,2)')), 'Utf8'), + floor(arrow_cast(9.99,'Decimal32(7,2)')), arrow_typeof(floor(arrow_cast(-9.01,'Decimal32(7,2)'))), - arrow_cast(floor(arrow_cast(-9.01,'Decimal32(7,2)')), 'Utf8'), + floor(arrow_cast(-9.01,'Decimal32(7,2)')), arrow_typeof(floor(arrow_cast(10.00,'Decimal32(7,2)'))), - arrow_cast(floor(arrow_cast(10.00,'Decimal32(7,2)')), 'Utf8'), + floor(arrow_cast(10.00,'Decimal32(7,2)')), arrow_typeof(floor(arrow_cast(-0.01,'Decimal32(7,2)'))), - arrow_cast(floor(arrow_cast(-0.01,'Decimal32(7,2)')), 'Utf8'); + floor(arrow_cast(-0.01,'Decimal32(7,2)')); ---- Decimal32(7, 2) 9.00 Decimal32(7, 2) -10.00 Decimal32(7, 2) 10.00 Decimal32(7, 2) -1.00 @@ -536,9 +536,9 @@ Decimal32(7, 2) 9.00 Decimal32(7, 2) -10.00 Decimal32(7, 2) 10.00 Decimal32(7, 2 query TTTT select arrow_typeof(floor(arrow_cast(123456789,'Decimal64(18,0)'))), - arrow_cast(floor(arrow_cast(123456789,'Decimal64(18,0)')), 'Utf8'), + floor(arrow_cast(123456789,'Decimal64(18,0)')), arrow_typeof(floor(arrow_cast(-987654321,'Decimal64(18,0)'))), - arrow_cast(floor(arrow_cast(-987654321,'Decimal64(18,0)')), 'Utf8'); + floor(arrow_cast(-987654321,'Decimal64(18,0)')); ---- Decimal64(18, 0) 123456789 Decimal64(18, 0) -987654321 @@ -546,9 +546,9 @@ Decimal64(18, 0) 123456789 Decimal64(18, 0) -987654321 query TTTT select arrow_typeof(floor(arrow_cast('9999999999999999999999999999999999.99','Decimal256(38,2)'))), - arrow_cast(floor(arrow_cast('9999999999999999999999999999999999.99','Decimal256(38,2)')), 'Utf8'), + floor(arrow_cast('9999999999999999999999999999999999.99','Decimal256(38,2)')), arrow_typeof(floor(arrow_cast('-9999999999999999999999999999999999.99','Decimal256(38,2)'))), - arrow_cast(floor(arrow_cast('-9999999999999999999999999999999999.99','Decimal256(38,2)')), 'Utf8'); + floor(arrow_cast('-9999999999999999999999999999999999.99','Decimal256(38,2)')); ---- Decimal256(38, 2) 9999999999999999999999999999999999.00 Decimal256(38, 2) -10000000000000000000000000000000000.00 From df2be81d4ede0e981fab823adbe54961b14408a2 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Sat, 29 Nov 2025 17:04:20 +0530 Subject: [PATCH 5/6] fixed tests and updated docs --- datafusion/sqllogictest/test_files/scalar.slt | 32 +++++++++---------- .../source/user-guide/sql/scalar_functions.md | 2 +- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 363177ce00ad..1194e57b0861 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -331,13 +331,13 @@ select query TTTTTTTT select arrow_typeof(ceil(arrow_cast(9.01,'Decimal32(7,2)'))), - ceil(arrow_cast(9.01,'Decimal32(7,2)')), + arrow_cast(ceil(arrow_cast(9.01,'Decimal32(7,2)')), 'Utf8'), arrow_typeof(ceil(arrow_cast(-9.01,'Decimal32(7,2)'))), - ceil(arrow_cast(-9.01,'Decimal32(7,2)')), + arrow_cast(ceil(arrow_cast(-9.01,'Decimal32(7,2)')), 'Utf8'), arrow_typeof(ceil(arrow_cast(10.00,'Decimal32(7,2)'))), - ceil(arrow_cast(10.00,'Decimal32(7,2)')), + arrow_cast(ceil(arrow_cast(10.00,'Decimal32(7,2)')), 'Utf8'), arrow_typeof(ceil(arrow_cast(-0.99,'Decimal32(7,2)'))), - ceil(arrow_cast(-0.99,'Decimal32(7,2)')); + arrow_cast(ceil(arrow_cast(-0.99,'Decimal32(7,2)')), 'Utf8'); ---- Decimal32(7, 2) 10.00 Decimal32(7, 2) -9.00 Decimal32(7, 2) 10.00 Decimal32(7, 2) 0.00 @@ -345,9 +345,9 @@ Decimal32(7, 2) 10.00 Decimal32(7, 2) -9.00 Decimal32(7, 2) 10.00 Decimal32(7, 2 query TTTT select arrow_typeof(ceil(arrow_cast(123456789,'Decimal64(18,0)'))), - ceil(arrow_cast(123456789,'Decimal64(18,0)')), + arrow_cast(ceil(arrow_cast(123456789,'Decimal64(18,0)')), 'Utf8'), arrow_typeof(ceil(arrow_cast(-987654321,'Decimal64(18,0)'))), - ceil(arrow_cast(-987654321,'Decimal64(18,0)')); + arrow_cast(ceil(arrow_cast(-987654321,'Decimal64(18,0)')), 'Utf8'); ---- Decimal64(18, 0) 123456789 Decimal64(18, 0) -987654321 @@ -355,9 +355,9 @@ Decimal64(18, 0) 123456789 Decimal64(18, 0) -987654321 query TTTT select arrow_typeof(ceil(arrow_cast('9999999999999999999999999999999999.01','Decimal256(38,2)'))), - ceil(arrow_cast('9999999999999999999999999999999999.01','Decimal256(38,2)')), + arrow_cast(ceil(arrow_cast('9999999999999999999999999999999999.01','Decimal256(38,2)')), 'Utf8'), arrow_typeof(ceil(arrow_cast('-9999999999999999999999999999999999.01','Decimal256(38,2)'))), - ceil(arrow_cast('-9999999999999999999999999999999999.01','Decimal256(38,2)')); + arrow_cast(ceil(arrow_cast('-9999999999999999999999999999999999.01','Decimal256(38,2)')), 'Utf8'); ---- Decimal256(38, 2) 10000000000000000000000000000000000.00 Decimal256(38, 2) -9999999999999999999999999999999999.00 @@ -522,13 +522,13 @@ select query TTTTTTTT select arrow_typeof(floor(arrow_cast(9.99,'Decimal32(7,2)'))), - floor(arrow_cast(9.99,'Decimal32(7,2)')), + arrow_cast(floor(arrow_cast(9.99,'Decimal32(7,2)')), 'Utf8'), arrow_typeof(floor(arrow_cast(-9.01,'Decimal32(7,2)'))), - floor(arrow_cast(-9.01,'Decimal32(7,2)')), + arrow_cast(floor(arrow_cast(-9.01,'Decimal32(7,2)')), 'Utf8'), arrow_typeof(floor(arrow_cast(10.00,'Decimal32(7,2)'))), - floor(arrow_cast(10.00,'Decimal32(7,2)')), + arrow_cast(floor(arrow_cast(10.00,'Decimal32(7,2)')), 'Utf8'), arrow_typeof(floor(arrow_cast(-0.01,'Decimal32(7,2)'))), - floor(arrow_cast(-0.01,'Decimal32(7,2)')); + arrow_cast(floor(arrow_cast(-0.01,'Decimal32(7,2)')), 'Utf8'); ---- Decimal32(7, 2) 9.00 Decimal32(7, 2) -10.00 Decimal32(7, 2) 10.00 Decimal32(7, 2) -1.00 @@ -536,9 +536,9 @@ Decimal32(7, 2) 9.00 Decimal32(7, 2) -10.00 Decimal32(7, 2) 10.00 Decimal32(7, 2 query TTTT select arrow_typeof(floor(arrow_cast(123456789,'Decimal64(18,0)'))), - floor(arrow_cast(123456789,'Decimal64(18,0)')), + arrow_cast(floor(arrow_cast(123456789,'Decimal64(18,0)')), 'Utf8'), arrow_typeof(floor(arrow_cast(-987654321,'Decimal64(18,0)'))), - floor(arrow_cast(-987654321,'Decimal64(18,0)')); + arrow_cast(floor(arrow_cast(-987654321,'Decimal64(18,0)')), 'Utf8'); ---- Decimal64(18, 0) 123456789 Decimal64(18, 0) -987654321 @@ -546,9 +546,9 @@ Decimal64(18, 0) 123456789 Decimal64(18, 0) -987654321 query TTTT select arrow_typeof(floor(arrow_cast('9999999999999999999999999999999999.99','Decimal256(38,2)'))), - floor(arrow_cast('9999999999999999999999999999999999.99','Decimal256(38,2)')), + arrow_cast(floor(arrow_cast('9999999999999999999999999999999999.99','Decimal256(38,2)')), 'Utf8'), arrow_typeof(floor(arrow_cast('-9999999999999999999999999999999999.99','Decimal256(38,2)'))), - floor(arrow_cast('-9999999999999999999999999999999999.99','Decimal256(38,2)')); + arrow_cast(floor(arrow_cast('-9999999999999999999999999999999999.99','Decimal256(38,2)')), 'Utf8'); ---- Decimal256(38, 2) 9999999999999999999999999999999999.00 Decimal256(38, 2) -10000000000000000000000000000000000.00 diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index c5380b2e84b7..2d6bb2817cfc 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -294,7 +294,7 @@ ceil(numeric_expression) #### Example ```sql - > SELECT ceil(3.14); +> SELECT ceil(3.14); +------------+ | ceil(3.14) | +------------+ From fc22f632023b6c37c3fdbf7122516547c36d36c0 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Sat, 29 Nov 2025 21:08:39 +0530 Subject: [PATCH 6/6] tests for floor/ceil overflow --- datafusion/functions/src/math/ceil.rs | 60 +++++++++------- datafusion/functions/src/math/decimal.rs | 70 +++++++------------ datafusion/functions/src/math/floor.rs | 60 +++++++++------- datafusion/sqllogictest/test_files/scalar.slt | 8 +++ 4 files changed, 105 insertions(+), 93 deletions(-) diff --git a/datafusion/functions/src/math/ceil.rs b/datafusion/functions/src/math/ceil.rs index 1b61d44a7c0d..68a76c7bc3b3 100644 --- a/datafusion/functions/src/math/ceil.rs +++ b/datafusion/functions/src/math/ceil.rs @@ -112,30 +112,42 @@ impl ScalarUDFImpl for CeilFunc { DataType::Null => { return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))) } - DataType::Decimal32(_, scale) => apply_decimal_op::( - value, - *scale, - self.name(), - ceil_decimal_value, - )?, - DataType::Decimal64(_, scale) => apply_decimal_op::( - value, - *scale, - self.name(), - ceil_decimal_value, - )?, - DataType::Decimal128(_, scale) => apply_decimal_op::( - value, - *scale, - self.name(), - ceil_decimal_value, - )?, - DataType::Decimal256(_, scale) => apply_decimal_op::( - value, - *scale, - self.name(), - ceil_decimal_value, - )?, + DataType::Decimal32(precision, scale) => { + apply_decimal_op::( + value, + *precision, + *scale, + self.name(), + ceil_decimal_value, + )? + } + DataType::Decimal64(precision, scale) => { + apply_decimal_op::( + value, + *precision, + *scale, + self.name(), + ceil_decimal_value, + )? + } + DataType::Decimal128(precision, scale) => { + apply_decimal_op::( + value, + *precision, + *scale, + self.name(), + ceil_decimal_value, + )? + } + DataType::Decimal256(precision, scale) => { + apply_decimal_op::( + value, + *precision, + *scale, + self.name(), + ceil_decimal_value, + )? + } other => { return exec_err!( "Unsupported data type {other:?} for function {}", diff --git a/datafusion/functions/src/math/decimal.rs b/datafusion/functions/src/math/decimal.rs index 060fbdb4e0d5..b27130dbd3a0 100644 --- a/datafusion/functions/src/math/decimal.rs +++ b/datafusion/functions/src/math/decimal.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::ops::Rem; use std::sync::Arc; use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; @@ -26,6 +25,7 @@ use datafusion_common::{DataFusionError, Result}; pub(super) fn apply_decimal_op( array: &ArrayRef, + precision: u8, scale: i8, fn_name: &str, op: F, @@ -33,7 +33,7 @@ pub(super) fn apply_decimal_op( where T: DecimalType, T::Native: ArrowNativeType + ArrowNativeTypeOp, - F: Fn(T::Native, T::Native) -> std::result::Result, + F: Fn(T::Native, T::Native) -> T::Native, { if scale <= 0 { return Ok(Arc::clone(array)); @@ -43,9 +43,15 @@ where let decimal = array.as_primitive::(); let data_type = array.data_type().clone(); - let result: PrimitiveArray = decimal - .try_unary(|value| op(value, factor))? - .with_data_type(data_type); + let result: PrimitiveArray = decimal.try_unary(|value| { + let new_value = op(value, factor); + T::validate_decimal_precision(new_value, precision, scale).map_err(|_| { + ArrowError::ComputeError(format!("Decimal overflow while applying {fn_name}")) + })?; + Ok::<_, ArrowError>(new_value) + })?; + + let result = result.with_data_type(data_type); Ok(Arc::new(result)) } @@ -56,72 +62,46 @@ where T::Native: ArrowNativeType + ArrowNativeTypeOp, { let base = ::from_usize(10).ok_or_else(|| { - DataFusionError::Execution(format!( - "Decimal scale {scale} is too large for {fn_name}" - )) + DataFusionError::Execution(format!("Decimal overflow while applying {fn_name}")) })?; base.pow_checked(scale as u32).map_err(|_| { - DataFusionError::Execution(format!( - "Decimal scale {scale} is too large for {fn_name}" - )) + DataFusionError::Execution(format!("Decimal overflow while applying {fn_name}")) }) } -pub(super) fn ceil_decimal_value( - value: T, - factor: T, -) -> std::result::Result +pub(super) fn ceil_decimal_value(value: T, factor: T) -> T where - T: ArrowNativeTypeOp + Rem, + T: ArrowNativeTypeOp + std::ops::Rem, { let remainder = value % factor; if remainder == T::ZERO { - return Ok(value); + return value; } if value >= T::ZERO { - let increment = factor - .sub_checked(remainder) - .map_err(|_| overflow_err("ceil"))?; - value - .add_checked(increment) - .map_err(|_| overflow_err("ceil")) + let increment = factor.sub_wrapping(remainder); + value.add_wrapping(increment) } else { - value - .sub_checked(remainder) - .map_err(|_| overflow_err("ceil")) + value.sub_wrapping(remainder) } } -pub(super) fn floor_decimal_value( - value: T, - factor: T, -) -> std::result::Result +pub(super) fn floor_decimal_value(value: T, factor: T) -> T where - T: ArrowNativeTypeOp + Rem, + T: ArrowNativeTypeOp + std::ops::Rem, { let remainder = value % factor; if remainder == T::ZERO { - return Ok(value); + return value; } if value >= T::ZERO { - value - .sub_checked(remainder) - .map_err(|_| overflow_err("floor")) + value.sub_wrapping(remainder) } else { - let adjustment = factor - .add_checked(remainder) - .map_err(|_| overflow_err("floor"))?; - value - .sub_checked(adjustment) - .map_err(|_| overflow_err("floor")) + let adjustment = factor.add_wrapping(remainder); + value.sub_wrapping(adjustment) } } - -fn overflow_err(name: &str) -> ArrowError { - ArrowError::ComputeError(format!("Decimal overflow while applying {name}")) -} diff --git a/datafusion/functions/src/math/floor.rs b/datafusion/functions/src/math/floor.rs index 4a2d8d8b3bf5..8a4a888555cd 100644 --- a/datafusion/functions/src/math/floor.rs +++ b/datafusion/functions/src/math/floor.rs @@ -112,30 +112,42 @@ impl ScalarUDFImpl for FloorFunc { DataType::Null => { return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))) } - DataType::Decimal32(_, scale) => apply_decimal_op::( - value, - *scale, - self.name(), - floor_decimal_value, - )?, - DataType::Decimal64(_, scale) => apply_decimal_op::( - value, - *scale, - self.name(), - floor_decimal_value, - )?, - DataType::Decimal128(_, scale) => apply_decimal_op::( - value, - *scale, - self.name(), - floor_decimal_value, - )?, - DataType::Decimal256(_, scale) => apply_decimal_op::( - value, - *scale, - self.name(), - floor_decimal_value, - )?, + DataType::Decimal32(precision, scale) => { + apply_decimal_op::( + value, + *precision, + *scale, + self.name(), + floor_decimal_value, + )? + } + DataType::Decimal64(precision, scale) => { + apply_decimal_op::( + value, + *precision, + *scale, + self.name(), + floor_decimal_value, + )? + } + DataType::Decimal128(precision, scale) => { + apply_decimal_op::( + value, + *precision, + *scale, + self.name(), + floor_decimal_value, + )? + } + DataType::Decimal256(precision, scale) => { + apply_decimal_op::( + value, + *precision, + *scale, + self.name(), + floor_decimal_value, + )? + } other => { return exec_err!( "Unsupported data type {other:?} for function {}", diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 1194e57b0861..3a9e238e6b38 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -327,6 +327,10 @@ select ---- 2 -1 123 -123 +# ceil overflow with limited precision +query error Decimal overflow while applying ceil +select ceil(arrow_cast(9.23,'Decimal128(3,2)')); + # ceil with decimal32 argument (ensure decimal output) query TTTTTTTT select @@ -518,6 +522,10 @@ select ---- 1 -2 123 -123 +# floor overflow with limited precision +query error Decimal overflow while applying floor +select floor(arrow_cast(-9.23,'Decimal128(3,2)')); + # floor with decimal32 argument (ensure decimal output) query TTTTTTTT select