From 79a30741473fef653277b386a21f627e3b308691 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 28 Apr 2026 10:40:10 +0000 Subject: [PATCH] [SPARK-56632][SQL][CONNECT] Fix AMBIGUOUS_COLUMN_REFERENCE regression for reused DataFrame in natural join Fix an AMBIGUOUS_COLUMN_REFERENCE regression introduced by SPARK-55070 when a DataFrame is referenced both directly in a join and also nested under a natural/USING join elsewhere in the same plan. Replace the single broadened ancestor walk in `resolveDataFrameColumn` with a two-walk pattern, mirroring the `outputAttributes.resolve orElse outputMetadataAttributes.resolve` precedence in `LogicalPlan.resolve`. Regular access walks first with the strict `p.outputSet` filter; only on no match does it retry with `p.output ++ p.metadataOutput`. Metadata access keeps a single walk filtered by `p.metadataOutput`. Co-authored-by: Isaac Co-authored-by: Wenchen Fan --- python/pyspark/sql/tests/test_column.py | 20 +++++++++ .../analysis/ColumnResolutionHelper.scala | 41 +++++++++++++------ 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index 06359854c6d51..c9e3f81969c9f 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -558,6 +558,26 @@ def test_select_join_keys(self): self.assertTrue(df1.join(df2, "id", how).select(df1["id"]).count() >= 0, how) self.assertTrue(df1.join(df2, "id", how).select(df2["id"]).count() >= 0, how) + def test_select_regular_column_with_reused_dataframe_hidden_in_natural_join(self): + # A DataFrame appears both as a direct join side and inside a natural/USING + # join that hides one of its columns into `metadataOutput`. When resolving + # `dim["dim_id"]`, two candidates match the plan id: one from `p.output` + # (the direct join side) and one only visible via `p.metadataOutput` (the + # reused `dim` nested under the USING-join wrapper). We should prefer the + # regular candidate and not throw AMBIGUOUS_COLUMN_REFERENCE. + fact = self.spark.createDataFrame([(1, 10, "T1"), (2, 20, "T2")], ["id", "fk", "txn_id"]) + dim = self.spark.createDataFrame([(10, "X"), (20, "Y"), (30, "Z")], ["dim_id", "dim_name"]) + events = self.spark.createDataFrame( + [(10, "T1", 100), (20, "T2", 200)], ["dim_id", "txn_id", "amount"] + ) + enriched = events.join(dim, "dim_id", "left") + result = ( + fact.join(dim, fact["fk"] == dim["dim_id"], "left") + .join(enriched, "txn_id", "full_outer") + .select(dim["dim_id"]) + ) + self.assertEqual(result.count(), 2) + def test_drop_notexistent_col(self): df1 = self.spark.createDataFrame( [("a", "b", "c")], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index 93d71642ac9fd..8d5efefe378f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -527,10 +527,25 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { val planId = planIdOpt.get logDebug(s"Extract plan_id $planId from $u") - val isMetadataAccess = u.containsTag(LogicalPlan.IS_METADATA_COL) - - val (resolved, matched) = resolveDataFrameColumnByPlanId( - u, planId, isMetadataAccess, q, 0) + val (resolved, matched) = if (u.containsTag(LogicalPlan.IS_METADATA_COL)) { + // Metadata access (e.g. `df["_metadata"]`): the resolved attribute lives + // in `p.metadataOutput`, so filter ancestors by `p.metadataOutput`. + resolveDataFrameColumnByPlanId( + u, planId, true, q, 0, plan => AttributeSet(plan.metadataOutput)) + } else { + // Regular access: try the strict `p.outputSet` filter first. + // That drops candidates hidden at an ancestor, e.g. the right side's join + // key after a natural/USING join. Fall back to `p.output ++ p.metadataOutput` + // only when strict resolves nothing, handling the SPARK-55070 + // `rhs["join_key"]` case. Mirrors `outputAttributes.resolve orElse + // outputMetadataAttributes.resolve` in `LogicalPlan.resolve`. + resolveDataFrameColumnByPlanId( + u, planId, false, q, 0, plan => plan.outputSet) match { + case (Some(r), m) => (Some(r), m) + case _ => resolveDataFrameColumnByPlanId(u, planId, false, q, 0, + plan => AttributeSet(plan.output ++ plan.metadataOutput)) + } + } if (!matched) { // Can not find the target plan node with plan id, e.g. // df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]]) @@ -546,9 +561,11 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { id: Long, isMetadataAccess: Boolean, q: Seq[LogicalPlan], - currentDepth: Int): (Option[(NamedExpression, Int)], Boolean) = { + currentDepth: Int, + getAllowed: LogicalPlan => AttributeSet + ): (Option[(NamedExpression, Int)], Boolean) = { val resolved = q.map(resolveDataFrameColumnRecursively( - u, id, isMetadataAccess, _, currentDepth)) + u, id, isMetadataAccess, _, currentDepth, getAllowed)) val merged = resolved .flatMap(_._1) .sortBy(_._2) // sort by depth @@ -566,7 +583,9 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { id: Long, isMetadataAccess: Boolean, p: LogicalPlan, - currentDepth: Int): (Option[(NamedExpression, Int)], Boolean) = { + currentDepth: Int, + getAllowed: LogicalPlan => AttributeSet + ): (Option[(NamedExpression, Int)], Boolean) = { val (resolved, matched) = if (p.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(id)) { val resolved = if (!isMetadataAccess) { p.resolve(u.nameParts, conf.resolver) @@ -585,7 +604,8 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { case _: Union => Seq.empty[LogicalPlan] case _ => p.children } - resolveDataFrameColumnByPlanId(u, id, isMetadataAccess, children, currentDepth + 1) + resolveDataFrameColumnByPlanId( + u, id, isMetadataAccess, children, currentDepth + 1, getAllowed) } // In self join case like: @@ -619,10 +639,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { // In this case, resolveDataFrameColumnByPlanId returns None, // the dataframe column 'df.id' will remain unresolved, and the analyzer // will try to resolve 'id' without plan id later. - val filtered = resolved.filter { r => - // A DataFrame column can be resolved as a metadata column, we should keep it. - r._1.references.subsetOf(AttributeSet(p.output ++ p.metadataOutput)) - } + val filtered = resolved.filter { case (r, _) => r.references.subsetOf(getAllowed(p)) } (filtered, matched) }