@@ -717,7 +717,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
f.copy(condition = newCond)

// We should make sure all [[SortOrder]]s have been resolved.
-case s @ Sort(order, _, child)
+case s @ Sort(order, _, child, _)
if order.exists(hasGroupingFunction) && order.forall(_.resolved) =>
val groupingExprs = findGroupingExprs(child)
val gid = VirtualColumn.groupingIdAttribute
@@ -1832,7 +1832,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
case p if !p.childrenResolved => p
// Replace the index with the related attribute for ORDER BY,
// which is a 1-base position of the projection list.
-case Sort(orders, global, child)
+case Sort(orders, global, child, hint)
if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
val newOrders = orders map {
case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) =>
@@ -1843,14 +1843,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
case o => o
}
-Sort(newOrders, global, child)
+Sort(newOrders, global, child, hint)

// Replace the index with the corresponding expression in aggregateExpressions. The index is
// a 1-base position of aggregateExpressions, which is output columns (select expression)
-case Aggregate(groups, aggs, child, _) if aggs.forall(_.resolved) &&
+case Aggregate(groups, aggs, child, hint) if aggs.forall(_.resolved) &&
groups.exists(containUnresolvedOrdinal) =>
val newGroups = groups.map(resolveGroupByExpressionOrdinal(_, aggs))
-Aggregate(newGroups, aggs, child)
+Aggregate(newGroups, aggs, child, hint)
}

private def containUnresolvedOrdinal(e: Expression): Boolean = e match {
@@ -2634,15 +2634,15 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
Filter(newExprs.head, newChild)
})

-case s @ Sort(_, _, agg: Aggregate) if agg.resolved && s.order.forall(_.resolved) =>
+case s @ Sort(_, _, agg: Aggregate, _) if agg.resolved && s.order.forall(_.resolved) =>
resolveOperatorWithAggregate(s.order.map(_.child), agg, (newExprs, newChild) => {
val newSortOrder = s.order.zip(newExprs).map {
case (sortOrder, expr) => sortOrder.copy(child = expr)
}
s.copy(order = newSortOrder, child = newChild)
})

-case s @ Sort(_, _, f @ Filter(cond, agg: Aggregate))
+case s @ Sort(_, _, f @ Filter(cond, agg: Aggregate), _)
if agg.resolved && cond.resolved && s.order.forall(_.resolved) =>
resolveOperatorWithAggregate(s.order.map(_.child), agg, (newExprs, newChild) => {
val newSortOrder = s.order.zip(newExprs).map {
@@ -3895,10 +3895,10 @@ object CleanupAliases extends Rule[LogicalPlan] with AliasHelper {
val cleanedAggs = aggs.map(trimNonTopLevelAliases)
Aggregate(grouping.map(trimAliases), cleanedAggs, child, hint)

-case Window(windowExprs, partitionSpec, orderSpec, child) =>
+case Window(windowExprs, partitionSpec, orderSpec, child, hint) =>
val cleanedWindowExprs = windowExprs.map(trimNonTopLevelAliases)
Window(cleanedWindowExprs, partitionSpec.map(trimAliases),
-orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)
+orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child, hint)

case CollectMetrics(name, metrics, child, dataframeId) =>
val cleanedMetrics = metrics.map(trimNonTopLevelAliases)
@@ -598,7 +598,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
case up: Unpivot if up.canBeCoercioned && !up.valuesTypeCoercioned =>
throw QueryCompilationErrors.unpivotValueDataTypeMismatchError(up.values.get)

-case Sort(orders, _, _) =>
+case Sort(orders, _, _, _) =>
orders.foreach { order =>
if (!RowOrdering.isOrderable(order.dataType)) {
order.failAnalysis(
@@ -607,7 +607,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
}
}

-case Window(_, partitionSpec, _, _) =>
+case Window(_, partitionSpec, _, _, _) =>
// Both `partitionSpec` and `orderSpec` must be orderable. We only need an extra check
// for `partitionSpec` here because `orderSpec` has the type check itself.
partitionSpec.foreach { p =>
@@ -423,7 +423,7 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

-case oldVersion @ Window(windowExpressions, _, _, child)
+case oldVersion @ Window(windowExpressions, _, _, child, _)
if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
.nonEmpty =>
val newVersion = oldVersion.copy(windowExpressions = newAliases(windowExpressions))
@@ -484,14 +484,14 @@ object UnsupportedOperationChecker extends Logging {

case Offset(_, _) => throwError("Offset is not supported on streaming DataFrames/Datasets")

-case Sort(_, _, _) if !containsCompleteData(subPlan) =>
+case Sort(_, _, _, _) if !containsCompleteData(subPlan) =>
throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " +
"aggregated DataFrame/Dataset in Complete output mode")

case Sample(_, _, _, _, child) if child.isStreaming =>
throwError("Sampling is not supported on streaming DataFrames/Datasets")

-case Window(windowExpression, _, _, child) if child.isStreaming =>
+case Window(windowExpression, _, _, child, _) if child.isStreaming =>
val (windowFuncList, columnNameList, windowSpecList) = windowExpression.flatMap { e =>
e.collect {
case we: WindowExpression =>
@@ -662,7 +662,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
// of limit in that case. This branch is for the case where there's no limit operator
// above offset.
val (child, ordering) = input match {
-case Sort(order, _, child) => (child, order)
+case Sort(order, _, child, _) => (child, order)
case _ => (input, Seq())
}
val (newChild, joinCond, outerReferenceMap) =
@@ -705,8 +705,8 @@ object DecorrelateInnerQuery extends PredicateHelper {
// SELECT T2.a, row_number() OVER (PARTITION BY T2.b ORDER BY T2.c) AS rn FROM T2)
// WHERE rn > 2 AND rn <= 2+3
val (child, ordering, offsetExpr) = input match {
-case Sort(order, _, child) => (child, order, Literal(0))
-case Offset(offsetExpr, offsetChild@(Sort(order, _, child))) =>
+case Sort(order, _, child, _) => (child, order, Literal(0))
+case Offset(offsetExpr, offsetChild@(Sort(order, _, child, _))) =>
(child, order, offsetExpr)
case Offset(offsetExpr, child) =>
(child, Seq(), offsetExpr)
@@ -754,7 +754,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
(project, joinCond, outerReferenceMap)
}

-case w @ Window(projectList, partitionSpec, orderSpec, child) =>
+case w @ Window(projectList, partitionSpec, orderSpec, child, hint) =>
val outerReferences = collectOuterReferences(w.expressions)
assert(outerReferences.isEmpty, s"Correlated column is not allowed in window " +
s"function: $w")
@@ -770,7 +770,7 @@ object DecorrelateInnerQuery extends PredicateHelper {

val newWindow = Window(newProjectList ++ referencesToAdd,
partitionSpec = newPartitionSpec ++ referencesToAdd,
-orderSpec = newOrderSpec, newChild)
+orderSpec = newOrderSpec, newChild, hint)
(newWindow, joinCond, outerReferenceMap)

case a @ Aggregate(groupingExpressions, aggregateExpressions, child, _) =>
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{WINDOW, WINDOW_EXPRESSIO
object EliminateWindowPartitions extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(WINDOW), ruleId) {
-case w @ Window(windowExprs, partitionSpec, _, _) if partitionSpec.exists(_.foldable) =>
+case w @ Window(windowExprs, partitionSpec, _, _, _) if partitionSpec.exists(_.foldable) =>
val newWindowExprs = windowExprs.map(_.transformWithPruning(
_.containsPattern(WINDOW_EXPRESSION)) {
case windowExpr @ WindowExpression(_, wsd @ WindowSpecDefinition(ps, _, _))
@@ -74,7 +74,7 @@ object InferWindowGroupLimit extends Rule[LogicalPlan] with PredicateHelper {

plan.transformWithPruning(_.containsAllPatterns(FILTER, WINDOW), ruleId) {
case filter @ Filter(condition,
-window @ Window(windowExpressions, partitionSpec, orderSpec, child))
+window @ Window(windowExpressions, partitionSpec, orderSpec, child, _))
if !child.isInstanceOf[WindowGroupLimit] && windowExpressions.forall(isExpandingWindow) &&
orderSpec.nonEmpty =>
val limits = windowExpressions.collect {
@@ -43,14 +43,14 @@ object LimitPushDownThroughWindow extends Rule[LogicalPlan] {
_.containsAllPatterns(WINDOW, LIMIT), ruleId) {
// Adding an extra Limit below WINDOW when the partitionSpec of all window functions is empty.
case LocalLimit(limitExpr @ IntegerLiteral(limit),
-window @ Window(windowExpressions, Nil, orderSpec, child))
+window @ Window(windowExpressions, Nil, orderSpec, child, _))
if supportsPushdownThroughWindow(windowExpressions) && child.maxRows.forall(_ > limit) &&
limit < conf.topKSortFallbackThreshold =>
// Sort is needed here because we need global sort.
window.copy(child = Limit(limitExpr, Sort(orderSpec, true, child)))
// There is a Project between LocalLimit and Window if they do not have the same output.
case LocalLimit(limitExpr @ IntegerLiteral(limit), project @ Project(_,
-window @ Window(windowExpressions, Nil, orderSpec, child)))
+window @ Window(windowExpressions, Nil, orderSpec, child, _)))
if supportsPushdownThroughWindow(windowExpressions) && child.maxRows.forall(_ > limit) &&
limit < conf.topKSortFallbackThreshold =>
// Sort is needed here because we need global sort.
@@ -46,9 +46,9 @@ object OptimizeOneRowPlan extends Rule[LogicalPlan] {
val enableForStreaming = conf.getConf(SQLConf.STREAMING_OPTIMIZE_ONE_ROW_PLAN_ENABLED)

plan.transformUpWithPruning(_.containsAnyPattern(SORT, AGGREGATE), ruleId) {
-case Sort(_, _, child) if child.maxRows.exists(_ <= 1L) &&
+case Sort(_, _, child, _) if child.maxRows.exists(_ <= 1L) &&
isChildEligible(child, enableForStreaming) => child
-case Sort(_, false, child) if child.maxRowsPerPartition.exists(_ <= 1L) &&
+case Sort(_, false, child, _) if child.maxRowsPerPartition.exists(_ <= 1L) &&
isChildEligible(child, enableForStreaming) => child
case agg @ Aggregate(_, _, child, _) if agg.groupOnly && child.maxRows.exists(_ <= 1L) &&
isChildEligible(child, enableForStreaming) =>
@@ -334,7 +334,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
return plan
}
plan match {
-case Sort(_, _, child) => child
+case Sort(_, _, child, _) => child
case Project(fields, child) => Project(fields, removeTopLevelSort(child))
case other => other
}
@@ -1303,7 +1303,7 @@ object CollapseRepartition extends Rule[LogicalPlan] {
// Case 2: When a RepartitionByExpression has a child of global Sort, Repartition or
// RepartitionByExpression we can remove the child.
case r @ RepartitionByExpression(
-_, child @ (Sort(_, true, _) | _: RepartitionOperation), _, _) =>
+_, child @ (Sort(_, true, _, _) | _: RepartitionOperation), _, _) =>
r.withNewChildren(child.children)
// Case 3: When a RebalancePartitions has a child of local or global Sort, Repartition or
// RepartitionByExpression we can remove the child.
@@ -1370,11 +1370,11 @@ object CollapseWindow extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(WINDOW), ruleId) {
-case w1 @ Window(we1, _, _, w2 @ Window(we2, _, _, grandChild))
+case w1 @ Window(we1, _, _, w2 @ Window(we2, _, _, grandChild, _), _)
if windowsCompatible(w1, w2) =>
w1.copy(windowExpressions = we2 ++ we1, child = grandChild)

-case w1 @ Window(we1, _, _, Project(pl, w2 @ Window(we2, _, _, grandChild)))
+case w1 @ Window(we1, _, _, Project(pl, w2 @ Window(we2, _, _, grandChild, _)), _)
if windowsCompatible(w1, w2) && w1.references.subsetOf(grandChild.outputSet) =>
Project(
pl ++ w1.windowOutputSet,
@@ -1403,11 +1403,11 @@ object TransposeWindow extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(WINDOW), ruleId) {
-case w1 @ Window(_, _, _, w2 @ Window(_, _, _, grandChild))
+case w1 @ Window(_, _, _, w2 @ Window(_, _, _, grandChild, _), _)
if windowsCompatible(w1, w2) =>
Project(w1.output, w2.copy(child = w1.copy(child = grandChild)))

-case w1 @ Window(_, _, _, Project(pl, w2 @ Window(_, _, _, grandChild)))
+case w1 @ Window(_, _, _, Project(pl, w2 @ Window(_, _, _, grandChild, _)), _)
if windowsCompatible(w1, w2) && w1.references.subsetOf(grandChild.outputSet) =>
Project(
pl ++ w1.windowOutputSet,
@@ -1649,14 +1649,14 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
*/
object EliminateSorts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(_.containsPattern(SORT)) {
-case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
+case s @ Sort(orders, _, child, _) if orders.isEmpty || orders.exists(_.child.foldable) =>
val newOrders = orders.filterNot(_.child.foldable)
if (newOrders.isEmpty) {
child
} else {
s.copy(order = newOrders)
}
-case s @ Sort(_, global, child) => s.copy(child = recursiveRemoveSort(child, global))
+case s @ Sort(_, global, child, _) => s.copy(child = recursiveRemoveSort(child, global))
case j @ Join(originLeft, originRight, _, cond, _) if cond.forall(_.deterministic) =>
j.copy(left = recursiveRemoveSort(originLeft, true),
right = recursiveRemoveSort(originRight, true))
@@ -1675,7 +1675,7 @@ object EliminateSorts extends Rule[LogicalPlan] {
return plan
}
plan match {
-case Sort(_, global, child) if canRemoveGlobalSort || !global =>
+case Sort(_, global, child, _) if canRemoveGlobalSort || !global =>
recursiveRemoveSort(child, canRemoveGlobalSort)
case other if canEliminateSort(other) =>
other.withNewChildren(other.children.map(c => recursiveRemoveSort(c, canRemoveGlobalSort)))
@@ -36,14 +36,14 @@ object RemoveRedundantSorts extends Rule[LogicalPlan] {
return plan
}
plan match {
-case s @ Sort(orders, false, child) =>
+case s @ Sort(orders, false, child, _) =>
if (SortOrder.orderingSatisfies(child.outputOrdering, orders)) {
recursiveRemoveSort(child, optimizeGlobalSort = false)
} else {
s.withNewChildren(Seq(recursiveRemoveSort(child, optimizeGlobalSort = true)))
}

-case s @ Sort(orders, true, child) =>
+case s @ Sort(orders, true, child, _) =>
val newChild = recursiveRemoveSort(child, optimizeGlobalSort = false)
if (optimizeGlobalSort) {
// For this case, the upper sort is local so the ordering of present sort is unnecessary,
@@ -364,7 +364,7 @@ object PhysicalWindow {
(WindowFunctionType, Seq[NamedExpression], Seq[Expression], Seq[SortOrder], LogicalPlan)

def unapply(a: Any): Option[ReturnType] = a match {
-case expr @ logical.Window(windowExpressions, partitionSpec, orderSpec, child) =>
+case expr @ logical.Window(windowExpressions, partitionSpec, orderSpec, child, _) =>

// The window expression should not be empty here, otherwise it's a bug.
if (windowExpressions.isEmpty) {
@@ -945,7 +945,8 @@ case class WithWindowDefinition(
case class Sort(
order: Seq[SortOrder],
global: Boolean,
-child: LogicalPlan) extends UnaryNode {
+child: LogicalPlan,
+hint: Option[SortHint] = None) extends UnaryNode {
override def output: Seq[Attribute] = child.output
override def maxRows: Option[Long] = child.maxRows
override def maxRowsPerPartition: Option[Long] = {
@@ -1265,7 +1266,8 @@ case class Window(
windowExpressions: Seq[NamedExpression],
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
child: LogicalPlan) extends UnaryNode {
child: LogicalPlan,
hint: Option[WindowHint] = None) extends UnaryNode {
override def maxRows: Option[Long] = child.maxRows
override def output: Seq[Attribute] =
child.output ++ windowExpressions.map(_.toAttribute)
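Most of the hunks above are mechanical fallout of the two case-class changes in this file: adding a constructor parameter to `Sort` and `Window` changes the arity of their compiler-generated extractors, so every existing `case Sort(...)` / `case Window(...)` pattern needs a trailing `_` (or to bind `hint`) even though the new field defaults to `None`. Below is a minimal standalone sketch of that effect, using toy stand-ins rather than the real Spark classes; the `SkipSort` hint and the simplified `Sort`/`SortOrder` types are invented for illustration only.

```scala
// Toy illustration only -- NOT the Spark classes, just stand-ins that mirror
// the shape of the change in this diff.
sealed trait SortHint
case object SkipSort extends SortHint // hypothetical concrete hint

case class SortOrder(col: String, ascending: Boolean = true)
case class Sort(
    order: Seq[SortOrder],
    global: Boolean,
    child: String,
    hint: Option[SortHint] = None) // the new, defaulted field

object ArityDemo {
  def describe(plan: Sort): String = plan match {
    // Patterns must now have four elements; `case Sort(order, _, child)` no longer compiles.
    case Sort(_, _, child, Some(SkipSort)) => s"skippable sort of $child"
    case Sort(order, _, child, _)          => s"sort of $child by ${order.map(_.col)}"
  }

  def main(args: Array[String]): Unit = {
    // Call sites that pass no hint keep compiling thanks to the default value.
    println(describe(Sort(Seq(SortOrder("a")), global = true, child = "relation")))
    println(describe(Sort(Seq(SortOrder("a")), global = true, child = "relation",
      hint = Some(SkipSort))))
  }
}
```

Constructor calls keep compiling because the parameter is defaulted; the call sites updated in this diff were changed only where the rule now binds `hint` in the pattern and propagates it when rebuilding the node.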
@@ -199,6 +199,10 @@ case object NO_BROADCAST_AND_REPLICATION extends JoinStrategyHint {

abstract class AggregateHint;

+abstract class WindowHint;
+
+abstract class SortHint;
+
/**
* The callback for implementing customized strategies of handling hint errors.
*/
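`SortHint` and `WindowHint` are introduced above as empty marker bases, mirroring the existing `AggregateHint`; this diff does not define concrete subclasses or consume the hints anywhere. As a hedged sketch only, assuming a hypothetical concrete hint and rule that are not part of this change, a downstream consumer might look like the following; `NoSort`, `SkipRedundantSort`, and the toy plan nodes are invented names.

```scala
// Hypothetical usage sketch -- NoSort, SkipRedundantSort and the toy Plan nodes
// below are invented for illustration, not part of this PR or of Spark.
abstract class SortHint
case object NoSort extends SortHint

sealed trait Plan
case class Relation(name: String) extends Plan
case class Sort(global: Boolean, child: Plan, hint: Option[SortHint] = None) extends Plan

object SkipRedundantSort {
  // A rule would pattern match on the hint, much like the `case Sort(..., _)` patterns
  // updated in this diff, and rewrite the plan when the hint marks sorting as unnecessary.
  def apply(plan: Plan): Plan = plan match {
    case Sort(_, child, Some(NoSort)) => apply(child)
    case other                        => other
  }
}

object HintDemo extends App {
  val plan = Sort(global = true, child = Relation("t"), hint = Some(NoSort))
  println(SkipRedundantSort(plan)) // prints: Relation(t)
}
```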