[SPARK-45509][SQL][3.5] Fix df column reference behavior for Spark Connect

Backports #43465 to 3.5.

### What changes were proposed in this pull request?

This PR fixes a few problems with column resolution for Spark Connect, to make the behavior closer to classic Spark SQL (unfortunately, some behavior differences remain in corner cases).
1. Resolve df column references in both `resolveExpressionByPlanChildren` and `resolveExpressionByPlanOutput`. Previously they were resolved only in `resolveExpressionByPlanChildren`.
2. When a plan id has multiple matches, fail with `AMBIGUOUS_COLUMN_REFERENCE` (see the sketch below).
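For illustration, here is a minimal sketch of the new behavior in the Spark Connect Scala client, distilled from the tests added in this PR. It assumes an active `SparkSession` bound to a stable `spark` value; the DataFrame and column names are placeholders:

```scala
import org.apache.spark.sql.functions.col
import spark.implicits._ // needed for toDF on a local Seq

val df = Seq(1 -> "a").toDF("i", "j")

// df appears on both sides of the join, so df("i") matches two plan
// subtrees and now fails with AMBIGUOUS_COLUMN_REFERENCE:
// df.join(df, df("i") === 1).collect()  // AnalysisException

// The workaround suggested by the error message: alias both sides and
// refer to the columns by qualified name.
df.alias("a").join(df.alias("b"), col("a.i") > col("b.i")).show()
```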

### Why are the changes needed?

Fix behavior differences between Spark Connect and classic Spark SQL.

### Does this PR introduce _any_ user-facing change?

Yes, for the Spark Connect Scala client.

### How was this patch tested?

New tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #43699 from cloud-fan/backport.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
cloud-fan authored and zhengruifeng committed Nov 8, 2023
1 parent eac87e3 commit 85fbb3a
Showing 6 changed files with 138 additions and 42 deletions.
9 changes: 9 additions & 0 deletions common/utils/src/main/resources/error/error-classes.json
@@ -28,6 +28,15 @@
],
"sqlState" : "42702"
},
"AMBIGUOUS_COLUMN_REFERENCE" : {
"message" : [
"Column <name> is ambiguous. It's because you joined several DataFrame together, and some of these DataFrames are the same.",
"This column points to one of the DataFrame but Spark is unable to figure out which one.",
"Please alias the DataFrames with different names via `DataFrame.alias` before joining them,",
"and specify the column using qualified name, e.g. `df.alias(\"a\").join(df.alias(\"b\"), col(\"a.id\") > col(\"b.id\"))`."
],
"sqlState" : "42702"
},
"AMBIGUOUS_LATERAL_COLUMN_ALIAS" : {
"message" : [
"Lateral column alias <name> is ambiguous and has <n> matches."
58 changes: 58 additions & 0 deletions connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -767,6 +767,64 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
assert(joined2.schema.catalogString === "struct<id:bigint,a:double>")
}

test("SPARK-45509: ambiguous column reference") {
val session = spark
import session.implicits._
val df1 = Seq(1 -> "a").toDF("i", "j")
val df1_filter = df1.filter(df1("i") > 0)
val df2 = Seq(2 -> "b").toDF("i", "y")

checkSameResult(
Seq(Row(1)),
// df1("i") is not ambiguous, and it's still valid in the filtered df.
df1_filter.select(df1("i")))

val e1 = intercept[AnalysisException] {
// df1("i") is not ambiguous, but it's not valid in the projected df.
df1.select((df1("i") + 1).as("plus")).select(df1("i")).collect()
}
assert(e1.getMessage.contains("MISSING_ATTRIBUTES.RESOLVED_ATTRIBUTE_MISSING_FROM_INPUT"))

checkSameResult(
Seq(Row(1, "a")),
// All these column references are not ambiguous and are still valid after join.
df1.join(df2, df1("i") + 1 === df2("i")).sort(df1("i").desc).select(df1("i"), df1("j")))

val e2 = intercept[AnalysisException] {
// df1("i") is ambiguous as df1 appears in both join sides.
df1.join(df1, df1("i") === 1).collect()
}
assert(e2.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE"))

val e3 = intercept[AnalysisException] {
// df1("i") is ambiguous as df1 appears in both join sides.
df1.join(df1).select(df1("i")).collect()
}
assert(e3.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE"))

val e4 = intercept[AnalysisException] {
// df1("i") is ambiguous as df1 appears in both join sides (df1_filter contains df1).
df1.join(df1_filter, df1("i") === 1).collect()
}
assert(e4.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE"))

checkSameResult(
Seq(Row("a")),
// df1_filter("i") is not ambiguous as df1_filter does not exist in the join left side.
df1.join(df1_filter, df1_filter("i") === 1).select(df1_filter("j")))

val e5 = intercept[AnalysisException] {
// df1("i") is ambiguous as df1 appears in both sides of the first join.
df1.join(df1_filter, df1_filter("i") === 1).join(df2, df1("i") === 1).collect()
}
assert(e5.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE"))

checkSameResult(
Seq(Row("a")),
// df1_filter("i") is not ambiguous as df1_filter only appears once.
df1.join(df1_filter).join(df2, df1_filter("i") === 1).select(df1_filter("j")))
}

test("broadcast join") {
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
val left = spark.range(100).select(col("id"), rand(10).as("a"))
9 changes: 9 additions & 0 deletions docs/sql-error-conditions.md
@@ -55,6 +55,15 @@ See '`<docroot>`/sql-migration-guide.html#query-engine'.

Column or field `<name>` is ambiguous and has `<n>` matches.

### AMBIGUOUS_COLUMN_REFERENCE

[SQLSTATE: 42702](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)

Column `<name>` is ambiguous. It's because you joined several DataFrame together, and some of these DataFrames are the same.
This column points to one of the DataFrame but Spark is unable to figure out which one.
Please alias the DataFrames with different names via `DataFrame.alias` before joining them,
and specify the column using qualified name, e.g. `df.alias("a").join(df.alias("b"), col("a.id") > col("b.id"))`.

### AMBIGUOUS_LATERAL_COLUMN_ALIAS

[SQLSTATE: 42702](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
2 changes: 1 addition & 1 deletion python/pyspark/pandas/indexes/multi.py
@@ -815,7 +815,7 @@ def symmetric_difference( # type: ignore[override]
sdf_symdiff = sdf_self.union(sdf_other).subtract(sdf_self.intersect(sdf_other))

if sort:
sdf_symdiff = sdf_symdiff.sort(*self._internal.index_spark_columns)
sdf_symdiff = sdf_symdiff.sort(*self._internal.index_spark_column_names)

internal = InternalFrame(
spark_frame=sdf_symdiff,
4 changes: 3 additions & 1 deletion python/pyspark/sql/connect/plan.py
@@ -2123,7 +2123,9 @@ def __init__(
self._input_grouping_cols = input_grouping_cols
self._other_grouping_cols = other_grouping_cols
self._other = cast(LogicalPlan, other)
self._func = function._build_common_inline_user_defined_function(*cols)
# The function takes entire DataFrame as inputs, no need to do
# column binding (no input columns).
self._func = function._build_common_inline_user_defined_function()

def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
98 changes: 58 additions & 40 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -29,10 +29,10 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors}
import org.apache.spark.sql.internal.SQLConf

trait ColumnResolutionHelper extends Logging {
trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {

def conf: SQLConf

@@ -337,7 +337,7 @@
throws: Boolean = false,
allowOuter: Boolean = false): Expression = {
resolveExpression(
expr,
tryResolveColumnByPlanId(expr, plan),
resolveColumnByName = nameParts => {
plan.resolve(nameParts, conf.resolver)
},
@@ -358,21 +358,8 @@
e: Expression,
q: LogicalPlan,
allowOuter: Boolean = false): Expression = {
val newE = if (e.exists(_.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty)) {
// If the TreeNodeTag 'LogicalPlan.PLAN_ID_TAG' is attached, it means that the plan and
// expression are from Spark Connect, and need to be resolved in this way:
// 1, extract the attached plan id from the expression (UnresolvedAttribute only for now);
// 2, top-down traverse the query plan to find the plan node that matches the plan id;
// 3, if can not find the matching node, fail the analysis due to illegal references;
// 4, resolve the expression with the matching node, if any error occurs here, apply the
// old code path;
resolveExpressionByPlanId(e, q)
} else {
e
}

resolveExpression(
newE,
tryResolveColumnByPlanId(e, q),
resolveColumnByName = nameParts => {
q.resolveChildren(nameParts, conf.resolver)
},
@@ -392,39 +392,46 @@
}
}

private def resolveExpressionByPlanId(
// If the TreeNodeTag 'LogicalPlan.PLAN_ID_TAG' is attached, it means that the plan and
// expression are from Spark Connect, and need to be resolved in this way:
// 1. extract the attached plan id from UnresolvedAttribute;
// 2. top-down traverse the query plan to find the plan node that matches the plan id;
// 3. if can not find the matching node, fail the analysis due to illegal references;
// 4. if more than one matching nodes are found, fail due to ambiguous column reference;
// 5. resolve the expression with the matching node, if any error occurs here, return the
// original expression as it is.
private def tryResolveColumnByPlanId(
e: Expression,
q: LogicalPlan): Expression = {
if (!e.exists(_.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty)) {
return e
}

e match {
case u: UnresolvedAttribute =>
resolveUnresolvedAttributeByPlanId(u, q).getOrElse(u)
case _ =>
e.mapChildren(c => resolveExpressionByPlanId(c, q))
}
q: LogicalPlan,
idToPlan: mutable.HashMap[Long, LogicalPlan] = mutable.HashMap.empty): Expression = e match {
case u: UnresolvedAttribute =>
resolveUnresolvedAttributeByPlanId(
u, q, idToPlan: mutable.HashMap[Long, LogicalPlan]
).getOrElse(u)
case _ if e.containsPattern(UNRESOLVED_ATTRIBUTE) =>
e.mapChildren(c => tryResolveColumnByPlanId(c, q, idToPlan))
case _ => e
}

private def resolveUnresolvedAttributeByPlanId(
u: UnresolvedAttribute,
q: LogicalPlan): Option[NamedExpression] = {
q: LogicalPlan,
idToPlan: mutable.HashMap[Long, LogicalPlan]): Option[NamedExpression] = {
val planIdOpt = u.getTagValue(LogicalPlan.PLAN_ID_TAG)
if (planIdOpt.isEmpty) return None
val planId = planIdOpt.get
logDebug(s"Extract plan_id $planId from $u")

val planOpt = q.find(_.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(planId))
if (planOpt.isEmpty) {
// For example:
// df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]])
// df2 = spark.createDataFrame([Row(a = 1, b = 2)]])
// df1.select(df2.a) <- illegal reference df2.a
throw new AnalysisException(s"When resolving $u, " +
s"fail to find subplan with plan_id=$planId in $q")
}
val plan = planOpt.get
val plan = idToPlan.getOrElseUpdate(planId, {
findPlanById(u, planId, q).getOrElse {
// For example:
// df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]])
// df2 = spark.createDataFrame([Row(a = 1, b = 2)]])
// df1.select(df2.a) <- illegal reference df2.a
throw new AnalysisException(s"When resolving $u, " +
s"fail to find subplan with plan_id=$planId in $q")
}
})

try {
plan.resolve(u.nameParts, conf.resolver)
@@ -434,4 +428,28 @@
None
}
}

private def findPlanById(
u: UnresolvedAttribute,
id: Long,
plan: LogicalPlan): Option[LogicalPlan] = {
if (plan.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(id)) {
Some(plan)
} else if (plan.children.length == 1) {
findPlanById(u, id, plan.children.head)
} else if (plan.children.length > 1) {
val matched = plan.children.flatMap(findPlanById(u, id, _))
if (matched.length > 1) {
throw new AnalysisException(
errorClass = "AMBIGUOUS_COLUMN_REFERENCE",
messageParameters = Map("name" -> toSQLId(u.nameParts)),
origin = u.origin
)
} else {
matched.headOption
}
} else {
None
}
}
}
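To see the shape of the new lookup in isolation, here is a toy, self-contained Scala sketch of the search that `tryResolveColumnByPlanId` and `findPlanById` perform; `Node` and `findById` are illustrative stand-ins for `LogicalPlan` and the real helper, not Spark APIs:

```scala
// Toy stand-in for LogicalPlan with an optional PLAN_ID_TAG.
case class Node(planId: Option[Long], children: Seq[Node] = Nil)

// Top-down search for the node tagged with `id`. When matches exist
// under more than one child of a multi-child node (e.g. both sides of
// a join), the reference is ambiguous; the real code raises
// AMBIGUOUS_COLUMN_REFERENCE at this point.
def findById(plan: Node, id: Long): Option[Node] = {
  if (plan.planId.contains(id)) {
    Some(plan)
  } else if (plan.children.length == 1) {
    findById(plan.children.head, id)
  } else if (plan.children.length > 1) {
    val matched = plan.children.flatMap(findById(_, id))
    if (matched.length > 1) {
      throw new IllegalStateException(s"ambiguous reference: plan id $id")
    }
    matched.headOption
  } else {
    None
  }
}

// df1.join(df1, ...): the subtree tagged 1 sits under both join
// children, so resolving df1("i") has two candidates and must fail.
val selfJoin = Node(None, Seq(Node(Some(1L)), Node(Some(1L))))
// findById(selfJoin, 1L)  // throws: ambiguous reference: plan id 1
```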
