From 06387c01ac289c0b007cd49aa79318a6771a83f9 Mon Sep 17 00:00:00 2001 From: Arshdeep54 Date: Sun, 10 Aug 2025 20:14:37 +0530 Subject: [PATCH] feat: Add decimal support for `round` Signed-off-by: Arshdeep54 --- datafusion/functions/src/math/round.rs | 290 +++++++++++++++++++++++-- 1 file changed, 276 insertions(+), 14 deletions(-) diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index e13d6b8f9a52..4185ae4f53f7 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -22,11 +22,15 @@ use crate::utils::make_scalar_function; use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::DataType::{Float32, Float64, Int32}; -use arrow::datatypes::{DataType, Float32Type, Float64Type, Int32Type}; +use arrow::datatypes::DataType::{ + Decimal128, Decimal256, Float32, Float64, Int32, Int64, +}; +use arrow::datatypes::{ + DataType, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int32Type, +}; +use arrow_buffer::i256; use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -56,17 +60,8 @@ impl Default for RoundFunc { impl RoundFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::one_of( - vec![ - Exact(vec![Float64, Int64]), - Exact(vec![Float32, Int64]), - Exact(vec![Float64]), - Exact(vec![Float32]), - ], - Volatility::Immutable, - ), + signature: Signature::user_defined(Volatility::Immutable), } } } @@ -84,9 +79,41 @@ impl ScalarUDFImpl for RoundFunc { &self.signature } + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 && arg_types.len() != 2 { + return exec_err!( + "round function requires one or two arguments, got {}", + arg_types.len() + ); + } + + if arg_types.len() == 1 { + match arg_types[0].clone() { + Decimal128(p, s) => Ok(vec![Decimal128(p, s)]), + Decimal256(p, s) => Ok(vec![Decimal256(p, s)]), + Float32 => Ok(vec![Float32]), + _ => Ok(vec![Float64]), + } + } else if arg_types.len() == 2 { + match arg_types[0].clone() { + Decimal128(p, s) => Ok(vec![Decimal128(p, s), Int64]), + Decimal256(p, s) => Ok(vec![Decimal256(p, s), Int64]), + Float32 => Ok(vec![Float32, Int64]), + _ => Ok(vec![Float64, Int64]), + } + } else { + exec_err!( + "round function requires one or two arguments, got {}", + arg_types.len() + ) + } + } + fn return_type(&self, arg_types: &[DataType]) -> Result { match arg_types[0] { Float32 => Ok(Float32), + Decimal128(p, s) => Ok(Decimal128(p, s)), + Decimal256(p, s) => Ok(Decimal256(p, s)), _ => Ok(Float64), } } @@ -215,17 +242,141 @@ pub fn round(args: &[ArrayRef]) -> Result { } }, + Decimal128(precision, scale) => match decimal_places { + ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { + let decimal_places: i32 = decimal_places.try_into().map_err(|e| { + exec_datafusion_err!( + "Invalid value for decimal places: {decimal_places}: {e}" + ) + })?; + + let values = args[0].as_primitive::(); + let result = values.unary::<_, Decimal128Type>(|value| { + round_decimal128(value, *scale, decimal_places) + }); + + Ok(Arc::new(result.with_precision_and_scale(*precision, *scale)?) as _) + } + ColumnarValue::Array(decimal_places) => { + let options = CastOptions { + safe: false, // raise error if the cast is not possible + ..Default::default() + }; + let decimal_places = cast_with_options(&decimal_places, &Int32, &options) + .map_err(|e| { + exec_datafusion_err!("Invalid values for decimal places: {e}") + })?; + + let values = args[0].as_primitive::(); + let decimal_places = decimal_places.as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Decimal128Type>( + values, + decimal_places, + |value, decimal_places| { + round_decimal128(value, *scale, decimal_places) + }, + )?; + + Ok(Arc::new(result.with_precision_and_scale(*precision, *scale)?) as _) + } + _ => { + exec_err!("round function requires a scalar or array for decimal_places") + } + }, + + Decimal256(precision, scale) => match decimal_places { + ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { + let decimal_places: i32 = decimal_places.try_into().map_err(|e| { + exec_datafusion_err!( + "Invalid value for decimal places: {decimal_places}: {e}" + ) + })?; + + let values = args[0].as_primitive::(); + let result = values.unary::<_, Decimal256Type>(|value| { + round_decimal256(value, *scale, decimal_places) + }); + + Ok(Arc::new(result.with_precision_and_scale(*precision, *scale)?) as _) + } + ColumnarValue::Array(decimal_places) => { + let options = CastOptions { + safe: false, + ..Default::default() + }; + let decimal_places = cast_with_options(&decimal_places, &Int32, &options) + .map_err(|e| { + exec_datafusion_err!("Invalid values for decimal places: {e}") + })?; + + let values = args[0].as_primitive::(); + let decimal_places = decimal_places.as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Decimal256Type>( + values, + decimal_places, + |value, decimal_places| { + round_decimal256(value, *scale, decimal_places) + }, + )?; + + Ok(Arc::new(result.with_precision_and_scale(*precision, *scale)?) as _) + } + _ => { + exec_err!("round function requires a scalar or array for decimal_places") + } + }, + other => exec_err!("Unsupported data type {other:?} for function round"), } } +#[inline] +fn round_decimal128(value: i128, current_scale: i8, decimal_places: i32) -> i128 { + let scale_adjustment = current_scale as i32 - decimal_places; + + if scale_adjustment > 0 { + let remove_factor = 10_i128.pow(scale_adjustment as u32); + let half = remove_factor / 2; + + if value >= 0 { + ((value + half) / remove_factor) * remove_factor + } else { + ((value - half) / remove_factor) * remove_factor + } + } else { + value + } +} + +#[inline] +fn round_decimal256(value: i256, current_scale: i8, decimal_places: i32) -> i256 { + let scale_adjustment = current_scale as i32 - decimal_places; + + if scale_adjustment > 0 { + let remove_factor = i256::from_i128(10_i128.pow(scale_adjustment as u32)); + let half = remove_factor / i256::from_i128(2); + + if value >= i256::from_i128(0) { + ((value + half) / remove_factor) * remove_factor + } else { + ((value - half) / remove_factor) * remove_factor + } + } else { + value + } +} + #[cfg(test)] mod test { use std::sync::Arc; use crate::math::round::round; - use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; + use arrow::array::{ + ArrayRef, Decimal128Array, Decimal256Array, Float32Array, Float64Array, + Int64Array, + }; + use arrow_buffer::i256; use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_common::DataFusionError; @@ -307,4 +458,115 @@ mod test { assert!(result.is_err()); assert!(matches!(result, Err(DataFusionError::Execution { .. }))); } + + #[test] + fn test_round_decimal128() { + let args: Vec = vec![ + Arc::new( + Decimal128Array::from(vec![1252345_i128; 10]) + .with_precision_and_scale(10, 4) + .unwrap(), + ), + Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), + ]; + + let result = round(&args).expect("failed to initialize function round"); + let decimals = result.as_any().downcast_ref::().unwrap(); + + let expected = Decimal128Array::from(vec![ + 1250000_i128, + 1252000_i128, + 1252300_i128, + 1252350_i128, + 1252345_i128, + 1252345_i128, + 1300000_i128, + 1000000_i128, + 0_i128, + 0_i128, + ]) + .with_precision_and_scale(10, 4) + .unwrap(); + + assert_eq!(decimals, &expected); + } + + #[test] + fn test_round_decimal128_one_input() { + let args: Vec = vec![Arc::new( + Decimal128Array::from(vec![1252345_i128, 123450_i128, 12340_i128, 1234_i128]) + .with_precision_and_scale(10, 4) + .unwrap(), + )]; + + let result = round(&args).expect("failed to initialize function round"); + let decimals = result.as_any().downcast_ref::().unwrap(); + + let expected = + Decimal128Array::from(vec![1250000_i128, 120000_i128, 10000_i128, 0_i128]) + .with_precision_and_scale(10, 4) + .unwrap(); + + assert_eq!(decimals, &expected); + } + + #[test] + fn test_round_decimal256() { + let args: Vec = vec![ + Arc::new( + Decimal256Array::from(vec![i256::from_i128(1252345_i128); 10]) + .with_precision_and_scale(20, 4) + .unwrap(), + ), + Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), + ]; + + let result = round(&args).expect("failed to initialize function round"); + let decimals = result.as_any().downcast_ref::().unwrap(); + + let expected = Decimal256Array::from(vec![ + i256::from_i128(1250000_i128), + i256::from_i128(1252000_i128), + i256::from_i128(1252300_i128), + i256::from_i128(1252350_i128), + i256::from_i128(1252345_i128), + i256::from_i128(1252345_i128), + i256::from_i128(1300000_i128), + i256::from_i128(1000000_i128), + i256::from_i128(0_i128), + i256::from_i128(0_i128), + ]) + .with_precision_and_scale(20, 4) + .unwrap(); + + assert_eq!(decimals, &expected); + } + + #[test] + fn test_round_decimal256_one_input() { + let args: Vec = vec![Arc::new( + Decimal256Array::from(vec![ + i256::from_i128(1252345_i128), + i256::from_i128(123450_i128), + i256::from_i128(12340_i128), + i256::from_i128(1234_i128), + ]) + .with_precision_and_scale(20, 4) + .unwrap(), + )]; + + let result = round(&args).expect("failed to initialize function round"); + let decimals = result.as_any().downcast_ref::().unwrap(); + + let expected = Decimal256Array::from(vec![ + i256::from_i128(1250000_i128), + i256::from_i128(120000_i128), + i256::from_i128(10000_i128), + i256::from_i128(0_i128), + ]) + .with_precision_and_scale(20, 4) + .unwrap(); + + assert_eq!(decimals, &expected); + } }