From b30de470a11ca3f360260a8a36bc1e5eb4f355e8 Mon Sep 17 00:00:00 2001
From: Zhenhua Wang
Date: Thu, 19 Oct 2017 10:45:53 +0800
Subject: [PATCH 1/4] refactor

---
 .../BasicStatsPlanVisitor.scala               |   4 +-
 .../statsEstimation/JoinEstimation.scala      | 167 +++++++++---
 2 files changed, 86 insertions(+), 85 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala
index 4cff72d45a400..ca0775a2e8408 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala
@@ -17,9 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
 
-import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.types.LongType
 
 /**
  * An [[LogicalPlanVisitor]] that computes the statistics used in a cost-based optimizer.
@@ -54,7 +52,7 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
   override def visitIntersect(p: Intersect): Statistics = fallback(p)
 
   override def visitJoin(p: Join): Statistics = {
-    JoinEstimation.estimate(p).getOrElse(fallback(p))
+    JoinEstimation(p).estimate.getOrElse(fallback(p))
   }
 
   override def visitLocalLimit(p: LocalLimit): Statistics = fallback(p)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
index dcbe36da91dfc..3c011d983bdc6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
@@ -28,45 +28,43 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
 
-object JoinEstimation extends Logging {
+case class JoinEstimation(join: Join) extends Logging {
+
+  private val leftStats = join.left.stats
+  private val rightStats = join.right.stats
+  private val keyStatsAfterJoin = new mutable.HashMap[Attribute, ColumnStat]()
+
   /**
    * Estimate statistics after join. Return `None` if the join type is not supported, or we don't
    * have enough statistics for estimation.
    */
-  def estimate(join: Join): Option[Statistics] = {
+  def estimate: Option[Statistics] = {
     join.joinType match {
       case Inner | Cross | LeftOuter | RightOuter | FullOuter =>
-        InnerOuterEstimation(join).doEstimate()
+        estimateInnerOuterJoin()
       case LeftSemi | LeftAnti =>
-        LeftSemiAntiEstimation(join).doEstimate()
+        estimateLeftSemiAntiJoin()
       case _ =>
         logDebug(s"[CBO] Unsupported join type: ${join.joinType}")
         None
     }
   }
-}
-
-case class InnerOuterEstimation(join: Join) extends Logging {
-
-  private val leftStats = join.left.stats
-  private val rightStats = join.right.stats
 
   /**
    * Estimate output size and number of rows after a join operator, and update output column stats.
*/ - def doEstimate(): Option[Statistics] = join match { + private def estimateInnerOuterJoin(): Option[Statistics] = join match { case _ if !rowCountsExist(join.left, join.right) => None case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) => // 1. Compute join selectivity val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys) - val selectivity = joinSelectivity(joinKeyPairs) + val innerJoinedRows = joinCardinality(joinKeyPairs) // 2. Estimate the number of output rows val leftRows = leftStats.rowCount.get val rightRows = rightStats.rowCount.get - val innerJoinedRows = ceil(BigDecimal(leftRows * rightRows) * selectivity) // Make sure outputRows won't be too small based on join type. val outputRows = joinType match { @@ -93,7 +91,7 @@ case class InnerOuterEstimation(join: Join) extends Logging { val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) { // The output is empty, we don't need to keep column stats. Nil - } else if (selectivity == 0) { + } else if (innerJoinedRows == 0) { joinType match { // For outer joins, if the join selectivity is 0, the number of output rows is the // same as that of the outer side. And column stats of join keys from the outer side @@ -115,24 +113,23 @@ case class InnerOuterEstimation(join: Join) extends Logging { } case _ => Nil } - } else if (selectivity == 1) { + } else if (innerJoinedRows == leftRows * rightRows) { // Cartesian product, just propagate the original column stats inputAttrStats.toSeq } else { - val joinKeyStats = getIntersectedStats(joinKeyPairs) join.joinType match { // For outer joins, don't update column stats from the outer side. case LeftOuter => fromLeft.map(a => (a, inputAttrStats(a))) ++ - updateAttrStats(outputRows, fromRight, inputAttrStats, joinKeyStats) + updateOutputStats(outputRows, fromRight, inputAttrStats) case RightOuter => - updateAttrStats(outputRows, fromLeft, inputAttrStats, joinKeyStats) ++ + updateOutputStats(outputRows, fromLeft, inputAttrStats) ++ fromRight.map(a => (a, inputAttrStats(a))) case FullOuter => inputAttrStats.toSeq case _ => // Update column stats from both sides for inner or cross join. - updateAttrStats(outputRows, attributesWithStat, inputAttrStats, joinKeyStats) + updateOutputStats(outputRows, attributesWithStat, inputAttrStats) } } @@ -157,64 +154,100 @@ case class InnerOuterEstimation(join: Join) extends Logging { // scalastyle:off /** * The number of rows of A inner join B on A.k1 = B.k1 is estimated by this basic formula: - * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)), where V is the number of distinct values of - * that column. The underlying assumption for this formula is: each value of the smaller domain - * is included in the larger domain. - * Generally, inner join with multiple join keys can also be estimated based on the above - * formula: + * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)), + * where V is the number of distinct values (ndv) of that column. The underlying assumption for + * this formula is: each value of the smaller domain is included in the larger domain. + * + * Generally, inner join with multiple join keys can be estimated based on the above formula: * T(A IJ B) = T(A) * T(B) / (max(V(A.k1), V(B.k1)) * max(V(A.k2), V(B.k2)) * ... * max(V(A.kn), V(B.kn))) * However, the denominator can become very large and excessively reduce the result, so we use a * conservative strategy to take only the largest max(V(A.ki), V(B.ki)) as the denominator. + * + * That is, join estimation is based on the most selective join keys. 
We follow this strategy + * when different types of column statistics are available. E.g., if card1 is the cardinality + * estimated by ndv of join key A.k1 and B.k1, card2 is the cardinality estimated by histograms + * of join key A.k2 and B.k2, then the result cardinality would be min(card1, card2). */ // scalastyle:on - def joinSelectivity(joinKeyPairs: Seq[(AttributeReference, AttributeReference)]): BigDecimal = { - var ndvDenom: BigInt = -1 + private def joinCardinality(joinKeyPairs: Seq[(AttributeReference, AttributeReference)]) + : BigInt = { + // If there's no column stats available for join keys, estimate as cartesian product. + var minCard: BigInt = leftStats.rowCount.get * rightStats.rowCount.get var i = 0 - while(i < joinKeyPairs.length && ndvDenom != 0) { + while(i < joinKeyPairs.length && minCard != 0) { val (leftKey, rightKey) = joinKeyPairs(i) // Check if the two sides are disjoint - val leftKeyStats = leftStats.attributeStats(leftKey) - val rightKeyStats = rightStats.attributeStats(rightKey) - val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) - val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) + val leftKeyStat = leftStats.attributeStats(leftKey) + val rightKeyStat = rightStats.attributeStats(rightKey) + val lInterval = ValueInterval(leftKeyStat.min, leftKeyStat.max, leftKey.dataType) + val rInterval = ValueInterval(rightKeyStat.min, rightKeyStat.max, rightKey.dataType) if (ValueInterval.isIntersected(lInterval, rInterval)) { - // Get the largest ndv among pairs of join keys - val maxNdv = leftKeyStats.distinctCount.max(rightKeyStats.distinctCount) - if (maxNdv > ndvDenom) ndvDenom = maxNdv + val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType) + val card = joinCardByNdv(leftKey, rightKey, newMin, newMax) + // Return cardinality estimated from the most selective join keys. + if (card < minCard) minCard = card } else { - // Set ndvDenom to zero to indicate that this join should have no output - ndvDenom = 0 + // One of the join key pairs is disjoint, thus the two sides of join is disjoint. + minCard = 0 } i += 1 } + minCard + } - if (ndvDenom < 0) { - // We can't find any join key pairs with column stats, estimate it as cartesian join. - 1 - } else if (ndvDenom == 0) { - // One of the join key pairs is disjoint, thus the two sides of join is disjoint. - 0 - } else { - 1 / BigDecimal(ndvDenom) + /** Compute join cardinality using the basic formula, and update column stats for join keys. */ + private def joinCardByNdv( + leftKey: AttributeReference, + rightKey: AttributeReference, + newMin: Option[Any], + newMax: Option[Any]): BigInt = { + val leftKeyStat = leftStats.attributeStats(leftKey) + val rightKeyStat = rightStats.attributeStats(rightKey) + val maxNdv = leftKeyStat.distinctCount.max(rightKeyStat.distinctCount) + // Compute cardinality by the basic formula. + val card = BigDecimal(leftStats.rowCount.get * rightStats.rowCount.get) / BigDecimal(maxNdv) + + // Update intersected column stats. 
+ val newNdv = leftKeyStat.distinctCount.min(rightKeyStat.distinctCount) + val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen) + val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2 + + join.joinType match { + case LeftOuter => + keyStatsAfterJoin.put(leftKey, leftKeyStat) + keyStatsAfterJoin.put(rightKey, + ColumnStat(newNdv, newMin, newMax, rightKeyStat.nullCount, newAvgLen, newMaxLen)) + case RightOuter => + keyStatsAfterJoin.put(leftKey, + ColumnStat(newNdv, newMin, newMax, leftKeyStat.nullCount, newAvgLen, newMaxLen)) + keyStatsAfterJoin.put(rightKey, rightKeyStat) + case FullOuter => + keyStatsAfterJoin.put(leftKey, leftKeyStat) + keyStatsAfterJoin.put(rightKey, rightKeyStat) + case _ => + val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) + keyStatsAfterJoin.put(leftKey, newStats) + keyStatsAfterJoin.put(rightKey, newStats) } + + ceil(card) } /** * Propagate or update column stats for output attributes. */ - private def updateAttrStats( + private def updateOutputStats( outputRows: BigInt, - attributes: Seq[Attribute], - oldAttrStats: AttributeMap[ColumnStat], - joinKeyStats: AttributeMap[ColumnStat]): Seq[(Attribute, ColumnStat)] = { + output: Seq[Attribute], + oldAttrStats: AttributeMap[ColumnStat]): Seq[(Attribute, ColumnStat)] = { val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]() val leftRows = leftStats.rowCount.get val rightRows = rightStats.rowCount.get - attributes.foreach { a => + output.foreach { a => // check if this attribute is a join key - if (joinKeyStats.contains(a)) { - outputAttrStats += a -> joinKeyStats(a) + if (keyStatsAfterJoin.contains(a)) { + outputAttrStats += a -> keyStatsAfterJoin(a) } else { val oldColStat = oldAttrStats(a) val oldNdv = oldColStat.distinctCount @@ -231,34 +264,6 @@ case class InnerOuterEstimation(join: Join) extends Logging { outputAttrStats } - /** Get intersected column stats for join keys. */ - private def getIntersectedStats(joinKeyPairs: Seq[(AttributeReference, AttributeReference)]) - : AttributeMap[ColumnStat] = { - - val intersectedStats = new mutable.HashMap[Attribute, ColumnStat]() - joinKeyPairs.foreach { case (leftKey, rightKey) => - val leftKeyStats = leftStats.attributeStats(leftKey) - val rightKeyStats = rightStats.attributeStats(rightKey) - val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) - val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) - // When we reach here, join selectivity is not zero, so each pair of join keys should be - // intersected. 
- assert(ValueInterval.isIntersected(lInterval, rInterval)) - - // Update intersected column stats - assert(leftKey.dataType.sameType(rightKey.dataType)) - val newNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount) - val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType) - val newMaxLen = math.min(leftKeyStats.maxLen, rightKeyStats.maxLen) - val newAvgLen = (leftKeyStats.avgLen + rightKeyStats.avgLen) / 2 - val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) - - intersectedStats.put(leftKey, newStats) - intersectedStats.put(rightKey, newStats) - } - AttributeMap(intersectedStats.toSeq) - } - private def extractJoinKeysWithColStats( leftKeys: Seq[Expression], rightKeys: Seq[Expression]): Seq[(AttributeReference, AttributeReference)] = { @@ -270,10 +275,8 @@ case class InnerOuterEstimation(join: Join) extends Logging { if columnStatsExist((leftStats, lk), (rightStats, rk)) => (lk, rk) } } -} -case class LeftSemiAntiEstimation(join: Join) { - def doEstimate(): Option[Statistics] = { + private def estimateLeftSemiAntiJoin(): Option[Statistics] = { // TODO: It's error-prone to estimate cardinalities for LeftSemi and LeftAnti based on basic // column stats. Now we just propagate the statistics from left side. We should do more // accurate estimation when advanced stats (e.g. histograms) are available. From 18edc1471d9bcfe6bb500afa77c6d9c1a4bd23dc Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Sat, 28 Oct 2017 14:55:29 +0800 Subject: [PATCH 2/4] remove global map and rename --- .../statsEstimation/JoinEstimation.scala | 79 ++++++++----------- 1 file changed, 34 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 3c011d983bdc6..0feb813fc679c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -32,7 +32,6 @@ case class JoinEstimation(join: Join) extends Logging { private val leftStats = join.left.stats private val rightStats = join.right.stats - private val keyStatsAfterJoin = new mutable.HashMap[Attribute, ColumnStat]() /** * Estimate statistics after join. Return `None` if the join type is not supported, or we don't @@ -60,7 +59,7 @@ case class JoinEstimation(join: Join) extends Logging { case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) => // 1. Compute join selectivity val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys) - val innerJoinedRows = joinCardinality(joinKeyPairs) + val (numInnerJoinedRows, keyStatsAfterJoin) = computeCardinalityAndStats(joinKeyPairs) // 2. Estimate the number of output rows val leftRows = leftStats.rowCount.get @@ -70,16 +69,16 @@ case class JoinEstimation(join: Join) extends Logging { val outputRows = joinType match { case LeftOuter => // All rows from left side should be in the result. - leftRows.max(innerJoinedRows) + leftRows.max(numInnerJoinedRows) case RightOuter => // All rows from right side should be in the result. 
- rightRows.max(innerJoinedRows) + rightRows.max(numInnerJoinedRows) case FullOuter => // T(A FOJ B) = T(A LOJ B) + T(A ROJ B) - T(A IJ B) - leftRows.max(innerJoinedRows) + rightRows.max(innerJoinedRows) - innerJoinedRows - case _ => + leftRows.max(numInnerJoinedRows) + rightRows.max(numInnerJoinedRows) - numInnerJoinedRows + case Inner | Cross => // Don't change for inner or cross join - innerJoinedRows + numInnerJoinedRows } // 3. Update statistics based on the output of join @@ -91,7 +90,7 @@ case class JoinEstimation(join: Join) extends Logging { val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) { // The output is empty, we don't need to keep column stats. Nil - } else if (innerJoinedRows == 0) { + } else if (numInnerJoinedRows == 0) { joinType match { // For outer joins, if the join selectivity is 0, the number of output rows is the // same as that of the outer side. And column stats of join keys from the outer side @@ -111,9 +110,9 @@ case class JoinEstimation(join: Join) extends Logging { val oriColStat = inputAttrStats(a) (a, oriColStat.copy(nullCount = oriColStat.nullCount + leftRows)) } - case _ => Nil + case Inner | Cross => Nil } - } else if (innerJoinedRows == leftRows * rightRows) { + } else if (numInnerJoinedRows == leftRows * rightRows) { // Cartesian product, just propagate the original column stats inputAttrStats.toSeq } else { @@ -121,15 +120,15 @@ case class JoinEstimation(join: Join) extends Logging { // For outer joins, don't update column stats from the outer side. case LeftOuter => fromLeft.map(a => (a, inputAttrStats(a))) ++ - updateOutputStats(outputRows, fromRight, inputAttrStats) + updateOutputStats(outputRows, fromRight, inputAttrStats, keyStatsAfterJoin) case RightOuter => - updateOutputStats(outputRows, fromLeft, inputAttrStats) ++ + updateOutputStats(outputRows, fromLeft, inputAttrStats, keyStatsAfterJoin) ++ fromRight.map(a => (a, inputAttrStats(a))) case FullOuter => inputAttrStats.toSeq - case _ => + case Inner | Cross => // Update column stats from both sides for inner or cross join. - updateOutputStats(outputRows, attributesWithStat, inputAttrStats) + updateOutputStats(outputRows, attributesWithStat, inputAttrStats, keyStatsAfterJoin) } } @@ -167,15 +166,20 @@ case class JoinEstimation(join: Join) extends Logging { * when different types of column statistics are available. E.g., if card1 is the cardinality * estimated by ndv of join key A.k1 and B.k1, card2 is the cardinality estimated by histograms * of join key A.k2 and B.k2, then the result cardinality would be min(card1, card2). + * + * @param keyPairs pairs of join keys + * + * @return join cardinality, and column stats for join keys after the join */ // scalastyle:on - private def joinCardinality(joinKeyPairs: Seq[(AttributeReference, AttributeReference)]) - : BigInt = { + private def computeCardinalityAndStats(keyPairs: Seq[(AttributeReference, AttributeReference)]) + : (BigInt, Map[Attribute, ColumnStat]) = { // If there's no column stats available for join keys, estimate as cartesian product. 
- var minCard: BigInt = leftStats.rowCount.get * rightStats.rowCount.get + var cardJoin: BigInt = leftStats.rowCount.get * rightStats.rowCount.get + val keyStatsAfterJoin = new mutable.HashMap[Attribute, ColumnStat]() var i = 0 - while(i < joinKeyPairs.length && minCard != 0) { - val (leftKey, rightKey) = joinKeyPairs(i) + while(i < keyPairs.length && cardJoin != 0) { + val (leftKey, rightKey) = keyPairs(i) // Check if the two sides are disjoint val leftKeyStat = leftStats.attributeStats(leftKey) val rightKeyStat = rightStats.attributeStats(rightKey) @@ -183,24 +187,25 @@ case class JoinEstimation(join: Join) extends Logging { val rInterval = ValueInterval(rightKeyStat.min, rightKeyStat.max, rightKey.dataType) if (ValueInterval.isIntersected(lInterval, rInterval)) { val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType) - val card = joinCardByNdv(leftKey, rightKey, newMin, newMax) + val (cardKeyPair, joinStatsKeyPair) = computeByNdv(leftKey, rightKey, newMin, newMax) + keyStatsAfterJoin ++= joinStatsKeyPair // Return cardinality estimated from the most selective join keys. - if (card < minCard) minCard = card + if (cardKeyPair < cardJoin) cardJoin = cardKeyPair } else { // One of the join key pairs is disjoint, thus the two sides of join is disjoint. - minCard = 0 + cardJoin = 0 } i += 1 } - minCard + (cardJoin, keyStatsAfterJoin.toMap) } /** Compute join cardinality using the basic formula, and update column stats for join keys. */ - private def joinCardByNdv( + private def computeByNdv( leftKey: AttributeReference, rightKey: AttributeReference, newMin: Option[Any], - newMax: Option[Any]): BigInt = { + newMax: Option[Any]): (BigInt, Map[Attribute, ColumnStat]) = { val leftKeyStat = leftStats.attributeStats(leftKey) val rightKeyStat = rightStats.attributeStats(rightKey) val maxNdv = leftKeyStat.distinctCount.max(rightKeyStat.distinctCount) @@ -211,26 +216,9 @@ case class JoinEstimation(join: Join) extends Logging { val newNdv = leftKeyStat.distinctCount.min(rightKeyStat.distinctCount) val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen) val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2 + val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) - join.joinType match { - case LeftOuter => - keyStatsAfterJoin.put(leftKey, leftKeyStat) - keyStatsAfterJoin.put(rightKey, - ColumnStat(newNdv, newMin, newMax, rightKeyStat.nullCount, newAvgLen, newMaxLen)) - case RightOuter => - keyStatsAfterJoin.put(leftKey, - ColumnStat(newNdv, newMin, newMax, leftKeyStat.nullCount, newAvgLen, newMaxLen)) - keyStatsAfterJoin.put(rightKey, rightKeyStat) - case FullOuter => - keyStatsAfterJoin.put(leftKey, leftKeyStat) - keyStatsAfterJoin.put(rightKey, rightKeyStat) - case _ => - val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) - keyStatsAfterJoin.put(leftKey, newStats) - keyStatsAfterJoin.put(rightKey, newStats) - } - - ceil(card) + (ceil(card), Map(leftKey -> newStats, rightKey -> newStats)) } /** @@ -239,7 +227,8 @@ case class JoinEstimation(join: Join) extends Logging { private def updateOutputStats( outputRows: BigInt, output: Seq[Attribute], - oldAttrStats: AttributeMap[ColumnStat]): Seq[(Attribute, ColumnStat)] = { + oldAttrStats: AttributeMap[ColumnStat], + keyStatsAfterJoin: Map[Attribute, ColumnStat]): Seq[(Attribute, ColumnStat)] = { val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]() val leftRows = leftStats.rowCount.get val rightRows = rightStats.rowCount.get From 
18cb42f84736c00f1ae3b7453ae5ff2f0c823484 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Sat, 28 Oct 2017 15:21:44 +0800 Subject: [PATCH 3/4] fix style --- .../plans/logical/statsEstimation/JoinEstimation.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 0feb813fc679c..bd44b632e486b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -76,7 +76,8 @@ case class JoinEstimation(join: Join) extends Logging { case FullOuter => // T(A FOJ B) = T(A LOJ B) + T(A ROJ B) - T(A IJ B) leftRows.max(numInnerJoinedRows) + rightRows.max(numInnerJoinedRows) - numInnerJoinedRows - case Inner | Cross => + case _ => + assert(joinType == Inner || joinType == Cross) // Don't change for inner or cross join numInnerJoinedRows } @@ -110,7 +111,9 @@ case class JoinEstimation(join: Join) extends Logging { val oriColStat = inputAttrStats(a) (a, oriColStat.copy(nullCount = oriColStat.nullCount + leftRows)) } - case Inner | Cross => Nil + case _ => + assert(joinType == Inner || joinType == Cross) + Nil } } else if (numInnerJoinedRows == leftRows * rightRows) { // Cartesian product, just propagate the original column stats @@ -126,7 +129,8 @@ case class JoinEstimation(join: Join) extends Logging { fromRight.map(a => (a, inputAttrStats(a))) case FullOuter => inputAttrStats.toSeq - case Inner | Cross => + case _ => + assert(joinType == Inner || joinType == Cross) // Update column stats from both sides for inner or cross join. updateOutputStats(outputRows, attributesWithStat, inputAttrStats, keyStatsAfterJoin) } From a2dbb8eefacf560ff5e7090baa3796b2353daed2 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Tue, 31 Oct 2017 15:23:21 +0800 Subject: [PATCH 4/4] fix comments --- .../statsEstimation/JoinEstimation.scala | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index bd44b632e486b..b073108c26ee5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -177,12 +177,12 @@ case class JoinEstimation(join: Join) extends Logging { */ // scalastyle:on private def computeCardinalityAndStats(keyPairs: Seq[(AttributeReference, AttributeReference)]) - : (BigInt, Map[Attribute, ColumnStat]) = { + : (BigInt, AttributeMap[ColumnStat]) = { // If there's no column stats available for join keys, estimate as cartesian product. 
- var cardJoin: BigInt = leftStats.rowCount.get * rightStats.rowCount.get + var joinCard: BigInt = leftStats.rowCount.get * rightStats.rowCount.get val keyStatsAfterJoin = new mutable.HashMap[Attribute, ColumnStat]() var i = 0 - while(i < keyPairs.length && cardJoin != 0) { + while(i < keyPairs.length && joinCard != 0) { val (leftKey, rightKey) = keyPairs(i) // Check if the two sides are disjoint val leftKeyStat = leftStats.attributeStats(leftKey) @@ -191,38 +191,38 @@ case class JoinEstimation(join: Join) extends Logging { val rInterval = ValueInterval(rightKeyStat.min, rightKeyStat.max, rightKey.dataType) if (ValueInterval.isIntersected(lInterval, rInterval)) { val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType) - val (cardKeyPair, joinStatsKeyPair) = computeByNdv(leftKey, rightKey, newMin, newMax) - keyStatsAfterJoin ++= joinStatsKeyPair + val (card, joinStat) = computeByNdv(leftKey, rightKey, newMin, newMax) + keyStatsAfterJoin += (leftKey -> joinStat, rightKey -> joinStat) // Return cardinality estimated from the most selective join keys. - if (cardKeyPair < cardJoin) cardJoin = cardKeyPair + if (card < joinCard) joinCard = card } else { // One of the join key pairs is disjoint, thus the two sides of join is disjoint. - cardJoin = 0 + joinCard = 0 } i += 1 } - (cardJoin, keyStatsAfterJoin.toMap) + (joinCard, AttributeMap(keyStatsAfterJoin.toSeq)) } - /** Compute join cardinality using the basic formula, and update column stats for join keys. */ + /** Returns join cardinality and the column stat for this pair of join keys. */ private def computeByNdv( leftKey: AttributeReference, rightKey: AttributeReference, newMin: Option[Any], - newMax: Option[Any]): (BigInt, Map[Attribute, ColumnStat]) = { + newMax: Option[Any]): (BigInt, ColumnStat) = { val leftKeyStat = leftStats.attributeStats(leftKey) val rightKeyStat = rightStats.attributeStats(rightKey) val maxNdv = leftKeyStat.distinctCount.max(rightKeyStat.distinctCount) // Compute cardinality by the basic formula. val card = BigDecimal(leftStats.rowCount.get * rightStats.rowCount.get) / BigDecimal(maxNdv) - // Update intersected column stats. + // Get the intersected column stat. val newNdv = leftKeyStat.distinctCount.min(rightKeyStat.distinctCount) val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen) val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2 val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) - (ceil(card), Map(leftKey -> newStats, rightKey -> newStats)) + (ceil(card), newStats) } /** @@ -232,7 +232,7 @@ case class JoinEstimation(join: Join) extends Logging { outputRows: BigInt, output: Seq[Attribute], oldAttrStats: AttributeMap[ColumnStat], - keyStatsAfterJoin: Map[Attribute, ColumnStat]): Seq[(Attribute, ColumnStat)] = { + keyStatsAfterJoin: AttributeMap[ColumnStat]): Seq[(Attribute, ColumnStat)] = { val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]() val leftRows = leftStats.rowCount.get val rightRows = rightStats.rowCount.get
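The estimation strategy these patches document is easy to sanity-check outside Spark. Below is a minimal standalone sketch; it is not part of the patch series and not Spark's API: JoinCardinalitySketch, KeyStat, cardByNdv, innerJoinCard and outputRows are invented stand-ins for ColumnStat, computeByNdv, computeCardinalityAndStats and the outputRows match in estimateInnerOuterJoin. It covers only ndv-based estimates over numeric keys (no null counts and no histograms, although PATCH 4/4's scaladoc says histogram-based per-key cardinalities feed the same min): the basic formula T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)), the collapse to zero when a key pair's value ranges are disjoint, and the per-join-type row-count adjustments.

import scala.math.BigDecimal.RoundingMode

object JoinCardinalitySketch {
  // Simplified stand-in for Spark's ColumnStat: distinct count (ndv) plus a
  // numeric [min, max] range for one join key column.
  case class KeyStat(ndv: BigInt, min: BigDecimal, max: BigDecimal)

  // Basic formula for one equi-join key pair: T(A) * T(B) / max(V(A.k), V(B.k)),
  // rounded up like the patch does with EstimationUtils' ceil.
  def cardByNdv(leftRows: BigInt, rightRows: BigInt, l: KeyStat, r: KeyStat): BigInt =
    (BigDecimal(leftRows * rightRows) / BigDecimal(l.ndv max r.ndv))
      .setScale(0, RoundingMode.CEILING).toBigInt

  // Inner join cardinality over several key pairs: start from the cartesian
  // product (the estimate when no key pair has usable stats), collapse to zero
  // if any pair's ranges are disjoint, and otherwise keep the smallest per-pair
  // estimate, i.e. the "most selective join keys" strategy from the scaladoc.
  def innerJoinCard(
      leftRows: BigInt,
      rightRows: BigInt,
      keyPairs: Seq[(KeyStat, KeyStat)]): BigInt =
    keyPairs.foldLeft(leftRows * rightRows) { case (minCard, (l, r)) =>
      if (l.max < r.min || r.max < l.min) BigInt(0) // disjoint ranges: empty join
      else minCard min cardByNdv(leftRows, rightRows, l, r)
    }

  // Row-count adjustment by join type, mirroring the outputRows match in the
  // patch; a plain String stands in for Catalyst's JoinType here.
  def outputRows(joinType: String, leftRows: BigInt, rightRows: BigInt,
      innerRows: BigInt): BigInt = joinType match {
    case "leftouter"  => leftRows max innerRows   // all left rows survive
    case "rightouter" => rightRows max innerRows  // all right rows survive
    case "fullouter"  => // T(A FOJ B) = T(A LOJ B) + T(A ROJ B) - T(A IJ B)
      (leftRows max innerRows) + (rightRows max innerRows) - innerRows
    case _            => innerRows                // inner or cross join
  }
}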
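A quick check of the sketch with made-up numbers, under the same assumptions:

object JoinCardinalitySketchDemo extends App {
  import JoinCardinalitySketch._

  // T(A) = 1000 rows, T(B) = 500 rows; one key pair with intersecting ranges
  // [0, 99] and [0, 49], and ndv 100 vs. 50.
  val inner = innerJoinCard(1000, 500, Seq((KeyStat(100, 0, 99), KeyStat(50, 0, 49))))
  println(inner)  // 5000 = ceil(1000 * 500 / max(100, 50))

  // A left outer join keeps at least all left rows: max(1000, 5000) = 5000 here.
  println(outputRows("leftouter", 1000, 500, inner))  // 5000

  // Disjoint key ranges collapse the whole join to zero rows.
  println(innerJoinCard(1000, 500, Seq((KeyStat(100, 0, 99), KeyStat(50, 200, 249)))))  // 0
}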