diff --git a/datafusion/core/src/physical_plan/join_utils.rs b/datafusion/core/src/physical_plan/join_utils.rs index 780a5e96f03e..d010f4219995 100644 --- a/datafusion/core/src/physical_plan/join_utils.rs +++ b/datafusion/core/src/physical_plan/join_utils.rs @@ -22,6 +22,7 @@ use crate::logical_expr::JoinType; use crate::physical_plan::expressions::Column; use arrow::datatypes::{Field, Schema}; use arrow::error::ArrowError; +use datafusion_common::ScalarValue; use datafusion_physical_expr::PhysicalExpr; use futures::future::{BoxFuture, Shared}; use futures::{ready, FutureExt}; @@ -423,7 +424,9 @@ fn estimate_inner_join_cardinality( return None; } - let max_distinct = max(left_stat.distinct_count, right_stat.distinct_count); + let left_max_distinct = max_distinct_count(left_num_rows, left_stat.clone()); + let right_max_distinct = max_distinct_count(right_num_rows, right_stat.clone()); + let max_distinct = max(left_max_distinct, right_max_distinct); if max_distinct > join_selectivity { // Seems like there are a few implementations of this algorithm that implement // exponential decay for the selectivity (like Hive's Optiq Optimizer). Needs @@ -447,6 +450,50 @@ fn estimate_inner_join_cardinality( } } +/// Estimate the number of maximum distinct values that can be present in the +/// given column from its statistics. +/// +/// If distinct_count is available, uses it directly. If the column numeric, and +/// has min/max values, then they might be used as a fallback option. Otherwise, +/// returns None. +fn max_distinct_count(num_rows: usize, stats: ColumnStatistics) -> Option { + match (stats.distinct_count, stats.max_value, stats.min_value) { + (Some(_), _, _) => stats.distinct_count, + (_, Some(max), Some(min)) => { + // Note that float support is intentionally omitted here, since the computation + // of a range between two float values is not trivial and the result would be + // highly inaccurate. + let numeric_range = get_int_range(min, max)?; + + // The number can never be greater than the number of rows we have (minus + // the nulls, since they don't count as distinct values). + let ceiling = num_rows - stats.null_count.unwrap_or(0); + Some(numeric_range.min(ceiling)) + } + _ => None, + } +} + +/// Return the numeric range between the given min and max values. +fn get_int_range(min: ScalarValue, max: ScalarValue) -> Option { + let delta = &max.sub(&min).ok()?; + match delta { + ScalarValue::Int8(Some(delta)) if *delta >= 0 => Some(*delta as usize), + ScalarValue::Int16(Some(delta)) if *delta >= 0 => Some(*delta as usize), + ScalarValue::Int32(Some(delta)) if *delta >= 0 => Some(*delta as usize), + ScalarValue::Int64(Some(delta)) if *delta >= 0 => Some(*delta as usize), + ScalarValue::UInt8(Some(delta)) => Some(*delta as usize), + ScalarValue::UInt16(Some(delta)) => Some(*delta as usize), + ScalarValue::UInt32(Some(delta)) => Some(*delta as usize), + ScalarValue::UInt64(Some(delta)) => Some(*delta as usize), + _ => None, + } + // The delta (directly) is not the real range, since it does not include the + // first term. + // E.g. (min=2, max=4) -> (4 - 2) -> 2, but the actual result should be 3 (1, 2, 3). + .map(|open_ended_range| open_ended_range + 1) +} + enum OnceFutState { Pending(OnceFutPending), Ready(Arc>), @@ -626,19 +673,19 @@ mod tests { } fn create_column_stats( - min: Option, - max: Option, + min: Option, + max: Option, distinct_count: Option, ) -> ColumnStatistics { ColumnStatistics { distinct_count, - min_value: min.map(|size| ScalarValue::UInt64(Some(size))), - max_value: max.map(|size| ScalarValue::UInt64(Some(size))), + min_value: min.map(|size| ScalarValue::Int64(Some(size))), + max_value: max.map(|size| ScalarValue::Int64(Some(size))), ..Default::default() } } - type PartialStats = (usize, u64, u64, Option); + type PartialStats = (usize, Option, Option, Option); // This is mainly for validating the all edge cases of the estimation, but // more advanced (and real world test cases) are below where we need some control @@ -650,40 +697,135 @@ mod tests { // | left(rows, min, max, distinct), right(rows, min, max, distinct), expected | // ----------------------------------------------------------------------------- - // distinct(left) is None OR distinct(right) is None + // Cardinality computation + // ======================= + // + // distinct(left) == NaN, distinct(right) == NaN + ( + (10, Some(1), Some(10), None), + (10, Some(1), Some(10), None), + Some(10), + ), + // range(left) > range(right) + ( + (10, Some(6), Some(10), None), + (10, Some(8), Some(10), None), + Some(20), + ), + // range(right) > range(left) + ( + (10, Some(8), Some(10), None), + (10, Some(6), Some(10), None), + Some(20), + ), + // range(left) > len(left), range(right) > len(right) + ( + (10, Some(1), Some(15), None), + (20, Some(1), Some(40), None), + Some(10), + ), + // When we have distinct count. + ( + (10, Some(1), Some(10), Some(10)), + (10, Some(1), Some(10), Some(10)), + Some(10), + ), + // distinct(left) > distinct(right) + ( + (10, Some(1), Some(10), Some(5)), + (10, Some(1), Some(10), Some(2)), + Some(20), + ), + // distinct(right) > distinct(left) + ( + (10, Some(1), Some(10), Some(2)), + (10, Some(1), Some(10), Some(5)), + Some(20), + ), + // min(left) < 0 (range(left) > range(right)) + ( + (10, Some(-5), Some(5), None), + (10, Some(1), Some(5), None), + Some(10), + ), + // min(right) < 0, max(right) < 0 (range(right) > range(left)) + ( + (10, Some(-25), Some(-20), None), + (10, Some(-25), Some(-15), None), + Some(10), + ), + // range(left) < 0, range(right) >= 0 + // (there isn't a case where both left and right ranges are negative + // so one of them is always going to work, this just proves negative + // ranges with bigger absolute values are not are not accidentally used). + ( + (10, Some(10), Some(0), None), + (10, Some(0), Some(10), Some(5)), + Some(20), // It would have been ten if we have used abs(range(left)) + ), + // range(left) = 1, range(right) = 1 + ( + (10, Some(1), Some(1), None), + (10, Some(1), Some(1), None), + Some(100), + ), // - // len(left) = len(right), len(left) * len(right) - ((10, 0, 10, None), (10, 0, 10, None), None), - // len(left) > len(right) OR len(left) < len(right), len(left) * len(right) - ((10, 0, 10, None), (5, 0, 10, None), None), - ((5, 0, 10, None), (10, 0, 10, None), None), - ((10, 0, 10, None), (5, 0, 10, None), None), - ((5, 0, 10, None), (10, 0, 10, None), None), - // min(left) > max(right) OR min(right) > max(left), None - ((10, 0, 10, None), (10, 11, 20, None), None), - ((10, 11, 20, None), (10, 0, 10, None), None), - ((10, 5, 10, None), (10, 11, 3, None), None), - ((10, 10, 5, None), (10, 3, 7, None), None), - // distinct(left) is not None AND distinct(right) is not None + // Edge cases + // ========== // - // len(left) = len(right), len(left) * len(right) / max(distinct(left), distinct(right)) - ((10, 0, 10, Some(5)), (10, 0, 10, Some(5)), Some(20)), - ((10, 0, 10, Some(10)), (10, 0, 10, Some(5)), Some(10)), - ((10, 0, 10, Some(5)), (10, 0, 10, Some(10)), Some(10)), + // No column level stats. + ((10, None, None, None), (10, None, None, None), None), + // No min or max (or both). + ((10, None, None, Some(3)), (10, None, None, Some(3)), None), + ( + (10, Some(2), None, Some(3)), + (10, None, Some(5), Some(3)), + None, + ), + ( + (10, None, Some(3), Some(3)), + (10, Some(1), None, Some(3)), + None, + ), + ((10, None, Some(3), None), (10, Some(1), None, None), None), + // Non overlapping min/max. + ( + (10, Some(0), Some(10), None), + (10, Some(11), Some(20), None), + None, + ), + ( + (10, Some(11), Some(20), None), + (10, Some(0), Some(10), None), + None, + ), + ( + (10, Some(5), Some(10), Some(10)), + (10, Some(11), Some(3), Some(10)), + None, + ), + ( + (10, Some(10), Some(5), Some(10)), + (10, Some(3), Some(7), Some(10)), + None, + ), + // distinct(left) = 0, distinct(right) = 0 + ( + (10, Some(1), Some(10), Some(0)), + (10, Some(1), Some(10), Some(0)), + None, + ), ]; for (left_info, right_info, expected_cardinality) in cases { let left_num_rows = left_info.0; - let left_col_stats = vec![create_column_stats( - Some(left_info.1), - Some(left_info.2), - left_info.3, - )]; + let left_col_stats = + vec![create_column_stats(left_info.1, left_info.2, left_info.3)]; let right_num_rows = right_info.0; let right_col_stats = vec![create_column_stats( - Some(right_info.1), - Some(right_info.2), + right_info.1, + right_info.2, right_info.3, )]; @@ -740,6 +882,29 @@ mod tests { Ok(()) } + #[test] + fn test_inner_join_cardinality_decimal_range() -> Result<()> { + let left_col_stats = vec![ColumnStatistics { + distinct_count: None, + min_value: Some(ScalarValue::Decimal128(Some(32500), 14, 4)), + max_value: Some(ScalarValue::Decimal128(Some(35000), 14, 4)), + ..Default::default() + }]; + + let right_col_stats = vec![ColumnStatistics { + distinct_count: None, + min_value: Some(ScalarValue::Decimal128(Some(33500), 14, 4)), + max_value: Some(ScalarValue::Decimal128(Some(34000), 14, 4)), + ..Default::default() + }]; + + assert_eq!( + estimate_inner_join_cardinality(100, 100, left_col_stats, right_col_stats), + None + ); + Ok(()) + } + #[test] fn test_join_cardinality() -> Result<()> { // Left table (rows=1000)