diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 24d64399a2abb..9d86803c9006a 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -47,17 +47,29 @@ case class FilterEstimation(plan: Filter) extends Logging { // Estimate selectivity of this filter predicate, and update column stats if needed. // For not-supported condition, set filter selectivity to a conservative estimate 100% - val filterSelectivity = calculateFilterSelectivity(plan.condition).getOrElse(1.0) + val filterSelectivity = + calculateFilterSelectivity(plan.condition).map(boundProbability).getOrElse(1.0) - val filteredRowCount: BigInt = ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity) + val childRowCount = childStats.rowCount.get + val filteredRowCount: BigInt = + ceil(BigDecimal(childRowCount) * filterSelectivity).min(childRowCount) val newColStats = if (filteredRowCount == 0) { // The output is empty, we don't need to keep column stats. AttributeMap[ColumnStat](Nil) } else { - colStatsMap.outputColumnStats(rowsBeforeFilter = childStats.rowCount.get, + colStatsMap.outputColumnStats(rowsBeforeFilter = childRowCount, rowsAfterFilter = filteredRowCount) } - val filteredSizeInBytes: BigInt = getOutputSize(plan.output, filteredRowCount, newColStats) + val sizeByOutputAttrs = getOutputSize(plan.output, filteredRowCount, newColStats) + val sizeByChildScaling = if (childRowCount > 0 && filteredRowCount > 0) { + ceil( + BigDecimal(childStats.sizeInBytes) * BigDecimal(filteredRowCount) / + BigDecimal(childRowCount)) + .max(BigInt(1)) + } else { + BigInt(1) + } + val filteredSizeInBytes: BigInt = sizeByOutputAttrs.min(sizeByChildScaling) Some(childStats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount), attributeStats = newColStats)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 2ec247564caf3..02dc2735d8b4b 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -290,6 +290,35 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 10) } + test("IS NOT NULL should not increase sizeInBytes over child") { + val attrStringLarge = AttributeReference("cstring_large", StringType)() + val colStatStringLarge = ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(1000), maxLen = Some(1000)) + val childPlan = StatsTestPlan( + outputList = Seq(attrStringLarge), + rowCount = 1000, + attributeStats = AttributeMap(Seq(attrStringLarge -> colStatStringLarge)), + size = Some(1000) + ) + val filter = Filter(IsNotNull(attrStringLarge), childPlan) + val filterStats = filter.stats + assert(filterStats.rowCount.contains(1000)) + assert(filterStats.sizeInBytes == 1000) + } + + test("Range filter should scale sizeInBytes based on child size") { + val childPlan = StatsTestPlan( + outputList = Seq(attrInt), + rowCount = 10L, + attributeStats = AttributeMap(Seq(attrInt -> colStatInt)), + size = Some(100) + ) + val filter = Filter(GreaterThan(attrInt, Literal(6)), childPlan) + val filterStats = filter.stats + assert(filterStats.rowCount.contains(5)) + assert(filterStats.sizeInBytes == 50) + } + test("cint IS NOT NULL && null") { // 'cint < null' will be optimized to 'cint IS NOT NULL && null'. // More similar cases can be found in the Optimizer NullPropagation.