From 852acdb3510ab850123d9318a70b1c04e6a0ece1 Mon Sep 17 00:00:00 2001 From: Vladimir Golubev Date: Tue, 4 Mar 2025 13:00:11 +0000 Subject: [PATCH] Normalize out projection added in DeduplicateRelations for union child output deduplication --- .../analysis/DeduplicateRelations.scala | 7 +- .../analysis/resolver/UnionResolver.scala | 78 ++----------------- .../sql/catalyst/plans/NormalizePlan.scala | 7 +- 3 files changed, 18 insertions(+), 74 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala index 8398fb8d1e830..752a2a648ce99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala @@ -22,9 +22,12 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference, AttributeSet, Expression, NamedExpression, OuterReference, SubqueryExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.trees.TreePattern._ object DeduplicateRelations extends Rule[LogicalPlan] { + val PROJECT_FOR_EXPRESSION_ID_DEDUPLICATION = + TreeNodeTag[Unit]("project_for_expression_id_deduplication") type ExprIdMap = mutable.HashMap[Class[_], mutable.HashSet[Long]] @@ -67,7 +70,9 @@ object DeduplicateRelations extends Rule[LogicalPlan] { val projectList = child.output.map { attr => Alias(attr, attr.name)() } - Project(projectList, child) + val project = Project(projectList, child) + project.setTagValue(DeduplicateRelations.PROJECT_FOR_EXPRESSION_ID_DEDUPLICATION, ()) + project } } u.copy(children = newChildren) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnionResolver.scala index 0e4eed3c20f15..cfd81114b0300 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnionResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnionResolver.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{ TypeCoercionBase } import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, ExprId} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Project, Union} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.{DataType, MetadataBuilder} @@ -52,8 +52,6 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver) * for partially resolved subtrees from DataFrame programs. * - Resolve each child in the context of a) New [[NameScope]] b) New [[ExpressionIdAssigner]] * mapping. Collect child outputs to coerce them later. - * - Perform projection-based expression ID deduplication if required. This is a hack to stay - * compatible with fixed-point [[Analyzer]]. * - Perform individual output deduplication to handle the distinct union case described in * [[performIndividualOutputExpressionIdDeduplication]] scaladoc. * - Validate that child outputs have same length or throw "NUM_COLUMNS_MISMATCH" otherwise. @@ -68,10 +66,10 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver) * - Return the resolved [[Union]] with new children. */ override def resolve(unresolvedUnion: Union): Union = { - val (oldOutput, oldChildOutputs) = if (unresolvedUnion.resolved) { - (Some(unresolvedUnion.output), Some(unresolvedUnion.children.map(_.output))) + val oldOutput = if (unresolvedUnion.resolved) { + Some(unresolvedUnion.output) } else { - (None, None) + None } val (resolvedChildren, childOutputs) = unresolvedUnion.children.zipWithIndex.map { @@ -84,16 +82,10 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver) } }.unzip - val (projectBasedDeduplicatedChildren, projectBasedDeduplicatedChildOutputs) = - performProjectionBasedExpressionIdDeduplication( - resolvedChildren, - childOutputs, - oldChildOutputs - ) val (deduplicatedChildren, deduplicatedChildOutputs) = performIndividualOutputExpressionIdDeduplication( - projectBasedDeduplicatedChildren, - projectBasedDeduplicatedChildOutputs + resolvedChildren, + childOutputs ) val (newChildren, newChildOutputs) = if (needToCoerceChildOutputs(deduplicatedChildOutputs)) { @@ -117,64 +109,6 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver) unresolvedUnion.copy(children = newChildren) } - /** - * Fixed-point [[Analyzer]] uses [[DeduplicateRelations]] rule to handle duplicate expression IDs - * in multi-child operator outputs. For [[Union]]s it uses a "projection-based deduplication", - * i.e. places another [[Project]] operator with new [[Alias]]es on the right child if duplicate - * expression IDs detected. New [[Alias]] "covers" the original attribute with new expression ID. - * This is done for all child operators except [[LeafNode]]s. - * - * We don't need this operation in single-pass [[Resolver]], since we have - * [[ExpressionIdAssigner]] for expression ID deduplication, but perform it nevertheless to stay - * compatible with fixed-point [[Analyzer]]. Since new outputs are already deduplicated by - * [[ExpressionIdAssigner]], we check the _old_ outputs for duplicates and place a [[Project]] - * only if old outputs are available (i.e. we are dealing with a resolved subtree from - * DataFrame program). - */ - private def performProjectionBasedExpressionIdDeduplication( - children: Seq[LogicalPlan], - childOutputs: Seq[Seq[Attribute]], - oldChildOutputs: Option[Seq[Seq[Attribute]]] - ): (Seq[LogicalPlan], Seq[Seq[Attribute]]) = { - oldChildOutputs match { - case Some(oldChildOutputs) => - val oldExpressionIds = new HashSet[ExprId] - - children - .zip(childOutputs) - .zip(oldChildOutputs) - .map { - case ((child: LeafNode, output), _) => - (child, output) - - case ((child, output), oldOutput) => - val oldOutputExpressionIds = new HashSet[ExprId] - - val hasConflicting = oldOutput.exists { oldAttribute => - oldOutputExpressionIds.add(oldAttribute.exprId) - oldExpressionIds.contains(oldAttribute.exprId) - } - - if (hasConflicting) { - val newExpressions = output.map { attribute => - Alias(attribute, attribute.name)() - } - ( - Project(projectList = newExpressions, child = child), - newExpressions.map(_.toAttribute) - ) - } else { - oldExpressionIds.addAll(oldOutputExpressionIds) - - (child, output) - } - } - .unzip - case _ => - (children, childOutputs) - } - } - /** * Deduplicate expression IDs at the scope of each individual child output. This is necessary to * handle the following case: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala index 62ef65eb11128..1651003dd7744 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans import java.util.HashMap -import org.apache.spark.sql.catalyst.analysis.GetViewColumnByNameAndOrdinal +import org.apache.spark.sql.catalyst.analysis.{DeduplicateRelations, GetViewColumnByNameAndOrdinal} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions @@ -145,6 +145,11 @@ object NormalizePlan extends PredicateHelper { .sortBy(_.hashCode()) .reduce(And) Join(left, right, newJoinType, Some(newCondition), hint) + case project: Project + if project + .getTagValue(DeduplicateRelations.PROJECT_FOR_EXPRESSION_ID_DEDUPLICATION) + .isDefined => + project.child case Project(projectList, child) => val projList = projectList .map { e =>