Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 196 additions & 31 deletions datafusion/core/src/physical_plan/join_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::logical_expr::JoinType;
use crate::physical_plan::expressions::Column;
use arrow::datatypes::{Field, Schema};
use arrow::error::ArrowError;
use datafusion_common::ScalarValue;
use datafusion_physical_expr::PhysicalExpr;
use futures::future::{BoxFuture, Shared};
use futures::{ready, FutureExt};
Expand Down Expand Up @@ -423,7 +424,9 @@ fn estimate_inner_join_cardinality(
return None;
}

let max_distinct = max(left_stat.distinct_count, right_stat.distinct_count);
let left_max_distinct = max_distinct_count(left_num_rows, left_stat.clone());
let right_max_distinct = max_distinct_count(right_num_rows, right_stat.clone());
let max_distinct = max(left_max_distinct, right_max_distinct);
if max_distinct > join_selectivity {
// Seems like there are a few implementations of this algorithm that implement
// exponential decay for the selectivity (like Hive's Optiq Optimizer). Needs
Expand All @@ -447,6 +450,50 @@ fn estimate_inner_join_cardinality(
}
}

/// Estimate the number of maximum distinct values that can be present in the
/// given column from its statistics.
///
/// If distinct_count is available, uses it directly. If the column numeric, and
/// has min/max values, then they might be used as a fallback option. Otherwise,
/// returns None.
fn max_distinct_count(num_rows: usize, stats: ColumnStatistics) -> Option<usize> {
match (stats.distinct_count, stats.max_value, stats.min_value) {
(Some(_), _, _) => stats.distinct_count,
(_, Some(max), Some(min)) => {
// Note that float support is intentionally omitted here, since the computation
// of a range between two float values is not trivial and the result would be
// highly inaccurate.
let numeric_range = get_int_range(min, max)?;

// The number can never be greater than the number of rows we have (minus
// the nulls, since they don't count as distinct values).
let ceiling = num_rows - stats.null_count.unwrap_or(0);
Some(numeric_range.min(ceiling))
}
_ => None,
}
}

/// Return the numeric range between the given min and max values.
fn get_int_range(min: ScalarValue, max: ScalarValue) -> Option<usize> {
let delta = &max.sub(&min).ok()?;
match delta {
ScalarValue::Int8(Some(delta)) if *delta >= 0 => Some(*delta as usize),
ScalarValue::Int16(Some(delta)) if *delta >= 0 => Some(*delta as usize),
ScalarValue::Int32(Some(delta)) if *delta >= 0 => Some(*delta as usize),
ScalarValue::Int64(Some(delta)) if *delta >= 0 => Some(*delta as usize),
ScalarValue::UInt8(Some(delta)) => Some(*delta as usize),
ScalarValue::UInt16(Some(delta)) => Some(*delta as usize),
ScalarValue::UInt32(Some(delta)) => Some(*delta as usize),
ScalarValue::UInt64(Some(delta)) => Some(*delta as usize),
_ => None,
}
// The delta (directly) is not the real range, since it does not include the
// first term.
// E.g. (min=2, max=4) -> (4 - 2) -> 2, but the actual result should be 3 (1, 2, 3).
.map(|open_ended_range| open_ended_range + 1)
}

enum OnceFutState<T> {
Pending(OnceFutPending<T>),
Ready(Arc<Result<T>>),
Expand Down Expand Up @@ -626,19 +673,19 @@ mod tests {
}

fn create_column_stats(
min: Option<u64>,
max: Option<u64>,
min: Option<i64>,
max: Option<i64>,
distinct_count: Option<usize>,
) -> ColumnStatistics {
ColumnStatistics {
distinct_count,
min_value: min.map(|size| ScalarValue::UInt64(Some(size))),
max_value: max.map(|size| ScalarValue::UInt64(Some(size))),
min_value: min.map(|size| ScalarValue::Int64(Some(size))),
max_value: max.map(|size| ScalarValue::Int64(Some(size))),
..Default::default()
}
}

type PartialStats = (usize, u64, u64, Option<usize>);
type PartialStats = (usize, Option<i64>, Option<i64>, Option<usize>);

// This is mainly for validating the all edge cases of the estimation, but
// more advanced (and real world test cases) are below where we need some control
Expand All @@ -650,40 +697,135 @@ mod tests {
// | left(rows, min, max, distinct), right(rows, min, max, distinct), expected |
// -----------------------------------------------------------------------------

// distinct(left) is None OR distinct(right) is None
// Cardinality computation
// =======================
//
// distinct(left) == NaN, distinct(right) == NaN
(
(10, Some(1), Some(10), None),
(10, Some(1), Some(10), None),
Some(10),
),
// range(left) > range(right)
(
(10, Some(6), Some(10), None),
(10, Some(8), Some(10), None),
Some(20),
),
// range(right) > range(left)
(
(10, Some(8), Some(10), None),
(10, Some(6), Some(10), None),
Some(20),
),
// range(left) > len(left), range(right) > len(right)
(
(10, Some(1), Some(15), None),
(20, Some(1), Some(40), None),
Some(10),
),
// When we have distinct count.
(
(10, Some(1), Some(10), Some(10)),
(10, Some(1), Some(10), Some(10)),
Some(10),
),
// distinct(left) > distinct(right)
(
(10, Some(1), Some(10), Some(5)),
(10, Some(1), Some(10), Some(2)),
Some(20),
),
// distinct(right) > distinct(left)
(
(10, Some(1), Some(10), Some(2)),
(10, Some(1), Some(10), Some(5)),
Some(20),
),
// min(left) < 0 (range(left) > range(right))
(
(10, Some(-5), Some(5), None),
(10, Some(1), Some(5), None),
Some(10),
),
// min(right) < 0, max(right) < 0 (range(right) > range(left))
(
(10, Some(-25), Some(-20), None),
(10, Some(-25), Some(-15), None),
Some(10),
),
// range(left) < 0, range(right) >= 0
// (there isn't a case where both left and right ranges are negative
// so one of them is always going to work, this just proves negative
// ranges with bigger absolute values are not are not accidentally used).
(
(10, Some(10), Some(0), None),
(10, Some(0), Some(10), Some(5)),
Some(20), // It would have been ten if we have used abs(range(left))
),
// range(left) = 1, range(right) = 1
(
(10, Some(1), Some(1), None),
(10, Some(1), Some(1), None),
Some(100),
),
//
// len(left) = len(right), len(left) * len(right)
((10, 0, 10, None), (10, 0, 10, None), None),
// len(left) > len(right) OR len(left) < len(right), len(left) * len(right)
((10, 0, 10, None), (5, 0, 10, None), None),
((5, 0, 10, None), (10, 0, 10, None), None),
((10, 0, 10, None), (5, 0, 10, None), None),
((5, 0, 10, None), (10, 0, 10, None), None),
// min(left) > max(right) OR min(right) > max(left), None
((10, 0, 10, None), (10, 11, 20, None), None),
((10, 11, 20, None), (10, 0, 10, None), None),
((10, 5, 10, None), (10, 11, 3, None), None),
((10, 10, 5, None), (10, 3, 7, None), None),
// distinct(left) is not None AND distinct(right) is not None
// Edge cases
// ==========
//
// len(left) = len(right), len(left) * len(right) / max(distinct(left), distinct(right))
((10, 0, 10, Some(5)), (10, 0, 10, Some(5)), Some(20)),
((10, 0, 10, Some(10)), (10, 0, 10, Some(5)), Some(10)),
((10, 0, 10, Some(5)), (10, 0, 10, Some(10)), Some(10)),
// No column level stats.
((10, None, None, None), (10, None, None, None), None),
// No min or max (or both).
((10, None, None, Some(3)), (10, None, None, Some(3)), None),
(
(10, Some(2), None, Some(3)),
(10, None, Some(5), Some(3)),
None,
),
(
(10, None, Some(3), Some(3)),
(10, Some(1), None, Some(3)),
None,
),
((10, None, Some(3), None), (10, Some(1), None, None), None),
// Non overlapping min/max.
(
(10, Some(0), Some(10), None),
(10, Some(11), Some(20), None),
None,
),
(
(10, Some(11), Some(20), None),
(10, Some(0), Some(10), None),
None,
),
(
(10, Some(5), Some(10), Some(10)),
(10, Some(11), Some(3), Some(10)),
None,
),
(
(10, Some(10), Some(5), Some(10)),
(10, Some(3), Some(7), Some(10)),
None,
),
// distinct(left) = 0, distinct(right) = 0
(
(10, Some(1), Some(10), Some(0)),
(10, Some(1), Some(10), Some(0)),
None,
),
];

for (left_info, right_info, expected_cardinality) in cases {
let left_num_rows = left_info.0;
let left_col_stats = vec![create_column_stats(
Some(left_info.1),
Some(left_info.2),
left_info.3,
)];
let left_col_stats =
vec![create_column_stats(left_info.1, left_info.2, left_info.3)];

let right_num_rows = right_info.0;
let right_col_stats = vec![create_column_stats(
Some(right_info.1),
Some(right_info.2),
right_info.1,
right_info.2,
right_info.3,
)];

Expand Down Expand Up @@ -740,6 +882,29 @@ mod tests {
Ok(())
}

#[test]
fn test_inner_join_cardinality_decimal_range() -> Result<()> {
let left_col_stats = vec![ColumnStatistics {
distinct_count: None,
min_value: Some(ScalarValue::Decimal128(Some(32500), 14, 4)),
max_value: Some(ScalarValue::Decimal128(Some(35000), 14, 4)),
..Default::default()
}];

let right_col_stats = vec![ColumnStatistics {
distinct_count: None,
min_value: Some(ScalarValue::Decimal128(Some(33500), 14, 4)),
max_value: Some(ScalarValue::Decimal128(Some(34000), 14, 4)),
..Default::default()
}];

assert_eq!(
estimate_inner_join_cardinality(100, 100, left_col_stats, right_col_stats),
None
);
Ok(())
}

#[test]
fn test_join_cardinality() -> Result<()> {
// Left table (rows=1000)
Expand Down