
Datafusion: math function does not support array type f32 #699

@lvheyang


Describe the bug
Math functions such as sqrt() and sin() do not support f32 (Float32) arrays.

To Reproduce

Example Code

use std::sync::Arc;

use datafusion::arrow::array::{Float32Array, Float64Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::arrow::util::pretty;

use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    // define a schema.
    let schema = Arc::new(Schema::new(vec![
        Field::new("f64", DataType::Float64, false),
        Field::new("f32", DataType::Float32, false),
    ]));

    // define data.
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0])),
            Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0])),
        ],
    )?;

    let mut ctx = ExecutionContext::new();
    let provider = MemTable::try_new(schema, vec![vec![batch]])?;
    ctx.register_table("t", Arc::new(provider))?;

    // build a DataFrame for the SQL query "SELECT sqrt(f32) FROM t"
    let df = ctx.sql("SELECT sqrt(f32) FROM t")?; // fails
    // let df = ctx.sql("SELECT sqrt(f64) FROM t")?; // succeeds

    // execute
    let results = df.collect().await?;

    // print the results
    pretty::print_batches(&results)?;

    Ok(())
}

The program fails with the following error:

Error: ArrowError(InvalidArgumentError("column types must match schema types, expected Float64 but found Float32 at column index 0"))

Expected behavior

The query should run without error.

Additional context

Possible reason:

In datafusion/src/physical_plan/functions.rs, the return type of the built-in math functions is always Float64:

BuiltinScalarFunction::Abs
        | BuiltinScalarFunction::Acos
        | ...
        | BuiltinScalarFunction::Sqrt
        | BuiltinScalarFunction::Tan
        | BuiltinScalarFunction::Trunc => Ok(DataType::Float64),

But when the math function is evaluated in datafusion/src/physical_plan/math_expressions.rs, an f32 input array produces a Float32Array:

macro_rules! unary_primitive_array_op {
    ($VALUE:expr, $NAME:expr, $FUNC:ident) => {{
        match ($VALUE) {
            ColumnarValue::Array(array) => match array.data_type() {
                DataType::Float32 => {
                    let result = downcast_compute_op!(array, $NAME, $FUNC, Float32Array);
                    Ok(ColumnarValue::Array(result?))
                }
                DataType::Float64 => {
                    let result = downcast_compute_op!(array, $NAME, $FUNC, Float64Array);
                    Ok(ColumnarValue::Array(result?))
                }
                other => Err(DataFusionError::Internal(format!(
                    "Unsupported data type {:?} for function {}",
                    other, $NAME,
                ))),
            },
       ...
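To make the mismatch between the two excerpts concrete, here is a minimal standalone sketch. It does not use the real DataFusion or Arrow types: `DataType`, `Array`, `declared_return_type`, `eval_sqrt`, `check_batch` and `run_sqrt` are all hypothetical stand-ins that only mimic the behaviour described above (planning always declares Float64, evaluation keeps the input type, and a schema check rejects the result).

```rust
// Simplified stand-ins for Arrow's DataType and arrays.
// These are NOT the real DataFusion/Arrow types, only a model of the issue.
#[derive(Debug, Clone, Copy, PartialEq)]
enum DataType {
    Float32,
    Float64,
}

#[derive(Debug, Clone, PartialEq)]
enum Array {
    Float32(Vec<f32>),
    Float64(Vec<f64>),
}

impl Array {
    fn data_type(&self) -> DataType {
        match self {
            Array::Float32(_) => DataType::Float32,
            Array::Float64(_) => DataType::Float64,
        }
    }
}

/// Planning side (hypothetical): like the functions.rs excerpt,
/// the declared type is Float64 regardless of the argument type.
fn declared_return_type(_arg: DataType) -> DataType {
    DataType::Float64
}

/// Execution side (hypothetical): like the macro excerpt,
/// the output array keeps the input array's type.
fn eval_sqrt(input: &Array) -> Array {
    match input {
        Array::Float32(v) => Array::Float32(v.iter().map(|x| x.sqrt()).collect()),
        Array::Float64(v) => Array::Float64(v.iter().map(|x| x.sqrt()).collect()),
    }
}

/// Models the schema check that rejects a column of the wrong type.
fn check_batch(expected: DataType, column: &Array) -> Result<(), String> {
    let found = column.data_type();
    if expected == found {
        Ok(())
    } else {
        Err(format!(
            "column types must match schema types, expected {:?} but found {:?} at column index 0",
            expected, found
        ))
    }
}

/// Plan, evaluate, then validate the output against the declared type.
fn run_sqrt(input: &Array) -> Result<Array, String> {
    let expected = declared_return_type(input.data_type());
    let output = eval_sqrt(input);
    check_batch(expected, &output)?;
    Ok(output)
}
```

In this model, calling `run_sqrt` on an `Array::Float64` succeeds, while calling it on an `Array::Float32` returns an error of the same shape as the one reported above.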

It seems that we should either

  1. implicitly cast f32 to f64 before applying math functions, or
  2. take the input arguments' data types into account when inferring the schema, e.g. sqrt(f32) -> f32, sqrt(f64) -> f64.

I think the second option is more reasonable.
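The two options can be sketched side by side. As before, this is only an illustration on simplified stand-in types, not DataFusion code: `DataType`, `Array`, `sqrt_with_cast`, `sqrt_return_type` and `sqrt_same_type` are hypothetical names introduced for the example.

```rust
// Simplified stand-ins, NOT the real DataFusion/Arrow types.
#[derive(Debug, Clone, Copy, PartialEq)]
enum DataType {
    Float32,
    Float64,
}

#[derive(Debug, Clone, PartialEq)]
enum Array {
    Float32(Vec<f32>),
    Float64(Vec<f64>),
}

impl Array {
    fn data_type(&self) -> DataType {
        match self {
            Array::Float32(_) => DataType::Float32,
            Array::Float64(_) => DataType::Float64,
        }
    }
}

/// Option 1 (hypothetical): widen f32 input to f64 before computing,
/// so the output always matches a declared Float64 return type.
fn sqrt_with_cast(input: &Array) -> Array {
    let widened: Vec<f64> = match input {
        Array::Float32(v) => v.iter().map(|&x| f64::from(x)).collect(),
        Array::Float64(v) => v.clone(),
    };
    Array::Float64(widened.into_iter().map(f64::sqrt).collect())
}

/// Option 2 (hypothetical), planning side: the declared return type
/// follows the argument type, i.e. sqrt(f32) -> f32, sqrt(f64) -> f64.
fn sqrt_return_type(arg: DataType) -> DataType {
    arg
}

/// Option 2 (hypothetical), execution side: compute in the input's own type.
fn sqrt_same_type(input: &Array) -> Array {
    match input {
        Array::Float32(v) => Array::Float32(v.iter().map(|x| x.sqrt()).collect()),
        Array::Float64(v) => Array::Float64(v.iter().map(|x| x.sqrt()).collect()),
    }
}
```

With option 1 an f32 column always comes back as Float64, at the cost of a widening copy. With option 2 the output type of `sqrt_same_type` agrees with `sqrt_return_type` for both input types, so no cast is needed and f32 data stays f32.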

Labels: bug
