diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index fca851ce63c9..3ad197afb64a 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -314,77 +314,45 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { /// function return type of variance pub fn variance_return_type(arg_type: &DataType) -> Result { - match arg_type { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 => Ok(DataType::Float64), - other => Err(DataFusionError::Plan(format!( - "VAR does not support {other:?}" - ))), + if NUMERICS.contains(arg_type) { + Ok(DataType::Float64) + } else { + Err(DataFusionError::Plan(format!( + "VAR does not support {arg_type:?}" + ))) } } /// function return type of covariance pub fn covariance_return_type(arg_type: &DataType) -> Result { - match arg_type { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 => Ok(DataType::Float64), - other => Err(DataFusionError::Plan(format!( - "COVAR does not support {other:?}" - ))), + if NUMERICS.contains(arg_type) { + Ok(DataType::Float64) + } else { + Err(DataFusionError::Plan(format!( + "COVAR does not support {arg_type:?}" + ))) } } /// function return type of correlation pub fn correlation_return_type(arg_type: &DataType) -> Result { - match arg_type { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 => Ok(DataType::Float64), - other => Err(DataFusionError::Plan(format!( - "CORR does not support {other:?}" - ))), + if NUMERICS.contains(arg_type) { + Ok(DataType::Float64) + } else { + Err(DataFusionError::Plan(format!( + "CORR does not support {arg_type:?}" + ))) } } /// function return type of standard deviation pub fn stddev_return_type(arg_type: &DataType) -> Result { - match arg_type { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 => Ok(DataType::Float64), - other => Err(DataFusionError::Plan(format!( - "STDDEV does not support {other:?}" - ))), + if NUMERICS.contains(arg_type) { + Ok(DataType::Float64) + } else { + Err(DataFusionError::Plan(format!( + "STDDEV does not support {arg_type:?}" + ))) } } @@ -398,16 +366,7 @@ pub fn avg_return_type(arg_type: &DataType) -> Result { let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4); Ok(DataType::Decimal128(new_precision, new_scale)) } - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 => Ok(DataType::Float64), + arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), other => Err(DataFusionError::Plan(format!( "AVG does not support {other:?}" ))), @@ -417,98 +376,44 @@ pub fn avg_return_type(arg_type: &DataType) -> Result { pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - | DataType::Decimal128(_, _) + arg_type if NUMERICS.contains(arg_type) + || matches!(arg_type, DataType::Decimal128(_, _)) ) } pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - | DataType::Decimal128(_, _) + arg_type if NUMERICS.contains(arg_type) + || matches!(arg_type, DataType::Decimal128(_, _)) ) } pub fn is_variance_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 + arg_type if NUMERICS.contains(arg_type) ) } pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 + arg_type if NUMERICS.contains(arg_type) ) } pub fn is_stddev_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 + arg_type if NUMERICS.contains(arg_type) ) } pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 + arg_type if NUMERICS.contains(arg_type) ) } @@ -531,16 +436,7 @@ pub fn is_integer_arg_type(arg_type: &DataType) -> bool { pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 + arg_type if NUMERICS.contains(arg_type) ) }