From 376691af5f639ca4f7ff07cc9e8f572d53e961bf Mon Sep 17 00:00:00 2001 From: xiaoli Date: Sun, 8 Nov 2015 10:28:12 -0800 Subject: [PATCH] Spark-10838 --- .../sql/catalyst/analysis/Analyzer.scala | 7 +++- .../expressions/namedExpressions.scala | 3 +- .../org/apache/spark/sql/DataFrame.scala | 18 ++++++++-- .../apache/spark/sql/DataFrameJoinSuite.scala | 34 +++++++++++++++++++ 4 files changed, 58 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index beabacfc88e32..6d3467561f371 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -374,7 +374,12 @@ class Analyzer( case a: Attribute => attributeRewrites.get(a).getOrElse(a) } } - j.copy(right = newRight) + val newCondition = j.condition.map ( _.transform { + case a: AttributeReference if a.resolved && a.qualifiers.headOption == Some("RIGHT_TREE") => + attributeRewrites.get(a).getOrElse(a).withQualifiers(Nil) + case o => o + }) + j.copy(right = newRight, condition = newCondition ) } // When resolve `SortOrder`s in Sort based on child, don't report errors as diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 8957df0be6814..8d083f6cdcacc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -194,7 +194,8 @@ case class AttributeReference( def sameRef(other: AttributeReference): Boolean = this.exprId == other.exprId override def equals(other: Any): Boolean = other match { - case ar: AttributeReference => name == ar.name && exprId == ar.exprId 
&& dataType == ar.dataType + case ar: AttributeReference => name == ar.name && exprId == ar.exprId && + dataType == ar.dataType && qualifiers == ar.qualifiers case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 174bc6f42ad8d..2cd5911564fb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -560,6 +560,16 @@ class DataFrame private[sql]( * @since 1.3.0 */ def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = { + // Tag each AttributeReference in the join condition with the side it came from. + val newJoinExprs = joinExprs.expr.transform { + case arLeft: AttributeReference + if arLeft.qualifiers.headOption == Some(this.hashCode.toString) => + arLeft.withQualifiers("LEFT_TREE" :: Nil) + case arRight: AttributeReference + if arRight.qualifiers.headOption == Some(right.hashCode.toString) => + arRight.withQualifiers("RIGHT_TREE" :: Nil) + case o => o + } // Note that in this function, we introduce a hack in the case of self-join to automatically // resolve ambiguous join conditions into ones that might make sense [SPARK-6231]. // Consider this case: df.join(df, df("key") === df("key")) @@ -570,7 +580,7 @@ class DataFrame private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. - val plan = Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)) + val plan = Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(newJoinExprs)) .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. 
@@ -669,7 +679,11 @@ class DataFrame private[sql]( case "*" => Column(ResolvedStar(schema.fieldNames.map(resolve))) case _ => - val expr = resolve(colName) + val expr = resolve(colName) match { + case ar: AttributeReference => + ar.withQualifiers(this.hashCode.toString :: Nil) + case o => o + } Column(expr) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 56ad71ea4f487..50e37700c57e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -106,6 +106,40 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil) } + test("[SPARK-10838] self join - conflicting attributes in condition - incorrect result 1") { + val df1 = Seq((1, 3), (2, 1)).toDF("keyCol1", "keyCol2") + val df2 = Seq((1, 4), (2, 1)).toDF("keyCol1", "keyCol3") + + val df3 = df1.join(df2, df1("keyCol1") === df2("keyCol1")).select(df1("keyCol1")) + + checkAnswer( + df3.join(df1, df1("keyCol2") === df3("keyCol1")), + Row(1, 2, 1) :: Nil) + } + + test("[SPARK-10838] self join - conflicting attributes in condition - incorrect result 2") { + val df1 = Seq((1, 3), (2, 1)).toDF("keyCol1", "keyCol2") + val df2 = Seq((1, 4), (2, 1)).toDF("keyCol1", "keyCol3") + + val df3 = df1.join(df2, df1("keyCol1") === df2("keyCol1")).select(df1("keyCol1"), $"keyCol3") + + checkAnswer( + df3.join(df1, df3("keyCol3") === df1("keyCol1") && df1("keyCol1") === df3("keyCol3")), + Row(2, 1, 1, 3) :: Nil) + } + + test("[SPARK-10838] self join - conflicting attributes in condition - exception") { + val df1 = Seq((1, 3), (2, 1)).toDF("keyCol1", "keyCol2") + val df2 = Seq((1, 4), (2, 1)).toDF("keyCol1", "keyCol3") + + val df3 = df1.join(df2, df1("keyCol1") === df2("keyCol1")).select(df1("keyCol1"), $"keyCol3") + val df4 = df2.as("df4") + + checkAnswer( + df3.join(df4, 
df3("keyCol3") === df4("keyCol1") && df3("keyCol3") === df4("keyCol1")), + Row(2, 1, 1, 4) :: Nil) + } + test("broadcast join hint") { val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")