-
Notifications
You must be signed in to change notification settings - Fork 28.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-9024] Unsafe HashJoin/HashOuterJoin/HashSemiJoin #7480
Changes from 6 commits
bea4a50
95d0762
6acbb11
184b852
a6c0b7d
60371f2
ab1690f
1a40f02
69e38f5
0f4380d
68f5cd9
ca2b40f
9481ae8
611d2ed
a05b4f6
84c9807
dede020
10583f1
6294b1e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,10 +17,11 @@ | |
|
||
package org.apache.spark.sql.catalyst.expressions; | ||
|
||
import org.apache.spark.sql.catalyst.InternalRow; | ||
import org.apache.spark.sql.catalyst.util.ObjectPool; | ||
import org.apache.spark.unsafe.PlatformDependent; | ||
import org.apache.spark.unsafe.array.ByteArrayMethods; | ||
import org.apache.spark.unsafe.bitset.BitSetMethods; | ||
import org.apache.spark.unsafe.hash.Murmur3_x86_32; | ||
import org.apache.spark.unsafe.types.UTF8String; | ||
|
||
|
||
|
@@ -345,7 +346,7 @@ public double getDouble(int i) { | |
* This method is only supported on UnsafeRows that do not use ObjectPools. | ||
*/ | ||
@Override | ||
public InternalRow copy() { | ||
public UnsafeRow copy() { | ||
if (pool != null) { | ||
throw new UnsupportedOperationException( | ||
"Copy is not supported for UnsafeRows that use object pools"); | ||
|
@@ -365,8 +366,50 @@ public InternalRow copy() { | |
} | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42); | ||
} | ||
|
||
@Override | ||
public boolean equals(Object other) { | ||
if (other instanceof UnsafeRow) { | ||
UnsafeRow o = (UnsafeRow) other; | ||
return ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset, | ||
sizeInBytes); | ||
} | ||
return false; | ||
} | ||
|
||
/** | ||
* Returns the underlying bytes for this UnsafeRow. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "underline" -> "underlying" |
||
*/ | ||
public byte[] getBytes() { | ||
if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a nice optimization! |
||
&& (((byte[]) baseObject).length == sizeInBytes)) { | ||
return (byte[]) baseObject; | ||
} else { | ||
byte[] bytes = new byte[sizeInBytes]; | ||
PlatformDependent.copyMemory(baseObject, baseOffset, bytes, | ||
PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes); | ||
return bytes; | ||
} | ||
} | ||
|
||
// This is for debugging | ||
@Override | ||
public String toString(){ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Style nit: space after `toString()` before the opening brace, i.e. `toString() {` |
||
StringBuilder build = new StringBuilder("["); | ||
for (int i = 0; i < sizeInBytes; i += 8) { | ||
build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i)); | ||
build.append(','); | ||
} | ||
build.append(']'); | ||
return build.toString(); | ||
} | ||
|
||
@Override | ||
public boolean anyNull() { | ||
return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes); | ||
return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you add a unit test for this? i'd imagine it affects correctness |
||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -83,19 +83,32 @@ abstract class UnsafeProjection extends Projection { | |
} | ||
|
||
object UnsafeProjection { | ||
def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType)) | ||
def canSupport(types: Seq[DataType]): Boolean = types.forall(UnsafeColumnWriter.canEmbed(_)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You could even add a |
||
|
||
def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType)) | ||
|
||
def create(fields: Seq[DataType]): UnsafeProjection = { | ||
def create(fields: Array[DataType]): UnsafeProjection = { | ||
val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)) | ||
create(exprs) | ||
} | ||
|
||
def create(exprs: Seq[Expression]): UnsafeProjection = { | ||
GenerateUnsafeProjection.generate(exprs) | ||
} | ||
|
||
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { | ||
create(exprs.map(BindReferences.bindReference(_, inputSchema))) | ||
} | ||
} | ||
|
||
/** | ||
* A projection that could turn UnsafeRow into GenericInternalRow | ||
*/ | ||
case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection { | ||
|
||
def this(schema: StructType) = this(schema.fields.map(_.dataType)) | ||
|
||
private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) => | ||
new BoundReference(idx, dt, true) | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,16 +40,15 @@ case class BroadcastLeftSemiJoinHash( | |
val buildIter = right.execute().map(_.copy()).collect().toIterator | ||
|
||
if (condition.isEmpty) { | ||
// rowKey may be not serializable (from codegen) | ||
val hashSet = buildKeyHashSet(buildIter, copy = true) | ||
val hashSet = buildKeyHashSet(buildIter) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Orthogonal to this patch, we should work on removing BroadcastLeftSemiJoinHash, and just use an equi-join. Otherwise we have too many paths we need to optimize for. |
||
val broadcastedRelation = sparkContext.broadcast(hashSet) | ||
|
||
left.execute().mapPartitions { streamIter => | ||
hashSemiJoin(streamIter, broadcastedRelation.value) | ||
} | ||
} else { | ||
val hashRelation = HashedRelation(buildIter, rightKeyGenerator) | ||
val broadcastedRelation = sparkContext.broadcast(hashRelation) | ||
val hashed = buildHashRelation(buildIter) | ||
val broadcastedRelation = sparkContext.broadcast(hashed) | ||
|
||
left.execute().mapPartitions { streamIter => | ||
hashSemiJoin(streamIter, broadcastedRelation.value) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,7 +38,7 @@ trait HashOuterJoin { | |
val left: SparkPlan | ||
val right: SparkPlan | ||
|
||
override def outputPartitioning: Partitioning = joinType match { | ||
override def outputPartitioning: Partitioning = joinType match { | ||
case LeftOuter => left.outputPartitioning | ||
case RightOuter => right.outputPartitioning | ||
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) | ||
|
@@ -59,6 +59,30 @@ override def outputPartitioning: Partitioning = joinType match { | |
} | ||
} | ||
|
||
protected[this] lazy val (buildPlan, streamedPlan) = joinType match { | ||
case RightOuter => (left, right) | ||
case LeftOuter => (right, left) | ||
case x => | ||
throw new IllegalArgumentException( | ||
s"BroadcastHashOuterJoin should not take $x as the JoinType") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this code is now in |
||
} | ||
|
||
protected[this] lazy val (buildKeys, streamedKeys) = joinType match { | ||
case RightOuter => (leftKeys, rightKeys) | ||
case LeftOuter => (rightKeys, leftKeys) | ||
case x => | ||
throw new IllegalArgumentException( | ||
s"BroadcastHashOuterJoin should not take $x as the JoinType") | ||
} | ||
|
||
protected[this] def streamedKeyGenerator(): Projection = { | ||
if (self.codegenEnabled && UnsafeProjection.canSupport(streamedKeys.map(_.dataType))) { | ||
UnsafeProjection.create(streamedKeys, streamedPlan.output) | ||
} else { | ||
newProjection(streamedKeys, streamedPlan.output) | ||
} | ||
} | ||
|
||
@transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) | ||
@transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() | ||
|
||
|
@@ -76,8 +100,12 @@ override def outputPartitioning: Partitioning = joinType match { | |
rightIter: Iterable[InternalRow]): Iterator[InternalRow] = { | ||
val ret: Iterable[InternalRow] = { | ||
if (!key.anyNull) { | ||
val temp = rightIter.collect { | ||
case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy() | ||
val temp = if (rightIter != null) { | ||
rightIter.collect { | ||
case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy() | ||
} | ||
} else { | ||
List() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that you can use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given that the old code seemed to make a special point of using an |
||
} | ||
if (temp.isEmpty) { | ||
joinedRow.withRight(rightNullRow).copy :: Nil | ||
|
@@ -97,9 +125,13 @@ override def outputPartitioning: Partitioning = joinType match { | |
joinedRow: JoinedRow): Iterator[InternalRow] = { | ||
val ret: Iterable[InternalRow] = { | ||
if (!key.anyNull) { | ||
val temp = leftIter.collect { | ||
case l if boundCondition(joinedRow.withLeft(l)) => | ||
joinedRow.copy() | ||
val temp = if (leftIter != null) { | ||
leftIter.collect { | ||
case l if boundCondition(joinedRow.withLeft(l)) => | ||
joinedRow.copy() | ||
} | ||
} else { | ||
List() | ||
} | ||
if (temp.isEmpty) { | ||
joinedRow.withLeft(leftNullRow).copy :: Nil | ||
|
@@ -178,4 +210,16 @@ override def outputPartitioning: Partitioning = joinType match { | |
|
||
hashTable | ||
} | ||
|
||
protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { | ||
if (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys.map(_.dataType)) | ||
&& UnsafeProjection.canSupport(buildPlan.output.map(_.dataType))) { | ||
UnsafeHashedRelation( | ||
buildIter, | ||
buildKeys.map(BindReferences.bindReference(_, buildPlan.output)), | ||
buildPlan.schema) | ||
} else { | ||
HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) | ||
} | ||
} | ||
} |
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think that we should check whether the rows' `sizeInBytes` are equal before attempting to compare their contents.