Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,18 @@ object BroadcastUtils {
val handle = ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName, b)
numRows += b.numRows()
try {
ColumnarBatchSerializerJniWrapper
val unsafeBuffer = ColumnarBatchSerializerJniWrapper
.create(
Runtimes
.contextInstance(
BackendsApiManager.getBackendName,
"BroadcastUtils#serializeStream"))
.serialize(handle)
try {
unsafeBuffer.toByteArray
} finally {
unsafeBuffer.close()
}
} finally {
ColumnarBatches.release(b)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,19 @@ class ColumnarCachedBatchSerializer extends CachedBatchSerializer with Logging {

override def next(): CachedBatch = {
val batch = veloxBatches.next()
val results =
ColumnarBatchSerializerJniWrapper
.create(
Runtimes.contextInstance(
BackendsApiManager.getBackendName,
"ColumnarCachedBatchSerializer#serialize"))
.serialize(
ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName, batch))
CachedColumnarBatch(batch.numRows(), results.length, results)
val unsafeBuffer = ColumnarBatchSerializerJniWrapper
.create(
Runtimes.contextInstance(
BackendsApiManager.getBackendName,
"ColumnarCachedBatchSerializer#serialize"))
.serialize(ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName, batch))
val bytes =
try {
unsafeBuffer.toByteArray
} finally {
unsafeBuffer.close()
}
CachedColumnarBatch(batch.numRows(), bytes.length, bytes)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,94 +29,94 @@ import org.apache.spark.unsafe.memory.MemoryAllocator
*
* @param arraySize
* underlying array[array[byte]]'s length
* @param bytesBufferLengths
* underlying array[array[byte]] per bytesBuffer length
* @param byteBufferLengths
 *   the length of each individual byteBuffer in the underlying array[array[byte]]
* @param totalBytes
* all bytesBuffer's length plus together
 *   the sum of all byteBuffer lengths
*/
// scalastyle:off no.finalize
@Experimental
case class UnsafeBytesBufferArray(arraySize: Int, bytesBufferLengths: Array[Int], totalBytes: Long)
case class UnsafeByteBufferArray(arraySize: Int, byteBufferLengths: Array[Int], totalBytes: Long)
extends Logging {
{
assert(
arraySize == bytesBufferLengths.length,
arraySize == byteBufferLengths.length,
"Unsafe buffer array size " +
"not equal to buffer lengths!")
assert(totalBytes >= 0, "Unsafe buffer array total bytes can't be negative!")
}
private val allocatedBytes = (totalBytes + 7) / 8 * 8

/**
* A single array to store all bytesBufferArray's value, it's inited once when first time get
 * A single array that stores all byteBufferArray values; it is initialized the first time it is
* accessed.
*/
private var longArray: LongArray = _

/** Index the start of each byteBuffer's offset to underlying LongArray's initial position. */
private val bytesBufferOffset = if (bytesBufferLengths.isEmpty) {
private val byteBufferOffset = if (byteBufferLengths.isEmpty) {
new Array(0)
} else {
bytesBufferLengths.init.scanLeft(0L)(_ + _)
byteBufferLengths.init.scanLeft(0L)(_ + _)
}

/**
* Put bytesBuffer at specified array index.
* Put byteBuffer at specified array index.
*
* @param index
* index of the array.
* @param bytesBuffer
* bytesBuffer to put.
* @param byteBuffer
* byteBuffer to put.
*/
def putBytesBuffer(index: Int, bytesBuffer: Array[Byte]): Unit = this.synchronized {
def putByteBuffer(index: Int, byteBuffer: Array[Byte]): Unit = this.synchronized {
assert(index < arraySize)
assert(bytesBuffer.length == bytesBufferLengths(index))
assert(byteBuffer.length == byteBufferLengths(index))
// first to allocate underlying long array
if (null == longArray && index == 0) {
GlobalOffHeapMemory.acquire(allocatedBytes)
longArray = new LongArray(MemoryAllocator.UNSAFE.allocate(allocatedBytes))
}

Platform.copyMemory(
bytesBuffer,
byteBuffer,
Platform.BYTE_ARRAY_OFFSET,
longArray.getBaseObject,
longArray.getBaseOffset + bytesBufferOffset(index),
bytesBufferLengths(index))
longArray.getBaseOffset + byteBufferOffset(index),
byteBufferLengths(index))
}

/**
* Get bytesBuffer at specified index.
* Get byteBuffer at specified index.
* @param index
* @return
*/
def getBytesBuffer(index: Int): Array[Byte] = {
def getByteBuffer(index: Int): Array[Byte] = {
assert(index < arraySize)
if (null == longArray) {
return new Array[Byte](0)
}
val bytes = new Array[Byte](bytesBufferLengths(index))
val bytes = new Array[Byte](byteBufferLengths(index))
Platform.copyMemory(
longArray.getBaseObject,
longArray.getBaseOffset + bytesBufferOffset(index),
longArray.getBaseOffset + byteBufferOffset(index),
bytes,
Platform.BYTE_ARRAY_OFFSET,
bytesBufferLengths(index))
byteBufferLengths(index))
bytes
}

/**
* Get the bytesBuffer memory address and length at specified index, usually used when read memory
 * Get the byteBuffer memory address and length at the specified index; typically used when reading memory
* direct from offheap.
*
* @param index
* @return
*/
def getBytesBufferOffsetAndLength(index: Int): (Long, Int) = {
def getByteBufferOffsetAndLength(index: Int): (Long, Int) = {
assert(index < arraySize)
assert(longArray != null, "The broadcast data in offheap should not be null!")
val offset = longArray.getBaseOffset + bytesBufferOffset(index)
val length = bytesBufferLengths(index)
val offset = longArray.getBaseOffset + byteBufferOffset(index)
val length = byteBufferLengths(index)
(offset, length)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ object UnsafeColumnarBuildSideRelation {
// Keep constructors with BroadcastMode for compatibility
def apply(
output: Seq[Attribute],
batches: UnsafeBytesBufferArray,
batches: UnsafeByteBufferArray,
mode: BroadcastMode): UnsafeColumnarBuildSideRelation = {
val boundMode = mode match {
case HashedRelationBroadcastMode(keys, isNullAware) =>
Expand All @@ -65,7 +65,7 @@ object UnsafeColumnarBuildSideRelation {
}
def apply(
output: Seq[Attribute],
bytesBufferArray: Array[Array[Byte]],
byteBufferArray: Array[Array[Byte]],
mode: BroadcastMode): UnsafeColumnarBuildSideRelation = {
val boundMode = mode match {
case HashedRelationBroadcastMode(keys, isNullAware) =>
Expand All @@ -78,7 +78,7 @@ object UnsafeColumnarBuildSideRelation {
}
new UnsafeColumnarBuildSideRelation(
output,
bytesBufferArray,
byteBufferArray,
BroadcastModeUtils.toSafe(boundMode)
)
}
Expand All @@ -97,7 +97,7 @@ object UnsafeColumnarBuildSideRelation {
@Experimental
case class UnsafeColumnarBuildSideRelation(
private var output: Seq[Attribute],
private var batches: UnsafeBytesBufferArray,
private var batches: UnsafeByteBufferArray,
var safeBroadcastMode: SafeBroadcastMode)
extends BuildSideRelation
with Externalizable
Expand All @@ -118,38 +118,38 @@ case class UnsafeColumnarBuildSideRelation(

/** needed for serialization. */
def this() = {
this(null, null.asInstanceOf[UnsafeBytesBufferArray], null)
this(null, null.asInstanceOf[UnsafeByteBufferArray], null)
}

def this(
output: Seq[Attribute],
bytesBufferArray: Array[Array[Byte]],
byteBufferArray: Array[Array[Byte]],
safeMode: SafeBroadcastMode
) = {
this(
output,
UnsafeBytesBufferArray(
bytesBufferArray.length,
bytesBufferArray.map(_.length),
bytesBufferArray.map(_.length.toLong).sum
UnsafeByteBufferArray(
byteBufferArray.length,
byteBufferArray.map(_.length),
byteBufferArray.map(_.length.toLong).sum
),
safeMode
)
val batchesSize = bytesBufferArray.length
val batchesSize = byteBufferArray.length
for (i <- 0 until batchesSize) {
// copy the bytes to off-heap memory.
batches.putBytesBuffer(i, bytesBufferArray(i))
batches.putByteBuffer(i, byteBufferArray(i))
}
}

override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeObject(output)
out.writeObject(safeBroadcastMode)
out.writeInt(batches.arraySize)
out.writeObject(batches.bytesBufferLengths)
out.writeObject(batches.byteBufferLengths)
out.writeLong(batches.totalBytes)
for (i <- 0 until batches.arraySize) {
val bytes = batches.getBytesBuffer(i)
val bytes = batches.getByteBuffer(i)
out.write(bytes)
}
}
Expand All @@ -158,10 +158,10 @@ case class UnsafeColumnarBuildSideRelation(
kryo.writeObject(out, output.toList)
kryo.writeClassAndObject(out, safeBroadcastMode)
out.writeInt(batches.arraySize)
kryo.writeObject(out, batches.bytesBufferLengths)
kryo.writeObject(out, batches.byteBufferLengths)
out.writeLong(batches.totalBytes)
for (i <- 0 until batches.arraySize) {
val bytes = batches.getBytesBuffer(i)
val bytes = batches.getByteBuffer(i)
out.write(bytes)
}
}
Expand All @@ -170,7 +170,7 @@ case class UnsafeColumnarBuildSideRelation(
output = in.readObject().asInstanceOf[Seq[Attribute]]
safeBroadcastMode = in.readObject().asInstanceOf[SafeBroadcastMode]
val totalArraySize = in.readInt()
val bytesBufferLengths = in.readObject().asInstanceOf[Array[Int]]
val byteBufferLengths = in.readObject().asInstanceOf[Array[Int]]
val totalBytes = in.readLong()

// scalastyle:off
Expand All @@ -180,30 +180,30 @@ case class UnsafeColumnarBuildSideRelation(
*/
// scalastyle:on

batches = UnsafeBytesBufferArray(totalArraySize, bytesBufferLengths, totalBytes)
batches = UnsafeByteBufferArray(totalArraySize, byteBufferLengths, totalBytes)

for (i <- 0 until totalArraySize) {
val length = bytesBufferLengths(i)
val length = byteBufferLengths(i)
val tmpBuffer = new Array[Byte](length)
in.readFully(tmpBuffer)
batches.putBytesBuffer(i, tmpBuffer)
batches.putByteBuffer(i, tmpBuffer)
}
}

override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
output = kryo.readObject(in, classOf[List[_]]).asInstanceOf[Seq[Attribute]]
safeBroadcastMode = kryo.readClassAndObject(in).asInstanceOf[SafeBroadcastMode]
val totalArraySize = in.readInt()
val bytesBufferLengths = kryo.readObject(in, classOf[Array[Int]])
val byteBufferLengths = kryo.readObject(in, classOf[Array[Int]])
val totalBytes = in.readLong()

batches = UnsafeBytesBufferArray(totalArraySize, bytesBufferLengths, totalBytes)
batches = UnsafeByteBufferArray(totalArraySize, byteBufferLengths, totalBytes)

for (i <- 0 until totalArraySize) {
val length = bytesBufferLengths(i)
val length = byteBufferLengths(i)
val tmpBuffer = new Array[Byte](length)
in.read(tmpBuffer)
batches.putBytesBuffer(i, tmpBuffer)
batches.putByteBuffer(i, tmpBuffer)
}
}

Expand Down Expand Up @@ -252,7 +252,7 @@ case class UnsafeColumnarBuildSideRelation(

override def next: ColumnarBatch = {
val (offset, length) =
batches.getBytesBufferOffsetAndLength(batchId)
batches.getByteBufferOffsetAndLength(batchId)
batchId += 1
val handle =
jniWrapper.deserializeDirect(serializerHandle, offset, length)
Expand Down Expand Up @@ -309,7 +309,7 @@ case class UnsafeColumnarBuildSideRelation(
}

override def next(): Iterator[InternalRow] = {
val (offset, length) = batches.getBytesBufferOffsetAndLength(batchId)
val (offset, length) = batches.getByteBufferOffsetAndLength(batchId)
batchId += 1
val batchHandle =
serializerJniWrapper.deserializeDirect(serializerHandle, offset, length)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ class UnsafeColumnarBuildSideRelationTest extends SharedSparkSession {
val totalArraySize = 1
val perArraySize = new Array[Int](totalArraySize)
perArraySize(0) = 10
val bytesArray = UnsafeBytesBufferArray(
val bytesArray = UnsafeByteBufferArray(
1,
perArraySize,
10
)
bytesArray.putBytesBuffer(0, "1234567890".getBytes())
bytesArray.putByteBuffer(0, "1234567890".getBytes())
unsafeRelWithIdentityMode = UnsafeColumnarBuildSideRelation(
output,
bytesArray,
Expand Down
Loading