diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala
new file mode 100644
index 0000000000000..ce45449c5e798
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation}
+
+/**
+ * A wrapper of [[HashJoinNode]]. It builds the [[HashedRelation]] according to the value of
+ * `buildSide`. The actual work of this node is delegated to the [[HashJoinNode]]
+ * that is created in `open`.
+ */
+case class BinaryHashJoinNode(
+    conf: SQLConf,
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    buildSide: BuildSide,
+    left: LocalNode,
+    right: LocalNode) extends BinaryLocalNode(conf) {
+
+  private[this] lazy val (buildNode, buildKeys, streamedNode, streamedKeys) = buildSide match {
+    case BuildLeft => (left, leftKeys, right, rightKeys)
+    case BuildRight => (right, rightKeys, left, leftKeys)
+  }
+
+  private[this] var hashJoinNode: HashJoinNode = _
+
+  override def output: Seq[Attribute] = left.output ++ right.output
+
+  private[this] def isUnsafeMode: Boolean = {
+    (codegenEnabled && unsafeEnabled && UnsafeProjection.canSupport(buildKeys))
+  }
+
+  private[this] def buildSideKeyGenerator: Projection = {
+    if (isUnsafeMode) {
+      UnsafeProjection.create(buildKeys, buildNode.output)
+    } else {
+      newMutableProjection(buildKeys, buildNode.output)()
+    }
+  }
+
+  override def open(): Unit = {
+    // buildNode's prepare has already been called in this.prepare.
+    buildNode.open()
+    val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator)
+    // We have built the HashedRelation, so buildNode can be closed.
+    buildNode.close()
+
+    // Open the streamedNode.
+    streamedNode.open()
+    // Create the HashJoinNode based on the streamedNode and the HashedRelation.
+    hashJoinNode =
+      HashJoinNode(
+        conf = conf,
+        streamedKeys = streamedKeys,
+        streamedNode = streamedNode,
+        buildSide = buildSide,
+        buildOutput = buildNode.output,
+        hashedRelation = hashedRelation,
+        isWrapped = true)
+    // Set up this HashJoinNode. We still call these methods in case there is any
+    // setup work that needs to be done in it. Because isWrapped is true, prepare
+    // and open will not propagate to its child, streamedNode.
+    hashJoinNode.prepare()
+    hashJoinNode.open()
+  }
+
+  override def next(): Boolean = {
+    hashJoinNode.next()
+  }
+
+  override def fetch(): InternalRow = {
+    hashJoinNode.fetch()
+  }
+
+  override def close(): Unit = {
+    // Close the internal HashJoinNode. We still call this in case there is any
+    // teardown work that needs to be done in it. Because isWrapped is true,
+    // close will not propagate to its child, streamedNode.
+    hashJoinNode.close()
+    // Now, close the streamedNode.
+    streamedNode.close()
+    // Note that we do not need to close buildNode here because it was
+    // already closed in this.open.
+  }
+}
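The wrap-and-delegate contract that BinaryHashJoinNode (and BroadcastHashJoinNode below) shares with HashJoinNode is easy to lose track of across files. The following is a minimal, self-contained Scala sketch of that contract; it is not part of this patch, and every name in it is hypothetical:

// A hypothetical, dependency-free model of the isWrapped contract; none of
// these classes exist in Spark. MiniInner stands for HashJoinNode and
// MiniWrapper for BinaryHashJoinNode / BroadcastHashJoinNode.
trait MiniNode {
  def prepare(): Unit = ()
  def open(): Unit
  def close(): Unit
}

// Stands in for the streamed side; tracks how often it has been opened.
class MiniStream extends MiniNode {
  var openCount = 0
  override def open(): Unit = openCount += 1
  override def close(): Unit = openCount -= 1
}

// When wrapped, the inner node assumes its owner manages the child.
class MiniInner(child: MiniNode, isWrapped: Boolean) extends MiniNode {
  override def open(): Unit = if (!isWrapped) child.open()
  override def close(): Unit = if (!isWrapped) child.close()
}

// The wrapper opens the child once, then drives the inner node directly.
class MiniWrapper(child: MiniNode) extends MiniNode {
  private var inner: MiniInner = _
  override def open(): Unit = {
    child.open()                                   // wrapper owns the child
    inner = new MiniInner(child, isWrapped = true)
    inner.prepare()                                // inner's own setup only
    inner.open()                                   // does NOT reopen the child
  }
  override def close(): Unit = {
    inner.close()                                  // inner's own teardown only
    child.close()                                  // child closed exactly once
  }
}

Driving a MiniStream through MiniWrapper leaves openCount at exactly 1 between open and close, which is the guarantee the comments in open() and close() above describe.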
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala
new file mode 100644
index 0000000000000..7f129a5a21872
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation}
+
+/**
+ * A wrapper of [[HashJoinNode]] for broadcast joins. The actual work of this node is
+ * delegated to the [[HashJoinNode]] that is created in `open`.
+ */
+case class BroadcastHashJoinNode(
+    conf: SQLConf,
+    streamedKeys: Seq[Expression],
+    streamedNode: LocalNode,
+    buildSide: BuildSide,
+    buildOutput: Seq[Attribute],
+    hashedRelation: Broadcast[HashedRelation]) extends UnaryLocalNode(conf) {
+
+  override val child = streamedNode
+
+  private[this] var hashJoinNode: HashJoinNode = _
+
+  // Because the buildNode is not passed in, we use buildOutput to
+  // create the inputSet properly.
+  override def inputSet: AttributeSet = AttributeSet(child.output ++ buildOutput)
+
+  override def output: Seq[Attribute] = buildSide match {
+    case BuildRight => streamedNode.output ++ buildOutput
+    case BuildLeft => buildOutput ++ streamedNode.output
+  }
+
+  override def open(): Unit = {
+    // Open the streamedNode.
+    streamedNode.open()
+    // Create the HashJoinNode based on the streamedNode and the HashedRelation.
+    hashJoinNode =
+      HashJoinNode(
+        conf = conf,
+        streamedKeys = streamedKeys,
+        streamedNode = streamedNode,
+        buildSide = buildSide,
+        buildOutput = buildOutput,
+        hashedRelation = hashedRelation.value,
+        isWrapped = true)
+    // Set up this HashJoinNode. We still call these methods in case there is any
+    // setup work that needs to be done in it. Because isWrapped is true, prepare
+    // and open will not propagate to its child, streamedNode.
+    hashJoinNode.prepare()
+    hashJoinNode.open()
+  }
+
+  override def next(): Boolean = {
+    hashJoinNode.next()
+  }
+
+  override def fetch(): InternalRow = {
+    hashJoinNode.fetch()
+  }
+
+  override def close(): Unit = {
+    // Close the internal HashJoinNode. We still call this in case there is any
+    // teardown work that needs to be done in it. Because isWrapped is true,
+    // close will not propagate to its child, streamedNode.
+    hashJoinNode.close()
+    // Now, close the streamedNode.
+    streamedNode.close()
+  }
+}
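For reference, a caller drives BroadcastHashJoinNode through the usual LocalNode lifecycle of prepare/open/next/fetch/close. The sketch below is illustrative only: `conf`, `streamedKeys`, `streamedNode`, `buildOutput`, and `hashedRelation` are assumed to already exist (the test suite below builds them the same way), so it is not self-contained:

// Illustrative wiring only; all inputs are assumed to be prebuilt, as in
// HashJoinNodeSuite below. sparkContext.broadcast ships the prebuilt
// HashedRelation to the node.
val broadcastedRelation = sqlContext.sparkContext.broadcast(hashedRelation)
val join = BroadcastHashJoinNode(
  conf, streamedKeys, streamedNode, BuildRight, buildOutput, broadcastedRelation)
join.prepare()
join.open()
while (join.next()) {
  val row = join.fetch() // one joined InternalRow per successful next()
}
join.close()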
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
index 10e341a1ecb9d..504c54d8e8ec2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
@@ -24,6 +24,14 @@ import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
+ * A node for inner hash equi-join. It can be used individually or wrapped by other
+ * inner hash equi-join nodes such as [[BinaryHashJoinNode]]. This node takes an already
+ * built [[HashedRelation]] and a [[LocalNode]] representing the streamed side.
+ * If this node is used individually, `isWrapped` should be set to false.
+ * If this node is wrapped in another node, `isWrapped` should be set to true
+ * and the node that wraps it should call `prepare`, `open`, and `close` on
+ * the `streamedNode`.
+ *
  * Much of this code is similar to [[org.apache.spark.sql.execution.joins.HashJoin]].
  */
 case class HashJoinNode(
@@ -32,7 +40,8 @@ case class HashJoinNode(
     streamedNode: LocalNode,
     buildSide: BuildSide,
     buildOutput: Seq[Attribute],
-    hashedRelation: HashedRelation) extends UnaryLocalNode(conf) {
+    hashedRelation: HashedRelation,
+    isWrapped: Boolean) extends UnaryLocalNode(conf) {
 
   override val child = streamedNode
 
@@ -55,7 +64,6 @@ case class HashJoinNode(
   private[this] val hashed: HashedRelation = hashedRelation
   private[this] var joinKeys: Projection = _
 
-
   private[this] def isUnsafeMode: Boolean = {
     (codegenEnabled && unsafeEnabled && UnsafeProjection.canSupport(schema))
   }
@@ -68,8 +76,18 @@ case class HashJoinNode(
     }
   }
 
+  override def prepare(): Unit = {
+    if (!isWrapped) {
+      // This node is used individually, so we should propagate the prepare call.
+      super.prepare()
+    }
+  }
+
   override def open(): Unit = {
-    streamedNode.open()
+    if (!isWrapped) {
+      // This node is used individually, so we should propagate the open call.
+      streamedNode.open()
+    }
     joinRow = new JoinedRow
     resultProjection = {
       if (isUnsafeMode) {
@@ -123,6 +141,9 @@ case class HashJoinNode(
   }
 
   override def close(): Unit = {
-    streamedNode.close()
+    if (!isWrapped) {
+      // This node is used individually, so we should propagate the close call.
+      streamedNode.close()
+    }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
index b37492b23d7a6..901217bd044c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
@@ -21,9 +21,10 @@ import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Expression}
 import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide}
+import org.apache.spark.sql.test.SharedSQLContext
 
-class HashJoinNodeSuite extends LocalNodeTest {
+class HashJoinNodeSuite extends LocalNodeTest with SharedSQLContext {
 
   // Test all combinations of the two dimensions: with/out unsafe and build sides
   private val maybeUnsafeAndCodegen = Seq(false, true)
@@ -86,7 +87,12 @@ class HashJoinNodeSuite extends LocalNodeTest {
     val rightInputMap = rightInput.toMap
     val leftNode = new DummyNode(joinNameAttributes, leftInput)
     val rightNode = new DummyNode(joinNicknameAttributes, rightInput)
-    val makeNode = (node1: LocalNode, node2: LocalNode) => {
+    val makeBinaryHashJoinNode = (node1: LocalNode, node2: LocalNode) => {
+      val binaryHashJoinNode =
+        BinaryHashJoinNode(conf, Seq('id1), Seq('id2), buildSide, node1, node2)
+      resolveExpressions(binaryHashJoinNode)
+    }
+    val makeBroadcastJoinNode = (node1: LocalNode, node2: LocalNode) => {
       val leftKeys = Seq('id1.attr)
       val rightKeys = Seq('id2.attr)
       // Figure out the build side and stream side.
@@ -98,28 +104,33 @@ class HashJoinNodeSuite extends LocalNodeTest {
       val resolvedBuildNode = resolveExpressions(buildNode)
       val resolvedBuildKeys = resolveExpressions(buildKeys, resolvedBuildNode)
       val hashedRelation = buildHashedRelation(conf, resolvedBuildKeys, resolvedBuildNode)
+      val broadcastHashedRelation = sqlContext.sparkContext.broadcast(hashedRelation)
 
-      // Build the HashJoinNode.
       val hashJoinNode =
-        HashJoinNode(
+        BroadcastHashJoinNode(
           conf,
           streamedKeys,
           streamedNode,
           buildSide,
           resolvedBuildNode.output,
-          hashedRelation)
+          broadcastHashedRelation)
       resolveExpressions(hashJoinNode)
     }
 
-    val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode
-    val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
+
     val expectedOutput = leftInput
       .filter { case (k, _) => rightInputMap.contains(k) }
       .map { case (k, v) => (k, v, k, rightInputMap(k)) }
-    val actualOutput = hashJoinNode.collect().map { row =>
-      // (id, name, id, nickname)
-      (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
+
+    Seq(makeBinaryHashJoinNode, makeBroadcastJoinNode).foreach { makeNode =>
+      val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode
+      val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
+
+      val actualOutput = hashJoinNode.collect().map { row =>
+        // (id, name, id, nickname)
+        (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
+      }
+      assert(actualOutput === expectedOutput)
    }
-    assert(actualOutput === expectedOutput)
   }
 
   test(s"$testNamePrefix: empty") {
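The reworked test keeps a single `expectedOutput` and checks both node factories against it. The expected rows are just an inner equi-join computed with plain collections; the following self-contained snippet (sample data invented for illustration) shows the same computation the suite performs:

// Sample data invented for illustration; the suite's real inputs come
// from DummyNode. The join keeps only ids present on both sides.
val leftInput = Array((1, "a"), (2, "b"), (3, "c"))
val rightInput = Array((2, "x"), (3, "y"), (4, "z"))
val rightInputMap = rightInput.toMap
val expectedOutput = leftInput
  .filter { case (k, _) => rightInputMap.contains(k) }
  .map { case (k, v) => (k, v, k, rightInputMap(k)) }
// expectedOutput: Array((2,"b",2,"x"), (3,"c",3,"y"))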