From b9d5bc9efcbd3eae8ce6a114cf5fc16864ccbe0c Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sat, 3 Oct 2015 21:01:13 -0700 Subject: [PATCH] Use withHashedRelation to set the HashedRelation used by a HashJoinNode at runtime. --- ...oinNode.scala => BinaryHashJoinNode.scala} | 25 +++++++++---------- .../local/BroadcastHashJoinNode.scala | 22 ++++++++-------- .../sql/execution/local/HashJoinNode.scala | 10 +++++--- .../execution/local/HashJoinNodeSuite.scala | 2 +- 4 files changed, 31 insertions(+), 28 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/local/{BinarydHashJoinNode.scala => BinaryHashJoinNode.scala} (88%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinarydHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala similarity index 88% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinarydHashJoinNode.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala index ce45449c5e798..6c1cf90823cdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinarydHashJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRig * `buildSide`. The actual work of this node will be delegated to the [[HashJoinNode]] * that is created in `open`. 
*/ -case class BinarydHashJoinNode( +case class BinaryHashJoinNode( conf: SQLConf, leftKeys: Seq[Expression], rightKeys: Seq[Expression], @@ -40,8 +40,15 @@ case class BinarydHashJoinNode( case BuildRight => (right, rightKeys, left, leftKeys) } - private[this] var hashJoinNode: HashJoinNode = _ - + private[this] val hashJoinNode: HashJoinNode = { + HashJoinNode( + conf = conf, + streamedKeys = streamedKeys, + streamedNode = streamedNode, + buildSide = buildSide, + buildOutput = buildNode.output, + isWrapped = true) + } override def output: Seq[Attribute] = left.output ++ right.output private[this] def isUnsafeMode: Boolean = { @@ -65,16 +72,8 @@ case class BinarydHashJoinNode( // Call the open of streamedNode. streamedNode.open() - // Create the HashJoinNode based on the streamedNode and HashedRelation. - hashJoinNode = - HashJoinNode( - conf = conf, - streamedKeys = streamedKeys, - streamedNode = streamedNode, - buildSide = buildSide, - buildOutput = buildNode.output, - hashedRelation = hashedRelation, - isWrapped = true) + // Set the HashedRelation used by the HashJoinNode. + hashJoinNode.withHashedRelation(hashedRelation) // Setup this HashJoinNode. We still call these in case there is any setup work // that needs to be done in this HashJoinNode. Because isWrapped is true, // prepare and open will not propagate to the child of streamedNode. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala index 7f129a5a21872..00d3c5f0096c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala @@ -37,7 +37,15 @@ case class BroadcastHashJoinNode( override val child = streamedNode - private[this] var hashJoinNode: HashJoinNode = _ + private[this] val hashJoinNode: HashJoinNode = { + HashJoinNode( + conf = conf, + streamedKeys = streamedKeys, + streamedNode = streamedNode, + buildSide = buildSide, + buildOutput = buildOutput, + isWrapped = true) + } // Because we do not pass in the buildNode, we take the output of buildNode to // create the inputSet properly. @@ -51,16 +59,8 @@ case class BroadcastHashJoinNode( override def open(): Unit = { // Call the open of streamedNode. streamedNode.open() - // Create the HashJoinNode based on the streamedNode and HashedRelation. - hashJoinNode = - HashJoinNode( - conf = conf, - streamedKeys = streamedKeys, - streamedNode = streamedNode, - buildSide = buildSide, - buildOutput = buildOutput, - hashedRelation = hashedRelation.value, - isWrapped = true) + // Set the HashedRelation used by the HashJoinNode. + hashJoinNode.withHashedRelation(hashedRelation.value) // Setup this HashJoinNode. We still call these in case there is any setup work // that needs to be done in this HashJoinNode. Because isWrapped is true, // prepare and open will not propagate to the child of streamedNode. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala index 504c54d8e8ec2..ea4c8c44d82f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics /** * A node for inner hash equi-join. It can be used individually or wrapped by other - * inner hash equi-join nodes such as [[BinarydHashJoinNode]]. This node takes a already + * inner hash equi-join nodes such as [[BinaryHashJoinNode]]. This node takes an already * built [[HashedRelation]] and a [[LocalNode]] representing the streamed side. * If this node is used individually, `isWrapped` should be set to false. * If this node is wrapped in another node, `isWrapped` should be set to true @@ -40,7 +40,6 @@ case class HashJoinNode( streamedNode: LocalNode, buildSide: BuildSide, buildOutput: Seq[Attribute], - hashedRelation: HashedRelation, isWrapped: Boolean) extends UnaryLocalNode(conf) { override val child = streamedNode @@ -61,7 +60,7 @@ case class HashJoinNode( private[this] var joinRow: JoinedRow = _ private[this] var resultProjection: (InternalRow) => InternalRow = _ - private[this] val hashed: HashedRelation = hashedRelation + private[this] var hashed: HashedRelation = _ private[this] var joinKeys: Projection = _ private[this] def isUnsafeMode: Boolean = { @@ -76,6 +75,11 @@ case class HashJoinNode( } } + /** Sets the HashedRelation used by this node. */ + def withHashedRelation(hashedRelation: HashedRelation): Unit = { + hashed = hashedRelation + } + override def prepare(): Unit = { if (!isWrapped) { // This node is used individually, we should propagate prepare call. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala index 901217bd044c2..01702f26560d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -89,7 +89,7 @@ class HashJoinNodeSuite extends LocalNodeTest with SharedSQLContext { val rightNode = new DummyNode(joinNicknameAttributes, rightInput) val makeBinaryHashJoinNode = (node1: LocalNode, node2: LocalNode) => { val binaryHashJoinNode = - BinarydHashJoinNode(conf, Seq('id1), Seq('id2), buildSide, node1, node2) + BinaryHashJoinNode(conf, Seq('id1), Seq('id2), buildSide, node1, node2) resolveExpressions(binaryHashJoinNode) } val makeBroadcastJoinNode = (node1: LocalNode, node2: LocalNode) => {