diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9518b3d4f29b2..7e0f664d33b49 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -988,9 +988,19 @@ class Analyzer(override val catalogManager: CatalogManager)
   object AddMetadataColumns extends Rule[LogicalPlan] {
     import org.apache.spark.sql.catalyst.util._
 
+    private def getMetadataAttributes(plan: LogicalPlan): Seq[Attribute] = {
+      lazy val childMetadataOutput = plan.children.flatMap(_.metadataOutput)
+      plan.expressions.collect {
+        case a: Attribute if a.isMetadataCol => a
+        case a: Attribute if childMetadataOutput.exists(_.exprId == a.exprId) =>
+          childMetadataOutput.find(_.exprId == a.exprId).get
+      }
+    }
+
     private def hasMetadataCol(plan: LogicalPlan): Boolean = {
+      lazy val childMetadataOutput = plan.children.flatMap(_.metadataOutput)
       plan.expressions.exists(_.find {
-        case a: Attribute => a.isMetadataCol
+        case a: Attribute => a.isMetadataCol || childMetadataOutput.exists(_.exprId == a.exprId)
         case _ => false
       }.isDefined)
     }
@@ -1006,7 +1016,7 @@ class Analyzer(override val catalogManager: CatalogManager)
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
       case node if node.children.nonEmpty && node.resolved && hasMetadataCol(node) =>
         val inputAttrs = AttributeSet(node.children.flatMap(_.output))
-        val metaCols = node.expressions.flatMap(_.collect {
+        val metaCols = getMetadataAttributes(node).flatMap(_.collect {
          case a: Attribute if a.isMetadataCol && !inputAttrs.contains(a) => a
        })
        if (metaCols.isEmpty) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
index 54b01416381c6..629b39afb04e7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
@@ -85,7 +85,7 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
     if (!analyzed) {
       AnalysisHelper.allowInvokingTransformsInAnalyzer {
         val afterRuleOnChildren = mapChildren(_.resolveOperatorsUp(rule))
-        if (self fastEquals afterRuleOnChildren) {
+        val newNode = if (self fastEquals afterRuleOnChildren) {
           CurrentOrigin.withOrigin(origin) {
             rule.applyOrElse(self, identity[LogicalPlan])
           }
@@ -94,6 +94,8 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
             rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan])
           }
         }
+        newNode.copyTagsFrom(this)
+        newNode
       }
     } else {
       self
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index c2555a1991414..a803fa88ed313 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -477,4 +477,26 @@ class DataFrameJoinSuite extends QueryTest
 
     checkAnswer(df3.except(df4), Row(10, 50, 2, Row(10, 50, 2)))
   }
+
+  test("SPARK-34527: Resolve common columns from USING JOIN") {
+    val joinDf = testData2.as("testData2").join(
+      testData3.as("testData3"), usingColumns = Seq("a"), joinType = "fullouter")
+    val dfQuery = joinDf.select(
+      $"a", $"testData2.a", $"testData2.b", $"testData3.a", $"testData3.b")
+    val dfQuery2 = joinDf.select(
+      $"a", testData2.col("a"), testData2.col("b"), testData3.col("a"), testData3.col("b"))
+
+    Seq(dfQuery, dfQuery2).map { query =>
+      checkAnswer(query,
+        Seq(
+          Row(1, 1, 1, 1, null),
+          Row(1, 1, 2, 1, null),
+          Row(2, 2, 1, 2, 2),
+          Row(2, 2, 2, 2, 2),
+          Row(3, 3, 1, null, null),
+          Row(3, 3, 2, null, null)
+        )
+      )
+    }
+  }
 }