diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index 9bc65ae32a276..2d50fe1a1a1a8 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -28,6 +28,15 @@
     ],
     "sqlState" : "42702"
   },
+  "AMBIGUOUS_COLUMN_REFERENCE" : {
+    "message" : [
+      "Column `<name>` is ambiguous. It's because you joined several DataFrames together, and some of these DataFrames are the same.",
+      "This column points to one of the DataFrames but Spark is unable to figure out which one.",
+      "Please alias the DataFrames with different names via `DataFrame.alias` before joining them,",
+      "and specify the column using the qualified name, e.g. `df.alias(\"a\").join(df.alias(\"b\"), col(\"a.id\") > col(\"b.id\"))`."
+    ],
+    "sqlState" : "42702"
+  },
   "AMBIGUOUS_LATERAL_COLUMN_ALIAS" : {
     "message" : [
       "Lateral column alias `<name>` is ambiguous and has `<n>` matches."
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index df36b53791a81..feefd19000d1d 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -767,6 +767,64 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester {
     assert(joined2.schema.catalogString === "struct")
   }
 
+  test("SPARK-45509: ambiguous column reference") {
+    val session = spark
+    import session.implicits._
+    val df1 = Seq(1 -> "a").toDF("i", "j")
+    val df1_filter = df1.filter(df1("i") > 0)
+    val df2 = Seq(2 -> "b").toDF("i", "y")
+
+    checkSameResult(
+      Seq(Row(1)),
+      // df1("i") is not ambiguous, and it's still valid in the filtered df.
+      df1_filter.select(df1("i")))
+
+    val e1 = intercept[AnalysisException] {
+      // df1("i") is not ambiguous, but it's no longer valid in the projected df.
+      df1.select((df1("i") + 1).as("plus")).select(df1("i")).collect()
+    }
+    assert(e1.getMessage.contains("MISSING_ATTRIBUTES.RESOLVED_ATTRIBUTE_MISSING_FROM_INPUT"))
+
+    checkSameResult(
+      Seq(Row(1, "a")),
+      // None of these column references is ambiguous, and all remain valid after the join.
+      df1.join(df2, df1("i") + 1 === df2("i")).sort(df1("i").desc).select(df1("i"), df1("j")))
+
+    val e2 = intercept[AnalysisException] {
+      // df1("i") is ambiguous as df1 appears on both sides of the join.
+      df1.join(df1, df1("i") === 1).collect()
+    }
+    assert(e2.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE"))
+
+    val e3 = intercept[AnalysisException] {
+      // df1("i") is ambiguous as df1 appears on both sides of the join.
+      df1.join(df1).select(df1("i")).collect()
+    }
+    assert(e3.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE"))
+
+    val e4 = intercept[AnalysisException] {
+      // df1("i") is ambiguous as df1 appears on both sides of the join (df1_filter contains df1).
+      df1.join(df1_filter, df1("i") === 1).collect()
+    }
+    assert(e4.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE"))
+
+    checkSameResult(
+      Seq(Row("a")),
+      // df1_filter("i") is not ambiguous as df1_filter does not appear on the join's left side.
+      df1.join(df1_filter, df1_filter("i") === 1).select(df1_filter("j")))
+
+    val e5 = intercept[AnalysisException] {
+      // df1("i") is ambiguous as df1 appears on both sides of the first join.
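+      // (df1_filter is derived from df1, so df1's plan id matches both the left side df1
+      // and the right side df1_filter.)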
+      df1.join(df1_filter, df1_filter("i") === 1).join(df2, df1("i") === 1).collect()
+    }
+    assert(e5.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE"))
+
+    checkSameResult(
+      Seq(Row("a")),
+      // df1_filter("i") is not ambiguous as df1_filter only appears once.
+      df1.join(df1_filter).join(df2, df1_filter("i") === 1).select(df1_filter("j")))
+  }
+
   test("broadcast join") {
     withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
       val left = spark.range(100).select(col("id"), rand(10).as("a"))
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index 90d21f9758573..0cf05748f58f0 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -55,6 +55,15 @@ See '`<docroot>`/sql-migration-guide.html#query-engine'.
 
 Column or field `<name>` is ambiguous and has `<n>` matches.
 
+### AMBIGUOUS_COLUMN_REFERENCE
+
+[SQLSTATE: 42702](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+Column `<name>` is ambiguous. It's because you joined several DataFrames together, and some of these DataFrames are the same.
+This column points to one of the DataFrames but Spark is unable to figure out which one.
+Please alias the DataFrames with different names via `DataFrame.alias` before joining them,
+and specify the column using the qualified name, e.g. `df.alias("a").join(df.alias("b"), col("a.id") > col("b.id"))`.
+
 ### AMBIGUOUS_LATERAL_COLUMN_ALIAS
 
 [SQLSTATE: 42702](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py
index dd93e31d0235e..74e0b328e4dfb 100644
--- a/python/pyspark/pandas/indexes/multi.py
+++ b/python/pyspark/pandas/indexes/multi.py
@@ -815,7 +815,7 @@ def symmetric_difference(  # type: ignore[override]
         sdf_symdiff = sdf_self.union(sdf_other).subtract(sdf_self.intersect(sdf_other))
 
         if sort:
-            sdf_symdiff = sdf_symdiff.sort(*self._internal.index_spark_columns)
+            sdf_symdiff = sdf_symdiff.sort(*self._internal.index_spark_column_names)
 
         internal = InternalFrame(
             spark_frame=sdf_symdiff,
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 9af5823dd8b84..b49274e399c48 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -2123,7 +2123,9 @@ def __init__(
         self._input_grouping_cols = input_grouping_cols
         self._other_grouping_cols = other_grouping_cols
         self._other = cast(LogicalPlan, other)
-        self._func = function._build_common_inline_user_defined_function(*cols)
+        # The function takes the entire input DataFrames, so there is no need to do
+        # column binding (there are no input columns).
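+        # (Likely rationale: binding *cols would embed column references that carry
+        # plan ids, which the stricter server-side plan-id resolution could reject;
+        # the function receives whole DataFrames rather than bound columns.)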
+        self._func = function._build_common_inline_user_defined_function()
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 98cbdea72d53b..c48006286be9a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -29,10 +29,10 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.toPrettySQL
-import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors}
 import org.apache.spark.sql.internal.SQLConf
 
-trait ColumnResolutionHelper extends Logging {
+trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
 
   def conf: SQLConf
 
@@ -337,7 +337,7 @@ trait ColumnResolutionHelper extends Logging {
       throws: Boolean = false,
       allowOuter: Boolean = false): Expression = {
     resolveExpression(
-      expr,
+      tryResolveColumnByPlanId(expr, plan),
       resolveColumnByName = nameParts => {
         plan.resolve(nameParts, conf.resolver)
       },
@@ -358,21 +358,8 @@ trait ColumnResolutionHelper extends Logging {
       e: Expression,
       q: LogicalPlan,
       allowOuter: Boolean = false): Expression = {
-    val newE = if (e.exists(_.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty)) {
-      // If the TreeNodeTag 'LogicalPlan.PLAN_ID_TAG' is attached, it means that the plan and
-      // expression are from Spark Connect, and need to be resolved in this way:
-      // 1, extract the attached plan id from the expression (UnresolvedAttribute only for now);
-      // 2, top-down traverse the query plan to find the plan node that matches the plan id;
-      // 3, if can not find the matching node, fail the analysis due to illegal references;
-      // 4, resolve the expression with the matching node, if any error occurs here, apply the
-      //    old code path;
-      resolveExpressionByPlanId(e, q)
-    } else {
-      e
-    }
     resolveExpression(
-      newE,
+      tryResolveColumnByPlanId(e, q),
       resolveColumnByName = nameParts => {
         q.resolveChildren(nameParts, conf.resolver)
       },
@@ -392,39 +379,46 @@ trait ColumnResolutionHelper extends Logging {
     }
   }
 
-  private def resolveExpressionByPlanId(
+  // If the TreeNodeTag 'LogicalPlan.PLAN_ID_TAG' is attached, it means that the plan and
+  // expression are from Spark Connect, and need to be resolved in this way:
+  // 1. extract the attached plan id from the UnresolvedAttribute;
+  // 2. traverse the query plan top-down to find the plan node that matches the plan id;
+  // 3. if no matching node can be found, fail the analysis due to an illegal reference;
+  // 4. if more than one matching node is found, fail due to an ambiguous column reference;
+  // 5. resolve the expression against the matching node; if any error occurs here, return
+  //    the original expression as is.
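+  // For example, in `df1.join(df1, ...)` the plan id attached to `df1("i")` matches both
+  // join sides, so step 4 applies and AMBIGUOUS_COLUMN_REFERENCE is raised.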
+  private def tryResolveColumnByPlanId(
       e: Expression,
-      q: LogicalPlan): Expression = {
-    if (!e.exists(_.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty)) {
-      return e
-    }
-
-    e match {
-      case u: UnresolvedAttribute =>
-        resolveUnresolvedAttributeByPlanId(u, q).getOrElse(u)
-      case _ =>
-        e.mapChildren(c => resolveExpressionByPlanId(c, q))
-    }
+      q: LogicalPlan,
+      idToPlan: mutable.HashMap[Long, LogicalPlan] = mutable.HashMap.empty): Expression = e match {
+    case u: UnresolvedAttribute =>
+      resolveUnresolvedAttributeByPlanId(u, q, idToPlan).getOrElse(u)
+    case _ if e.containsPattern(UNRESOLVED_ATTRIBUTE) =>
+      e.mapChildren(c => tryResolveColumnByPlanId(c, q, idToPlan))
+    case _ => e
   }
 
   private def resolveUnresolvedAttributeByPlanId(
       u: UnresolvedAttribute,
-      q: LogicalPlan): Option[NamedExpression] = {
+      q: LogicalPlan,
+      idToPlan: mutable.HashMap[Long, LogicalPlan]): Option[NamedExpression] = {
     val planIdOpt = u.getTagValue(LogicalPlan.PLAN_ID_TAG)
     if (planIdOpt.isEmpty) return None
     val planId = planIdOpt.get
     logDebug(s"Extract plan_id $planId from $u")
 
-    val planOpt = q.find(_.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(planId))
-    if (planOpt.isEmpty) {
-      // For example:
-      // df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]])
-      // df2 = spark.createDataFrame([Row(a = 1, b = 2)]])
-      // df1.select(df2.a) <- illegal reference df2.a
-      throw new AnalysisException(s"When resolving $u, " +
-        s"fail to find subplan with plan_id=$planId in $q")
-    }
-    val plan = planOpt.get
+    val plan = idToPlan.getOrElseUpdate(planId, {
+      findPlanById(u, planId, q).getOrElse {
+        // For example:
+        // df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)])
+        // df2 = spark.createDataFrame([Row(a = 1, b = 2)])
+        // df1.select(df2.a) <- illegal reference to df2.a
+        throw new AnalysisException(s"When resolving $u, " +
+          s"failed to find the subplan with plan_id=$planId in $q")
+      }
+    })
 
     try {
       plan.resolve(u.nameParts, conf.resolver)
@@ -434,4 +428,28 @@ trait ColumnResolutionHelper extends Logging {
       None
     }
   }
+
+  private def findPlanById(
+      u: UnresolvedAttribute,
+      id: Long,
+      plan: LogicalPlan): Option[LogicalPlan] = {
+    if (plan.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(id)) {
+      Some(plan)
+    } else if (plan.children.length == 1) {
+      findPlanById(u, id, plan.children.head)
+    } else if (plan.children.length > 1) {
+      val matched = plan.children.flatMap(findPlanById(u, id, _))
+      if (matched.length > 1) {
+        throw new AnalysisException(
+          errorClass = "AMBIGUOUS_COLUMN_REFERENCE",
+          messageParameters = Map("name" -> toSQLId(u.nameParts)),
+          origin = u.origin
+        )
+      } else {
+        matched.headOption
+      }
+    } else {
+      None
+    }
+  }
 }
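For illustration, a minimal sketch of the user-facing behavior this patch enforces. It mirrors the `e2` case in the test suite above and the workaround suggested by the new error message; the session setup is assumed, not part of the patch:

```scala
// Assumes a Spark Connect session in scope as `spark`, as in ClientE2ETestSuite.
val session = spark
import session.implicits._
import org.apache.spark.sql.functions.col

val df = Seq(1 -> "a").toDF("i", "j")

// df appears on both sides of the join, so the plan id carried by df("i") matches two
// subplans and analysis fails with AMBIGUOUS_COLUMN_REFERENCE.
df.join(df, df("i") === 1).collect()

// The workaround from the error message: alias the two sides with distinct names and
// refer to the columns by their qualified names.
df.alias("a").join(df.alias("b"), col("a.i") > col("b.i")).collect()
```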