[SQL][SPARK-2212]Hash Outer Join #1147

Closed
Changes from 2 commits
@@ -94,6 +94,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil

case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
execution.HashOuterJoin(
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil

case _ => Nil
}
}
194 changes: 192 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -37,6 +37,135 @@ case object BuildLeft extends BuildSide
@DeveloperApi
case object BuildRight extends BuildSide

/**
* Constant values used by the binary join nodes.
*/
object BinaryJoinNode {
val SINGLE_NULL_LIST = Seq[Row](null)
val EMPTY_NULL_LIST = Seq[Row]()
Contributor: You can just do Seq.empty[Row] in the cases where you use this variable. Due to erasure no extra object needs to be allocated.

}
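The reviewer's erasure point can be checked with a small standalone snippet (no Spark dependency; the object name here is illustrative). Because `Seq.empty` returns the same cached `Nil` instance regardless of the element type, no allocation happens per call:

```scala
object EmptySeqDemo {
  // Seq.empty returns the cached Nil singleton for every element type,
  // so reference equality (eq) holds and nothing is allocated per call.
  def sameInstance: Boolean = Seq.empty[String] eq Seq.empty[Int]

  def main(args: Array[String]): Unit =
    println(sameInstance) // expected: true
}
```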

// TODO Should null join keys be considered equal? In Hive this is configurable.
Contributor: I think this was mostly a bug that they were forced to continue supporting due to backwards compatibility.

Author: OK, I got it. Will remove the TODO.


/**
* Outputs the tuples for a matched join group (rows sharing the same join key), based on the join type.
* Both input iterators must be repeatable.
*/
trait BinaryRepeatableIteratorNode extends BinaryNode {
Contributor: Since this trait is only used once, what do you think of just putting all of this code in HashOuterJoin?

Author: Actually I am planning to do the sort merge join in the next step, and I think it will use this trait.

Contributor: Yes, but we can always add abstraction later when it is needed. In general I think it's better to have code be only as general as is currently required and no more so.

self: Product =>

val leftNullRow = new GenericRow(left.output.length)
Contributor: It might be better to do these inside of the closures where they are used so they don't add to the serialized size of the operator.

Author: I have another concern: if we put the entire row inside of closures, the object left will be serialized as well, right?

Contributor: I think that is likely already happening.

Author: Ignore my previous comment please.

val rightNullRow = new GenericRow(right.output.length)

val joinedRow = new JoinedRow()
Contributor: I'm worried about having a mutable structure that isn't explicitly allocated per partition. @rxin is doing a lot of work trying to make us more efficient by broadcasting closures per job instead of serializing them per task, and I think this could break in subtle ways.

Author: Should I put the joinedRow inside of the closures, as well as the leftNullRow and rightNullRow? Sorry, I have no idea about the broadcasting closures.

Contributor: Yeah, I think that would be safer. It'll be kind of a long function, but I think in this case it's probably okay to just inline all of the functions you have here into the match so they can reuse things like joinedRow.

Author: Thanks for sharing this, understood it now.


val boundCondition = InterpretedPredicate(
Contributor: Use newPredicate.

Contributor: Also I think you should move this into the closure too. When codegen is turned on, the code needs to be JIT-ed on each machine when the task is run. The code generation logic internally does caching to prevent recompiling the same logic multiple times on the same machine.

condition
.map(c => BindReferences.bindReference(c, left.output ++ right.output))
Contributor: You can just pass the input schema to newPredicate instead of binding the references manually (this was all cleaned up by the recent codegen PR).

.getOrElse(Literal(true)))

def condition: Option[Expression]
def joinType: JoinType

// TODO we need to rewrite all of the iterators with our own implementation instead of the scala
// iterator for performance / memory usage reason.

def leftOuterIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
: Iterator[Row] = {
Contributor: We don't usually orphan the return type on its own line. Instead I'd wrap the arguments.

leftIter.iterator.flatMap { l =>
joinedRow.withLeft(l)
var matched = false
(if (!key.anyNull) rightIter else BinaryJoinNode.EMPTY_NULL_LIST).collect {
case r if (boundCondition(joinedRow.withRight(r))) => {
matched = true
joinedRow.copy
}
} ++ BinaryJoinNode.SINGLE_NULL_LIST.collect {
Contributor: The logic of using a list with a dummy null value and a collect to do an if is very subtle. (That is what this is doing, right?) Is this just ++ if (!matched) Seq(joinedRow.withRight(rightNullRow).copy) else Nil?

Author: Yeah, that's a pretty cool implementation, I will update it.

case dummy if (!matched) => {
joinedRow.withRight(rightNullRow).copy
}
}
}
}
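The reviewer's proposed rewrite of the dummy-null collect can be sanity-checked with a Spark-free sketch (hypothetical names, plain strings standing in for rows): collecting over a one-element null list behind a `!matched` guard produces exactly the same result as a plain if/else.

```scala
object CollectTrickDemo {
  // Stand-in for BinaryJoinNode.SINGLE_NULL_LIST.
  private val singleNullList = Seq[String](null)

  // The subtle original form: collect over a dummy one-element list.
  def viaCollect(matched: Boolean): Seq[String] =
    singleNullList.collect { case _ if !matched => "nullPaddedRow" }

  // The reviewer's suggested direct form.
  def viaIf(matched: Boolean): Seq[String] =
    if (!matched) Seq("nullPaddedRow") else Nil

  def main(args: Array[String]): Unit = {
    // Both forms emit one padding row exactly when no match was found.
    println(viaCollect(false) == viaIf(false)) // expected: true
    println(viaCollect(true) == viaIf(true))   // expected: true
  }
}
```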

// TODO need to unit test this, currently it's the dead code, but should be used in SortMergeJoin
Contributor: Let's remove dead code.

def leftSemiIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
: Iterator[Row] = {
leftIter.iterator.filter { l =>
joinedRow.withLeft(l)
(if (!key.anyNull) rightIter else BinaryJoinNode.EMPTY_NULL_LIST).exists {
case r => (boundCondition(joinedRow.withRight(r)))
}
}
}

def rightOuterIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
: Iterator[Row] = {
rightIter.iterator.flatMap{r =>
joinedRow.withRight(r)
var matched = false
(if (!key.anyNull) leftIter else BinaryJoinNode.EMPTY_NULL_LIST).collect {
case l if (boundCondition(joinedRow.withLeft(l))) => {
matched = true
joinedRow.copy
}
} ++ BinaryJoinNode.SINGLE_NULL_LIST.collect {
Contributor: I'm having a pretty hard time following the logic of each of these cases that are getting ++'ed together. I think it might be more clear if you assign each intermediate phase to a variable (e.g., matchedPairs, unmatchedLeftRows...) with a comment about what conditions are being checked, why, and what the result is.

case dummy if (!matched) => {
joinedRow.withLeft(leftNullRow).copy
}
}
}
}
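The naming suggestion above might look roughly like this simplified, Spark-free sketch of the right-outer case (hypothetical names; `Option[Int]` models the left-side null padding, and the join condition is a plain function):

```scala
object RightOuterSketch {
  // Right outer join for one join group: matchedPairs holds every (l, r)
  // pair that passes the condition; unmatchedPad pads the right row with a
  // left-side "null" (None) exactly when nothing matched.
  def rightOuterGroup(
      leftRows: Seq[Int],
      rightRows: Seq[Int],
      condition: (Int, Int) => Boolean): Seq[(Option[Int], Int)] =
    rightRows.flatMap { r =>
      val matchedPairs = leftRows.collect {
        case l if condition(l, r) => (Some(l), r)
      }
      val unmatchedPad = if (matchedPairs.isEmpty) Seq((None, r)) else Nil
      matchedPairs ++ unmatchedPad
    }

  def main(args: Array[String]): Unit =
    println(rightOuterGroup(Seq(1, 2), Seq(2, 3), _ == _))
    // right row 2 matches left row 2; right row 3 is padded with None
}
```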

def fullOuterIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
: Iterator[Row] = {
if (!key.anyNull) {
val rightMatchedSet = scala.collection.mutable.Set[Int]()
Contributor: Only one space after =.

leftIter.iterator.flatMap[Row] { l =>
joinedRow.withLeft(l)
var matched = false
rightIter.zipWithIndex.collect {
case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> {
matched = true
rightMatchedSet.add(idx)
joinedRow.copy
}
} ++ BinaryJoinNode.SINGLE_NULL_LIST.collect {
case dummy if (!matched) => {
joinedRow.withRight(rightNullRow).copy
}
}
} ++ rightIter.zipWithIndex.collect {
case (r, idx) if (!rightMatchedSet.contains(idx)) => {
joinedRow(leftNullRow, r).copy
}
}
} else {
leftIter.iterator.map[Row] { l =>
joinedRow(l, rightNullRow).copy
} ++ rightIter.iterator.map[Row] { r =>
joinedRow(leftNullRow, r).copy
}
}
}
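The per-key full-outer logic above is dense; here is a minimal Spark-free sketch of the same three phases (hypothetical names, `Option[Int]` standing in for nullable rows): emit every cross-pair that satisfies the condition, pad unmatched left rows on the right, and finally pad any right rows whose index was never recorded in the matched set.

```scala
object FullOuterSketch {
  def fullOuterGroup(
      leftRows: Seq[Int],
      rightRows: Seq[Int],
      condition: (Int, Int) => Boolean): Seq[(Option[Int], Option[Int])] = {
    // Indices of right rows that matched at least one left row.
    val rightMatched = scala.collection.mutable.Set[Int]()
    val fromLeft = leftRows.flatMap { l =>
      val matchedPairs = rightRows.zipWithIndex.collect {
        case (r, idx) if condition(l, r) =>
          rightMatched.add(idx)
          (Some(l), Some(r))
      }
      // An unmatched left row is padded with a right-side "null".
      if (matchedPairs.nonEmpty) matchedPairs else Seq((Some(l), None))
    }
    // Right rows never matched by any left row, padded on the left.
    val fromRight = rightRows.zipWithIndex.collect {
      case (r, idx) if !rightMatched.contains(idx) => (None, Some(r))
    }
    fromLeft ++ fromRight
  }

  def main(args: Array[String]): Unit =
    println(fullOuterGroup(Seq(1, 2), Seq(2, 3), _ == _))
    // left 1 is right-padded, 2 matches 2, right 3 is left-padded
}
```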

// TODO need to unit test this, currently it's the dead code, but should be used in SortMergeJoin
Contributor: Remove dead code.

def innerIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
: Iterator[Row] = {
// Ignore the join filter for inner joins; we assume it will be done in the select filter.
if (!key.anyNull) {
leftIter.iterator.flatMap { l =>
joinedRow.withLeft(l)
rightIter.iterator.collect {
case r if boundCondition(joinedRow.withRight(r)) => joinedRow
}
}
} else {
BinaryJoinNode.EMPTY_NULL_LIST.iterator
}
}
}

trait HashJoin {
self: SparkPlan =>

@@ -72,7 +201,7 @@ trait HashJoin {
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if(!rowKey.anyNull) {
if (!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey)
val matchList = if (existingMatchList == null) {
val newMatchList = new ArrayBuffer[Row]()
@@ -136,6 +265,67 @@ trait HashJoin {
}
}

/**
* :: DeveloperApi ::
* Performs a hash join of two child relations by shuffling the data using the join keys.
* This operator requires loading both tables into memory.
*/
@DeveloperApi
case class HashOuterJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryRepeatableIteratorNode {

override def outputPartitioning: Partitioning = left.outputPartitioning

override def requiredChildDistribution =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

def output = left.output ++ right.output

private[this] def buildHashTable(iter: Iterator[Row], keyGenerator: Projection)
: Map[Row, ArrayBuffer[Row]] = {
// TODO: Use Spark's HashMap implementation.
val hashTable = scala.collection.mutable.Map[Row, ArrayBuffer[Row]]()
Contributor: We should probably at least be using java.util here. The scala collection library seems to have weird performance sometimes.

while (iter.hasNext) {
val currentRow = iter.next()
val rowKey = keyGenerator(currentRow)

val existingMatchList = hashTable.getOrElseUpdate(rowKey, {new ArrayBuffer[Row]()})
existingMatchList += currentRow.copy()
}

hashTable.toMap[Row, ArrayBuffer[Row]]
Contributor: Why the extra .toMap? Is this doing a full copy?

}
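The java.util suggestion above could look roughly like this (a hypothetical sketch, not the PR's code: Spark's Row and Projection are replaced by a type parameter and a plain key function, mirroring the get/null/put pattern HashJoin already uses):

```scala
import java.util.{ArrayList => JArrayList, HashMap => JHashMap}

object BuildHashTableSketch {
  // Group rows by key into a java.util.HashMap of ArrayBuffers-equivalents,
  // avoiding the Scala mutable.Map and the trailing .toMap copy.
  def buildHashTable[K, V](iter: Iterator[V], keyOf: V => K): JHashMap[K, JArrayList[V]] = {
    val hashTable = new JHashMap[K, JArrayList[V]]()
    while (iter.hasNext) {
      val row = iter.next()
      val key = keyOf(row)
      var matchList = hashTable.get(key)
      if (matchList == null) { // first row seen for this key
        matchList = new JArrayList[V]()
        hashTable.put(key, matchList)
      }
      matchList.add(row)
    }
    hashTable
  }

  def main(args: Array[String]): Unit = {
    val table = buildHashTable(
      Iterator(("a", 1), ("b", 2), ("a", 3)),
      (t: (String, Int)) => t._1)
    println(table.get("a").size) // expected: 2
    println(table.get("b").size) // expected: 1
  }
}
```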

def execute() = {
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
// TODO this probably can be replaced by external sort (sort merged join?)
val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
val rightHashTable= buildHashTable(rightIter, newProjection(rightKeys, right.output))
Contributor: Space before =.


joinType match {
case LeftOuter => leftHashTable.keysIterator.flatMap { key =>
leftOuterIterator(key, leftHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST),
rightHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST))
}
case RightOuter => rightHashTable.keysIterator.flatMap { key =>
rightOuterIterator(key, leftHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST),
rightHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST))
}
case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
fullOuterIterator(key, leftHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST),
rightHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST))
}
case x => throw new Exception(s"Need to add implementation for $x")
}
}
}
}

/**
* :: DeveloperApi ::
* Performs an inner hash join of two child relations by first shuffling the data using the join
@@ -189,7 +379,7 @@ case class LeftSemiJoinHash(
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if(!rowKey.anyNull) {
if (!rowKey.anyNull) {
val keyExists = hashSet.contains(rowKey)
if (!keyExists) {
hashSet.add(rowKey)