From ded46fc1946f792ce24758d5c393e94814606533 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Wed, 27 May 2026 10:51:43 -0400 Subject: [PATCH 1/2] . --- datafusion/functions/src/math/nanvl.rs | 63 ++++++++++++------- datafusion/sqllogictest/test_files/scalar.slt | 25 ++++++++ 2 files changed, 66 insertions(+), 22 deletions(-) diff --git a/datafusion/functions/src/math/nanvl.rs b/datafusion/functions/src/math/nanvl.rs index 251e98bb72c03..67920acf97cc9 100644 --- a/datafusion/functions/src/math/nanvl.rs +++ b/datafusion/functions/src/math/nanvl.rs @@ -20,8 +20,9 @@ use std::sync::Arc; use arrow::array::{ArrayRef, AsArray, Float16Array, Float32Array, Float64Array}; use arrow::datatypes::DataType::{Float16, Float32, Float64}; use arrow::datatypes::{DataType, Float16Type, Float32Type, Float64Type}; -use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args}; -use datafusion_expr::TypeSignature::Exact; +use datafusion_common::{ + Result, ScalarValue, exec_err, plan_err, utils::take_function_args, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -64,14 +65,8 @@ impl Default for NanvlFunc { impl NanvlFunc { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![ - Exact(vec![Float16, Float16]), - Exact(vec![Float32, Float32]), - Exact(vec![Float64, Float64]), - ], - Volatility::Immutable, - ), + // Argument coercion is handled by `coerce_types`. + signature: Signature::user_defined(Volatility::Immutable), } } } @@ -86,27 +81,42 @@ impl ScalarUDFImpl for NanvlFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match &arg_types[0] { - Float16 => Ok(Float16), - Float32 => Ok(Float32), + match (&arg_types[0], &arg_types[1]) { + (Float16, Float16) => Ok(Float16), + (Float32, Float32) => Ok(Float32), _ => Ok(Float64), } } + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [x, y] = take_function_args(self.name(), arg_types)?; + + // Integers, decimals, and NULL become Float64; choosing Float64 ensures + // we can represent as many inputs as possible before rounding. The two + // inputs are then unified to the widest float type. For example, + // (Float16, Float32) -> Float32, not Float64. + let to_float = |t: &DataType| match t { + Float16 => Ok(Float16), + Float32 => Ok(Float32), + t if t.is_numeric() || t.is_null() => Ok(Float64), + t => plan_err!("Function 'nanvl' expects numeric arguments, got {t}"), + }; + let common = match (to_float(x)?, to_float(y)?) { + (Float64, _) | (_, Float64) => Float64, + (Float32, _) | (_, Float32) => Float32, + _ => Float16, + }; + Ok(vec![common.clone(), common]) + } + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let [x, y] = take_function_args(self.name(), args.args)?; match (x, y) { - (ColumnarValue::Scalar(ScalarValue::Float16(Some(v))), y) if v.is_nan() => { - Ok(y) - } - (ColumnarValue::Scalar(ScalarValue::Float32(Some(v))), y) if v.is_nan() => { - Ok(y) - } - (ColumnarValue::Scalar(ScalarValue::Float64(Some(v))), y) if v.is_nan() => { - Ok(y) - } + // Scalar x: return y if x is NaN, otherwise x (which may be NULL). + (ColumnarValue::Scalar(ref x), y) if scalar_is_nan(x) => Ok(y), (x @ ColumnarValue::Scalar(_), _) => Ok(x), + // At least one argument is an array: evaluate element-wise. (x, y) => { let args = ColumnarValue::values_to_arrays(&[x, y])?; Ok(ColumnarValue::Array(nanvl(&args)?)) @@ -119,6 +129,15 @@ impl ScalarUDFImpl for NanvlFunc { } } +fn scalar_is_nan(scalar: &ScalarValue) -> bool { + match scalar { + ScalarValue::Float16(Some(v)) => v.is_nan(), + ScalarValue::Float32(Some(v)) => v.is_nan(), + ScalarValue::Float64(Some(v)) => v.is_nan(), + _ => false, + } +} + /// Nanvl SQL function /// /// - x is NaN -> output is y (which may itself be NULL) diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 38f76f13151bc..9ba8904c3f3d3 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -777,6 +777,31 @@ select nanvl(null, null); ---- NULL +# nanvl evaluates in the common (widest) float type of its arguments. Mixing +# narrower floats widens losslessly (Float16 + Float32 -> Float32), while +# integers, decimals, and NULL contribute Float64 so they are never narrowed. +query TTTTTTTT +select + arrow_typeof(nanvl(arrow_cast(1.0, 'Float16'), arrow_cast(2.0, 'Float16'))), + arrow_typeof(nanvl(arrow_cast(1.0, 'Float32'), arrow_cast(2.0, 'Float32'))), + arrow_typeof(nanvl(arrow_cast(1.0, 'Float16'), arrow_cast(2.0, 'Float32'))), + arrow_typeof(nanvl(arrow_cast(1.0, 'Float16'), arrow_cast(2.0, 'Float64'))), + arrow_typeof(nanvl(arrow_cast(1.0, 'Float32'), arrow_cast(2.0, 'Float64'))), + arrow_typeof(nanvl(1, 2)), + arrow_typeof(nanvl(1, arrow_cast(2.0, 'Float32'))), + arrow_typeof(nanvl(null, null)); +---- +Float16 Float32 Float32 Float64 Float64 Float64 Float64 Float64 + +# nanvl with integer inputs must be computed in double precision, even when the +# other argument is Float32 (the integer must not be narrowed to Float32). +query BB +select + nanvl(16777217, 1) = nanvl(arrow_cast(16777217, 'Float64'), 1.0), + nanvl(16777217, arrow_cast(1.0, 'Float32')) = nanvl(arrow_cast(16777217, 'Float64'), 1.0); +---- +true true + # nanvl with columns (round is needed to normalize the outputs of different operating systems) query RRR rowsort select round(nanvl(asin(f + a), 2), 5), round(nanvl(asin(b + c), 3), 5), round(nanvl(asin(d + e), 4), 5) from small_floats; From 2b33ef48f38213be847161b7895089b25d7576a4 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Wed, 27 May 2026 14:25:22 -0400 Subject: [PATCH 2/2] Tweak comments --- datafusion/sqllogictest/test_files/scalar.slt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 9ba8904c3f3d3..42815bc9fb0f1 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -779,7 +779,7 @@ NULL # nanvl evaluates in the common (widest) float type of its arguments. Mixing # narrower floats widens losslessly (Float16 + Float32 -> Float32), while -# integers, decimals, and NULL contribute Float64 so they are never narrowed. +# integers, decimals, and NULL are coerced to Float64. query TTTTTTTT select arrow_typeof(nanvl(arrow_cast(1.0, 'Float16'), arrow_cast(2.0, 'Float16'))), @@ -793,8 +793,8 @@ select ---- Float16 Float32 Float32 Float64 Float64 Float64 Float64 Float64 -# nanvl with integer inputs must be computed in double precision, even when the -# other argument is Float32 (the integer must not be narrowed to Float32). +# nanvl with an integer argument is computed in double precision, even when the +# other argument is Float32. query BB select nanvl(16777217, 1) = nanvl(arrow_cast(16777217, 'Float64'), 1.0),