Skip to content

Commit

Permalink
Build HashedRelation outside of HashJoinNode.
Browse files Browse the repository at this point in the history
  • Loading branch information
yhuai committed Sep 30, 2015
1 parent 03cca5d commit 0cf8d44
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,21 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
*/
case class HashJoinNode(
conf: SQLConf,
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
streamedKeys: Seq[Expression],
streamedNode: LocalNode,
buildSide: BuildSide,
left: LocalNode,
right: LocalNode) extends BinaryLocalNode(conf) {
buildOutput: Seq[Attribute],
hashedRelation: HashedRelation) extends UnaryLocalNode(conf) {

private[this] lazy val (buildNode, buildKeys, streamedNode, streamedKeys) = buildSide match {
case BuildLeft => (left, leftKeys, right, rightKeys)
case BuildRight => (right, rightKeys, left, leftKeys)
override val child = streamedNode

// Because we do not pass in the buildNode, we take the output of buildNode to
// create the inputSet properly.
override def inputSet: AttributeSet = AttributeSet(child.output ++ buildOutput)

override def output: Seq[Attribute] = buildSide match {
case BuildRight => streamedNode.output ++ buildOutput
case BuildLeft => buildOutput ++ streamedNode.output
}

private[this] var currentStreamedRow: InternalRow = _
Expand All @@ -46,23 +52,12 @@ case class HashJoinNode(
private[this] var joinRow: JoinedRow = _
private[this] var resultProjection: (InternalRow) => InternalRow = _

private[this] var hashed: HashedRelation = _
private[this] val hashed: HashedRelation = hashedRelation
private[this] var joinKeys: Projection = _

override def output: Seq[Attribute] = left.output ++ right.output

private[this] def isUnsafeMode: Boolean = {
(codegenEnabled && unsafeEnabled
&& UnsafeProjection.canSupport(buildKeys)
&& UnsafeProjection.canSupport(schema))
}

private[this] def buildSideKeyGenerator: Projection = {
if (isUnsafeMode) {
UnsafeProjection.create(buildKeys, buildNode.output)
} else {
newMutableProjection(buildKeys, buildNode.output)()
}
(codegenEnabled && unsafeEnabled && UnsafeProjection.canSupport(schema))
}

private[this] def streamSideKeyGenerator: Projection = {
Expand All @@ -74,8 +69,6 @@ case class HashJoinNode(
}

override def open(): Unit = {
buildNode.open()
hashed = HashedRelation(buildNode, buildSideKeyGenerator)
streamedNode.open()
joinRow = new JoinedRow
resultProjection = {
Expand Down Expand Up @@ -130,7 +123,6 @@ case class HashJoinNode(
}

override def close(): Unit = {
left.close()
right.close()
streamedNode.close()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin
result
}

protected def newProjection(
protected[sql] def newProjection(
expressions: Seq[Expression],
inputSchema: Seq[Attribute]): Projection = {
log.debug(
Expand All @@ -129,7 +129,7 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin
}
}

protected def newMutableProjection(
protected[sql] def newMutableProjection(
expressions: Seq[Expression],
inputSchema: Seq[Attribute]): () => MutableProjection = {
log.debug(
Expand All @@ -151,7 +151,7 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin
}
}

protected def newPredicate(
protected[sql] def newPredicate(
expression: Expression,
inputSchema: Seq[Attribute]): (InternalRow) => Boolean = {
if (codegenEnabled) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.local

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Expression}
import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide}


class HashJoinNodeSuite extends LocalNodeTest {
Expand All @@ -33,6 +34,40 @@ class HashJoinNodeSuite extends LocalNodeTest {
}
}

/**
* Builds a [[HashedRelation]] based on a resolved `buildKeys`
* and a resolved `buildNode`.
*/
private def buildHashedRelation(
conf: SQLConf,
buildKeys: Seq[Expression],
buildNode: LocalNode): HashedRelation = {

// Check if we are in the Unsafe mode.
val isUnsafeMode =
conf.codegenEnabled &&
conf.unsafeEnabled &&
UnsafeProjection.canSupport(buildKeys)

// Create projection used for extracting keys
val buildSideKeyGenerator =
if (isUnsafeMode) {
UnsafeProjection.create(buildKeys, buildNode.output)
} else {
buildNode.newMutableProjection(buildKeys, buildNode.output)()
}

// Setup the node.
buildNode.prepare()
buildNode.open()
// Build the HashedRelation
val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator)
// Close the node.
buildNode.close()

hashedRelation
}

/**
* Test inner hash join with varying degrees of matches.
*/
Expand All @@ -52,8 +87,28 @@ class HashJoinNodeSuite extends LocalNodeTest {
val leftNode = new DummyNode(joinNameAttributes, leftInput)
val rightNode = new DummyNode(joinNicknameAttributes, rightInput)
val makeNode = (node1: LocalNode, node2: LocalNode) => {
resolveExpressions(new HashJoinNode(
conf, Seq('id1), Seq('id2), buildSide, node1, node2))
val leftKeys = Seq('id1.attr)
val rightKeys = Seq('id2.attr)
// Figure out the build side and stream side.
val (buildNode, buildKeys, streamedNode, streamedKeys) = buildSide match {
case BuildLeft => (node1, leftKeys, node2, rightKeys)
case BuildRight => (node2, rightKeys, node1, leftKeys)
}
// Resolve the expressions of the build side and then create a HashedRelation.
val resolvedBuildNode = resolveExpressions(buildNode)
val resolvedBuildKeys = resolveExpressions(buildKeys, resolvedBuildNode)
val hashedRelation = buildHashedRelation(conf, resolvedBuildKeys, resolvedBuildNode)

// Build the HashJoinNode.
val hashJoinNode =
HashJoinNode(
conf,
streamedKeys,
streamedNode,
buildSide,
resolvedBuildNode.output,
hashedRelation)
resolveExpressions(hashJoinNode)
}
val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode
val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.local
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType}


Expand Down Expand Up @@ -67,4 +67,22 @@ class LocalNodeTest extends SparkFunSuite {
}
}

/**
* Resolve all expressions in `expressions` based on the `output` of `localNode`.
* It assumes that all expressions in the `localNode` are resolved.
*/
protected def resolveExpressions(
expressions: Seq[Expression],
localNode: LocalNode): Seq[Expression] = {
require(localNode.expressions.forall(_.resolved))
val inputMap = localNode.output.map { a => (a.name, a) }.toMap
expressions.map { expression =>
expression.transformUp {
case UnresolvedAttribute(Seq(u)) =>
inputMap.getOrElse(u,
sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
}
}
}

}

0 comments on commit 0cf8d44

Please sign in to comment.