Skip to content

Commit

Permalink
[SPARK-36874][SQL] DeduplicateRelations should copy dataset_id tag to…
Browse files Browse the repository at this point in the history
… avoid ambiguous self join

### What changes were proposed in this pull request?

This PR fixes an issue where an ambiguous self join is not detected when the left and right DataFrames are swapped.
This is an example.
```
val df1 = Seq((1, 2, "A1"),(2, 1, "A2")).toDF("key1", "key2", "value")
val df2 = df1.filter($"value" === "A2")

df1.join(df2, df1("key1") === df2("key2")) // Ambiguous self join is detected and AnalysisException is thrown.

df2.join(df1, df1("key1") === df2("key2")) // Ambiguous self join is not detected.
```

The root cause appears to be that the inner function `collectConflictPlans` in `DeduplicateRelations` doesn't copy the `dataset_id` tag when it copies a `LogicalPlan`.

### Why are the changes needed?

Bug fix.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New tests.

Closes #34172 from sarutak/fix-deduplication-issue.

Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
sarutak authored and cloud-fan committed Oct 5, 2021
1 parent 65eb4a2 commit fa1805d
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 16 deletions.
Expand Up @@ -181,13 +181,16 @@ object DeduplicateRelations extends Rule[LogicalPlan] {

case oldVersion: SerializeFromObject
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
Seq((oldVersion, oldVersion.copy(
serializer = oldVersion.serializer.map(_.newInstance()))))
val newVersion = oldVersion.copy(serializer = oldVersion.serializer.map(_.newInstance()))
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

// Handle projects that create conflicting aliases.
case oldVersion @ Project(projectList, _)
if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
Seq((oldVersion, oldVersion.copy(projectList = newAliases(projectList))))
val newVersion = oldVersion.copy(projectList = newAliases(projectList))
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

// Handle projects that create conflicting outer references.
case oldVersion @ Project(projectList, _)
Expand All @@ -197,7 +200,9 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
case o @ OuterReference(a) if conflictingAttributes.contains(a) => Alias(o, a.name)()
case other => other
}
Seq((oldVersion, oldVersion.copy(projectList = aliasedProjectList)))
val newVersion = oldVersion.copy(projectList = aliasedProjectList)
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

// We don't need to search child plan recursively if the projectList of a Project
// is only composed of Alias and doesn't contain any conflicting attributes.
Expand All @@ -209,8 +214,9 @@ object DeduplicateRelations extends Rule[LogicalPlan] {

case oldVersion @ Aggregate(_, aggregateExpressions, _)
if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
Seq((oldVersion, oldVersion.copy(
aggregateExpressions = newAliases(aggregateExpressions))))
val newVersion = oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

// We don't search the child plan recursively for the same reason as the above Project.
case _ @ Aggregate(_, aggregateExpressions, _)
Expand All @@ -219,24 +225,34 @@ object DeduplicateRelations extends Rule[LogicalPlan] {

case oldVersion @ FlatMapGroupsInPandas(_, _, output, _)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

case oldVersion @ FlatMapCoGroupsInPandas(_, _, _, output, _, _)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

case oldVersion @ MapInPandas(_, output, _)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

case oldVersion @ AttachDistributedSequence(sequenceAttr, _)
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
Seq((oldVersion, oldVersion.copy(sequenceAttr = sequenceAttr.newInstance())))
val newVersion = oldVersion.copy(sequenceAttr = sequenceAttr.newInstance())
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

case oldVersion: Generate
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
val newOutput = oldVersion.generatorOutput.map(_.newInstance())
Seq((oldVersion, oldVersion.copy(generatorOutput = newOutput)))
val newVersion = oldVersion.copy(generatorOutput = newOutput)
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

case oldVersion: Expand
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
Expand All @@ -248,16 +264,22 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
attr
}
}
Seq((oldVersion, oldVersion.copy(output = newOutput)))
val newVersion = oldVersion.copy(output = newOutput)
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

case oldVersion @ Window(windowExpressions, _, _, child)
if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
.nonEmpty =>
Seq((oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))))
val newVersion = oldVersion.copy(windowExpressions = newAliases(windowExpressions))
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

case oldVersion @ ScriptTransformation(_, output, _, _)
if AttributeSet(output).intersect(conflictingAttributes).nonEmpty =>
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

case _ => plan.children.flatMap(collectConflictPlans)
}
Expand Down
Expand Up @@ -17,12 +17,15 @@

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{count, sum}
import org.apache.spark.sql.functions.{count, explode, sum}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SQLTestData.TestData
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}

class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
import testImplicits._
Expand Down Expand Up @@ -344,4 +347,124 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
assertAmbiguousSelfJoin(df1.join(df2).join(df5).join(df4).select(df2("b")))
}
}

test("SPARK-36874: DeduplicateRelations should copy dataset_id tag " +
  "to avoid ambiguous self join") {
  // Checks the self join of `left` and `right` in both join orders. The condition is taken
  // by name so that it is rebuilt inside every assertion rather than once up front.
  def assertAmbiguousInBothOrders(left: DataFrame, right: DataFrame)(cond: => Column): Unit = {
    assertAmbiguousSelfJoin(left.join(right, cond))
    assertAmbiguousSelfJoin(right.join(left, cond))
  }

  // Builds a pandas UDF stub (null function) whose result has two long columns, x and y.
  def pandasUDF(name: String, evalType: Int): PythonUDF = {
    val resultType = StructType(Seq(StructField("x", LongType), StructField("y", LongType)))
    PythonUDF(name, null, resultType, Seq.empty, evalType, true)
  }

  // Fresh integer attributes x and y, used as the output of the hand-built plan nodes.
  def xyOutput: Seq[AttributeReference] =
    Seq(AttributeReference("x", IntegerType)(), AttributeReference("y", IntegerType)())

  // Project
  val baseDf = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value")
  val baseFiltered = baseDf.filter($"value" === "A2")
  assertAmbiguousInBothOrders(baseDf, baseFiltered) {
    baseDf("key1") === baseFiltered("key2")
  }

  // SerializeFromObject
  val serialized = spark.sparkContext.parallelize(1 to 10).map(x => (x, x)).toDF
  val serializedFiltered = serialized.filter($"_1" <=> 0)
  assertAmbiguousInBothOrders(serialized, serializedFiltered) {
    serialized("_1") === serializedFiltered("_2")
  }

  // Aggregate
  val aggregated = baseDf.groupBy($"key1").agg(count($"value") as "count")
  val aggregatedFiltered = aggregated.filter($"key1" > 0)
  assertAmbiguousInBothOrders(aggregated, aggregatedFiltered) {
    aggregated("key1") === aggregatedFiltered("count")
  }

  // MapInPandas
  val mapInPandasUDF = pandasUDF("mapInPandasUDF", PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
  val mapped = baseDf.mapInPandas(mapInPandasUDF)
  val mappedFiltered = mapped.filter($"x" > 0)
  assertAmbiguousInBothOrders(mapped, mappedFiltered) {
    mapped("x") === mappedFiltered("y")
  }

  // FlatMapGroupsInPandas
  val flatMapGroupsInPandasUDF =
    pandasUDF("flagMapGroupsInPandasUDF", PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
  val grouped = baseDf.groupBy($"key1").flatMapGroupsInPandas(flatMapGroupsInPandasUDF)
  val groupedFiltered = grouped.filter($"x" > 0)
  assertAmbiguousInBothOrders(grouped, groupedFiltered) {
    grouped("x") === groupedFiltered("y")
  }

  // FlatMapCoGroupsInPandas
  val flatMapCoGroupsInPandasUDF =
    pandasUDF("flagMapCoGroupsInPandasUDF", PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
  val cogrouped = baseDf.groupBy($"key1").flatMapCoGroupsInPandas(
    baseDf.groupBy($"key2"), flatMapCoGroupsInPandasUDF)
  val cogroupedFiltered = cogrouped.filter($"x" > 0)
  assertAmbiguousInBothOrders(cogrouped, cogroupedFiltered) {
    cogrouped("x") === cogroupedFiltered("y")
  }

  // AttachDistributedSequence
  val sequenced = baseDf.withSequenceColumn("seq")
  val sequencedFiltered = sequenced.filter($"value" === "A2")
  assertAmbiguousInBothOrders(sequenced, sequencedFiltered) {
    sequenced("key1") === sequencedFiltered("key2")
  }

  // Generate: take the Generate node from the optimized plan so that it is the plan root.
  val generated = Seq((1, Seq(1, 2, 3))).toDF("a", "intList").select($"a", explode($"intList"))
    .queryExecution.optimizedPlan.find(_.isInstanceOf[Generate]).get.toDF
  val generatedFiltered = generated.filter($"a" > 0)
  assertAmbiguousInBothOrders(generated, generatedFiltered) {
    generated("a") === generatedFiltered("col")
  }

  // Expand: build the node directly so that it is the plan root.
  val expanded = Expand(
    Seq(Seq($"key1".expr, $"key2".expr)),
    xyOutput,
    baseDf.queryExecution.logical).toDF
  val expandedFiltered = expanded.filter($"x" > 0)
  assertAmbiguousInBothOrders(expanded, expandedFiltered) {
    expanded("x") === expandedFiltered("y")
  }

  // Window: build the node directly so that it is the plan root.
  val dfWithTS = spark.sql("SELECT timestamp'2021-10-15 01:52:00' time, 1 a, 2 b")
  val partitionAndOrderKey = dfWithTS("a").expr
  val windowed = WindowPlan(
    Seq(Alias(dfWithTS("time").expr, "ts")()),
    Seq(partitionAndOrderKey),
    Seq(SortOrder(partitionAndOrderKey, Ascending)),
    dfWithTS.queryExecution.logical).toDF
  val windowedFiltered = windowed.filter($"a" > 0)
  assertAmbiguousInBothOrders(windowed, windowedFiltered) {
    windowed("a") === windowedFiltered("b")
  }

  // ScriptTransformation: the same row format is used for the input and the output side.
  val rowFormat = Seq(
    ("TOK_TABLEROWFORMATFIELD", ","),
    ("TOK_TABLEROWFORMATCOLLITEMS", "#"),
    ("TOK_TABLEROWFORMATMAPKEYS", "@"),
    ("TOK_TABLEROWFORMATNULL", "null"),
    ("TOK_TABLEROWFORMATLINES", "\n"))
  val ioSchema = ScriptInputOutputSchema(
    rowFormat, rowFormat, None, None, List.empty, List.empty, None, None, false)
  // Build the node directly so that it is the plan root.
  val transformed = ScriptTransformation(
    "cat",
    xyOutput,
    baseDf.queryExecution.logical,
    ioSchema).toDF
  val transformedFiltered = transformed.filter($"x" > 0)
  assertAmbiguousInBothOrders(transformed, transformedFiltered) {
    transformed("x") === transformedFiltered("y")
  }
}
}

0 comments on commit fa1805d

Please sign in to comment.