fix tests
Davies Liu committed Jul 20, 2015
1 parent 95d0762 commit 6acbb11
Showing 12 changed files with 146 additions and 66 deletions.
@@ -381,6 +381,21 @@ public boolean equals(Object other) {
return false;
}

/**
 * Returns the underlying bytes for this UnsafeRow.
*/
public byte[] getBytes() {
if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET
&& (((byte[]) baseObject).length == sizeInBytes)) {
return (byte[]) baseObject;
} else {
byte[] bytes = new byte[sizeInBytes];
PlatformDependent.copyMemory(baseObject, baseOffset, bytes,
PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes);
return bytes;
}
}

// This is for debugging
@Override
public String toString(){
@@ -395,6 +410,6 @@ public String toString(){

@Override
public boolean anyNull() {
return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8);
}
}
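The new getBytes() returns the backing byte array directly when it exactly covers the row, and copies the row's region into a fresh array otherwise. Below is a minimal standalone sketch of that copy-vs-share decision, using a plain byte array and integer offsets in place of Spark's PlatformDependent off-heap accessors; RowBytes and its fields are illustrative names, not Spark API.

```scala
// Standalone sketch: share the backing array when it exactly covers the row,
// otherwise copy the row's region into a fresh, exactly-sized array.
object RowBytesSketch {
  final case class RowBytes(base: Array[Byte], offset: Int, sizeInBytes: Int) {
    def getBytes: Array[Byte] =
      if (offset == 0 && base.length == sizeInBytes) {
        base // already exactly the row's bytes, no copy needed
      } else {
        val bytes = new Array[Byte](sizeInBytes)
        System.arraycopy(base, offset, bytes, 0, sizeInBytes)
        bytes
      }
  }

  def main(args: Array[String]): Unit = {
    val backing = Array[Byte](1, 2, 3, 4, 5, 6)
    val whole = RowBytes(backing, 0, 6)
    val slice = RowBytes(backing, 2, 3)
    println(whole.getBytes eq backing)       // true: shared, no copy
    println(slice.getBytes.mkString(","))    // 3,4,5: copied region
  }
}
```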
@@ -21,7 +21,6 @@ import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.types._

/**
@@ -34,7 +33,23 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)

override def toString: String = s"input[$ordinal]"

override def eval(input: InternalRow): Any = input(ordinal)
// Use specialized getters for primitive types (needed for UnsafeRow)
override def eval(input: InternalRow): Any = {
if (input.isNullAt(ordinal)) {
null
} else {
dataType match {
case BooleanType => input.getBoolean(ordinal)
case ByteType => input.getByte(ordinal)
case ShortType => input.getShort(ordinal)
case IntegerType | DateType => input.getInt(ordinal)
case LongType | TimestampType => input.getLong(ordinal)
case FloatType => input.getFloat(ordinal)
case DoubleType => input.getDouble(ordinal)
case _ => input.get(ordinal)
}
}
}

override def name: String = s"i[$ordinal]"
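The eval change above routes primitive types through the typed getters, so rows that only expose primitives positionally (such as UnsafeRow) are never asked for a generic boxed value. A simplified, self-contained sketch of that dispatch follows; Row, DataType and ArrayRow here are stand-ins, not Spark's InternalRow API.

```scala
// Sketch: dispatch to a typed getter based on the column's data type,
// falling back to the generic (boxing) accessor only for other types.
object TypedGetterSketch {
  sealed trait DataType
  case object IntType    extends DataType
  case object LongType   extends DataType
  case object DoubleType extends DataType
  case object StringType extends DataType

  trait Row {
    def isNullAt(i: Int): Boolean
    def getInt(i: Int): Int
    def getLong(i: Int): Long
    def getDouble(i: Int): Double
    def get(i: Int): Any // generic fallback
  }

  def eval(row: Row, ordinal: Int, dataType: DataType): Any =
    if (row.isNullAt(ordinal)) null
    else dataType match {
      case IntType    => row.getInt(ordinal)
      case LongType   => row.getLong(ordinal)
      case DoubleType => row.getDouble(ordinal)
      case _          => row.get(ordinal)
    }

  // Toy array-backed row, just for the demo below.
  final class ArrayRow(values: Array[Any]) extends Row {
    def isNullAt(i: Int): Boolean = values(i) == null
    def getInt(i: Int): Int       = values(i).asInstanceOf[Int]
    def getLong(i: Int): Long     = values(i).asInstanceOf[Long]
    def getDouble(i: Int): Double = values(i).asInstanceOf[Double]
    def get(i: Int): Any          = values(i)
  }

  def main(args: Array[String]): Unit = {
    val row = new ArrayRow(Array[Any](42, 7L, null))
    println(eval(row, 0, IntType))    // 42, read via getInt
    println(eval(row, 2, DoubleType)) // null short-circuits before any getter
  }
}
```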

@@ -93,6 +93,10 @@ object UnsafeProjection {
def create(exprs: Seq[Expression]): UnsafeProjection = {
GenerateUnsafeProjection.generate(exprs)
}

def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
create(exprs.map(BindReferences.bindReference(_, inputSchema)))
}
}
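The new create(exprs, inputSchema) overload binds each expression to the input schema before generating the projection. Roughly, binding resolves an attribute reference by name to its ordinal position so the generated code can read fields by index at runtime; a toy sketch under that assumption, where Attr, Bound and bind are illustrative stand-ins for AttributeReference, BoundReference and BindReferences.bindReference.

```scala
// Sketch: resolve name-based column references to positional references
// against a given input schema, failing fast when a name is missing.
object BindSketch {
  sealed trait Expr
  final case class Attr(name: String)  extends Expr // refers to a column by name
  final case class Bound(ordinal: Int) extends Expr // refers to a column by position

  def bind(expr: Expr, inputSchema: Seq[String]): Expr = expr match {
    case Attr(name) =>
      val ordinal = inputSchema.indexOf(name)
      require(ordinal >= 0, s"Couldn't find $name in ${inputSchema.mkString(", ")}")
      Bound(ordinal)
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val schema = Seq("id", "name", "age")
    println(Seq(Attr("age"), Attr("id")).map(bind(_, schema))) // List(Bound(2), Bound(0))
  }
}
```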

/**
@@ -302,7 +302,7 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("FORMAT") {
val f = 'f.string.at(0)
val d1 = 'd.int.at(1)
val s1 = 's.int.at(2)
val s1 = 's.string.at(2)

val row1 = create_row("aa%d%s", 12, "cc")
val row2 = create_row(null, 12, "cc")
@@ -23,7 +23,7 @@ import scala.concurrent.duration._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BindReferences, UnsafeColumnWriter, Expression}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.util.ThreadUtils
@@ -62,14 +62,7 @@ case class BroadcastHashJoin(
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
val hashed = if (left.codegenEnabled &&
buildKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed(_))) {
UnsafeHashedRelation(input.iterator,
buildKeys.map(BindReferences.bindReference(_, buildPlan.output)),
buildPlan.schema)
} else {
HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
}
val hashed = buildHashRelation(input.iterator)
sparkContext.broadcast(hashed)
}(BroadcastHashJoin.broadcastHashJoinExecutionContext)
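For context, broadcastFuture materializes and hashes the build side on a separate thread so the driver is not blocked, and the result is awaited later when the join actually executes. A minimal sketch of that async-build pattern using only the standard library; ExecutionContext.global stands in for the dedicated broadcastHashJoinExecutionContext and the grouped Map stands in for HashedRelation.

```scala
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

// Sketch: build the hashed (grouped) relation on another thread, then await it.
object AsyncBuildSketch {
  implicit val ec: ExecutionContext = ExecutionContext.global

  def main(args: Array[String]): Unit = {
    val rows = Seq((1, "a"), (1, "b"), (2, "c"))
    val hashedF: Future[Map[Int, Seq[String]]] =
      Future(rows.groupBy(_._1).map { case (k, v) => k -> v.map(_._2) })
    // Later, when the join runs, block until the build side is ready.
    println(Await.result(hashedF, 10.seconds)) // Map(1 -> List(a, b), 2 -> List(c))
  }
}
```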

@@ -17,6 +17,9 @@

package org.apache.spark.sql.execution.joins

import scala.concurrent._
import scala.concurrent.duration._

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -26,10 +29,6 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.util.ThreadUtils

import scala.collection.JavaConversions._
import scala.concurrent._
import scala.concurrent.duration._

/**
* :: DeveloperApi ::
 * Performs an outer hash join for two child relations. When the output RDD of this operator is
@@ -58,28 +57,12 @@ case class BroadcastHashOuterJoin(
override def requiredChildDistribution: Seq[Distribution] =
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil

private[this] lazy val (buildPlan, streamedPlan) = joinType match {
case RightOuter => (left, right)
case LeftOuter => (right, left)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashOuterJoin should not take $x as the JoinType")
}

private[this] lazy val (buildKeys, streamedKeys) = joinType match {
case RightOuter => (leftKeys, rightKeys)
case LeftOuter => (rightKeys, leftKeys)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashOuterJoin should not take $x as the JoinType")
}

@transient
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
// buildHashTable uses code-generated rows as keys, which are not serializable
val hashed = buildHashTable(input.iterator, newProjection(buildKeys, buildPlan.output))
val hashed = buildHashRelation(input.iterator)
//val hashed = buildHashTable(input.iterator, newProjection(buildKeys, buildPlan.output))
sparkContext.broadcast(hashed)
}(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext)

@@ -96,14 +79,14 @@ case class BroadcastHashOuterJoin(
streamedIter.flatMap(currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withLeft(currentRow)
leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST))
leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey))
})

case RightOuter =>
streamedIter.flatMap(currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withRight(currentRow)
rightOuterIterator(rowKey, hashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow)
rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow)
})

case x =>
@@ -103,4 +103,14 @@ trait HashJoin {
}
}
}

protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
if (self.codegenEnabled && buildKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed(_))) {
UnsafeHashedRelation(buildIter,
buildKeys.map(BindReferences.bindReference(_, buildPlan.output)),
buildPlan.schema)
} else {
HashedRelation(buildIter, buildSideKeyGenerator)
}
}
}
@@ -38,7 +38,7 @@ trait HashOuterJoin {
val left: SparkPlan
val right: SparkPlan

override def outputPartitioning: Partitioning = joinType match {
override def outputPartitioning: Partitioning = joinType match {
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
@@ -59,6 +59,31 @@ override def outputPartitioning: Partitioning = joinType match {
}
}

protected[this] lazy val (buildPlan, streamedPlan) = joinType match {
case RightOuter => (left, right)
case LeftOuter => (right, left)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashOuterJoin should not take $x as the JoinType")
}

protected[this] lazy val (buildKeys, streamedKeys) = joinType match {
case RightOuter => (leftKeys, rightKeys)
case LeftOuter => (rightKeys, leftKeys)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashOuterJoin should not take $x as the JoinType")
}

protected[this] def streamedKeyGenerator(): Projection = {
if (self.codegenEnabled &&
streamedKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed(_))) {
UnsafeProjection.create(streamedKeys, streamedPlan.output)
} else {
newProjection(streamedKeys, streamedPlan.output)
}
}

@transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null)
@transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()

@@ -77,8 +102,12 @@ override def outputPartitioning: Partitioning = joinType match {
rightIter: Iterable[InternalRow]): Iterator[InternalRow] = {
val ret: Iterable[InternalRow] = {
if (!key.anyNull) {
val temp = rightIter.collect {
case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy()
val temp = if (rightIter != null) {
rightIter.collect {
case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy()
}
} else {
List()
}
if (temp.isEmpty) {
joinedRow.withRight(rightNullRow).copy :: Nil
Expand All @@ -98,9 +127,13 @@ override def outputPartitioning: Partitioning = joinType match {
joinedRow: JoinedRow): Iterator[InternalRow] = {
val ret: Iterable[InternalRow] = {
if (!key.anyNull) {
val temp = leftIter.collect {
case l if boundCondition(joinedRow.withLeft(l)) =>
joinedRow.copy()
val temp = if (leftIter != null) {
leftIter.collect {
case l if boundCondition(joinedRow.withLeft(l)) =>
joinedRow.copy()
}
} else {
List()
}
if (temp.isEmpty) {
joinedRow.withLeft(leftNullRow).copy :: Nil
@@ -179,4 +212,14 @@ override def outputPartitioning: Partitioning = joinType match {

hashTable
}

protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
if (self.codegenEnabled && buildKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed(_))) {
UnsafeHashedRelation(buildIter,
buildKeys.map(BindReferences.bindReference(_, buildPlan.output)),
buildPlan.schema)
} else {
HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
}
}
}
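With the switch to HashedRelation, the build-side lookup can return null for a missing key (much like java.util.HashMap.get) instead of a getOrElse(..., EMPTY_LIST) default, which is why the outer iterators above now guard against null before matching. A standalone sketch of that null-tolerant left-outer step; the names and Seq-based rows are illustrative only.

```scala
// Sketch: when the build side has no matches (lookup returned null), emit one
// null-padded row; otherwise emit one joined row per match.
object OuterJoinSketch {
  type Row = Seq[Any]

  def leftOuterRows(left: Row, matches: Seq[Row], rightNullRow: Row): Seq[Row] = {
    val matched =
      if (matches != null) matches.map(r => left ++ r)
      else Seq.empty[Row]
    if (matched.isEmpty) Seq(left ++ rightNullRow) else matched
  }

  def main(args: Array[String]): Unit = {
    val rightNull: Row = Seq(null, null)
    println(leftOuterRows(Seq(1, "a"), Seq(Seq(1, "x"), Seq(1, "y")), rightNull))
    println(leftOuterRows(Seq(2, "b"), null, rightNull)) // no match: null-padded row
  }
}
```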
@@ -21,7 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.util.{HashMap => JavaHashMap}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection, UnsafeRow, Projection}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.collection.CompactBuffer
@@ -158,13 +158,15 @@ private[joins] object HashedRelation {
*/
private[joins] final class UnsafeHashedRelation(
private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]],
private var keyTypes: Array[DataType])
private var keyTypes: Array[DataType],
private var rowTypes: Array[DataType])
extends HashedRelation with Externalizable {

def this() = this(null, null) // Needed for serialization
def this() = this(null, null, null) // Needed for serialization

// UnsafeProjection is not thread safe
@transient lazy val keyProjection = new ThreadLocal[UnsafeProjection]
@transient lazy val fromUnsafeProjection = new ThreadLocal[FromUnsafeProjection]

override def get(key: InternalRow): CompactBuffer[InternalRow] = {
val unsafeKey = if (key.isInstanceOf[UnsafeRow]) {
@@ -177,18 +179,38 @@ private[joins] final class UnsafeHashedRelation(
}
proj(key)
}
// rely on type erasure in Scala
hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]]

val values = hashTable.get(unsafeKey)
// Return GenericInternalRow so the values can be combined with other rows in a JoinedRow.
// TODO(davies): return UnsafeRow once we have UnsafeJoinRow.
if (values != null) {
var proj = fromUnsafeProjection.get()
if (proj eq null) {
proj = new FromUnsafeProjection(rowTypes)
fromUnsafeProjection.set(proj)
}
var i = 0
val ret = new CompactBuffer[InternalRow]
while (i < values.length) {
ret += proj(values(i)).copy()
i += 1
}
ret
} else {
null
}
}

override def writeExternal(out: ObjectOutput): Unit = {
writeBytes(out, SparkSqlSerializer.serialize(keyTypes))
writeBytes(out, SparkSqlSerializer.serialize(rowTypes))
val bytes = SparkSqlSerializer.serialize(hashTable)
writeBytes(out, bytes)
}

override def readExternal(in: ObjectInput): Unit = {
keyTypes = SparkSqlSerializer.deserialize(readBytes(in))
rowTypes = SparkSqlSerializer.deserialize(readBytes(in))
hashTable = SparkSqlSerializer.deserialize(readBytes(in))
}
}
@@ -229,7 +251,8 @@ private[joins] object UnsafeHashedRelation {
}
}

val keySchema = buildKey.map(_.dataType).toArray
new UnsafeHashedRelation(hashTable, keySchema)
val keyTypes = buildKey.map(_.dataType).toArray
val rowTypes = rowSchema.fields.map(_.dataType).toArray
new UnsafeHashedRelation(hashTable, keyTypes, rowTypes)
}
}
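Because UnsafeProjection and FromUnsafeProjection hold mutable state and are not thread safe, the relation above caches one instance per thread. A small sketch of that per-thread lazy-initialization pattern with ThreadLocal; RowConverter is a hypothetical stand-in for those projections.

```scala
// Sketch: each thread lazily creates and then reuses its own converter instance,
// so the mutable internal buffer is never shared across threads.
object ThreadLocalProjectionSketch {
  final class RowConverter {
    private val buffer = new StringBuilder // mutable state, not thread safe
    def convert(fields: Seq[Any]): String = {
      buffer.clear()
      fields.addString(buffer, "[", ", ", "]")
      buffer.toString
    }
  }

  private val converter = new ThreadLocal[RowConverter] {
    override def initialValue(): RowConverter = new RowConverter
  }

  def convert(fields: Seq[Any]): String = converter.get().convert(fields)

  def main(args: Array[String]): Unit = {
    val threads = (1 to 4).map { i =>
      new Thread(() => println(s"thread $i -> " + convert(Seq(i, "row" + i))))
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
  }
}
```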
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BindReferences, UnsafeColumnWriter, Expression}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}

@@ -46,14 +46,7 @@ case class ShuffledHashJoin(
protected override def doExecute(): RDD[InternalRow] = {
val codegenEnabled = left.codegenEnabled
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
val hashed = if (codegenEnabled &&
buildKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed(_))) {
UnsafeHashedRelation(buildIter,
buildKeys.map(BindReferences.bindReference(_, buildPlan.output)),
buildPlan.schema)
} else {
HashedRelation(buildIter, buildSideKeyGenerator)
}
val hashed = buildHashRelation(buildIter)
hashJoin(streamIter, hashed)
}
}
