Skip to content

Commit

Permalink
Address comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
yhuai committed Oct 7, 2015
1 parent ecbd1d0 commit 081e331
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@ abstract class BaseMutableProjection extends MutableProjection

/**
* Generates byte code that produces a [[MutableRow]] object that can update itself based on a new
* input [[InternalRow]] for a fixed set of [[Expression Expressions]]. It exposes a `target`
* method. This method is used to set the row that will be updated. So, when `target` is used, the
* [[MutableRow]] object created internally will not be used. If `target` is not used, the
* [[MutableRow]] object created internally will be used.
* input [[InternalRow]] for a fixed set of [[Expression Expressions]].
* It exposes a `target` method, which is used to set the row that will be updated.
* The [[MutableRow]] object created internally is used only when `target` is not used.
*/
object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.types._


/**
* Generates byte code that produces a [[MutableRow]] object (not a [[UnsafeRow]]) that can update
* Generates byte code that produces a [[MutableRow]] object (not an [[UnsafeRow]]) that can update
* itself based on a new input [[InternalRow]] for a fixed set of [[Expression Expressions]].
*/
object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ case class BinaryHashJoinNode(
}

protected override def doOpen(): Unit = {
// buildNode's prepare has been called in this.prepare.
buildNode.open()
val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator)
// We have built the HashedRelation. So, close buildNode.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.joins._

/**
* A node for inner hash equi-join. [[BinaryHashJoinNode]] and [[BroadcastHashJoinNode]]
* are based on this.
* An abstract node for sharing common functionality among different implementations of
* inner hash equi-join, notably [[BinaryHashJoinNode]] and [[BroadcastHashJoinNode]].
*
* Much of this code is similar to [[org.apache.spark.sql.execution.joins.HashJoin]].
*/
Expand Down Expand Up @@ -60,21 +60,21 @@ trait HashJoinNode {
}
}

/** Sets the HashedRelation used by this node. */
/**
* Sets the HashedRelation used by this node. This method needs to be called
* before the first `next` gets called.
*/
protected def withHashedRelation(hashedRelation: HashedRelation): Unit = {
hashed = hashedRelation
}

/**
* For nodes that extend this, they can use doOpen to add operations needed in the open method.
* The implementation of this method should invoke its children's open methods.
* Custom open implementation to be overridden by subclasses.
*/
protected def doOpen(): Unit

override def open(): Unit = {
// First, call doOpen to invoke custom operations for a node.
doOpen()
// Second, initialize common internal states.
joinRow = new JoinedRow
resultProjection = {
if (isUnsafeMode) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin
result
}

protected[sql] def newProjection(
protected def newProjection(
expressions: Seq[Expression],
inputSchema: Seq[Attribute]): Projection = {
log.debug(
Expand All @@ -129,7 +129,7 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin
}
}

protected[sql] def newMutableProjection(
protected def newMutableProjection(
expressions: Seq[Expression],
inputSchema: Seq[Attribute]): () => MutableProjection = {
log.debug(
Expand All @@ -151,7 +151,7 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin
}
}

protected[sql] def newPredicate(
protected def newPredicate(
expression: Expression,
inputSchema: Seq[Attribute]): (InternalRow) => Boolean = {
if (codegenEnabled) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@

package org.apache.spark.sql.execution.local

import org.mockito.Mockito.{mock, when}

import org.apache.spark.broadcast.TorrentBroadcast
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Expression}
import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, UnsafeProjection, Expression}
import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.test.SharedSQLContext


class HashJoinNodeSuite extends LocalNodeTest with SharedSQLContext {
class HashJoinNodeSuite extends LocalNodeTest {

// Test all combinations of the two dimensions: with/out unsafe and build sides
private val maybeUnsafeAndCodegen = Seq(false, true)
Expand All @@ -44,26 +45,21 @@ class HashJoinNodeSuite extends LocalNodeTest with SharedSQLContext {
buildKeys: Seq[Expression],
buildNode: LocalNode): HashedRelation = {

// Check if we are in the Unsafe mode.
val isUnsafeMode =
conf.codegenEnabled &&
conf.unsafeEnabled &&
UnsafeProjection.canSupport(buildKeys)

// Create projection used for extracting keys
val buildSideKeyGenerator =
if (isUnsafeMode) {
UnsafeProjection.create(buildKeys, buildNode.output)
} else {
buildNode.newMutableProjection(buildKeys, buildNode.output)()
new InterpretedMutableProjection(buildKeys, buildNode.output)
}

// Setup the node.
buildNode.prepare()
buildNode.open()
// Build the HashedRelation
val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator)
// Close the node.
buildNode.close()

hashedRelation
Expand Down Expand Up @@ -104,7 +100,8 @@ class HashJoinNodeSuite extends LocalNodeTest with SharedSQLContext {
val resolvedBuildNode = resolveExpressions(buildNode)
val resolvedBuildKeys = resolveExpressions(buildKeys, resolvedBuildNode)
val hashedRelation = buildHashedRelation(conf, resolvedBuildKeys, resolvedBuildNode)
val broadcastHashedRelation = sqlContext.sparkContext.broadcast(hashedRelation)
val broadcastHashedRelation = mock(classOf[TorrentBroadcast[HashedRelation]])
when(broadcastHashedRelation.value).thenReturn(hashedRelation)

val hashJoinNode =
BroadcastHashJoinNode(
Expand Down

0 comments on commit 081e331

Please sign in to comment.