Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 34 additions & 138 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,77 +314,45 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {

/// function return type of variance
pub fn variance_return_type(arg_type: &DataType) -> Result<DataType> {
match arg_type {
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Float32
| DataType::Float64 => Ok(DataType::Float64),
other => Err(DataFusionError::Plan(format!(
"VAR does not support {other:?}"
))),
if NUMERICS.contains(arg_type) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably we can reuse `DataType::is_numeric()` here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the DataType::is_numeric also includes Decimals 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, perhaps we can do a helper method checking is_numeric and exclude decimals.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something along these lines:

fn is_numeric(arg_type: &DataType) -> bool {
    return arg_type.is_numeric() && match arg_type {
        Decimal128(_, _) | Decimal() => false
        _ => true
    } 
}

Ok(DataType::Float64)
} else {
Err(DataFusionError::Plan(format!(
"VAR does not support {arg_type:?}"
)))
}
}

/// function return type of covariance
pub fn covariance_return_type(arg_type: &DataType) -> Result<DataType> {
match arg_type {
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Float32
| DataType::Float64 => Ok(DataType::Float64),
other => Err(DataFusionError::Plan(format!(
"COVAR does not support {other:?}"
))),
if NUMERICS.contains(arg_type) {
Ok(DataType::Float64)
} else {
Err(DataFusionError::Plan(format!(
"COVAR does not support {arg_type:?}"
)))
}
}

/// function return type of correlation
pub fn correlation_return_type(arg_type: &DataType) -> Result<DataType> {
match arg_type {
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Float32
| DataType::Float64 => Ok(DataType::Float64),
other => Err(DataFusionError::Plan(format!(
"CORR does not support {other:?}"
))),
if NUMERICS.contains(arg_type) {
Ok(DataType::Float64)
} else {
Err(DataFusionError::Plan(format!(
"CORR does not support {arg_type:?}"
)))
}
}

/// function return type of standard deviation
pub fn stddev_return_type(arg_type: &DataType) -> Result<DataType> {
match arg_type {
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Float32
| DataType::Float64 => Ok(DataType::Float64),
other => Err(DataFusionError::Plan(format!(
"STDDEV does not support {other:?}"
))),
if NUMERICS.contains(arg_type) {
Ok(DataType::Float64)
} else {
Err(DataFusionError::Plan(format!(
"STDDEV does not support {arg_type:?}"
)))
}
}

Expand All @@ -398,16 +366,7 @@ pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
Ok(DataType::Decimal128(new_precision, new_scale))
}
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Float32
| DataType::Float64 => Ok(DataType::Float64),
arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
other => Err(DataFusionError::Plan(format!(
"AVG does not support {other:?}"
))),
Expand All @@ -417,98 +376,44 @@ pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float32
| DataType::Float64
| DataType::Decimal128(_, _)
arg_type if NUMERICS.contains(arg_type)
|| matches!(arg_type, DataType::Decimal128(_, _))
)
}

pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float32
| DataType::Float64
| DataType::Decimal128(_, _)
arg_type if NUMERICS.contains(arg_type)
|| matches!(arg_type, DataType::Decimal128(_, _))
)
}

pub fn is_variance_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float32
| DataType::Float64
arg_type if NUMERICS.contains(arg_type)
)
}

pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float32
| DataType::Float64
arg_type if NUMERICS.contains(arg_type)
)
}

pub fn is_stddev_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float32
| DataType::Float64
arg_type if NUMERICS.contains(arg_type)
)
}

pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float32
| DataType::Float64
arg_type if NUMERICS.contains(arg_type)
)
}

Expand All @@ -531,16 +436,7 @@ pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float32
| DataType::Float64
arg_type if NUMERICS.contains(arg_type)
)
}

Expand Down