Skip to content

Commit

Permalink
support unique join in hive context
Browse files Browse the repository at this point in the history
  • Loading branch information
scwf committed Feb 3, 2015
1 parent 60f67e7 commit b7e89a9
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
// we figure out why this is the case let just ignore all of avro related tests.
".*avro.*",

// Unique joins are weird and will require a lot of hacks (see comments in hive parser).
"uniquejoin",

// Hive seems to get the wrong answer on some outer joins. MySQL agrees with catalyst.
"auto_join29",

Expand Down Expand Up @@ -971,6 +968,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"union_remove_3",
"union_remove_6",
"union_script",
"uniquejoin",
"varchar_2",
"varchar_join1",
"varchar_union1",
Expand Down
68 changes: 41 additions & 27 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.hive

import java.sql.Date
import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.Context
Expand Down Expand Up @@ -862,44 +863,57 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case (arg, i) => arg.getText == "TOK_TABREF"
}.map(_._2)

val isPreserved = tableOrdinals.map(i => (i - 1 < 0) || joinArgs(i - 1).getText == "PRESERVE")
val isPreserved = tableOrdinals.map { i =>
if(i == 0) false else joinArgs(i - 1).getText == "PRESERVE"
}
val tables = tableOrdinals.map(i => nodeToRelation(joinArgs(i)))
val joinExpressions = tableOrdinals.map(i => joinArgs(i + 1).getChildren.map(nodeToExpr))
var i = 1
var nextJoinExpression: mutable.Buffer[Expression] = null
var joinConditions = new ArrayBuffer[Expression]()
while(i < joinExpressions.length) {
nextJoinExpression = joinExpressions(i)
val predicates = joinExpressions.take(i).map { exps =>
exps.zip(nextJoinExpression).map {
case (e1, e2) => EqualTo(e1, e2): Expression
}.reduceLeft(And)
}.reduceLeft(Or)
joinConditions += predicates
i = i + 1
}

val joinConditions = joinExpressions.sliding(2).map {
case Seq(c1, c2) =>
val predicates = (c1, c2).zipped.map { case (e1, e2) => EqualTo(e1, e2): Expression }
predicates.reduceLeft(And)
}.toBuffer

val joinType = isPreserved.sliding(2).map {
case Seq(true, true) => FullOuter
case Seq(true, false) => LeftOuter
case Seq(false, true) => RightOuter
case Seq(false, false) => Inner
}.toBuffer

val joinedTables = tables.reduceLeft(Join(_,_, Inner, None))
// Must be transform down.
i = joinConditions.length
val fullOuterJoinedResult = tables.reduceLeft(Join(_,_, FullOuter, None)) transform {
case j: Join =>
i = i - 1
j.copy(
condition = Some(joinConditions(i)))
}

// Must be transform down.
val joinedResult = joinedTables transform {
val fullInnerJoinedResult = tables.reduceLeft(Join(_,_, Inner, None)) transform {
case j: Join =>
j.copy(
condition = Some(joinConditions.remove(joinConditions.length - 1)),
joinType = joinType.remove(joinType.length - 1))
condition = Some(joinConditions.remove(joinConditions.length - 1)))
}

val groups = (0 until joinExpressions.head.size).map(i => Coalesce(joinExpressions.map(_(i))))

// Unique join is not really the same as an outer join so we must group together results where
// the joinExpressions are the same, taking the First of each value is only okay because the
// user of a unique join is implicitly promising that there is only one result.
// TODO: This doesn't actually work since [[Star]] is not a valid aggregate expression.
// instead we should figure out how important supporting this feature is and whether it is
// worth the number of hacks that will be required to implement it. Namely, we need to add
// some sort of mapped star expansion that would expand all child output row to be similarly
// named output expressions where some aggregate expression has been applied (i.e. First).
??? // Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult)
val filterConditions = isPreserved.zip(joinExpressions).flatMap { case (preserved, expressions) =>
if (preserved) {
Seq(IsNotNull(expressions.get(0)))
} else {
None
}
}
if (isPreserved.reduceLeft(_ && _)) {
fullOuterJoinedResult
} else if (filterConditions.isEmpty) {
fullInnerJoinedResult
} else {
Filter(filterConditions.reduceLeft(Or), fullOuterJoinedResult)
}

case Token(allJoinTokens(joinToken),
relation1 ::
Expand Down

0 comments on commit b7e89a9

Please sign in to comment.