Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SQL][SPARK-2212]Hash Outer Join #1147

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil

case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
execution.HashOuterJoin(
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil

case _ => Nil
}
}
Expand Down
183 changes: 181 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ trait HashJoin {
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if(!rowKey.anyNull) {
if (!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey)
val matchList = if (existingMatchList == null) {
val newMatchList = new ArrayBuffer[Row]()
Expand Down Expand Up @@ -136,6 +136,185 @@ trait HashJoin {
}
}

/**
* Constant Value for Binary Join Node
*/
object HashOuterJoin {
val DUMMY_LIST = Seq[Row](null)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's see if we can remove this too. It really obfuscates your logic.

val EMPTY_LIST = Seq[Row]()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use Seq.empty[Row] inline instead of this variable.

}

/**
* :: DeveloperApi ::
* Performs a hash-based outer join of two child relations by shuffling the data using
* the join keys. This operator requires loading the associated partitions of both sides into memory.
*/
@DeveloperApi
case class HashOuterJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryNode {

override def outputPartitioning: Partitioning = left.outputPartitioning

override def requiredChildDistribution =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

def output = left.output ++ right.output

// TODO: for performance, rewrite all of the iterators with our own implementation instead of
// relying on the Scala iterators.

private[this] def leftOuterIterator(
key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
val joinedRow = new JoinedRow()
val rightNullRow = new GenericRow(right.output.length)
val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)

leftIter.iterator.flatMap { l =>
joinedRow.withLeft(l)
var matched = false
(if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) =>
matched = true
joinedRow.copy
} else {
Nil
}) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => {
// HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the tricky way instead of if (!matched) Iterator(joinedRow.withRight(rightNullRow).copy) else Iterator.empty? That seems much clearer to me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Logically, your example behaves differently. The matched check inside the closure is deferred until we have finished iterating over rightIter.collect, whereas in your version matched would always be false, because the check happens before rightIter.collect is iterated. Sorry for the confusion. I will think about how to improve its readability.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh wow... thanks for the explanation. That is even more subtle than I thought.

// as we don't know whether we need to append it until finish iterating all of the
// records in right side.
// If we didn't get any proper row, then append a single row with empty right
joinedRow.withRight(rightNullRow).copy
})
}
}

/**
 * Produces the right-outer-join output for the rows that share a single join key.
 *
 * Every row of `rightIter` yields at least one output row: one row per row of `leftIter`
 * that satisfies the join condition, or, if there is no such row, a single row whose left
 * side is `leftNullRow`. When `key.anyNull` is true no left rows are considered at all, so
 * each right row is emitted with the empty left side.
 *
 * @param key       the join key shared by `leftIter` and `rightIter`
 * @param leftIter  rows of the left relation stored under `key` (may be empty)
 * @param rightIter rows of the right relation stored under `key`
 * @return an iterator over copies of the joined rows
 */
private[this] def rightOuterIterator(
    key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
  // A single mutable JoinedRow is reused for every candidate pair; only copies are emitted.
  val joinedRow = new JoinedRow()
  val leftNullRow = new GenericRow(left.output.length)
  val boundCondition =
    condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)

  rightIter.iterator.flatMap { r =>
    joinedRow.withRight(r)
    // Gather every left row that passes the join condition for this right row.
    val matches = if (!key.anyNull) {
      leftIter.collect { case l if boundCondition(joinedRow.withLeft(l)) => joinedRow.copy }
    } else {
      Nil
    }
    // If we didn't get any proper row, then append a single row with empty left.
    if (matches.isEmpty) joinedRow.withLeft(leftNullRow).copy :: Nil else matches
  }
}

/**
 * Produces the full-outer-join output for the rows that share a single join key.
 *
 * For a key without nulls the result consists of, in order:
 *  1. for each left row, one output row per right row that satisfies the join condition;
 *  2. for each left row without any such right row, one row with an empty right side;
 *  3. for each right row that never satisfied the condition with any left row, one row with
 *     an empty left side.
 *
 * When `key.anyNull` is true nothing is matched: every left row is emitted with an empty
 * right side, followed by every right row with an empty left side.
 *
 * @param key       the join key shared by `leftIter` and `rightIter`
 * @param leftIter  rows of the left relation stored under `key` (may be empty)
 * @param rightIter rows of the right relation stored under `key` (may be empty)
 * @return an iterator over copies of the joined rows
 */
private[this] def fullOuterIterator(
    key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
  // A single mutable JoinedRow is reused for every candidate pair; only copies are emitted.
  val joinedRow = new JoinedRow()
  val leftNullRow = new GenericRow(left.output.length)
  val rightNullRow = new GenericRow(right.output.length)
  val boundCondition =
    condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)

  if (!key.anyNull) {
    // Store the positions of records in right, if one of its associated row satisfy
    // the join condition.
    val rightMatchedSet = scala.collection.mutable.Set[Int]()
    // The indexed view of the right side does not depend on the left row, so build it once
    // instead of once per left row.
    val indexedRight = rightIter.zipWithIndex

    leftIter.iterator.flatMap[Row] { l =>
      joinedRow.withLeft(l)
      // 1. Matched records with both sides filled; remember which right rows were used.
      val matches = indexedRight.collect {
        case (r, idx) if boundCondition(joinedRow.withRight(r)) =>
          rightMatchedSet.add(idx)
          joinedRow.copy
      }
      // 2. For an unmatched left record, append a single row with empty right.
      if (matches.isEmpty) joinedRow.withRight(rightNullRow).copy :: Nil else matches
    } ++ indexedRight.collect {
      // 3. For those unmatched records in right, append additional records with empty left.
      //
      // NOTE: this relies on `Iterator.++` taking its argument by name: the collect below is
      // only evaluated once the left-side iterator above is exhausted, i.e. after
      // `rightMatchedSet` has been completely populated.
      case (r, idx) if !rightMatchedSet.contains(idx) => joinedRow(leftNullRow, r).copy
    }
  } else {
    leftIter.iterator.map[Row] { l => joinedRow(l, rightNullRow).copy } ++
      rightIter.iterator.map[Row] { r => joinedRow(leftNullRow, r).copy }
  }
}

private[this] def buildHashTable(
iter: Iterator[Row], keyGenerator: Projection): Map[Row, ArrayBuffer[Row]] = {
// TODO: Use Spark's HashMap implementation.
val hashTable = scala.collection.mutable.Map[Row, ArrayBuffer[Row]]()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably at least be using java.util here. The scala collection library seems to have weird performance sometimes.

while (iter.hasNext) {
val currentRow = iter.next()
val rowKey = keyGenerator(currentRow)

val existingMatchList = hashTable.getOrElseUpdate(rowKey, {new ArrayBuffer[Row]()})
existingMatchList += currentRow.copy()
}

hashTable.toMap[Row, ArrayBuffer[Row]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the extra .toMap? Is this doing a full copy?

}

/**
 * Runs the outer join partition by partition.
 *
 * The partitions of the two children are zipped together; for each pair a hash table keyed by
 * the join key is built for both sides, and the per-key iterator matching `joinType` is applied
 * to every relevant key:
 *  - `LeftOuter`  walks the keys of the left table,
 *  - `RightOuter` walks the keys of the right table,
 *  - `FullOuter`  walks the union of both key sets.
 *
 * A key missing from one side is passed on as an empty sequence of rows for that side.
 * Any other join type throws an `Exception`.
 */
def execute() = {
  left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
    // TODO this probably can be replaced by external sort (sort merged join?)
    // Build HashMap for current partition in left relation
    val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
    // Build HashMap for current partition in right relation
    val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))

    // The join condition is bound inside each per-key iterator, so it is not bound here.
    joinType match {
      case LeftOuter => leftHashTable.keysIterator.flatMap { key =>
        leftOuterIterator(key,
          leftHashTable.getOrElse(key, Seq.empty[Row]),
          rightHashTable.getOrElse(key, Seq.empty[Row]))
      }
      case RightOuter => rightHashTable.keysIterator.flatMap { key =>
        rightOuterIterator(key,
          leftHashTable.getOrElse(key, Seq.empty[Row]),
          rightHashTable.getOrElse(key, Seq.empty[Row]))
      }
      case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
        fullOuterIterator(key,
          leftHashTable.getOrElse(key, Seq.empty[Row]),
          rightHashTable.getOrElse(key, Seq.empty[Row]))
      }
      case x => throw new Exception(s"Need to add implementation for $x")
    }
  }
}
}

/**
* :: DeveloperApi ::
* Performs an inner hash join of two child relations by first shuffling the data using the join
Expand Down Expand Up @@ -189,7 +368,7 @@ case class LeftSemiJoinHash(
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if(!rowKey.anyNull) {
if (!rowKey.anyNull) {
val keyExists = hashSet.contains(rowKey)
if (!keyExists) {
hashSet.add(rowKey)
Expand Down
138 changes: 134 additions & 4 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,17 @@

package org.apache.spark.sql

import org.scalatest.BeforeAndAfterEach

import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner}
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._

class JoinSuite extends QueryTest {
class JoinSuite extends QueryTest with BeforeAndAfterEach {

// Ensures tables are loaded.
TestData
Expand All @@ -34,6 +40,56 @@ class JoinSuite extends QueryTest {
assert(planned.size === 1)
}

test("join operator selection") {
  // Plans `sqlString` and checks that the physical plan contains exactly one join operator
  // and that this operator is an instance of `c`.
  def assertJoin(sqlString: String, c: Class[_]): Unit = {
    val rdd = sql(sqlString)
    val physical = rdd.queryExecution.sparkPlan
    val operators = physical.collect {
      case j: ShuffledHashJoin => j
      case j: HashOuterJoin => j
      case j: LeftSemiJoinHash => j
      case j: BroadcastHashJoin => j
      case j: LeftSemiJoinBNL => j
      case j: CartesianProduct => j
      case j: BroadcastNestedLoopJoin => j
    }

    assert(operators.size === 1)
    if (operators(0).getClass() != c) {
      fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical")
    }
  }

  // Each entry pairs a query with the join operator the planner is expected to choose for it.
  val cases1 = Seq(
    ("SELECT * FROM testData left semi join testData2 ON key = a", classOf[LeftSemiJoinHash]),
    ("SELECT * FROM testData left semi join testData2", classOf[LeftSemiJoinBNL]),
    ("SELECT * FROM testData join testData2", classOf[CartesianProduct]),
    ("SELECT * FROM testData join testData2 where key=2", classOf[CartesianProduct]),
    ("SELECT * FROM testData left join testData2", classOf[CartesianProduct]),
    ("SELECT * FROM testData right join testData2", classOf[CartesianProduct]),
    ("SELECT * FROM testData full outer join testData2", classOf[CartesianProduct]),
    ("SELECT * FROM testData left join testData2 where key=2", classOf[CartesianProduct]),
    ("SELECT * FROM testData right join testData2 where key=2", classOf[CartesianProduct]),
    ("SELECT * FROM testData full outer join testData2 where key=2", classOf[CartesianProduct]),
    ("SELECT * FROM testData join testData2 where key>a", classOf[CartesianProduct]),
    ("SELECT * FROM testData full outer join testData2 where key>a", classOf[CartesianProduct]),
    ("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]),
    ("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]),
    ("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin]),
    ("SELECT * FROM testData left join testData2 ON key = a", classOf[HashOuterJoin]),
    ("SELECT * FROM testData right join testData2 ON key = a where key=2",
      classOf[HashOuterJoin]),
    ("SELECT * FROM testData right join testData2 ON key = a and key=2",
      classOf[HashOuterJoin]),
    ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin])
    // TODO add BroadcastNestedLoopJoin
  )
  cases1.foreach { case (query, expectedOperator) => assertJoin(query, expectedOperator) }
}

test("multiple-key equi-join is hash-join") {
val x = testData2.as('x)
val y = testData2.as('y)
Expand Down Expand Up @@ -114,6 +170,33 @@ class JoinSuite extends QueryTest {
(4, "D", 4, "d") ::
(5, "E", null, null) ::
(6, "F", null, null) :: Nil)

checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)),
(1, "A", null, null) ::
(2, "B", 2, "b") ::
(3, "C", 3, "c") ::
(4, "D", 4, "d") ::
(5, "E", null, null) ::
(6, "F", null, null) :: Nil)

checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)),
(1, "A", null, null) ::
(2, "B", 2, "b") ::
(3, "C", 3, "c") ::
(4, "D", 4, "d") ::
(5, "E", null, null) ::
(6, "F", null, null) :: Nil)

checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)),
(1, "A", 1, "a") ::
(2, "B", 2, "b") ::
(3, "C", 3, "c") ::
(4, "D", 4, "d") ::
(5, "E", null, null) ::
(6, "F", null, null) :: Nil)
}

test("right outer join") {
Expand All @@ -125,11 +208,38 @@ class JoinSuite extends QueryTest {
(4, "d", 4, "D") ::
(null, null, 5, "E") ::
(null, null, 6, "F") :: Nil)
checkAnswer(
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)),
(null, null, 1, "A") ::
(2, "b", 2, "B") ::
(3, "c", 3, "C") ::
(4, "d", 4, "D") ::
(null, null, 5, "E") ::
(null, null, 6, "F") :: Nil)
checkAnswer(
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)),
(null, null, 1, "A") ::
(2, "b", 2, "B") ::
(3, "c", 3, "C") ::
(4, "d", 4, "D") ::
(null, null, 5, "E") ::
(null, null, 6, "F") :: Nil)
checkAnswer(
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)),
(1, "a", 1, "A") ::
(2, "b", 2, "B") ::
(3, "c", 3, "C") ::
(4, "d", 4, "D") ::
(null, null, 5, "E") ::
(null, null, 6, "F") :: Nil)
}

test("full outer join") {
val left = upperCaseData.where('N <= 4).as('left)
val right = upperCaseData.where('N >= 3).as('right)
upperCaseData.where('N <= 4).registerAsTable("left")
upperCaseData.where('N >= 3).registerAsTable("right")

val left = UnresolvedRelation(None, "left", None)
val right = UnresolvedRelation(None, "right", None)

checkAnswer(
left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),
Expand All @@ -139,5 +249,25 @@ class JoinSuite extends QueryTest {
(4, "D", 4, "D") ::
(null, null, 5, "E") ::
(null, null, 6, "F") :: Nil)

checkAnswer(
left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))),
(1, "A", null, null) ::
(2, "B", null, null) ::
(3, "C", null, null) ::
(null, null, 3, "C") ::
(4, "D", 4, "D") ::
(null, null, 5, "E") ::
(null, null, 6, "F") :: Nil)

checkAnswer(
left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))),
(1, "A", null, null) ::
(2, "B", null, null) ::
(3, "C", null, null) ::
(null, null, 3, "C") ::
(4, "D", 4, "D") ::
(null, null, 5, "E") ::
(null, null, 6, "F") :: Nil)
}
}