diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
index 6a52326ff58de..33ebc380d21e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
@@ -24,6 +24,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics}
 object AggregateEstimation {
   import EstimationUtils._
 
+  /**
+   * Estimate the number of output rows based on column stats of group-by columns, and propagate
+   * column stats for aggregate expressions.
+   */
   def estimate(agg: Aggregate): Option[Statistics] = {
     val childStats = agg.child.statistics
     // Check if we have column stats for all group-by columns.
@@ -31,14 +35,10 @@ object AggregateEstimation {
       e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute])
     }
     if (rowCountsExist(agg.child) && colStatsExist) {
-      // Initial value for agg without group expressions
-      var outputRows: BigInt = 1
-      agg.groupingExpressions.map(_.asInstanceOf[Attribute]).foreach { attr =>
-        val colStat = childStats.attributeStats(attr)
-        // Multiply distinct counts of group by columns. This is an upper bound, which assumes
-        // the data contains all combinations of distinct values of group by columns.
-        outputRows *= colStat.distinctCount
-      }
+      // Multiply distinct counts of group-by columns. This is an upper bound, which assumes
+      // the data contains all combinations of distinct values of group-by columns.
+      var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))(
+        (res, expr) => res * childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount)
 
       // Here we set another upper bound for the number of output rows: it must not be larger than
       // child's number of rows.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala
index 99d362d3ced96..42ce2f8c5e8d6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.statsEstimation
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate.Count
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
@@ -25,89 +25,111 @@ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUti
 
 class AggEstimationSuite extends StatsEstimationTestBase {
 
-  /** Column info: names and column stats for group-by columns */
-  val (key11, colStat11) = (attr("key11"), ColumnStat(2, Some(1), Some(2), 0, 4, 4))
-  val (key12, colStat12) = (attr("key12"), ColumnStat(1, Some(10), Some(10), 0, 4, 4))
-  val (key21, colStat21) = (attr("key21"), colStat11)
-  val (key22, colStat22) = (attr("key22"), ColumnStat(4, Some(10), Some(40), 0, 4, 4))
-  val (key31, colStat31) = (attr("key31"), colStat11)
-  val (key32, colStat32) = (attr("key32"), ColumnStat(2, Some(10), Some(20), 0, 4, 4))
-
-  /** Tables for testing */
-  /** Data for table1: (1, 10), (2, 10) */
-  val table1 = StatsTestPlan(
-    outputList = Seq(key11, key12),
-    stats = Statistics(
-      sizeInBytes = 2 * (4 + 4),
-      rowCount = Some(2),
-      attributeStats = AttributeMap(Seq(key11 -> colStat11, key12 -> colStat12))))
-
-  /** Data for table2: (1, 10), (1, 20), (2, 30), (2, 40) */
-  val table2 = StatsTestPlan(
-    outputList = Seq(key21, key22),
-    stats = Statistics(
-      sizeInBytes = 4 * (4 + 4),
-      rowCount = Some(4),
-      attributeStats = AttributeMap(Seq(key21 -> colStat21, key22 -> colStat22))))
-
-  /** Data for table3: (1, 10), (1, 10), (1, 20), (2, 20), (2, 10), (2, 10) */
-  val table3 = StatsTestPlan(
-    outputList = Seq(key31, key32),
-    stats = Statistics(
-      sizeInBytes = 6 * (4 + 4),
-      rowCount = Some(6),
-      attributeStats = AttributeMap(Seq(key31 -> colStat31, key32 -> colStat32))))
+  /** Columns for testing */
+  private val columnInfo: Map[Attribute, ColumnStat] =
+    Map(
+      attr("key11") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
+        avgLen = 4, maxLen = 4),
+      attr("key12") -> ColumnStat(distinctCount = 1, min = Some(10), max = Some(10), nullCount = 0,
+        avgLen = 4, maxLen = 4),
+      attr("key21") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
+        avgLen = 4, maxLen = 4),
+      attr("key22") -> ColumnStat(distinctCount = 4, min = Some(10), max = Some(40), nullCount = 0,
+        avgLen = 4, maxLen = 4),
+      attr("key31") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
+        avgLen = 4, maxLen = 4),
+      attr("key32") -> ColumnStat(distinctCount = 2, min = Some(10), max = Some(20), nullCount = 0,
+        avgLen = 4, maxLen = 4))
+
+  private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1)
+  private val nameToColInfo: Map[String, (Attribute, ColumnStat)] =
+    columnInfo.map(kv => kv._1.name -> kv)
 
   test("empty group-by column") {
+    val colNames = Seq("key11", "key12")
+    // Suppose table1 has 2 records: (1, 10), (2, 10)
+    val table1 = StatsTestPlan(
+      outputList = colNames.map(nameToAttr),
+      stats = Statistics(
+        sizeInBytes = 2 * (4 + 4),
+        rowCount = Some(2),
+        attributeStats = AttributeMap(colNames.map(nameToColInfo))))
+
     checkAggStats(
-      testAgg = Aggregate(
-        groupingExpressions = Nil,
-        aggregateExpressions = Alias(Count(Literal(1)), "cnt")() :: Nil,
-        child = table1),
-      expectedRowCount = 1,
-      expectedAttrStats = AttributeMap(Nil))
+      child = table1,
+      colNames = Nil,
+      expectedRowCount = 1)
   }
 
   test("there's a primary key in group-by columns") {
+    val colNames = Seq("key11", "key12")
+    // Suppose table1 has 2 records: (1, 10), (2, 10)
+    val table1 = StatsTestPlan(
+      outputList = colNames.map(nameToAttr),
+      stats = Statistics(
+        sizeInBytes = 2 * (4 + 4),
+        rowCount = Some(2),
+        attributeStats = AttributeMap(colNames.map(nameToColInfo))))
+
     checkAggStats(
-      testAgg = Aggregate(
-        groupingExpressions = Seq(key11, key12),
-        aggregateExpressions = Seq(key11, key12),
-        child = table1),
+      child = table1,
+      colNames = colNames,
       // Column key11 a primary key, so row count = ndv of key11 = child's row count
-      expectedRowCount = table1.stats.rowCount.get,
-      expectedAttrStats = AttributeMap(Seq(key11 -> colStat11, key12 -> colStat12)))
+      expectedRowCount = table1.stats.rowCount.get)
   }
 
   test("the product of ndv's of group-by columns is too large") {
+    val colNames = Seq("key21", "key22")
+    // Suppose table2 has 4 records: (1, 10), (1, 20), (2, 30), (2, 40)
+    val table2 = StatsTestPlan(
+      outputList = colNames.map(nameToAttr),
+      stats = Statistics(
+        sizeInBytes = 4 * (4 + 4),
+        rowCount = Some(4),
+        attributeStats = AttributeMap(colNames.map(nameToColInfo))))
+
     checkAggStats(
-      testAgg = Aggregate(
-        groupingExpressions = Seq(key21, key22),
-        aggregateExpressions = Seq(key21, key22),
-        child = table2),
+      child = table2,
+      colNames = colNames,
       // Use child's row count as an upper bound
-      expectedRowCount = table2.stats.rowCount.get,
-      expectedAttrStats = AttributeMap(Seq(key21 -> colStat21, key22 -> colStat22)))
+      expectedRowCount = table2.stats.rowCount.get)
   }
 
   test("data contains all combinations of distinct values of group-by columns.") {
+    val colNames = Seq("key31", "key32")
+    // Suppose table3 has 6 records: (1, 10), (1, 10), (1, 20), (2, 20), (2, 10), (2, 10)
+    val table3 = StatsTestPlan(
+      outputList = colNames.map(nameToAttr),
+      stats = Statistics(
+        sizeInBytes = 6 * (4 + 4),
+        rowCount = Some(6),
+        attributeStats = AttributeMap(colNames.map(nameToColInfo))))
+
     checkAggStats(
-      testAgg = Aggregate(
-        groupingExpressions = Seq(key31, key32),
-        aggregateExpressions = Seq(key31, key32),
-        child = table3),
-      expectedRowCount = colStat31.distinctCount * colStat32.distinctCount,
-      expectedAttrStats = AttributeMap(Seq(key31 -> colStat31, key32 -> colStat32)))
+      child = table3,
+      colNames = colNames,
+      // Row count = product of ndv
+      expectedRowCount = nameToColInfo("key31")._2.distinctCount * nameToColInfo("key32")._2
+        .distinctCount)
   }
 
   private def checkAggStats(
-      testAgg: Aggregate,
-      expectedRowCount: BigInt,
-      expectedAttrStats: AttributeMap[ColumnStat]): Unit = {
+      child: LogicalPlan,
+      colNames: Seq[String],
+      expectedRowCount: BigInt): Unit = {
+
+    val columns = colNames.map(nameToAttr)
+    val testAgg = Aggregate(
+      groupingExpressions = columns,
+      aggregateExpressions = columns :+ Alias(Count(Literal(1)), "cnt")(),
+      child = child)
+
+    val expectedAttrStats = AttributeMap(colNames.map(nameToColInfo))
     val expectedStats = Statistics(
       sizeInBytes = expectedRowCount * getRowSize(testAgg.output, expectedAttrStats),
       rowCount = Some(expectedRowCount),
       attributeStats = expectedAttrStats)
+
     assert(testAgg.statistics == expectedStats)
   }
 }