Skip to content

Commit

Permalink
support UnsafeRow in LeftSemiJoinBNL and BroadcastNestedLoopJoin
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Jul 21, 2015
1 parent 611d2ed commit a05b4f6
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
/**
* Function for writing a column into an UnsafeRow.
*/
abstract class UnsafeColumnWriter {
private abstract class UnsafeColumnWriter {
/**
* Write a value into an UnsafeRow.
*
Expand All @@ -130,7 +130,7 @@ abstract class UnsafeColumnWriter {
def getSize(source: InternalRow, column: Int): Int
}

object UnsafeColumnWriter {
private object UnsafeColumnWriter {

def forType(dataType: DataType): UnsafeColumnWriter = {
dataType match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ case class BroadcastNestedLoopJoin(
case BuildLeft => (right, left)
}

// Reports unsafe-row output as soon as either child reports producing unsafe rows.
override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows
// This operator unconditionally declares that it can process unsafe rows.
override def canProcessUnsafeRows: Boolean = true

/**
 * Maps every joined row to the row format this operator reports via `outputsUnsafeRows`.
 *
 * When unsafe output is reported, rows go through `UnsafeProjection.create(schema)`;
 * otherwise they are passed through untouched.
 *
 * The value is typed as a plain function rather than `Projection`: the pass-through case
 * is an ordinary function value, and casting a lambda with `asInstanceOf[Projection]`
 * fails with a ClassCastException at first use whenever `Projection` is a class rather
 * than an alias for the function type.
 *
 * Callers apply it as `resultProjection(row)`, so the function type is a drop-in
 * replacement for them.
 */
@transient private[this] lazy val resultProjection: InternalRow => InternalRow = {
  if (outputsUnsafeRows) {
    UnsafeProjection.create(schema)
  } else {
    // Safe rows need no conversion; hand the joined row back as-is.
    identity[InternalRow]
  }
}

// The join's output partitioning is taken directly from the streamed side.
override def outputPartitioning: Partitioning = streamed.outputPartitioning

override def output: Seq[Attribute] = {
Expand Down Expand Up @@ -74,6 +85,7 @@ case class BroadcastNestedLoopJoin(
val includedBroadcastTuples =
new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
val joinedRow = new JoinedRow

val leftNulls = new GenericMutableRow(left.output.size)
val rightNulls = new GenericMutableRow(right.output.size)

Expand All @@ -86,11 +98,11 @@ case class BroadcastNestedLoopJoin(
val broadcastedRow = broadcastedRelation.value(i)
buildSide match {
case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
matchedRows += joinedRow(streamedRow, broadcastedRow).copy()
matchedRows += resultProjection(joinedRow(streamedRow, broadcastedRow)).copy()
streamRowMatched = true
includedBroadcastTuples += i
case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) =>
matchedRows += joinedRow(broadcastedRow, streamedRow).copy()
matchedRows += resultProjection(joinedRow(broadcastedRow, streamedRow)).copy()
streamRowMatched = true
includedBroadcastTuples += i
case _ =>
Expand All @@ -100,22 +112,19 @@ case class BroadcastNestedLoopJoin(

(streamRowMatched, joinType, buildSide) match {
case (false, LeftOuter | FullOuter, BuildRight) =>
matchedRows += joinedRow(streamedRow, rightNulls).copy()
matchedRows += resultProjection(joinedRow(streamedRow, rightNulls)).copy()
case (false, RightOuter | FullOuter, BuildLeft) =>
matchedRows += joinedRow(leftNulls, streamedRow).copy()
matchedRows += resultProjection(joinedRow(leftNulls, streamedRow)).copy()
case _ =>
}
}
Iterator((matchedRows, includedBroadcastTuples))
}

val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2)
val allIncludedBroadcastTuples =
if (includedBroadcastTuples.count == 0) {
new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
} else {
includedBroadcastTuples.reduce(_ ++ _)
}
val allIncludedBroadcastTuples = includedBroadcastTuples.fold(
new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
)(_ ++ _)

val leftNulls = new GenericMutableRow(left.output.size)
val rightNulls = new GenericMutableRow(right.output.size)
Expand All @@ -127,8 +136,10 @@ case class BroadcastNestedLoopJoin(
while (i < rel.length) {
if (!allIncludedBroadcastTuples.contains(i)) {
(joinType, buildSide) match {
case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i))
case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls)
case (RightOuter | FullOuter, BuildRight) =>
buf += resultProjection(new JoinedRow(leftNulls, rel(i)))
case (LeftOuter | FullOuter, BuildLeft) =>
buf += resultProjection(new JoinedRow(rel(i), rightNulls))
case _ =>
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ case class LeftSemiJoinBNL(

// Only the left child's attributes are exposed as this operator's output.
override def output: Seq[Attribute] = left.output

// The output row format follows whatever the streamed child reports.
override def outputsUnsafeRows: Boolean = streamed.outputsUnsafeRows
// This operator unconditionally declares that it can process unsafe rows.
override def canProcessUnsafeRows: Boolean = true

/** The left child of this operator is the streamed relation. */
override def left: SparkPlan = streamed

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,19 @@ private[sql] object EnsureRowFormats extends Rule[SparkPlan] {
}
case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) =>
if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) {
// If this operator's children produce both unsafe and safe rows, then convert everything
// to unsafe rows
operator.withNewChildren {
operator.children.map {
c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c
// If this operator's children produce both unsafe and safe rows, convert everything to
// unsafe rows when all of the children's schemas are supported by UnsafeRow; otherwise
// convert everything to safe rows.
if (operator.children.forall(c => UnsafeProjection.canSupport(c.schema))) {
operator.withNewChildren {
operator.children.map {
c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c
}
}
} else {
operator.withNewChildren {
operator.children.map {
c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c
}
}
}
} else {
Expand Down

0 comments on commit a05b4f6

Please sign in to comment.