Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,25 @@ object SortOrder {
}
}
}

/**
* Returns if a sequence of SortOrder satisfies a sequence of expressions.
* @param orderings the sequence of SortOrder
* @param groupExpressions the sequence of expressions
* @return
*/
def satisfiesExpressions(
orderings: Seq[SortOrder],
groupExpressions: Seq[Expression]): Boolean = {
var collectedExpr = Seq[Expression]()
orderings.foreach { order =>
groupExpressions.foreach {
case e if order.children.exists(_.semanticEquals(e)) => collectedExpr :+= e
case _ =>
}
}
groupExpressions.toSet == collectedExpr.toSet
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,17 @@ object ReplaceHashWithSortAgg extends Rule[SparkPlan] {
hashAgg.child match {
case partialAgg: BaseAggregateExec
if isHashBasedAggWithKeys(partialAgg) && isPartialAgg(partialAgg, hashAgg) =>
if (SortOrder.orderingSatisfies(
partialAgg.child.outputOrdering, sortAgg.requiredChildOrdering.head)) {
if (SortOrder.satisfiesExpressions(
partialAgg.child.outputOrdering, hashAgg.groupingExpressions)) {
sortAgg.copy(
aggregateExpressions = sortAgg.aggregateExpressions.map(_.copy(mode = Complete)),
child = partialAgg.child)
} else {
hashAgg
}
case other =>
if (SortOrder.orderingSatisfies(
other.outputOrdering, sortAgg.requiredChildOrdering.head)) {
if (SortOrder.satisfiesExpressions(
other.outputOrdering, hashAgg.groupingExpressions)) {
sortAgg
} else {
hashAgg
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,34 @@ abstract class ReplaceHashWithSortAggSuiteBase
}
}
}

test("SPARK-44870: The orders of grouping expressions and child outputOrdering not matched") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val query =
s"""SELECT a, b, count(1)
|FROM values(1, 1, 1), (2, 2, 2) t1(a, b, c)
|JOIN values(1, 1, 1), (2, 2, 2) t2(d, e, f)
|ON a = d
|AND b = e
|GROUP by b, a
|""".stripMargin
checkAggs(query, 0, 1, 2, 0)
}
}

test("SPARK-44870: There are semantic expressions in grouping expression") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val query =
s"""SELECT a, b, count(1)
|FROM values(1, 1, 1), (2, 2, 2) t1(a, b, c)
|JOIN values(1, 1, 1), (2, 2, 2) t2(d, e, f)
|ON a = d
|AND b = e
|GROUP by a, b, d
|""".stripMargin
checkAggs(query, 0, 1, 2, 0)
}
}
}

class ReplaceHashWithSortAggSuite extends ReplaceHashWithSortAggSuiteBase
Expand Down