[SPARK-9024] Unsafe HashJoin/HashOuterJoin/HashSemiJoin #7480

Closed · wants to merge 19 commits · showing changes from 6 commits
@@ -17,10 +17,11 @@

package org.apache.spark.sql.catalyst.expressions;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ObjectPool;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.UTF8String;


@@ -345,7 +346,7 @@ public double getDouble(int i) {
* This method is only supported on UnsafeRows that do not use ObjectPools.
*/
@Override
public InternalRow copy() {
public UnsafeRow copy() {
if (pool != null) {
throw new UnsupportedOperationException(
"Copy is not supported for UnsafeRows that use object pools");
@@ -365,8 +366,50 @@ public InternalRow copy() {
}
}

@Override
public int hashCode() {
return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42);
}

@Override
public boolean equals(Object other) {
if (other instanceof UnsafeRow) {
UnsafeRow o = (UnsafeRow) other;
return ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
sizeInBytes);
}
return false;
}

Contributor comment (on the arrayEquals call): I think that we should check whether the rows' sizeInBytes are equal before attempting to compare their contents.
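The guard the reviewer has in mind might look like this minimal Scala sketch; unsafeRowEquals is a hypothetical standalone helper written against the getBytes() added below (UnsafeRow itself is Java), not part of this patch:

import org.apache.spark.sql.catalyst.expressions.UnsafeRow

// Hypothetical helper: reject rows of different sizes before comparing bytes.
// ByteArrayMethods.arrayEquals takes a single length, so comparing with one
// row's sizeInBytes is only safe once both sizes are known to match.
def unsafeRowEquals(a: UnsafeRow, b: UnsafeRow): Boolean = {
  val aBytes = a.getBytes
  val bBytes = b.getBytes
  aBytes.length == bBytes.length && java.util.Arrays.equals(aBytes, bBytes)
}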

/**
* Returns the underline bytes for this UnsafeRow.
*/

Contributor comment (on the doc line above): "underline" -> "underlying"
public byte[] getBytes() {
if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET
&& (((byte[]) baseObject).length == sizeInBytes)) {
return (byte[]) baseObject;
} else {
byte[] bytes = new byte[sizeInBytes];
PlatformDependent.copyMemory(baseObject, baseOffset, bytes,
PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes);
return bytes;
}
}

Contributor comment (on the byte[] fast path in getBytes): This is a nice optimization!

// This is for debugging
@Override
public String toString(){
Contributor comment: Style nit: space after ().

StringBuilder build = new StringBuilder("[");
for (int i = 0; i < sizeInBytes; i += 8) {
build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i));
build.append(',');
}
build.append(']');
return build.toString();
}

@Override
public boolean anyNull() {
return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8);
Contributor comment: can you add a unit test for this? i'd imagine it affects correctness

}
}
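On the unit-test request: anySet appears to expect a word count, so passing the byte width makes it scan eight times too many words, reading past the null bitset into the data region; a null-free row with non-zero data could then falsely report a null. A minimal regression-test sketch, assuming this branch's UnsafeProjection API (which suite it belongs in is left open):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types._

val project = UnsafeProjection.create(Array[DataType](IntegerType, LongType))

// Non-zero data words but no nulls: before the fix this could report true.
assert(!project(InternalRow(1, 2L)).anyNull)

// A genuinely null field must still be detected.
assert(project(InternalRow(null, 2L)).anyNull)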
@@ -28,11 +28,10 @@
import org.apache.spark.TaskContext;
import org.apache.spark.sql.AbstractScalaRowIterator;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.util.ObjectPool;
import org.apache.spark.sql.types.*;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
@@ -176,12 +175,7 @@ public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IO
*/
public static boolean supportsSchema(StructType schema) {
// TODO: add spilling note to explain why we do this for now:
for (StructField field : schema.fields()) {
if (!UnsafeColumnWriter.canEmbed(field.dataType())) {
return false;
}
}
return true;
return UnsafeProjection.canSupport(schema);
}

private static final class RowComparator extends RecordComparator {
@@ -21,7 +21,6 @@ import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.types._

/**
@@ -34,7 +33,23 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)

override def toString: String = s"input[$ordinal]"

override def eval(input: InternalRow): Any = input(ordinal)
// Use special getter for primitive types (for UnsafeRow)
override def eval(input: InternalRow): Any = {
if (input.isNullAt(ordinal)) {
null
} else {
dataType match {
case BooleanType => input.getBoolean(ordinal)
case ByteType => input.getByte(ordinal)
case ShortType => input.getShort(ordinal)
case IntegerType | DateType => input.getInt(ordinal)
case LongType | TimestampType => input.getLong(ordinal)
case FloatType => input.getFloat(ordinal)
case DoubleType => input.getDouble(ordinal)
case _ => input.get(ordinal)
}
}
}

override def name: String = s"i[$ordinal]"

@@ -83,19 +83,32 @@ abstract class UnsafeProjection extends Projection {
}

object UnsafeProjection {
def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType))
def canSupport(types: Seq[DataType]): Boolean = types.forall(UnsafeColumnWriter.canEmbed(_))
Contributor comment: You could even add a canSupport(exprs: Seq[Expression]) to be able to save some characters elsewhere.
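A sketch of that overload, hypothetical and mirroring the canSupport(types) one-liner above:

// Hypothetical overload for object UnsafeProjection: accept bound expressions
// directly instead of making callers map to data types first.
def canSupport(exprs: Seq[Expression]): Boolean =
  exprs.forall(e => UnsafeColumnWriter.canEmbed(e.dataType))

Call sites like buildHashRelation below could then write UnsafeProjection.canSupport(buildKeys) instead of mapping over _.dataType.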


def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType))

def create(fields: Seq[DataType]): UnsafeProjection = {
def create(fields: Array[DataType]): UnsafeProjection = {
val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))
create(exprs)
}

def create(exprs: Seq[Expression]): UnsafeProjection = {
GenerateUnsafeProjection.generate(exprs)
}

def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
create(exprs.map(BindReferences.bindReference(_, inputSchema)))
}
}

/**
* A projection that can turn an UnsafeRow back into a GenericInternalRow
*/
case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection {

def this(schema: StructType) = this(schema.fields.map(_.dataType))

private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) =>
new BoundReference(idx, dt, true)
}
@@ -111,7 +111,7 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
/**
* Function for writing a column into an UnsafeRow.
*/
private abstract class UnsafeColumnWriter {
abstract class UnsafeColumnWriter {
/**
* Write a value into an UnsafeRow.
*
@@ -130,7 +130,7 @@ private abstract class UnsafeColumnWriter {
def getSize(source: InternalRow, column: Int): Int
}

private object UnsafeColumnWriter {
object UnsafeColumnWriter {

def forType(dataType: DataType): UnsafeColumnWriter = {
dataType match {
@@ -353,7 +353,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("FORMAT") {
val f = 'f.string.at(0)
val d1 = 'd.int.at(1)
val s1 = 's.int.at(2)
val s1 = 's.string.at(2)

val row1 = create_row("aa%d%s", 12, "cc")
val row2 = create_row(null, 12, "cc")
@@ -62,7 +62,7 @@ case class BroadcastHashJoin(
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
val hashed = buildHashRelation(input.iterator)
sparkContext.broadcast(hashed)
}(BroadcastHashJoin.broadcastHashJoinExecutionContext)

@@ -17,6 +17,9 @@

package org.apache.spark.sql.execution.joins

import scala.concurrent._
import scala.concurrent.duration._

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -26,10 +29,6 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.util.ThreadUtils

import scala.collection.JavaConversions._
import scala.concurrent._
import scala.concurrent.duration._

/**
* :: DeveloperApi ::
* Performs an outer hash join for two child relations. When the output RDD of this operator is
@@ -58,28 +57,11 @@ case class BroadcastHashOuterJoin(
override def requiredChildDistribution: Seq[Distribution] =
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil

private[this] lazy val (buildPlan, streamedPlan) = joinType match {
case RightOuter => (left, right)
case LeftOuter => (right, left)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashOuterJoin should not take $x as the JoinType")
}

private[this] lazy val (buildKeys, streamedKeys) = joinType match {
case RightOuter => (leftKeys, rightKeys)
case LeftOuter => (rightKeys, leftKeys)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashOuterJoin should not take $x as the JoinType")
}

@transient
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
// buildHashTable uses code-generated rows as keys, which are not serializable
val hashed = buildHashTable(input.iterator, newProjection(buildKeys, buildPlan.output))
val hashed = buildHashRelation(input.iterator)
sparkContext.broadcast(hashed)
}(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext)

@@ -96,14 +78,14 @@ case class BroadcastHashOuterJoin(
streamedIter.flatMap(currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withLeft(currentRow)
leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST))
leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey))
})

case RightOuter =>
streamedIter.flatMap(currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withRight(currentRow)
rightOuterIterator(rowKey, hashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow)
rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow)
})

case x =>
@@ -40,16 +40,15 @@ case class BroadcastLeftSemiJoinHash(
val buildIter = right.execute().map(_.copy()).collect().toIterator

if (condition.isEmpty) {
// rowKey may be not serializable (from codegen)
val hashSet = buildKeyHashSet(buildIter, copy = true)
val hashSet = buildKeyHashSet(buildIter)
Contributor comment: Orthogonal to this patch, we should work on removing BroadcastLeftSemiJoinHash, and just use an equi-join. Otherwise we have too many paths we need to optimize for.

val broadcastedRelation = sparkContext.broadcast(hashSet)

left.execute().mapPartitions { streamIter =>
hashSemiJoin(streamIter, broadcastedRelation.value)
}
} else {
val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
val broadcastedRelation = sparkContext.broadcast(hashRelation)
val hashed = buildHashRelation(buildIter)
val broadcastedRelation = sparkContext.broadcast(hashed)

left.execute().mapPartitions { streamIter =>
hashSemiJoin(streamIter, broadcastedRelation.value)
@@ -103,4 +103,16 @@ trait HashJoin {
}
}
}

protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
if (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys.map(_.dataType))
&& UnsafeProjection.canSupport(buildPlan.output.map(_.dataType))) {
UnsafeHashedRelation(
buildIter,
buildKeys.map(BindReferences.bindReference(_, buildPlan.output)),
buildPlan.schema)
} else {
HashedRelation(buildIter, buildSideKeyGenerator)
}
}
}
@@ -38,7 +38,7 @@ trait HashOuterJoin {
val left: SparkPlan
val right: SparkPlan

override def outputPartitioning: Partitioning = joinType match {
override def outputPartitioning: Partitioning = joinType match {
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
@@ -59,6 +59,30 @@ override def outputPartitioning: Partitioning = joinType match {
}
}

protected[this] lazy val (buildPlan, streamedPlan) = joinType match {
case RightOuter => (left, right)
case LeftOuter => (right, left)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashOuterJoin should not take $x as the JoinType")
Contributor comment: Since this code is now in HashOuterJoin instead of BroadcastHashOuterJoin, I think we should update this error message to reference the new class.
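Concretely, the suggestion is just to rename the class in both messages, e.g.:

case x =>
  throw new IllegalArgumentException(
    s"HashOuterJoin should not take $x as the JoinType")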

}

protected[this] lazy val (buildKeys, streamedKeys) = joinType match {
case RightOuter => (leftKeys, rightKeys)
case LeftOuter => (rightKeys, leftKeys)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashOuterJoin should not take $x as the JoinType")
}

protected[this] def streamedKeyGenerator(): Projection = {
if (self.codegenEnabled && UnsafeProjection.canSupport(streamedKeys.map(_.dataType))) {
UnsafeProjection.create(streamedKeys, streamedPlan.output)
} else {
newProjection(streamedKeys, streamedPlan.output)
}
}

@transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null)
@transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()

@@ -76,8 +100,12 @@ override def outputPartitioning: Partitioning = joinType match {
rightIter: Iterable[InternalRow]): Iterator[InternalRow] = {
val ret: Iterable[InternalRow] = {
if (!key.anyNull) {
val temp = rightIter.collect {
case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy()
val temp = if (rightIter != null) {
rightIter.collect {
case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy()
}
} else {
List()
Contributor comment: I think that you can use List.empty here, which, as far as I know, returns an immutable singleton. Not sure if List() creates a new instance or not...

Contributor comment: Given that the old code seemed to make a special point of using an EMPTY_LIST constant I'm thinking that it may be important to make sure we're not creating new objects here.
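For what it's worth, in the Scala versions Spark targets both forms evaluate to the Nil singleton, so neither allocates per call; a quick REPL check, offered as a sketch:

// Both List.empty and List() return the Nil singleton object.
assert(List.empty[Int] eq Nil)
assert(List() eq Nil)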

}
if (temp.isEmpty) {
joinedRow.withRight(rightNullRow).copy :: Nil
Expand All @@ -97,9 +125,13 @@ override def outputPartitioning: Partitioning = joinType match {
joinedRow: JoinedRow): Iterator[InternalRow] = {
val ret: Iterable[InternalRow] = {
if (!key.anyNull) {
val temp = leftIter.collect {
case l if boundCondition(joinedRow.withLeft(l)) =>
joinedRow.copy()
val temp = if (leftIter != null) {
leftIter.collect {
case l if boundCondition(joinedRow.withLeft(l)) =>
joinedRow.copy()
}
} else {
List()
}
if (temp.isEmpty) {
joinedRow.withLeft(leftNullRow).copy :: Nil
@@ -178,4 +210,16 @@ override def outputPartitioning: Partitioning = joinType match {

hashTable
}

protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
if (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys.map(_.dataType))
&& UnsafeProjection.canSupport(buildPlan.output.map(_.dataType))) {
UnsafeHashedRelation(
buildIter,
buildKeys.map(BindReferences.bindReference(_, buildPlan.output)),
buildPlan.schema)
} else {
HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
}
}
}