Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference, AttributeSet, Expression, NamedExpression, OuterReference, SubqueryExpression}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.trees.TreePattern._

object DeduplicateRelations extends Rule[LogicalPlan] {
val PROJECT_FOR_EXPRESSION_ID_DEDUPLICATION =
TreeNodeTag[Unit]("project_for_expression_id_deduplication")

type ExprIdMap = mutable.HashMap[Class[_], mutable.HashSet[Long]]

Expand Down Expand Up @@ -67,7 +70,9 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
val projectList = child.output.map { attr =>
Alias(attr, attr.name)()
}
Project(projectList, child)
val project = Project(projectList, child)
project.setTagValue(DeduplicateRelations.PROJECT_FOR_EXPRESSION_ID_DEDUPLICATION, ())
project
}
}
u.copy(children = newChildren)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{
TypeCoercionBase
}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, ExprId}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Project, Union}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{DataType, MetadataBuilder}

Expand All @@ -52,8 +52,6 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
* for partially resolved subtrees from DataFrame programs.
* - Resolve each child in the context of a) New [[NameScope]] b) New [[ExpressionIdAssigner]]
* mapping. Collect child outputs to coerce them later.
* - Perform projection-based expression ID deduplication if required. This is a hack to stay
* compatible with fixed-point [[Analyzer]].
* - Perform individual output deduplication to handle the distinct union case described in
* [[performIndividualOutputExpressionIdDeduplication]] scaladoc.
* - Validate that child outputs have same length or throw "NUM_COLUMNS_MISMATCH" otherwise.
Expand All @@ -68,10 +66,10 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
* - Return the resolved [[Union]] with new children.
*/
override def resolve(unresolvedUnion: Union): Union = {
val (oldOutput, oldChildOutputs) = if (unresolvedUnion.resolved) {
(Some(unresolvedUnion.output), Some(unresolvedUnion.children.map(_.output)))
val oldOutput = if (unresolvedUnion.resolved) {
Some(unresolvedUnion.output)
} else {
(None, None)
None
}

val (resolvedChildren, childOutputs) = unresolvedUnion.children.zipWithIndex.map {
Expand All @@ -84,16 +82,10 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
}
}.unzip

val (projectBasedDeduplicatedChildren, projectBasedDeduplicatedChildOutputs) =
performProjectionBasedExpressionIdDeduplication(
resolvedChildren,
childOutputs,
oldChildOutputs
)
val (deduplicatedChildren, deduplicatedChildOutputs) =
performIndividualOutputExpressionIdDeduplication(
projectBasedDeduplicatedChildren,
projectBasedDeduplicatedChildOutputs
resolvedChildren,
childOutputs
)

val (newChildren, newChildOutputs) = if (needToCoerceChildOutputs(deduplicatedChildOutputs)) {
Expand All @@ -117,64 +109,6 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
unresolvedUnion.copy(children = newChildren)
}

/**
 * Mimic the fixed-point [[Analyzer]]'s [[DeduplicateRelations]] handling of duplicate
 * expression IDs across [[Union]] children. Fixed-point analysis performs
 * "projection-based deduplication": whenever a child's output reuses an expression ID
 * already produced by an earlier child, that child is wrapped in a [[Project]] whose
 * [[Alias]]es mint fresh expression IDs. [[LeafNode]]s are exempt from the wrapping.
 *
 * The single-pass [[Resolver]] does not strictly need this — [[ExpressionIdAssigner]]
 * already deduplicates new outputs — but it is reproduced here for plan compatibility
 * with the fixed-point [[Analyzer]]. Because the new outputs are already unique, the
 * conflict detection runs over the _old_ outputs, and only when those are present
 * (i.e. we were handed an already-resolved subtree from a DataFrame program).
 */
private def performProjectionBasedExpressionIdDeduplication(
    children: Seq[LogicalPlan],
    childOutputs: Seq[Seq[Attribute]],
    oldChildOutputs: Option[Seq[Seq[Attribute]]]
): (Seq[LogicalPlan], Seq[Seq[Attribute]]) = {
  oldChildOutputs match {
    case None =>
      // No old outputs available: nothing to compare against, pass everything through.
      (children, childOutputs)

    case Some(previousOutputs) =>
      // Expression IDs observed in the old outputs of the children processed so far.
      val seenExpressionIds = new HashSet[ExprId]

      val deduplicated = children.zip(childOutputs).zip(previousOutputs).map {
        // Leaf nodes are never wrapped, matching fixed-point Analyzer behavior.
        case ((leaf: LeafNode, leafOutput), _) =>
          (leaf, leafOutput)

        case ((child, output), previousOutput) =>
          // A conflict exists if any of this child's old expression IDs was already
          // produced by an earlier (non-conflicting) child.
          val hasConflict = previousOutput.exists { previousAttribute =>
            seenExpressionIds.contains(previousAttribute.exprId)
          }

          if (hasConflict) {
            // Cover every attribute with a fresh Alias so the child's visible
            // expression IDs become unique; expose the alias attributes as output.
            val freshAliases = output.map { attribute =>
              Alias(attribute, attribute.name)()
            }
            (Project(projectList = freshAliases, child = child), freshAliases.map(_.toAttribute))
          } else {
            // Only non-conflicting children contribute their old IDs to the seen set,
            // exactly as the original per-child-set-then-merge implementation did.
            previousOutput.foreach { previousAttribute =>
              seenExpressionIds.add(previousAttribute.exprId)
            }
            (child, output)
          }
      }

      deduplicated.unzip
  }
}

/**
* Deduplicate expression IDs at the scope of each individual child output. This is necessary to
* handle the following case:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans

import java.util.HashMap

import org.apache.spark.sql.catalyst.analysis.GetViewColumnByNameAndOrdinal
import org.apache.spark.sql.catalyst.analysis.{DeduplicateRelations, GetViewColumnByNameAndOrdinal}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions
Expand Down Expand Up @@ -145,6 +145,11 @@ object NormalizePlan extends PredicateHelper {
.sortBy(_.hashCode())
.reduce(And)
Join(left, right, newJoinType, Some(newCondition), hint)
case project: Project
if project
.getTagValue(DeduplicateRelations.PROJECT_FOR_EXPRESSION_ID_DEDUPLICATION)
.isDefined =>
project.child
case Project(projectList, child) =>
val projList = projectList
.map { e =>
Expand Down