Add BinaryHashJoinNode and BroadcastHashJoinNode.
yhuai committed Oct 2, 2015
1 parent 0cf8d44 commit f262b36
Showing 4 changed files with 237 additions and 15 deletions.
@@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide}

/**
* A wrapper around [[HashJoinNode]]. It builds the [[HashedRelation]] from the side indicated
* by `buildSide`; the actual work of this node is delegated to the [[HashJoinNode]]
* that is created in `open`.
*/
case class BinaryHashJoinNode(
conf: SQLConf,
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
buildSide: BuildSide,
left: LocalNode,
right: LocalNode) extends BinaryLocalNode(conf) {

private[this] lazy val (buildNode, buildKeys, streamedNode, streamedKeys) = buildSide match {
case BuildLeft => (left, leftKeys, right, rightKeys)
case BuildRight => (right, rightKeys, left, leftKeys)
}

private[this] var hashJoinNode: HashJoinNode = _

override def output: Seq[Attribute] = left.output ++ right.output

private[this] def isUnsafeMode: Boolean = {
(codegenEnabled && unsafeEnabled && UnsafeProjection.canSupport(buildKeys))
}

private[this] def buildSideKeyGenerator: Projection = {
if (isUnsafeMode) {
UnsafeProjection.create(buildKeys, buildNode.output)
} else {
newMutableProjection(buildKeys, buildNode.output)()
}
}

override def open(): Unit = {
// buildNode's prepare has been called in this.prepare.
buildNode.open()
val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator)
// The HashedRelation has been built, so buildNode can be closed now.
buildNode.close()

// Call the open of streamedNode.
streamedNode.open()
// Create the HashJoinNode based on the streamedNode and HashedRelation.
hashJoinNode =
HashJoinNode(
conf = conf,
streamedKeys = streamedKeys,
streamedNode = streamedNode,
buildSide = buildSide,
buildOutput = buildNode.output,
hashedRelation = hashedRelation,
isWrapped = true)
// Set up this HashJoinNode. We still call prepare and open in case there is any setup
// work that needs to be done in this HashJoinNode itself. Because isWrapped is true,
// prepare and open will not propagate to streamedNode.
hashJoinNode.prepare()
hashJoinNode.open()
}

override def next(): Boolean = {
hashJoinNode.next()
}

override def fetch(): InternalRow = {
hashJoinNode.fetch()
}

override def close(): Unit = {
// Close the internal HashJoinNode. We still call this in case there is any teardown work
// that needs to be done in this HashJoinNode itself. Because isWrapped is true,
// close will not propagate to streamedNode.
hashJoinNode.close()
// Now, close the streamedNode.
streamedNode.close()
// Note that we do not need to close buildNode here because it was
// already closed in this.open.
}
}
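For context, a minimal sketch of the caller-side lifecycle this node expects. The setup is hypothetical: `conf`, `leftNode`, and `rightNode` are assumed to be a SQLConf and two resolved LocalNodes, mirroring the DummyNode-based construction in HashJoinNodeSuite further down.

// Hypothetical sketch; the key expressions would still need resolution,
// as resolveExpressions does in HashJoinNodeSuite.
val join = BinaryHashJoinNode(conf, Seq('id1), Seq('id2), BuildRight, leftNode, rightNode)
join.prepare()            // propagates to both children, including the build side
join.open()               // builds the HashedRelation, closes buildNode, opens the wrapped HashJoinNode
while (join.next()) {     // delegated to the wrapped HashJoinNode
  val row = join.fetch()  // the current joined InternalRow
}
join.close()              // closes the wrapped HashJoinNode, then streamedNode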
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation}

/**
* A wrapper around [[HashJoinNode]] for broadcast joins. The actual work of this node is
* delegated to the [[HashJoinNode]] that is created in `open`.
*/
case class BroadcastHashJoinNode(
conf: SQLConf,
streamedKeys: Seq[Expression],
streamedNode: LocalNode,
buildSide: BuildSide,
buildOutput: Seq[Attribute],
hashedRelation: Broadcast[HashedRelation]) extends UnaryLocalNode(conf) {

override val child = streamedNode

private[this] var hashJoinNode: HashJoinNode = _

// Because the build-side node is not passed in, we use buildOutput to
// construct the inputSet properly.
override def inputSet: AttributeSet = AttributeSet(child.output ++ buildOutput)

override def output: Seq[Attribute] = buildSide match {
case BuildRight => streamedNode.output ++ buildOutput
case BuildLeft => buildOutput ++ streamedNode.output
}

override def open(): Unit = {
// Call the open of streamedNode.
streamedNode.open()
// Create the HashJoinNode based on the streamedNode and HashedRelation.
hashJoinNode =
HashJoinNode(
conf = conf,
streamedKeys = streamedKeys,
streamedNode = streamedNode,
buildSide = buildSide,
buildOutput = buildOutput,
hashedRelation = hashedRelation.value,
isWrapped = true)
// Set up this HashJoinNode. We still call prepare and open in case there is any setup
// work that needs to be done in this HashJoinNode itself. Because isWrapped is true,
// prepare and open will not propagate to streamedNode.
hashJoinNode.prepare()
hashJoinNode.open()
}

override def next(): Boolean = {
hashJoinNode.next()
}

override def fetch(): InternalRow = {
hashJoinNode.fetch()
}

override def close(): Unit = {
// Close the internal HashJoinNode. We still call this in case there is any teardown work
// that needs to be done in this HashJoinNode itself. Because isWrapped is true,
// close will not propagate to streamedNode.
hashJoinNode.close()
// Now, close the streamedNode.
streamedNode.close()
}
}
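For context, a minimal sketch of how this node is wired up, mirroring the suite change further down. The names `streamedKeys`, `streamedNode`, `buildSide`, `resolvedBuildNode`, and `hashedRelation` are assumed to be prepared as in HashJoinNodeSuite.

// Hypothetical sketch: broadcast an already-built HashedRelation, then wrap it.
val broadcastRelation = sqlContext.sparkContext.broadcast(hashedRelation)
val broadcastJoin =
  BroadcastHashJoinNode(
    conf,
    streamedKeys,
    streamedNode,
    buildSide,
    resolvedBuildNode.output,
    broadcastRelation)
// The node is then driven with the usual prepare/open/next/fetch/close calls.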
@@ -24,6 +24,14 @@ import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.execution.metric.SQLMetrics

/**
* A node for inner hash equi-join. It can be used individually or wrapped by other
* inner hash equi-join nodes such as [[BinaryHashJoinNode]]. This node takes an already
* built [[HashedRelation]] and a [[LocalNode]] representing the streamed side.
* If this node is used individually, `isWrapped` should be set to false.
* If this node is wrapped in another node, `isWrapped` should be set to true,
* and the node wrapping this node should call `prepare`, `open`, and `close` on
* the `streamedNode`.
*
* Much of this code is similar to [[org.apache.spark.sql.execution.joins.HashJoin]].
*/
case class HashJoinNode(
@@ -32,7 +40,8 @@ case class HashJoinNode(
streamedNode: LocalNode,
buildSide: BuildSide,
buildOutput: Seq[Attribute],
hashedRelation: HashedRelation) extends UnaryLocalNode(conf) {
hashedRelation: HashedRelation,
isWrapped: Boolean) extends UnaryLocalNode(conf) {

override val child = streamedNode

@@ -55,7 +64,6 @@
private[this] val hashed: HashedRelation = hashedRelation
private[this] var joinKeys: Projection = _


private[this] def isUnsafeMode: Boolean = {
(codegenEnabled && unsafeEnabled && UnsafeProjection.canSupport(schema))
}
@@ -68,8 +76,18 @@
}
}

override def prepare(): Unit = {
if (!isWrapped) {
// This node is used individually, so we should propagate the prepare call.
super.prepare()
}
}

override def open(): Unit = {
streamedNode.open()
if (!isWrapped) {
// This node is used individually, so we should propagate the open call.
streamedNode.open()
}
joinRow = new JoinedRow
resultProjection = {
if (isUnsafeMode) {
@@ -123,6 +141,9 @@
}

override def close(): Unit = {
streamedNode.close()
if (!isWrapped) {
// This node is used individually, so we should propagate the close call.
streamedNode.close()
}
}
}
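To make the `isWrapped` contract concrete, a hedged sketch of the two usage modes; all constructor arguments other than `isWrapped` are assumed to be already resolved.

// Standalone use: this node manages streamedNode's lifecycle itself.
val standalone = HashJoinNode(
  conf, streamedKeys, streamedNode, buildSide, buildOutput, hashedRelation, isWrapped = false)
standalone.prepare()  // super.prepare() runs and reaches streamedNode
standalone.open()     // opens streamedNode before building projections

// Wrapped use (e.g. inside BinaryHashJoinNode or BroadcastHashJoinNode): the outer
// node has already prepared and opened streamedNode, so propagation is suppressed.
val wrapped = HashJoinNode(
  conf, streamedKeys, streamedNode, buildSide, buildOutput, hashedRelation, isWrapped = true)
wrapped.prepare()     // does not re-prepare streamedNode
wrapped.open()        // does not reopen streamedNode
wrapped.close()       // does not close streamedNode; the wrapper does that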
@@ -21,9 +21,10 @@ import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Expression}
import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.test.SharedSQLContext


class HashJoinNodeSuite extends LocalNodeTest {
class HashJoinNodeSuite extends LocalNodeTest with SharedSQLContext {

// Test all combinations of the two dimensions: with/without unsafe and both build sides
private val maybeUnsafeAndCodegen = Seq(false, true)
@@ -86,7 +87,12 @@ class HashJoinNodeSuite extends LocalNodeTest {
val rightInputMap = rightInput.toMap
val leftNode = new DummyNode(joinNameAttributes, leftInput)
val rightNode = new DummyNode(joinNicknameAttributes, rightInput)
val makeNode = (node1: LocalNode, node2: LocalNode) => {
val makeBinaryHashJoinNode = (node1: LocalNode, node2: LocalNode) => {
val binaryHashJoinNode =
BinaryHashJoinNode(conf, Seq('id1), Seq('id2), buildSide, node1, node2)
resolveExpressions(binaryHashJoinNode)
}
val makeBroadcastJoinNode = (node1: LocalNode, node2: LocalNode) => {
val leftKeys = Seq('id1.attr)
val rightKeys = Seq('id2.attr)
// Figure out the build side and stream side.
@@ -98,28 +104,33 @@
val resolvedBuildNode = resolveExpressions(buildNode)
val resolvedBuildKeys = resolveExpressions(buildKeys, resolvedBuildNode)
val hashedRelation = buildHashedRelation(conf, resolvedBuildKeys, resolvedBuildNode)
val broadcastHashedRelation = sqlContext.sparkContext.broadcast(hashedRelation)

// Build the HashJoinNode.
val hashJoinNode =
HashJoinNode(
BroadcastHashJoinNode(
conf,
streamedKeys,
streamedNode,
buildSide,
resolvedBuildNode.output,
hashedRelation)
broadcastHashedRelation)
resolveExpressions(hashJoinNode)
}
val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode
val hashJoinNode = makeUnsafeNode(leftNode, rightNode)

val expectedOutput = leftInput
.filter { case (k, _) => rightInputMap.contains(k) }
.map { case (k, v) => (k, v, k, rightInputMap(k)) }
val actualOutput = hashJoinNode.collect().map { row =>
// (id, name, id, nickname)
(row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))

Seq(makeBinaryHashJoinNode, makeBroadcastJoinNode).foreach { makeNode =>
val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode
val hashJoinNode = makeUnsafeNode(leftNode, rightNode)

val actualOutput = hashJoinNode.collect().map { row =>
// (id, name, id, nickname)
(row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
}
assert(actualOutput === expectedOutput)
}
assert(actualOutput === expectedOutput)
}

test(s"$testNamePrefix: empty") {