diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index d2dcec5f47d7..d4d3a8a14ac6 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -916,9 +916,10 @@ mod tests { physical_plan::expressions::AvgAccumulator, }; use arrow::array::{ - Array, ArrayRef, BinaryArray, DictionaryArray, Float64Array, Int32Array, - Int64Array, LargeBinaryArray, LargeStringArray, StringArray, - TimestampNanosecondArray, + Array, ArrayRef, BinaryArray, DictionaryArray, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, + LargeStringArray, StringArray, TimestampNanosecondArray, UInt16Array, + UInt32Array, UInt64Array, UInt8Array, }; use arrow::compute::add; use arrow::datatypes::*; @@ -2364,6 +2365,75 @@ mod tests { assert_batches_sorted_eq!(expected, &results); } + #[tokio::test] + async fn case_builtin_math_expression() { + let mut ctx = ExecutionContext::new(); + + let type_values = vec![ + ( + DataType::Int8, + Arc::new(Int8Array::from(vec![1])) as ArrayRef, + ), + ( + DataType::Int16, + Arc::new(Int16Array::from(vec![1])) as ArrayRef, + ), + ( + DataType::Int32, + Arc::new(Int32Array::from(vec![1])) as ArrayRef, + ), + ( + DataType::Int64, + Arc::new(Int64Array::from(vec![1])) as ArrayRef, + ), + ( + DataType::UInt8, + Arc::new(UInt8Array::from(vec![1])) as ArrayRef, + ), + ( + DataType::UInt16, + Arc::new(UInt16Array::from(vec![1])) as ArrayRef, + ), + ( + DataType::UInt32, + Arc::new(UInt32Array::from(vec![1])) as ArrayRef, + ), + ( + DataType::UInt64, + Arc::new(UInt64Array::from(vec![1])) as ArrayRef, + ), + ( + DataType::Float32, + Arc::new(Float32Array::from(vec![1.0_f32])) as ArrayRef, + ), + ( + DataType::Float64, + Arc::new(Float64Array::from(vec![1.0_f64])) as ArrayRef, + ), + ]; + + for (data_type, array) in type_values.iter() { + let schema = + Arc::new(Schema::new(vec![Field::new("v", data_type.clone(), false)])); + let batch = + RecordBatch::try_new(schema.clone(), vec![array.clone()]).unwrap(); + let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap(); + ctx.register_table("t", Arc::new(provider)).unwrap(); + let expected = vec![ + "+---------+", + "| sqrt(v) |", + "+---------+", + "| 1 |", + "+---------+", + ]; + let results = plan_and_collect(&mut ctx, "SELECT sqrt(v) FROM t") + .await + .unwrap(); + + assert_batches_sorted_eq!(expected, &results); + } + } + #[tokio::test] async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { let mut ctx = ExecutionContext::new(); diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index 01f7e95a0ee9..d856ca4bd606 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -468,7 +468,18 @@ pub fn return_type( | BuiltinScalarFunction::Sin | BuiltinScalarFunction::Sqrt | BuiltinScalarFunction::Tan - | BuiltinScalarFunction::Trunc => Ok(DataType::Float64), + | BuiltinScalarFunction::Trunc => { + if arg_types.is_empty() { + return Err(DataFusionError::Internal(format!( + "builtin scalar function {} does not support empty arguments", + fun + ))); + } + match arg_types[0] { + DataType::Float32 => Ok(DataType::Float32), + _ => Ok(DataType::Float64), + } + } } } @@ -1427,8 +1438,8 @@ mod tests { }; use arrow::{ array::{ - Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, Float64Array, - Int32Array, StringArray, UInt32Array, UInt64Array, + Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, Float32Array, + Float64Array, Int32Array, StringArray, UInt32Array, UInt64Array, }, datatypes::Field, record_batch::RecordBatch, @@ -1857,10 +1868,10 @@ mod tests { test_function!( Exp, &[lit(ScalarValue::Float32(Some(1.0)))], - Ok(Some((1.0_f32).exp() as f64)), - f64, - Float64, - Float64Array + Ok(Some((1.0_f32).exp())), + f32, + Float32, + Float32Array ); test_function!( InitCap, diff --git a/datafusion/src/physical_plan/math_expressions.rs b/datafusion/src/physical_plan/math_expressions.rs index cfc239cde661..eabacfc6eb18 100644 --- a/datafusion/src/physical_plan/math_expressions.rs +++ b/datafusion/src/physical_plan/math_expressions.rs @@ -60,7 +60,7 @@ macro_rules! unary_primitive_array_op { }, ColumnarValue::Scalar(a) => match a { ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar( - ScalarValue::Float64(a.map(|x| x.$FUNC() as f64)), + ScalarValue::Float32(a.map(|x| x.$FUNC())), )), ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar( ScalarValue::Float64(a.map(|x| x.$FUNC())),