diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b3e7eb44ae653..95d774c6e9915 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -2564,12 +2564,8 @@ object DecimalAggregates extends Rule[LogicalPlan] { /** Maximum number of decimal digits representable precisely in a Double */ private val MAX_DOUBLE_DIGITS = 15 - /** Tighter than the AVG fast path's `prec + 4 <= MAX_DOUBLE_DIGITS` (= 11): - * the strict-subset keeps SPARK-37024 Double-regime exposure unchanged. */ - private val AVG_PEEL_MAX_INNER_PRECISION = 7 - - /** Matches a scale-preserving widening decimal Cast; refuses CheckOverflow - * to preserve overflow semantics on the unscaled value. */ + /** Matches a scale-preserving widening decimal Cast. + * Refuses CheckOverflow so per-row overflow checks are not hoisted out. */ private object WidenedDecimalChild { def unapply(e: Expression): Option[(Expression, Int, Int, Int)] = e match { case Cast(inner @ DecimalExpression(p, s), DecimalType.Fixed(pPrime, sPrime), _, _) @@ -2604,27 +2600,35 @@ object DecimalAggregates extends Rule[LogicalPlan] { case _ => we } case ae @ AggregateExpression(af, _, _, _, _) => af match { + // Hoist a scale-preserving widening Cast out of Sum so the existing + // Long fast-path can fire on the inner. The MakeDecimal target type + // matches `Sum(Cast(x, dec(pPrime, s))).dataType` (see Sum.resultType) + // so the final-value overflow boundary is the same as the un-rewritten + // expression. 
case s @ Sum(WidenedDecimalChild(inner, p, pPrime, s_scale), _) if p + 10 <= MAX_LONG_DIGITS => - Cast( - MakeDecimal( - ae.copy(aggregateFunction = s.copy(child = UnscaledValue(inner))), - p + 10, s_scale), - DecimalType.bounded(pPrime + 10, s_scale), - Option(conf.sessionLocalTimeZone)) + val target = DecimalType.bounded(pPrime + 10, s_scale) + MakeDecimal( + ae.copy(aggregateFunction = s.copy(child = UnscaledValue(inner))), + target.precision, target.scale) case s @ Sum(e @ DecimalExpression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS => MakeDecimal(ae.copy(aggregateFunction = s.copy(child = UnscaledValue(e))), prec + 10, scale) - // Ordered before the un-widened Average arm: when pPrime in [8, 11], - // the outer Cast's DecimalType would otherwise match that arm first. + // Hoist a scale-preserving widening Cast out of Average. Guarded on + // the OUTER precision `pPrime + 4 <= MAX_DOUBLE_DIGITS` so the + // rewrite only fires inside the existing Double-regime envelope; + // for wider outer casts the un-rewritten Decimal-exact path is + // preserved. Ordered before the un-widened arm so the outer Cast's + // dataType does not let that arm intercept first (when pPrime <= 11, + // it would also match -- but on the outer Cast, not the inner). 
case a @ Average(WidenedDecimalChild(inner, p, pPrime, s_scale), _) - if p <= AVG_PEEL_MAX_INNER_PRECISION => + if pPrime + 4 <= MAX_DOUBLE_DIGITS => val newAggExpr = ae.copy(aggregateFunction = a.copy(child = UnscaledValue(inner))) Cast( Divide(newAggExpr, Literal.create(math.pow(10.0, s_scale), DoubleType)), - DecimalType.bounded(pPrime + 4, s_scale + 4), Option(conf.sessionLocalTimeZone)) + DecimalType(pPrime + 4, s_scale + 4), Option(conf.sessionLocalTimeZone)) case a @ Average(e @ DecimalExpression(prec, scale), _) if prec + 4 <= MAX_DOUBLE_DIGITS => val newAggExpr = ae.copy(aggregateFunction = a.copy(child = UnscaledValue(e))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala index 6f8c0db261b20..b65ce3a0f0179 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala @@ -74,13 +74,15 @@ class DecimalAggregatesSuite extends PlanTest with ScalaCheckDrivenPropertyCheck val testRelationC = LocalRelation($"c".decimal(7, 2)) - test("Decimal Average Aggregation widened-cast peel: Optimized (p=7, p'=12)") { + test("Decimal Average Aggregation widened-cast peel: " + + "Not Optimized (pPrime+4 > MAX_DOUBLE_DIGITS preserves Decimal-exact path)") { + // pPrime=12, pPrime+4=16 > 15. The new AVG arm only fires inside the + // existing Double-regime envelope (pPrime+4 <= 15); for wider outer casts + // the un-rewritten Decimal-exact path is preserved. 
val widened = $"c".cast(DecimalType(12, 2)) val originalQuery = testRelationC.select(avg(widened)) val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = testRelationC - .select((avg(UnscaledValue($"c")) / 100.0).cast(DecimalType(16, 6)) - .as("avg(CAST(c AS DECIMAL(12,2)))")).analyze + val correctAnswer = originalQuery.analyze comparePlans(optimized, correctAnswer) } @@ -107,7 +109,10 @@ class DecimalAggregatesSuite extends PlanTest with ScalaCheckDrivenPropertyCheck comparePlans(optimized, correctAnswer) } - test("Decimal Average Aggregation widened-cast peel: Not Optimized (boundary p=8)") { + test("Decimal Average Aggregation widened-cast peel: " + + "Not Optimized (pPrime+4 > MAX_DOUBLE_DIGITS, boundary)") { + // pPrime=13, pPrime+4=17 > 15. AVG peel does not fire; existing un-widened + // arm also does not fire on the outer Cast (same guard). Plan unchanged. val testRelationE = LocalRelation($"e".decimal(8, 2)) val widened = $"e".cast(DecimalType(13, 2)) val originalQuery = testRelationE.select(avg(widened)) @@ -117,29 +122,25 @@ class DecimalAggregatesSuite extends PlanTest with ScalaCheckDrivenPropertyCheck comparePlans(optimized, correctAnswer) } - // SPARK-56627 F2 regression: with `pPrime in [8, 11]`, the outer Cast's - // dataType `Decimal(pPrime, s)` would let the un-widened existing - // `Average(DecimalExpression)` arm match first via `prec + 4 <= MAX_DOUBLE_DIGITS` - // (= pPrime <= 11). New AVG peel arm must be ordered before to win this band - // and rewrite via the inner `p`-based UnscaledValue path. + // Cast-hoisting plan simplification: when pPrime+4 <= MAX_DOUBLE_DIGITS, the + // existing un-widened AVG arm would also match the outer Cast, but wraps + // UnscaledValue around the OUTER Cast (running the Cast per row). The new + // arm is ordered before so that UnscaledValue feeds directly off the inner. 
test("Decimal Average Aggregation widened-cast peel: " + - "Optimized for pPrime band [8, 11] (must beat existing AVG fast-path arm)") { + "Optimized for pPrime band [p+1, 11] (drops per-row inner Cast)") { val testRelationE = LocalRelation($"e".decimal(7, 2)) val widened = $"e".cast(DecimalType(10, 2)) val originalQuery = testRelationE.select(avg(widened).as("avg_widened")) val optimized = Optimize.execute(originalQuery.analyze) // Expected: peeled via WidenedDecimalChild(inner=e, p=7, pPrime=10, s=2), - // outer Cast bounded(pPrime+4=14, s+4=6). NOT - // `MakeDecimal(Sum(UnscaledValue(cast(e as dec(10,2)))), 14, 2)` (existing - // arm form), which would lose F2's intent of avoiding the widened-cast - // intermediate. + // outer Cast target DecimalType(pPrime+4=14, s+4=6). val correctAnswer = testRelationE .select( Cast( Divide( avg(UnscaledValue($"e")), Literal.create(math.pow(10.0, 2), DoubleType)), - DecimalType.bounded(14, 6), + DecimalType(14, 6), Option(conf.sessionLocalTimeZone)) .as("avg_widened")) .analyze @@ -147,18 +148,19 @@ class DecimalAggregatesSuite extends PlanTest with ScalaCheckDrivenPropertyCheck comparePlans(optimized, correctAnswer) } - // SPARK-56627 F1 regression: `WidenedDecimalChild` must NOT peel when the - // inner expression is a `CheckOverflow` (introduced by `DecimalPrecision` - // analyzer for nullOnOverflow semantics). Peeling through `CheckOverflow` - // would change the overflow behavior of the inner aggregate. + // WidenedDecimalChild must NOT peel when the inner expression is a + // CheckOverflow (introduced by DecimalPrecision for nullOnOverflow + // semantics). Peeling through CheckOverflow would hoist a per-row + // overflow check out of the aggregate. // - // The existing un-widened `Average(DecimalExpression)` arm still fires on - // the outer Cast (dataType `Decimal(pPrime=10, s=2)`, `pPrime + 4 = 14 <= 15`), - // so the optimized plan wraps `UnscaledValue` around the OUTER cast (not - // the inner CheckOverflow). 
The peel-arm-fired form would instead be - // `UnscaledValue(CheckOverflow(e))` (no outer cast), which we want to AVOID. + // The existing un-widened Average(DecimalExpression) arm still fires on + // the outer Cast (dataType Decimal(pPrime=10, s=2), pPrime + 4 = 14 <= 15), + // so the optimized plan wraps UnscaledValue around the OUTER cast. Without + // the CheckOverflow guard the peel arm would feed UnscaledValue off the + // inner CheckOverflow instead, which we want to AVOID. test("Decimal Average Aggregation widened-cast peel: " + - "Not peeled for Cast(CheckOverflow(inner), wider) form (F1 guard)") { + "Not peeled for Cast(CheckOverflow(inner), wider) form " + + "(CheckOverflow guard)") { val testRelationE = LocalRelation($"e".decimal(7, 2)) val co = CheckOverflow($"e", DecimalType(7, 2), nullOnOverflow = true) val widened = Cast(co, DecimalType(10, 2)) @@ -255,28 +257,22 @@ class DecimalAggregatesSuite extends PlanTest with ScalaCheckDrivenPropertyCheck $"i".int) test("SPARK-56627: SUM(CAST(dec(7,2) AS dec(17,2))) peels via widened-Cast fast path") { - // Witness chosen so p+10=17 <= MAX_LONG_DIGITS(18) < pPrime+10=27 -- the - // new case fires (a bare-Cast inner cannot fall through to the existing - // DecimalExpression case). Expected shape: - // Cast(MakeDecimal(Sum(UnscaledValue(d7_2)), p+10=17, s=2), - // DecimalType.bounded(pPrime+10=27, s=2), - // Option(conf.sessionLocalTimeZone)) + // Cast-hoisting framing: SUM(Cast(x, dec(pPrime, s))) is rewritten to + // SUM(x) wrapped in a MakeDecimal whose precision equals the un-rewritten + // Sum's output type `min(pPrime + 10, 38)`. 
Expected shape: + // MakeDecimal(Sum(UnscaledValue(d7_2)), min(pPrime+10, 38)=27, s=2) val q = widenRel.select(sum($"d7_2".cast(DecimalType(17, 2)))) val optimized = Optimize.execute(q.analyze) val correctAnswer = widenRel - .select(Cast( - MakeDecimal(sum(UnscaledValue($"d7_2")), 17, 2), - DecimalType.bounded(27, 2), - Option(conf.sessionLocalTimeZone)) + .select(MakeDecimal(sum(UnscaledValue($"d7_2")), 27, 2) .as("sum(CAST(d7_2 AS DECIMAL(17,2)))")).analyze comparePlans(optimized, correctAnswer) } test("SPARK-56627: SUM(CAST(dec(7,2) AS dec(17,2))) -- peel preserves schema") { - // Schema invariance via DataType equality (not string). - // Top-level output type of SUM(dec(p,s)) is DecimalType(min(p+10,38), s); - // peeled tree wraps inner with outer Cast(_, dec(pPrime+10,s)) = dec(27,2) - // -- identical to baseline schema. + // Schema invariance via DataType equality. Top-level output type of + // `SUM(Cast(x, dec(pPrime, s)))` is `DecimalType.bounded(pPrime+10, s)`; + // the peeled MakeDecimal target precision matches. val q = widenRel.select(sum($"d7_2".cast(DecimalType(17, 2)))) val baselineSchema = q.analyze.schema val optimized = Optimize.execute(q.analyze) @@ -293,8 +289,9 @@ class DecimalAggregatesSuite extends PlanTest with ScalaCheckDrivenPropertyCheck comparePlans(optimized, correctAnswer) } - test("SPARK-56627: AVG(CAST(dec(7,2) AS dec(17,2))) -- peel preserves schema") { - val q = widenRel.select(avg($"d7_2".cast(DecimalType(17, 2)))) + test("SPARK-56627: AVG(CAST(dec(7,2) AS dec(10,2))) -- peel preserves schema") { + // Witness inside the new AVG peel bound (pPrime+4 = 14 <= 15). 
+ val q = widenRel.select(avg($"d7_2".cast(DecimalType(10, 2)))) val baselineSchema = q.analyze.schema val optimized = Optimize.execute(q.analyze) assert(optimized.schema === baselineSchema, @@ -387,10 +384,7 @@ class DecimalAggregatesSuite extends PlanTest with ScalaCheckDrivenPropertyCheck val q = widenRel.select(sum(nullLit.cast(DecimalType(17, 2)))) val optimized = Optimize.execute(q.analyze) val correctAnswer = widenRel - .select(Cast( - MakeDecimal(sum(UnscaledValue(nullLit)), 17, 2), - DecimalType.bounded(27, 2), - Option(conf.sessionLocalTimeZone)) + .select(MakeDecimal(sum(UnscaledValue(nullLit)), 27, 2) .as("sum(CAST(NULL AS DECIMAL(17,2)))")).analyze comparePlans(optimized, correctAnswer) } @@ -401,10 +395,7 @@ class DecimalAggregatesSuite extends PlanTest with ScalaCheckDrivenPropertyCheck val q = emptyRel.select(sum($"d7_2".cast(DecimalType(17, 2)))) val optimized = Optimize.execute(q.analyze) val correctAnswer = emptyRel - .select(Cast( - MakeDecimal(sum(UnscaledValue($"d7_2")), 17, 2), - DecimalType.bounded(27, 2), - Option(conf.sessionLocalTimeZone)) + .select(MakeDecimal(sum(UnscaledValue($"d7_2")), 27, 2) .as("sum(CAST(d7_2 AS DECIMAL(17,2)))")).analyze comparePlans(optimized, correctAnswer) } @@ -453,69 +444,39 @@ class DecimalAggregatesSuite extends PlanTest with ScalaCheckDrivenPropertyCheck comparePlans(optimized, q) } - // Plan-shape property: structural invariants on the peeled tree. + // Plan-shape property: structural invariants on the peeled SUM tree. // - // Sweeps the (p, p', s) lattice where the widened-cast peel fires: - // regime (ii): p + 10 <= 18 <= p' + 10 (new arm, old fast-path off) - // regime (iii): p + 10 <= 18 < p' + 10 <= 38 - // Assertion (peel-on, structural -- NOT a hand-typed RHS clone): - // - aggregate expression is wrapped by exactly one outer Cast - // - the outer Cast wraps exactly one MakeDecimal - // - inside MakeDecimal, the Sum's child has dataType=LongType (i.e. 
- // UnscaledValue was inserted) - // - outer Cast target precision = p' + 10 (or 38, clamped) - // - outer Cast target scale = s - // Reframed away from RHS-equality to detect behavioural regressions - // rather than just refactor drift. - // Peel-off branch: plan is unchanged relative to its analyzed form - // (the local RuleExecutor runs only DecimalAggregates; no other rule - // can rewrite the SUM when the peel does not fire for a Cast child). + // Sweeps the (p, p', s) lattice where the widened-cast SUM peel fires: + // p + 10 <= 18 and p' > p, with p' <= 38. The rewrite produces a single + // MakeDecimal at precision min(p' + 10, 38) wrapping Sum(UnscaledValue(x)). + // I1. exactly one Sum node, whose child has LongType. + // I2. exactly one MakeDecimal node, with precision = min(p' + 10, 38) + // and scale = s -- matches Sum(Cast(x, dec(p', s))).dataType, so the + // final-value overflow boundary is unchanged from un-rewritten. private case class PeelInputs(p: Int, pPrime: Int, s: Int) - private val peelGen: Gen[PeelInputs] = Gen.frequency( - 5 -> (for { - p <- Gen.choose(1, 8) - pPrime <- Gen.choose(math.max(p + 1, 9), 28) - s <- Gen.choose(0, p) - } yield PeelInputs(p, pPrime, s)), - 5 -> (for { - p <- Gen.choose(1, 8) - pPrime <- Gen.choose(9, 28) - s <- Gen.choose(0, p) - } yield PeelInputs(p, pPrime, s)) - ) - - private val boundaryGen: Gen[PeelInputs] = Gen.oneOf( - PeelInputs(7, 17, 2), PeelInputs(7, 18, 2), PeelInputs(7, 19, 2)) - - private val peelSpaceGen: Gen[PeelInputs] = Gen.frequency( - 8 -> peelGen, - 2 -> boundaryGen - ).retryUntil(in => in.p + 10 <= 18 && in.p < in.pPrime && in.pPrime + 10 <= 38) + // Bounds already enforce the peel-firing predicate: + // p + 10 <= 18 (p <= 8), p < pPrime (pPrime >= p+1), pPrime + 10 <= 38 + // (pPrime <= 28). 
+ private val peelGen: Gen[PeelInputs] = for { + p <- Gen.choose(1, 8) + pPrime <- Gen.choose(p + 1, 28) + s <- Gen.choose(0, p) + } yield PeelInputs(p, pPrime, s) implicit override val generatorDrivenConfig: PropertyCheckConfiguration = PropertyCheckConfiguration(minSuccessful = 50, minSize = 0, sizeRange = 0) test("SPARK-56627: DecimalAggregates widened-Cast SUM peel -- plan-shape " + "structural-invariants property") { - forAll(peelSpaceGen) { in => + forAll(peelGen) { in => val rel = LocalRelation($"x".decimal(in.p, in.s)) val q = rel.select(sum($"x".cast(DecimalType(in.pPrime, in.s)))) val analyzed = q.analyze val optimized = Optimize.execute(analyzed) - // Structural invariants the peel rewrite must establish, regardless - // of incidental tree-shape changes from neighbouring rules: - // - // I1. exactly one Sum node, whose child has LongType (the peeled - // UnscaledValue feed); - // I2. exactly one MakeDecimal node in the tree (rebuilds Decimal - // from the LONG accumulator); - // I3. an outer Cast whose target DecimalType has precision at - // least as wide as the user-written widened cast, so we never - // narrow result precision below the baseline plan. 
val sums = optimized.expressions.flatMap(_.collect { case s: Sum => s }) assert(sums.size == 1, s"expected exactly 1 Sum, got ${sums.size} in $optimized") assert(sums.head.child.dataType == LongType, @@ -524,43 +485,23 @@ class DecimalAggregatesSuite extends PlanTest with ScalaCheckDrivenPropertyCheck val mds = optimized.expressions.flatMap(_.collect { case m: MakeDecimal => m }) assert(mds.size == 1, s"expected exactly 1 MakeDecimal, got ${mds.size} in $optimized") - - val outerCasts = optimized.expressions.flatMap(_.collect { - case c @ Cast(_, _: DecimalType, _, _) => c - }) - assert(outerCasts.nonEmpty, - s"expected an outer Cast to DecimalType, got none in $optimized") - val outerPrec = outerCasts.map(_.dataType.asInstanceOf[DecimalType].precision).max - assert(outerPrec >= in.pPrime, - s"outer Cast precision $outerPrec < baseline ${in.pPrime} in $optimized") + val expectedPrec = math.min(in.pPrime + 10, DecimalType.MAX_PRECISION) + assert(mds.head.precision == expectedPrec && mds.head.scale == in.s, + s"expected MakeDecimal($expectedPrec, ${in.s}), got " + + s"MakeDecimal(${mds.head.precision}, ${mds.head.scale}) in $optimized") } } - // --------------------------------------------------------------------------- - // F5 (skeptic round 1): Long-accumulator / Double-regime safety boundary - // invariant guards. - // - // Background: a strict "overflow oracle" cannot be written at unit-test - // scale -- the existing fast-path guards (`p + 10 <= MAX_LONG_DIGITS = 18` - // for SUM, `AVG_PEEL_MAX_INNER_PRECISION = 7` for AVG) keep the peel-eligible - // inner-precision band so narrow that the Long accumulator (~9.22e18) cannot - // wrap on any reachable peel input: at `p=8` we'd need ~9.22e10 rows. So - // there is no production input that exercises a "peeled Long-wrap vs - // un-peeled CheckOverflow" asymmetry to oracle against. 
- // - // What we CAN lock is the boundary itself: if someone in the future relaxes - // either guard (raising `MAX_LONG_DIGITS - 10` for SUM, or - // `AVG_PEEL_MAX_INNER_PRECISION` for AVG), the input shapes below WOULD - // start peeling -- and the assertion that the rule is a no-op for these - // inputs would fail. That is the safety net we want: a mechanical guard - // that catches accidental widening of the peel-trigger surface. + // Safety-boundary guards: pin the SUM Long-fast-path and AVG Double-fast-path + // bounds. If either guard is later relaxed (raising `MAX_LONG_DIGITS - 10` + // for SUM, or relaxing `pPrime + 4 <= MAX_DOUBLE_DIGITS` for AVG), the input + // shapes below would start peeling and these tests would fail, flagging the + // change for re-review. test("SPARK-56627: SUM(CAST(dec(9,2) AS dec(19,2))) does NOT peel " + "(Long-accumulator safety boundary)") { - // Boundary witness: inner p=9 makes widened-arm `p + 10 = 19 > 18` reject, - // AND outer-cast existing-arm `prec + 10 = 29 > 18` reject. Both arms are - // no-ops by design -- peel cannot fire on this shape today, and must not - // start firing if the inner-precision band is later widened without - // re-deriving the Long-accumulator bound. + // Inner p=9 makes the widened-arm guard p + 10 = 19 > 18 reject. The + // existing un-widened arm also rejects (prec + 10 = 29 > 18 on the outer + // Cast). Both arms are no-ops by design. val q = widenRel.select(sum($"d9_2".cast(DecimalType(19, 2)))) val optimized = Optimize.execute(q.analyze) val correctAnswer = q.analyze @@ -568,15 +509,10 @@ class DecimalAggregatesSuite extends PlanTest with ScalaCheckDrivenPropertyCheck } test("SPARK-56627: AVG(CAST(dec(8,2) AS dec(20,2))) does NOT peel " + - "(Double-regime / SPARK-37024 safety boundary)") { - // Boundary witness: inner p=8 makes widened-AVG arm - // `p > AVG_PEEL_MAX_INNER_PRECISION (7)` reject, AND outer-cast existing - // AVG arm `prec + 4 = 24 > MAX_DOUBLE_DIGITS (15)` reject. 
The strict- - // subset guard `p <= 7` keeps this rule's trigger surface strictly - // inside the existing AVG fast path's surface, so SPARK-37024 - // (Double-regime silent precision loss) is not amplified. If someone - // raises `AVG_PEEL_MAX_INNER_PRECISION` past 7 without first fixing - // SPARK-37024, this test will start firing and flag the regression. + "(Double-regime safety boundary)") { + // pPrime=20, pPrime+4 = 24 > 15 rejects the widened AVG peel arm. The + // existing un-widened AVG arm also rejects on the outer Cast (same + // guard). Plan unchanged. val q = widenRel.select(avg($"d8_2".cast(DecimalType(20, 2)))) val optimized = Optimize.execute(q.analyze) val correctAnswer = q.analyze @@ -638,11 +574,13 @@ class DecimalAggregatesSuite extends PlanTest with ScalaCheckDrivenPropertyCheck test("SPARK-56949: DecimalAggregates preserves Average.evalMode " + "for try_avg on widened-cast peel arm") { - val tryAvg = Average($"d7_2".cast(DecimalType(12, 2)), EvalMode.TRY) + // pPrime=10 keeps pPrime+4=14 <= MAX_DOUBLE_DIGITS so the AVG peel arm + // fires. (pPrime=12 is outside the new bound; see SPARK-56983.) 
+ val tryAvg = Average($"d7_2".cast(DecimalType(10, 2)), EvalMode.TRY) val q = widenRel.select(tryAvg.toAggregateExpression().as("ta")) val optimized = Optimize.execute(q.analyze) val avgs = findAverage(optimized) - assert(avgs.nonEmpty, "widened-cast AVG peel should fire for dec(7,2)->dec(12,2)") + assert(avgs.nonEmpty, "widened-cast AVG peel should fire for dec(7,2)->dec(10,2)") assert(avgs.forall(_.evalMode == EvalMode.TRY), s"evalMode should be preserved as TRY after rewrite, got " + avgs.map(_.evalMode).mkString(",")) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/explain.txt index ff0b0e468530e..f7c0dcd7c56b6 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/explain.txt @@ -257,7 +257,7 @@ Arguments: [[cs_quantity#4, cs_list_price#5, cs_sales_price#6, cs_coupon_amt#7, (46) HashAggregate [codegen id : 13] Input [12]: [cs_quantity#4, cs_list_price#5, cs_sales_price#6, cs_coupon_amt#7, cs_net_profit#8, cd_dep_count#14, c_birth_year#22, i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32] Keys [5]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32] -Functions [7]: [partial_avg(cast(cs_quantity#4 as decimal(12,2))), partial_avg(UnscaledValue(cs_list_price#5)), partial_avg(UnscaledValue(cs_coupon_amt#7)), partial_avg(UnscaledValue(cs_sales_price#6)), partial_avg(UnscaledValue(cs_net_profit#8)), partial_avg(cast(c_birth_year#22 as decimal(12,2))), partial_avg(cast(cd_dep_count#14 as decimal(12,2)))] +Functions [7]: [partial_avg(cast(cs_quantity#4 as decimal(12,2))), partial_avg(cast(cs_list_price#5 as decimal(12,2))), partial_avg(cast(cs_coupon_amt#7 as decimal(12,2))), partial_avg(cast(cs_sales_price#6 as decimal(12,2))), 
partial_avg(cast(cs_net_profit#8 as decimal(12,2))), partial_avg(cast(c_birth_year#22 as decimal(12,2))), partial_avg(cast(cd_dep_count#14 as decimal(12,2)))] Aggregate Attributes [14]: [sum#33, count#34, sum#35, count#36, sum#37, count#38, sum#39, count#40, sum#41, count#42, sum#43, count#44, sum#45, count#46] Results [19]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32, sum#47, count#48, sum#49, count#50, sum#51, count#52, sum#53, count#54, sum#55, count#56, sum#57, count#58, sum#59, count#60] @@ -268,9 +268,9 @@ Arguments: hashpartitioning(i_item_id#28, ca_country#29, ca_state#30, ca_county# (48) HashAggregate [codegen id : 14] Input [19]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32, sum#47, count#48, sum#49, count#50, sum#51, count#52, sum#53, count#54, sum#55, count#56, sum#57, count#58, sum#59, count#60] Keys [5]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32] -Functions [7]: [avg(cast(cs_quantity#4 as decimal(12,2))), avg(UnscaledValue(cs_list_price#5)), avg(UnscaledValue(cs_coupon_amt#7)), avg(UnscaledValue(cs_sales_price#6)), avg(UnscaledValue(cs_net_profit#8)), avg(cast(c_birth_year#22 as decimal(12,2))), avg(cast(cd_dep_count#14 as decimal(12,2)))] -Aggregate Attributes [7]: [avg(cast(cs_quantity#4 as decimal(12,2)))#61, avg(UnscaledValue(cs_list_price#5))#62, avg(UnscaledValue(cs_coupon_amt#7))#63, avg(UnscaledValue(cs_sales_price#6))#64, avg(UnscaledValue(cs_net_profit#8))#65, avg(cast(c_birth_year#22 as decimal(12,2)))#66, avg(cast(cd_dep_count#14 as decimal(12,2)))#67] -Results [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, avg(cast(cs_quantity#4 as decimal(12,2)))#61 AS agg1#68, cast((avg(UnscaledValue(cs_list_price#5))#62 / 100.0) as decimal(16,6)) AS agg2#69, cast((avg(UnscaledValue(cs_coupon_amt#7))#63 / 100.0) as decimal(16,6)) AS agg3#70, cast((avg(UnscaledValue(cs_sales_price#6))#64 / 100.0) as decimal(16,6)) AS agg4#71, 
cast((avg(UnscaledValue(cs_net_profit#8))#65 / 100.0) as decimal(16,6)) AS agg5#72, avg(cast(c_birth_year#22 as decimal(12,2)))#66 AS agg6#73, avg(cast(cd_dep_count#14 as decimal(12,2)))#67 AS agg7#74] +Functions [7]: [avg(cast(cs_quantity#4 as decimal(12,2))), avg(cast(cs_list_price#5 as decimal(12,2))), avg(cast(cs_coupon_amt#7 as decimal(12,2))), avg(cast(cs_sales_price#6 as decimal(12,2))), avg(cast(cs_net_profit#8 as decimal(12,2))), avg(cast(c_birth_year#22 as decimal(12,2))), avg(cast(cd_dep_count#14 as decimal(12,2)))] +Aggregate Attributes [7]: [avg(cast(cs_quantity#4 as decimal(12,2)))#61, avg(cast(cs_list_price#5 as decimal(12,2)))#62, avg(cast(cs_coupon_amt#7 as decimal(12,2)))#63, avg(cast(cs_sales_price#6 as decimal(12,2)))#64, avg(cast(cs_net_profit#8 as decimal(12,2)))#65, avg(cast(c_birth_year#22 as decimal(12,2)))#66, avg(cast(cd_dep_count#14 as decimal(12,2)))#67] +Results [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, avg(cast(cs_quantity#4 as decimal(12,2)))#61 AS agg1#68, avg(cast(cs_list_price#5 as decimal(12,2)))#62 AS agg2#69, avg(cast(cs_coupon_amt#7 as decimal(12,2)))#63 AS agg3#70, avg(cast(cs_sales_price#6 as decimal(12,2)))#64 AS agg4#71, avg(cast(cs_net_profit#8 as decimal(12,2)))#65 AS agg5#72, avg(cast(c_birth_year#22 as decimal(12,2)))#66 AS agg6#73, avg(cast(cd_dep_count#14 as decimal(12,2)))#67 AS agg7#74] (49) TakeOrderedAndProject Input [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, agg1#68, agg2#69, agg3#70, agg4#71, agg5#72, agg6#73, agg7#74] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/simplified.txt index 079bb6aba3ec8..276165729be54 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/simplified.txt @@ -1,6 +1,6 @@ 
TakeOrderedAndProject [ca_country,ca_state,ca_county,i_item_id,agg1,agg2,agg3,agg4,agg5,agg6,agg7] WholeStageCodegen (14) - HashAggregate [i_item_id,ca_country,ca_state,ca_county,spark_grouping_id,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count] [avg(cast(cs_quantity as decimal(12,2))),avg(UnscaledValue(cs_list_price)),avg(UnscaledValue(cs_coupon_amt)),avg(UnscaledValue(cs_sales_price)),avg(UnscaledValue(cs_net_profit)),avg(cast(c_birth_year as decimal(12,2))),avg(cast(cd_dep_count as decimal(12,2))),agg1,agg2,agg3,agg4,agg5,agg6,agg7,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count] + HashAggregate [i_item_id,ca_country,ca_state,ca_county,spark_grouping_id,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count] [avg(cast(cs_quantity as decimal(12,2))),avg(cast(cs_list_price as decimal(12,2))),avg(cast(cs_coupon_amt as decimal(12,2))),avg(cast(cs_sales_price as decimal(12,2))),avg(cast(cs_net_profit as decimal(12,2))),avg(cast(c_birth_year as decimal(12,2))),avg(cast(cd_dep_count as decimal(12,2))),agg1,agg2,agg3,agg4,agg5,agg6,agg7,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count] InputAdapter Exchange [i_item_id,ca_country,ca_state,ca_county,spark_grouping_id] #1 WholeStageCodegen (13) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/explain.txt index 8f25c83767ffc..7db1c87c52a6a 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/explain.txt @@ -227,7 +227,7 @@ Arguments: [[cs_quantity#4, cs_list_price#5, cs_sales_price#6, cs_coupon_amt#7, (40) HashAggregate [codegen id : 7] Input [12]: [cs_quantity#4, cs_list_price#5, cs_sales_price#6, cs_coupon_amt#7, cs_net_profit#8, cd_dep_count#14, c_birth_year#19, i_item_id#28, ca_country#29, 
ca_state#30, ca_county#31, spark_grouping_id#32] Keys [5]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32] -Functions [7]: [partial_avg(cast(cs_quantity#4 as decimal(12,2))), partial_avg(UnscaledValue(cs_list_price#5)), partial_avg(UnscaledValue(cs_coupon_amt#7)), partial_avg(UnscaledValue(cs_sales_price#6)), partial_avg(UnscaledValue(cs_net_profit#8)), partial_avg(cast(c_birth_year#19 as decimal(12,2))), partial_avg(cast(cd_dep_count#14 as decimal(12,2)))] +Functions [7]: [partial_avg(cast(cs_quantity#4 as decimal(12,2))), partial_avg(cast(cs_list_price#5 as decimal(12,2))), partial_avg(cast(cs_coupon_amt#7 as decimal(12,2))), partial_avg(cast(cs_sales_price#6 as decimal(12,2))), partial_avg(cast(cs_net_profit#8 as decimal(12,2))), partial_avg(cast(c_birth_year#19 as decimal(12,2))), partial_avg(cast(cd_dep_count#14 as decimal(12,2)))] Aggregate Attributes [14]: [sum#33, count#34, sum#35, count#36, sum#37, count#38, sum#39, count#40, sum#41, count#42, sum#43, count#44, sum#45, count#46] Results [19]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32, sum#47, count#48, sum#49, count#50, sum#51, count#52, sum#53, count#54, sum#55, count#56, sum#57, count#58, sum#59, count#60] @@ -238,9 +238,9 @@ Arguments: hashpartitioning(i_item_id#28, ca_country#29, ca_state#30, ca_county# (42) HashAggregate [codegen id : 8] Input [19]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32, sum#47, count#48, sum#49, count#50, sum#51, count#52, sum#53, count#54, sum#55, count#56, sum#57, count#58, sum#59, count#60] Keys [5]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32] -Functions [7]: [avg(cast(cs_quantity#4 as decimal(12,2))), avg(UnscaledValue(cs_list_price#5)), avg(UnscaledValue(cs_coupon_amt#7)), avg(UnscaledValue(cs_sales_price#6)), avg(UnscaledValue(cs_net_profit#8)), avg(cast(c_birth_year#19 as decimal(12,2))), avg(cast(cd_dep_count#14 as decimal(12,2)))] 
-Aggregate Attributes [7]: [avg(cast(cs_quantity#4 as decimal(12,2)))#61, avg(UnscaledValue(cs_list_price#5))#62, avg(UnscaledValue(cs_coupon_amt#7))#63, avg(UnscaledValue(cs_sales_price#6))#64, avg(UnscaledValue(cs_net_profit#8))#65, avg(cast(c_birth_year#19 as decimal(12,2)))#66, avg(cast(cd_dep_count#14 as decimal(12,2)))#67] -Results [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, avg(cast(cs_quantity#4 as decimal(12,2)))#61 AS agg1#68, cast((avg(UnscaledValue(cs_list_price#5))#62 / 100.0) as decimal(16,6)) AS agg2#69, cast((avg(UnscaledValue(cs_coupon_amt#7))#63 / 100.0) as decimal(16,6)) AS agg3#70, cast((avg(UnscaledValue(cs_sales_price#6))#64 / 100.0) as decimal(16,6)) AS agg4#71, cast((avg(UnscaledValue(cs_net_profit#8))#65 / 100.0) as decimal(16,6)) AS agg5#72, avg(cast(c_birth_year#19 as decimal(12,2)))#66 AS agg6#73, avg(cast(cd_dep_count#14 as decimal(12,2)))#67 AS agg7#74] +Functions [7]: [avg(cast(cs_quantity#4 as decimal(12,2))), avg(cast(cs_list_price#5 as decimal(12,2))), avg(cast(cs_coupon_amt#7 as decimal(12,2))), avg(cast(cs_sales_price#6 as decimal(12,2))), avg(cast(cs_net_profit#8 as decimal(12,2))), avg(cast(c_birth_year#19 as decimal(12,2))), avg(cast(cd_dep_count#14 as decimal(12,2)))] +Aggregate Attributes [7]: [avg(cast(cs_quantity#4 as decimal(12,2)))#61, avg(cast(cs_list_price#5 as decimal(12,2)))#62, avg(cast(cs_coupon_amt#7 as decimal(12,2)))#63, avg(cast(cs_sales_price#6 as decimal(12,2)))#64, avg(cast(cs_net_profit#8 as decimal(12,2)))#65, avg(cast(c_birth_year#19 as decimal(12,2)))#66, avg(cast(cd_dep_count#14 as decimal(12,2)))#67] +Results [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, avg(cast(cs_quantity#4 as decimal(12,2)))#61 AS agg1#68, avg(cast(cs_list_price#5 as decimal(12,2)))#62 AS agg2#69, avg(cast(cs_coupon_amt#7 as decimal(12,2)))#63 AS agg3#70, avg(cast(cs_sales_price#6 as decimal(12,2)))#64 AS agg4#71, avg(cast(cs_net_profit#8 as decimal(12,2)))#65 AS agg5#72, avg(cast(c_birth_year#19 
as decimal(12,2)))#66 AS agg6#73, avg(cast(cd_dep_count#14 as decimal(12,2)))#67 AS agg7#74] (43) TakeOrderedAndProject Input [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, agg1#68, agg2#69, agg3#70, agg4#71, agg5#72, agg6#73, agg7#74] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/simplified.txt index 7c3075e26fa23..269bfd3f44fcb 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/simplified.txt @@ -1,6 +1,6 @@ TakeOrderedAndProject [ca_country,ca_state,ca_county,i_item_id,agg1,agg2,agg3,agg4,agg5,agg6,agg7] WholeStageCodegen (8) - HashAggregate [i_item_id,ca_country,ca_state,ca_county,spark_grouping_id,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count] [avg(cast(cs_quantity as decimal(12,2))),avg(UnscaledValue(cs_list_price)),avg(UnscaledValue(cs_coupon_amt)),avg(UnscaledValue(cs_sales_price)),avg(UnscaledValue(cs_net_profit)),avg(cast(c_birth_year as decimal(12,2))),avg(cast(cd_dep_count as decimal(12,2))),agg1,agg2,agg3,agg4,agg5,agg6,agg7,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count] + HashAggregate [i_item_id,ca_country,ca_state,ca_county,spark_grouping_id,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count] [avg(cast(cs_quantity as decimal(12,2))),avg(cast(cs_list_price as decimal(12,2))),avg(cast(cs_coupon_amt as decimal(12,2))),avg(cast(cs_sales_price as decimal(12,2))),avg(cast(cs_net_profit as decimal(12,2))),avg(cast(c_birth_year as decimal(12,2))),avg(cast(cd_dep_count as decimal(12,2))),agg1,agg2,agg3,agg4,agg5,agg6,agg7,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count] InputAdapter Exchange [i_item_id,ca_country,ca_state,ca_county,spark_grouping_id] #1 WholeStageCodegen (7) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 180dd5d5db949..694b087d6c3fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -4822,11 +4822,10 @@ class DataFrameAggregateSuite extends SharedSparkSession // Numerical-equivalence property (sql-core layer). // - // Sweeps the (p, p', s, n) lattice where the widened-cast peel fires, - // asserting that SUM(CAST(x AS DECIMAL(p', s))) on an on-vs-off SQLConf - // pair returns bit-equal java.math.BigDecimal (same unscaled value AND - // same scale). Domain is restricted to the non-overflow regime so the - // peeled LONG accumulator cannot wrap. + // Sweeps the (p, p', s, n) lattice where the widened-cast SUM peel fires + // and asserts that the optimized result matches an external java.math.BigDecimal + // reference computed in pure Scala. Domain is restricted to the non-overflow + // regime so the peeled LONG accumulator cannot wrap. // // Non-overflow bound: with |unscaled(x)| < 10^p, p <= 8, n <= 1000, // worst-case accumulator is 1000 * (10^8 - 1) < 10^12 << 2^63. @@ -4913,29 +4912,26 @@ class DataFrameAggregateSuite extends SharedSparkSession // identical to the existing fast path on AVG(x) directly. Both arms in // Optimizer.DecimalAggregates produce // Cast(Divide(Avg(UnscaledValue(inner)), Lit(10^s, Double)), - // DecimalType.bounded(prec, s + 4)) + // DecimalType(prec, s + 4)) // and the peel arm makes `inner` equal to the user's column, so the // Double-divide dividends are bit-identical between the two paths; only // the outer Cast target precision differs (pPrime+4 vs p+4), a widening // precision Cast that preserves numerical value. We therefore assert // BigDecimal.compareTo == 0 (value equality across differing precisions). 
// - // Domain: inner p in [1, 7] (the AVG strict-subset guard - // `AVG_PEEL_MAX_INNER_PRECISION = 7`), pPrime in [8, 11] (the band where - // the existing `Average(DecimalExpression)` arm would intercept on the - // outer Cast type if not for our prepended arm), s in [0, p], - // n <= 1000 rows. The inner DataFrame schema is constructed as - // DecimalType(p, s) explicitly (NOT via tuple-inference, which would - // infer DecimalType.SYSTEM_DEFAULT and silently route through a DIFFERENT - // rule arm than intended -- the failure mode this PBT must lock down). + // Domain: pPrime in [p+1, 11] -- the band where pPrime + 4 <= MAX_DOUBLE_DIGITS + // so the new arm fires and the existing un-widened arm would also have + // matched the outer Cast (allowing comparison against AVG(x) as oracle). + // The inner DataFrame schema is constructed as DecimalType(p, s) explicitly + // (NOT via tuple-inference, which would infer DecimalType.SYSTEM_DEFAULT and + // silently route through a DIFFERENT rule arm than intended). private case class AvgDomain(p: Int, pPrime: Int, s: Int) private val avgDomainGen: Gen[AvgDomain] = (for { - p <- Gen.choose(1, 7) - pPrime <- Gen.choose(8, 11) + p <- Gen.choose(1, 10) + pPrime <- Gen.choose(p + 1, 11) s <- Gen.choose(0, p) } yield AvgDomain(p, pPrime, s)) - .retryUntil(d => d.p < d.pPrime) private def avgInputDf(unscaledLongs: Seq[Long], d: AvgDomain) = { val rows = unscaledLongs.map(u => Row(java.math.BigDecimal.valueOf(u, d.s))) @@ -4972,7 +4968,7 @@ class DataFrameAggregateSuite extends SharedSparkSession val peeled = avgCastResult(xs, d) val direct = avgDirectResult(xs, d) // BigDecimal.compareTo ignores trailing-zero precision differences: - // peeled has output DecimalType.bounded(pPrime+4, s+4), direct has + // peeled has output DecimalType(pPrime+4, s+4), direct has // DecimalType(p+4, s+4). Both wrap the same Double-divide bit pattern // so the underlying value is identical. 
assert(peeled.compareTo(direct) == 0, @@ -4982,17 +4978,13 @@ class DataFrameAggregateSuite extends SharedSparkSession } } - // Wider-pPrime regime shape witness: (p=4, p'=20, s=2). The equivalence - // PBT above only covers pPrime in [8, 11] (where the existing AVG arm - // would otherwise intercept and provide a comparable oracle). For pPrime - // outside that band the new arm still fires (only constrained by inner - // p <= 7), but the comparison oracle "AVG(x) directly" is no longer - // available because the existing arm targets a narrower output type. - // This witness asserts non-null result and the expected widened output - // schema, locking the rule's shape contract without claiming an - // unreachable oracle. - test("SPARK-56627: AVG(CAST(dec(4,2) AS dec(20,2))) peels and yields " + - "widened output schema (wider-pPrime regime shape witness)") { + // Wider-pPrime regime: when pPrime + 4 > MAX_DOUBLE_DIGITS the AVG peel arm + // is intentionally NOT fired so the un-rewritten Decimal-exact path is + // preserved. Witness: (p=4, p'=20, s=2) -- pPrime + 4 = 24 > 15. Asserts + // non-null result and the expected widened output schema; rule shape is + // covered by the catalyst-layer suite. + test("SPARK-56627: AVG(CAST(dec(4,2) AS dec(20,2))) yields " + + "widened output schema (Decimal-exact path preserved)") { val rows = Seq(123L, -456L, 789L, 0L) .map(u => Row(java.math.BigDecimal.valueOf(u, 2))) val schema = StructType(StructField("x", DecimalType(4, 2)) :: Nil) @@ -5001,10 +4993,9 @@ class DataFrameAggregateSuite extends SharedSparkSession val row = df.collect()(0) assert(!row.isNullAt(0), s"expected non-null AVG, got null; df schema = ${df.schema}") val outType = df.schema("a").dataType.asInstanceOf[DecimalType] - // Widened-arm output Cast target = DecimalType.bounded(pPrime + 4, s + 4) - // = DecimalType.bounded(24, 6). + // Un-rewritten Average.dataType = bounded(pPrime + 4, s + 4) = (24, 6). 
assert(outType.precision == 24 && outType.scale == 6, - s"expected DecimalType(24, 6) from widened-arm peel, got $outType") + s"expected DecimalType(24, 6) from un-rewritten AVG, got $outType") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DecimalAggregatesBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DecimalAggregatesBenchmark.scala index fc00bea62dd16..e006787dbfa13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DecimalAggregatesBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DecimalAggregatesBenchmark.scala @@ -46,9 +46,11 @@ import org.apache.spark.sql.types.Decimal * PBT in `DataFrameAggregateSuite`). * * Sections: - * A -- Aggregate SUM widened-cast sweep over (p, s, p') cases. - * B -- Aggregate AVG widened-cast sweep (p <= 7 per - * AVG_PEEL_MAX_INNER_PRECISION). + * A -- Aggregate SUM widened-cast sweep (`p + 10 <= MAX_LONG_DIGITS`, + * any `pPrime > p` up to 38). + * B -- Aggregate AVG widened-cast sweep (`pPrime + 4 <= MAX_DOUBLE_DIGITS` + * so the rule fires only inside the existing AVG Double-regime + * envelope; wider casts stay on the Decimal-exact path). * * NOTE on Window arm: the optimizer does not extend widened-Cast peel to * the Window arm (see DecimalAggregates rule comment) because the analyzer @@ -56,12 +58,13 @@ import org.apache.spark.sql.types.Decimal * not exercise this rule. A Window benchmark belongs with a future * plan-layer rewrite, not here. * - * Case design (`p+1` boundary vs `p+10`-class wider) deliberately includes - * both the minimum widening (one extra digit, e.g. `dec(7,2) -> dec(8,2)`) - * and a production-canonical wider one (e.g. `dec(7,2) -> dec(17,2)` is the - * inner-widened-decimal shape in TPC-DS q18) so reviewers see whether - * widening magnitude matters and whether the canonical shape behaves like - * the boundary one. 
+ * Case design: + * - Section A pairs a `p+1` boundary widening with a `p+10`-class wider + * cast (A2 mirrors the TPC-DS q18 inner-widened-decimal shape), so + * reviewers see whether widening magnitude matters. + * - Section B pairs a `p+1` boundary widening with the `pPrime <= 11` + * upper bound, the widest cast the AVG arm will accept under the + * semantics-preserving guard. * * Args: aN (Section A/B row count), iters, apl * (`spark.sql.decimalOperations.allowPrecisionLoss`; default true). @@ -106,17 +109,16 @@ object DecimalAggregatesBenchmark extends SqlBasedBenchmark { /** * Aggregate AVG cases: (label, p, s, widened p'). * - * All `p <= 7` per the conservative `AVG_PEEL_MAX_INNER_PRECISION = 7` - * guard (see design doc 0001 rev 7 S7.1 -- strict-subset narrowing so - * SPARK-37024 Double-regime exposure is NOT amplified by this rule). - * Same `p+1` / `p+10` split as Section A. B2 mirrors the canonical - * TPC-DS q18 AVG shape. + * All `pPrime + 4 <= MAX_DOUBLE_DIGITS = 15`, i.e. `pPrime <= 11` -- the + * AVG peel arm only fires inside the existing Double-regime envelope, so + * the un-rewritten Decimal-exact path is preserved for wider casts (see + * SPARK-56983). */ private val AvgAggCases: Seq[(String, Int, Int, Int)] = Seq( ("B1 p=7 s=2 p'=8", 7, 2, 8), // p+1 boundary - ("B2 p=7 s=2 p'=12", 7, 2, 12), // canonical TPC-DS q18 AVG shape + ("B2 p=7 s=2 p'=11", 7, 2, 11), // pPrime upper bound ("B3 p=5 s=0 p'=6", 5, 0, 6), // p+1 boundary, zero scale - ("B4 p=5 s=0 p'=15", 5, 0, 15) // p+10, zero scale + ("B4 p=5 s=0 p'=11", 5, 0, 11) // pPrime upper bound, zero scale ) /** Clamp generator to `10^(p-s) - 1` so rand() * bound fits `DECIMAL(p, s)`. 
*/ @@ -195,8 +197,9 @@ object DecimalAggregatesBenchmark extends SqlBasedBenchmark { runBenchmark("DecimalAggregates AVG widened-cast peel (Aggregate)") { AvgAggCases.foreach { case (label, p, s, pPrime) => require(pPrime > p, s"$label: p'=$pPrime must exceed p=$p") - require(p <= 7, - s"$label: p=$p violates conservative AVG_PEEL_MAX_INNER_PRECISION=7 guard") + require(pPrime + 4 <= 15, + s"$label: p'=$pPrime violates AVG fast-path guard " + + s"pPrime+4<=MAX_DOUBLE_DIGITS=15; rule would not fire") setupAggTable(spark, aN, p, s) runThreeWay(label, aN, nativeSql = "select avg(x) from t",