[SPARK-9024] Unsafe HashJoin/HashOuterJoin/HashSemiJoin
This PR introduces unsafe versions (using UnsafeRow) of HashJoin, HashOuterJoin and HashSemiJoin, covering both the broadcast and shuffled variants (except FullOuterJoin, which is better implemented using SortMergeJoin).

It uses a HashMap to store UnsafeRows for now; a follow-up PR will switch to BytesToBytesMap for better performance.
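
For illustration, a rough sketch of the build-side idea follows. This is not code from this patch: the helper name and the use of java.util.HashMap and ArrayBuffer are stand-ins for what HashedRelation does. The build side is keyed by an UnsafeRow of the join keys, which relies on the UnsafeRow.hashCode/equals added below.

import java.util.{HashMap => JHashMap}

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection, UnsafeRow}

// Hypothetical helper (not in this PR): hash the build side by its join keys.
def buildSideTable(
    rows: Iterator[InternalRow],
    buildKeys: Seq[Expression],
    buildOutput: Seq[Attribute]): JHashMap[UnsafeRow, ArrayBuffer[InternalRow]] = {
  // UnsafeProjection.create(exprs, inputSchema) is one of the factories added below.
  val keyGenerator = UnsafeProjection.create(buildKeys, buildOutput)
  val table = new JHashMap[UnsafeRow, ArrayBuffer[InternalRow]]()
  while (rows.hasNext) {
    val row = rows.next()
    val key = keyGenerator(row)
    if (!key.anyNull) {
      var matches = table.get(key)
      if (matches == null) {
        matches = new ArrayBuffer[InternalRow]()
        // copy() detaches the key from the projection's reused row buffer.
        table.put(key.copy(), matches)
      }
      matches += row.copy()
    }
  }
  table
}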

Author: Davies Liu <davies@databricks.com>

Closes #7480 from davies/unsafe_join and squashes the following commits:

6294b1e [Davies Liu] fix projection
10583f1 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join
dede020 [Davies Liu] fix test
84c9807 [Davies Liu] address comments
a05b4f6 [Davies Liu] support UnsafeRow in LeftSemiJoinBNL and BroadcastNestedLoopJoin
611d2ed [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join
9481ae8 [Davies Liu] return UnsafeRow after join()
ca2b40f [Davies Liu] revert unrelated change
68f5cd9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join
0f4380d [Davies Liu] ada a comment
69e38f5 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join
1a40f02 [Davies Liu] refactor
ab1690f [Davies Liu] address comments
60371f2 [Davies Liu] use UnsafeRow in SemiJoin
a6c0b7d [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join
184b852 [Davies Liu] fix style
6acbb11 [Davies Liu] fix tests
95d0762 [Davies Liu] remove println
bea4a50 [Davies Liu] Unsafe HashJoin
Davies Liu authored and davies committed Jul 22, 2015
1 parent 86f80e2 commit e0b7ba5
Showing 20 changed files with 444 additions and 135 deletions.
@@ -20,10 +20,11 @@
import java.io.IOException;
import java.io.OutputStream;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ObjectPool;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.UTF8String;


@@ -354,7 +355,7 @@ public double getDouble(int i) {
* This method is only supported on UnsafeRows that do not use ObjectPools.
*/
@Override
- public InternalRow copy() {
+ public UnsafeRow copy() {
if (pool != null) {
throw new UnsupportedOperationException(
"Copy is not supported for UnsafeRows that use object pools");
@@ -404,8 +405,51 @@ public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOException {
}
}

@Override
public int hashCode() {
return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42);
}

@Override
public boolean equals(Object other) {
if (other instanceof UnsafeRow) {
UnsafeRow o = (UnsafeRow) other;
return (sizeInBytes == o.sizeInBytes) &&
ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
sizeInBytes);
}
return false;
}

/**
* Returns the underlying bytes for this UnsafeRow.
*/
public byte[] getBytes() {
if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET
&& (((byte[]) baseObject).length == sizeInBytes)) {
return (byte[]) baseObject;
} else {
byte[] bytes = new byte[sizeInBytes];
PlatformDependent.copyMemory(baseObject, baseOffset, bytes,
PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes);
return bytes;
}
}

// This is for debugging
@Override
public String toString() {
StringBuilder build = new StringBuilder("[");
for (int i = 0; i < sizeInBytes; i += 8) {
build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i));
build.append(',');
}
build.append(']');
return build.toString();
}

@Override
public boolean anyNull() {
- return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
+ return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8);
}
}
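
The new hashCode and equals are what let UnsafeRow serve as a hash-table key: equality is a byte-wise comparison of the row's backing memory and the hash is Murmur3 over its 8-byte words. A minimal sketch, assuming InternalRow(...) as the generic-row factory and that two rows produced by the same UnsafeProjection share the same byte layout:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types._

val schema = StructType(Seq(StructField("a", IntegerType), StructField("b", LongType)))
val toUnsafe = UnsafeProjection.create(schema)

// copy() gives each row its own backing bytes instead of the reused buffer.
val r1 = toUnsafe(InternalRow(1, 2L)).copy()
val r2 = toUnsafe(InternalRow(1, 2L)).copy()

assert(r1 == r2)                    // byte-wise ByteArrayMethods.arrayEquals
assert(r1.hashCode == r2.hashCode)  // Murmur3_x86_32 over the row's words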
@@ -28,11 +28,10 @@
import org.apache.spark.TaskContext;
import org.apache.spark.sql.AbstractScalaRowIterator;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.util.ObjectPool;
- import org.apache.spark.sql.types.*;
+ import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
@@ -176,12 +175,7 @@ public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IOException {
*/
public static boolean supportsSchema(StructType schema) {
// TODO: add spilling note to explain why we do this for now:
- for (StructField field : schema.fields()) {
-   if (!UnsafeColumnWriter.canEmbed(field.dataType())) {
-     return false;
-   }
- }
- return true;
+ return UnsafeProjection.canSupport(schema);
}

private static final class RowComparator extends RecordComparator {
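With supportsSchema delegating to UnsafeProjection.canSupport, the sorter and the unsafe projection now agree on which schemas can be encoded as UnsafeRows. A hedged Scala illustration (which types are embeddable at this revision is an assumption; complex types such as arrays are expected to be rejected):

import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types._

val flat = StructType(Seq(StructField("id", LongType), StructField("price", DoubleType)))
val nested = StructType(Seq(StructField("xs", ArrayType(IntegerType))))

assert(UnsafeProjection.canSupport(flat))     // fixed-width columns embed fine
assert(!UnsafeProjection.canSupport(nested))  // complex types still take the safe path
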
@@ -21,7 +21,6 @@ import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.types._

/**
@@ -34,7 +33,23 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)

override def toString: String = s"input[$ordinal, $dataType]"

- override def eval(input: InternalRow): Any = input(ordinal)
// Use special getter for primitive types (for UnsafeRow)
override def eval(input: InternalRow): Any = {
if (input.isNullAt(ordinal)) {
null
} else {
dataType match {
case BooleanType => input.getBoolean(ordinal)
case ByteType => input.getByte(ordinal)
case ShortType => input.getShort(ordinal)
case IntegerType | DateType => input.getInt(ordinal)
case LongType | TimestampType => input.getLong(ordinal)
case FloatType => input.getFloat(ordinal)
case DoubleType => input.getDouble(ordinal)
case _ => input.get(ordinal)
}
}
}

override def name: String = s"i[$ordinal]"

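The eval change matters mainly for UnsafeRow inputs, which serve fixed-width columns through typed getters rather than a generic get. A hedged sketch of the effect (the schema and values are invented):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection}
import org.apache.spark.sql.types._

val toUnsafe = UnsafeProjection.create(Array[DataType](IntegerType, DoubleType))
val unsafeRow = toUnsafe(InternalRow(7, 1.5)).copy()

// With the dataType match above, this goes through getInt(0) rather than the
// generic input(ordinal) path that UnsafeRow cannot serve for primitives.
val col0 = BoundReference(0, IntegerType, nullable = true)
assert(col0.eval(unsafeRow) == 7)
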
@@ -83,19 +83,51 @@ abstract class UnsafeProjection extends Projection {
}

object UnsafeProjection {

/*
* Returns whether UnsafeProjection can support given StructType, Array[DataType] or
* Seq[Expression].
*/
def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType))
def canSupport(types: Array[DataType]): Boolean = types.forall(UnsafeColumnWriter.canEmbed(_))
def canSupport(exprs: Seq[Expression]): Boolean = canSupport(exprs.map(_.dataType).toArray)

/**
* Returns an UnsafeProjection for given StructType.
*/
def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType))

- def create(fields: Seq[DataType]): UnsafeProjection = {
+ /**
+  * Returns an UnsafeProjection for given Array of DataTypes.
+  */
+ def create(fields: Array[DataType]): UnsafeProjection = {
val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))
create(exprs)
}

/**
* Returns an UnsafeProjection for given sequence of Expressions (bounded).
*/
def create(exprs: Seq[Expression]): UnsafeProjection = {
GenerateUnsafeProjection.generate(exprs)
}

/**
* Returns an UnsafeProjection for given sequence of Expressions, which will be bound to
* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
create(exprs.map(BindReferences.bindReference(_, inputSchema)))
}
}

/**
* A projection that could turn UnsafeRow into GenericInternalRow
*/
case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection {

def this(schema: StructType) = this(schema.fields.map(_.dataType))

private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) =>
new BoundReference(idx, dt, true)
}
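Taken together, the helpers added here give a round trip between generic rows and UnsafeRows. A hedged usage sketch (the field names are invented; the FromUnsafeProjection auxiliary constructor is the one shown above):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeProjection}
import org.apache.spark.sql.types._

val schema = StructType(Seq(StructField("id", LongType), StructField("score", DoubleType)))
assert(UnsafeProjection.canSupport(schema))

val toUnsafe = UnsafeProjection.create(schema)
val fromUnsafe = new FromUnsafeProjection(schema)  // auxiliary constructor from above

val unsafeRow = toUnsafe(InternalRow(1L, 0.5)).copy()
val generic = fromUnsafe(unsafeRow)                // back to a generic InternalRow
assert(generic.getLong(0) == 1L)
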
@@ -62,7 +62,7 @@ case class BroadcastHashJoin(
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
- val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
+ val hashed = buildHashRelation(input.iterator)
sparkContext.broadcast(hashed)
}(BroadcastHashJoin.broadcastHashJoinExecutionContext)

@@ -17,6 +17,9 @@

package org.apache.spark.sql.execution.joins

+ import scala.concurrent._
+ import scala.concurrent.duration._

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -26,10 +29,6 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.util.ThreadUtils

- import scala.collection.JavaConversions._
- import scala.concurrent._
- import scala.concurrent.duration._

/**
* :: DeveloperApi ::
* Performs a outer hash join for two child relations. When the output RDD of this operator is
@@ -58,28 +57,11 @@ case class BroadcastHashOuterJoin(
override def requiredChildDistribution: Seq[Distribution] =
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil

private[this] lazy val (buildPlan, streamedPlan) = joinType match {
case RightOuter => (left, right)
case LeftOuter => (right, left)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashOuterJoin should not take $x as the JoinType")
}

private[this] lazy val (buildKeys, streamedKeys) = joinType match {
case RightOuter => (leftKeys, rightKeys)
case LeftOuter => (rightKeys, leftKeys)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashOuterJoin should not take $x as the JoinType")
}

@transient
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
- // buildHashTable uses code-generated rows as keys, which are not serializable
- val hashed = buildHashTable(input.iterator, newProjection(buildKeys, buildPlan.output))
+ val hashed = buildHashRelation(input.iterator)
sparkContext.broadcast(hashed)
}(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext)

@@ -89,21 +71,21 @@ case class BroadcastHashOuterJoin(
streamedPlan.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow()
val hashTable = broadcastRelation.value
- val keyGenerator = newProjection(streamedKeys, streamedPlan.output)
+ val keyGenerator = streamedKeyGenerator

joinType match {
case LeftOuter =>
streamedIter.flatMap(currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withLeft(currentRow)
- leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST))
+ leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey))
})

case RightOuter =>
streamedIter.flatMap(currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withRight(currentRow)
- rightOuterIterator(rowKey, hashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow)
+ rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow)
})

case x =>
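A hedged note on the getOrElse-to-get switch above: the assumption is that HashedRelation.get returns null when a key has no matches, so the outer-join iterators must now treat null as "no build-side rows" instead of relying on an EMPTY_LIST default. An invented sketch of that contract:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.JoinedRow

// Invented sketch of what leftOuterIterator has to cope with after the change:
// `matches` may now be null rather than an empty buffer.
def padLeftOuter(
    streamedRow: InternalRow,
    matches: Seq[InternalRow],
    rightNulls: InternalRow): Iterator[InternalRow] = {
  val joined = new JoinedRow().withLeft(streamedRow)
  if (matches != null && matches.nonEmpty) {
    matches.iterator.map(joined.withRight)
  } else {
    Iterator(joined.withRight(rightNulls))
  }
}
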
@@ -40,15 +40,14 @@ case class BroadcastLeftSemiJoinHash(
val buildIter = right.execute().map(_.copy()).collect().toIterator

if (condition.isEmpty) {
- // rowKey may be not serializable (from codegen)
- val hashSet = buildKeyHashSet(buildIter, copy = true)
+ val hashSet = buildKeyHashSet(buildIter)
val broadcastedRelation = sparkContext.broadcast(hashSet)

left.execute().mapPartitions { streamIter =>
hashSemiJoin(streamIter, broadcastedRelation.value)
}
} else {
- val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+ val hashRelation = buildHashRelation(buildIter)
val broadcastedRelation = sparkContext.broadcast(hashRelation)

left.execute().mapPartitions { streamIter =>
@@ -44,6 +44,19 @@ case class BroadcastNestedLoopJoin(
case BuildLeft => (right, left)
}

override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows
override def canProcessUnsafeRows: Boolean = true

@transient private[this] lazy val resultProjection: Projection = {
if (outputsUnsafeRows) {
UnsafeProjection.create(schema)
} else {
new Projection {
override def apply(r: InternalRow): InternalRow = r
}
}
}

override def outputPartitioning: Partitioning = streamed.outputPartitioning

override def output: Seq[Attribute] = {
@@ -74,6 +87,7 @@ case class BroadcastNestedLoopJoin(
val includedBroadcastTuples =
new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
val joinedRow = new JoinedRow

val leftNulls = new GenericMutableRow(left.output.size)
val rightNulls = new GenericMutableRow(right.output.size)

@@ -86,11 +100,11 @@
val broadcastedRow = broadcastedRelation.value(i)
buildSide match {
case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
- matchedRows += joinedRow(streamedRow, broadcastedRow).copy()
+ matchedRows += resultProjection(joinedRow(streamedRow, broadcastedRow)).copy()
streamRowMatched = true
includedBroadcastTuples += i
case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) =>
- matchedRows += joinedRow(broadcastedRow, streamedRow).copy()
+ matchedRows += resultProjection(joinedRow(broadcastedRow, streamedRow)).copy()
streamRowMatched = true
includedBroadcastTuples += i
case _ =>
@@ -100,22 +114,19 @@

(streamRowMatched, joinType, buildSide) match {
case (false, LeftOuter | FullOuter, BuildRight) =>
- matchedRows += joinedRow(streamedRow, rightNulls).copy()
+ matchedRows += resultProjection(joinedRow(streamedRow, rightNulls)).copy()
case (false, RightOuter | FullOuter, BuildLeft) =>
- matchedRows += joinedRow(leftNulls, streamedRow).copy()
+ matchedRows += resultProjection(joinedRow(leftNulls, streamedRow)).copy()
case _ =>
}
}
Iterator((matchedRows, includedBroadcastTuples))
}

val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2)
- val allIncludedBroadcastTuples =
-   if (includedBroadcastTuples.count == 0) {
-     new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
-   } else {
-     includedBroadcastTuples.reduce(_ ++ _)
-   }
+ val allIncludedBroadcastTuples = includedBroadcastTuples.fold(
+   new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
+ )(_ ++ _)

val leftNulls = new GenericMutableRow(left.output.size)
val rightNulls = new GenericMutableRow(right.output.size)
@@ -127,8 +138,10 @@
while (i < rel.length) {
if (!allIncludedBroadcastTuples.contains(i)) {
(joinType, buildSide) match {
- case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i))
- case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls)
+ case (RightOuter | FullOuter, BuildRight) =>
+   buf += resultProjection(new JoinedRow(leftNulls, rel(i)))
+ case (LeftOuter | FullOuter, BuildLeft) =>
+   buf += resultProjection(new JoinedRow(rel(i), rightNulls))
case _ =>
}
}
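The resultProjection introduced above re-encodes each JoinedRow as a single UnsafeRow whenever either child produces unsafe rows; the trailing .copy() stays necessary because a generated UnsafeProjection reuses one row buffer. A hedged illustration of that pitfall (the schema is arbitrary):

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types._

val project = UnsafeProjection.create(Array[DataType](IntegerType))
val buffered = new ArrayBuffer[InternalRow]()
for (i <- 1 to 3) {
  // Without copy(), every entry would alias the projection's single reused
  // UnsafeRow and end up reading the last value written.
  buffered += project(InternalRow(i)).copy()
}
assert(buffered.map(_.getInt(0)) == Seq(1, 2, 3))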
