From 842d2de710b8d15d711e52071e80f1835f1377c8 Mon Sep 17 00:00:00 2001
From: Hongze Zhang
Date: Fri, 31 Oct 2025 12:31:20 +0100
Subject: [PATCH 1/6] LongHashedRelation off-heap

---
 .../sql/execution/joins/HashedRelation.scala  | 146 +++++++++---------
 1 file changed, 77 insertions(+), 69 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 85c198290542..5f11b253469f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -22,7 +22,7 @@ import java.io._
 import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
 import com.esotericsoftware.kryo.io.{Input, Output}
 
-import org.apache.spark.{SparkConf, SparkEnv, SparkException, SparkUnsupportedOperationException}
+import org.apache.spark.{SparkConf, SparkEnv, SparkUnsupportedOperationException}
 import org.apache.spark.internal.config.{BUFFER_PAGESIZE, MEMORY_OFFHEAP_ENABLED}
 import org.apache.spark.memory._
 import org.apache.spark.sql.catalyst.InternalRow
@@ -32,6 +32,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types.LongType
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.map.BytesToBytesMap
+import org.apache.spark.unsafe.memory.MemoryBlock
 import org.apache.spark.util.{KnownSizeEstimation, Utils}
 
 /**
@@ -535,7 +536,7 @@ private[execution] final class LongToUnsafeRowMap(
     val mm: TaskMemoryManager,
     capacity: Int,
     ignoresDuplicatedKey: Boolean = false)
-  extends MemoryConsumer(mm, MemoryMode.ON_HEAP) with Externalizable with KryoSerializable {
+  extends MemoryConsumer(mm, mm.getTungstenMemoryMode) with Externalizable with KryoSerializable {
 
   // Whether the keys are stored in dense mode or not.
   private var isDense = false
@@ -550,15 +551,15 @@ private[execution] final class LongToUnsafeRowMap(
   //
   // Sparse mode: [key1] [offset1 | size1] [key2] [offset | size2] ...
   // Dense mode: [offset1 | size1] [offset2 | size2]
-  private var array: Array[Long] = null
+  private var array: UnsafeLongArray = null
   private var mask: Int = 0
 
   // The page to store all bytes of UnsafeRow and the pointer to next rows.
   // [row1][pointer1] [row2][pointer2]
-  private var page: Array[Long] = null
+  private var page: MemoryBlock = null
 
   // Current write cursor in the page.
-  private var cursor: Long = Platform.LONG_ARRAY_OFFSET
+  private var cursor: Long = -1
 
   // The number of bits for size in address
   private val SIZE_BITS = 28
@@ -583,24 +584,15 @@ private[execution] final class LongToUnsafeRowMap(
       0)
   }
 
-  private def ensureAcquireMemory(size: Long): Unit = {
-    // do not support spilling
-    val got = acquireMemory(size)
-    if (got < size) {
-      freeMemory(got)
-      throw QueryExecutionErrors.cannotAcquireMemoryToBuildLongHashedRelationError(size, got)
-    }
-  }
-
   private def init(): Unit = {
     if (mm != null) {
       require(capacity < 512000000, "Cannot broadcast 512 million or more rows")
       var n = 1
       while (n < capacity) n *= 2
-      ensureAcquireMemory(n * 2L * 8 + (1 << 20))
-      array = new Array[Long](n * 2)
+      array = new UnsafeLongArray(n * 2)
       mask = n * 2 - 2
-      page = new Array[Long](1 << 17) // 1M bytes
+      page = allocatePage(1 << 20) // 1M bytes
+      cursor = page.getBaseOffset
     }
   }
 
@@ -616,7 +608,7 @@ private[execution] final class LongToUnsafeRowMap(
   /**
   * Returns total memory consumption.
*/ - def getTotalMemoryConsumption: Long = array.length * 8L + page.length * 8L + def getTotalMemoryConsumption: Long = array.length * 8L + page.size() /** * Returns the first slot of array that store the keys (sparse mode). @@ -632,11 +624,11 @@ private[execution] final class LongToUnsafeRowMap( private def nextSlot(pos: Int): Int = (pos + 2) & mask private[this] def toAddress(offset: Long, size: Int): Long = { - ((offset - Platform.LONG_ARRAY_OFFSET) << SIZE_BITS) | size + (offset << SIZE_BITS) | size } private[this] def toOffset(address: Long): Long = { - (address >>> SIZE_BITS) + Platform.LONG_ARRAY_OFFSET + (address >>> SIZE_BITS) } private[this] def toSize(address: Long): Int = { @@ -644,7 +636,7 @@ private[execution] final class LongToUnsafeRowMap( } private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = { - resultRow.pointTo(page, toOffset(address), toSize(address)) + resultRow.pointTo(page.getBaseObject, page.getBaseOffset + toOffset(address), toSize(address)) resultRow } @@ -681,8 +673,8 @@ private[execution] final class LongToUnsafeRowMap( override def next(): UnsafeRow = { val offset = toOffset(addr) val size = toSize(addr) - resultRow.pointTo(page, offset, size) - addr = Platform.getLong(page, offset + size) + resultRow.pointTo(page.getBaseObject, page.getBaseOffset + offset, size) + addr = Platform.getLong(page.getBaseObject, page.getBaseOffset + offset + size) resultRow } } @@ -777,12 +769,13 @@ private[execution] final class LongToUnsafeRowMap( // copy the bytes of UnsafeRow val offset = cursor - Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page, cursor, row.getSizeInBytes) + Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page.getBaseObject, cursor, + row.getSizeInBytes) cursor += row.getSizeInBytes - Platform.putLong(page, cursor, 0) + Platform.putLong(page.getBaseObject, cursor, 0) cursor += 8 numValues += 1 - updateIndex(key, pos, toAddress(offset, row.getSizeInBytes)) + updateIndex(key, pos, toAddress(offset - page.getBaseOffset, row.getSizeInBytes)) } private def findKeyPosition(key: Long): Int = { @@ -816,26 +809,24 @@ private[execution] final class LongToUnsafeRowMap( } else { // there are some values for this key, put the address in the front of them. 
val pointer = toOffset(address) + toSize(address) - Platform.putLong(page, pointer, array(pos + 1)) + Platform.putLong(page.getBaseObject, page.getBaseOffset + pointer, array(pos + 1)) array(pos + 1) = address } } private def grow(inputRowSize: Int): Unit = { // There is 8 bytes for the pointer to next value - val neededNumWords = (cursor - Platform.LONG_ARRAY_OFFSET + 8 + inputRowSize + 7) / 8 - if (neededNumWords > page.length) { + val neededNumWords = (cursor - page.getBaseOffset + 8 + inputRowSize + 7) / 8 + if (neededNumWords > page.size() / 8) { if (neededNumWords > (1 << 30)) { throw QueryExecutionErrors.cannotBuildHashedRelationLargerThan8GError() } - val newNumWords = math.max(neededNumWords, math.min(page.length * 2, 1 << 30)) - ensureAcquireMemory(newNumWords * 8L) - val newPage = new Array[Long](newNumWords.toInt) - Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET, - cursor - Platform.LONG_ARRAY_OFFSET) - val used = page.length + val newNumWords = math.max(neededNumWords, math.min(page.size() / 8 * 2, 1 << 30)) + val newPage = allocatePage(newNumWords.toInt * 8) + Platform.copyMemory(page.getBaseObject, page.getBaseOffset, newPage.getBaseObject, + newPage.getBaseOffset, cursor - page.getBaseOffset) + freePage(page) page = newPage - freeMemory(used * 8L) } } @@ -843,8 +834,7 @@ private[execution] final class LongToUnsafeRowMap( var old_array = array val n = array.length numKeys = 0 - ensureAcquireMemory(n * 2 * 8L) - array = new Array[Long](n * 2) + array = new UnsafeLongArray(n * 2) mask = n * 2 - 2 var i = 0 while (i < old_array.length) { @@ -854,8 +844,8 @@ private[execution] final class LongToUnsafeRowMap( } i += 2 } + old_array.free() old_array = null // release the reference to old array - freeMemory(n * 8L) } /** @@ -866,14 +856,7 @@ private[execution] final class LongToUnsafeRowMap( // Convert to dense mode if it does not require more memory or could fit within L1 cache // SPARK-16740: Make sure range doesn't overflow if minKey has a large negative value if (range >= 0 && (range < array.length || range < 1024)) { - try { - ensureAcquireMemory((range + 1) * 8L) - } catch { - case e: SparkException => - // there is no enough memory to convert - return - } - val denseArray = new Array[Long]((range + 1).toInt) + val denseArray = new UnsafeLongArray((range + 1).toInt) var i = 0 while (i < array.length) { if (array(i + 1) > 0) { @@ -882,10 +865,9 @@ private[execution] final class LongToUnsafeRowMap( } i += 2 } - val old_length = array.length + array.free() array = denseArray isDense = true - freeMemory(old_length * 8L) } } @@ -894,25 +876,26 @@ private[execution] final class LongToUnsafeRowMap( */ def free(): Unit = { if (page != null) { - freeMemory(page.length * 8L) + freePage(page) page = null } if (array != null) { - freeMemory(array.length * 8L) + array.free() array = null } } - private def writeLongArray( + private def writeBytes( writeBuffer: (Array[Byte], Int, Int) => Unit, - arr: Array[Long], + baseObject: Object, + baseOffset: Long, len: Int): Unit = { val buffer = new Array[Byte](4 << 10) - var offset: Long = Platform.LONG_ARRAY_OFFSET - val end = len * 8L + Platform.LONG_ARRAY_OFFSET + var offset: Long = baseOffset + val end = len * 8L + offset while (offset < end) { val size = Math.min(buffer.length, end - offset) - Platform.copyMemory(arr, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size) + Platform.copyMemory(baseObject, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size) writeBuffer(buffer, 0, size.toInt) offset += size } 
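
Aside (illustrative sketch, not part of the patch): the toAddress/toOffset/toSize helpers changed above pack a page offset and a row size into a single long. The low SIZE_BITS = 28 bits hold the size; the upper 36 bits hold the offset, which after this patch is relative to the page's base so the encoding works for off-heap pages too. The 1 << 30 word cap in grow() keeps every offset well inside those 36 bits. A self-contained model of the encoding, with made-up values:

    object AddressPackingSketch {
      // Same packing scheme as LongToUnsafeRowMap: address = (offset << SIZE_BITS) | size.
      val SIZE_BITS = 28
      val SIZE_MASK = (1 << SIZE_BITS) - 1

      def toAddress(offset: Long, size: Int): Long = (offset << SIZE_BITS) | size
      def toOffset(address: Long): Long = address >>> SIZE_BITS
      def toSize(address: Long): Int = (address & SIZE_MASK).toInt

      def main(args: Array[String]): Unit = {
        // Hypothetical row of 48 bytes written at byte offset 4096 of the page.
        val address = toAddress(4096L, 48)
        assert(toOffset(address) == 4096L)
        assert(toSize(address) == 48)
      }
    }
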
@@ -929,10 +912,11 @@ private[execution] final class LongToUnsafeRowMap(
     writeLong(numValues)
 
     writeLong(array.length)
-    writeLongArray(writeBuffer, array, array.length)
-    val used = ((cursor - Platform.LONG_ARRAY_OFFSET) / 8).toInt
+    writeBytes(writeBuffer,
+      array.memoryBlock.getBaseObject, array.memoryBlock.getBaseOffset, array.length)
+    val used = ((cursor - page.getBaseOffset) / 8).toInt
     writeLong(used)
-    writeLongArray(writeBuffer, page, used)
+    writeBytes(writeBuffer, page.getBaseObject, page.getBaseOffset, used)
   }
 
   override def writeExternal(output: ObjectOutput): Unit = {
@@ -943,20 +927,20 @@ private[execution] final class LongToUnsafeRowMap(
     write(out.writeBoolean, out.writeLong, out.write)
   }
 
-  private def readLongArray(
+  private def readData(
       readBuffer: (Array[Byte], Int, Int) => Unit,
-      length: Int): Array[Long] = {
-    val array = new Array[Long](length)
+      baseObject: Object,
+      baseOffset: Long,
+      length: Int): Unit = {
     val buffer = new Array[Byte](4 << 10)
-    var offset: Long = Platform.LONG_ARRAY_OFFSET
-    val end = length * 8L + Platform.LONG_ARRAY_OFFSET
+    var offset: Long = baseOffset
+    val end = length * 8L + baseOffset
     while (offset < end) {
       val size = Math.min(buffer.length, end - offset)
       readBuffer(buffer, 0, size.toInt)
-      Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size)
+      Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, baseObject, offset, size)
       offset += size
     }
-    array
   }
 
   private def read(
@@ -971,11 +955,15 @@ private[execution] final class LongToUnsafeRowMap(
     val length = readLong().toInt
     mask = length - 2
-    array = readLongArray(readBuffer, length)
+    array.free()
+    array = new UnsafeLongArray(length)
+    readData(readBuffer, array.memoryBlock.getBaseObject, array.memoryBlock.getBaseOffset, length)
 
     val pageLength = readLong().toInt
-    page = readLongArray(readBuffer, pageLength)
+    freePage(page)
+    page = allocatePage(pageLength * 8L)
+    readData(readBuffer, page.getBaseObject, page.getBaseOffset, pageLength)
     // Restore cursor variable to make this map able to be serialized again on executors.
-    cursor = pageLength * 8 + Platform.LONG_ARRAY_OFFSET
+    cursor = pageLength * 8L + page.getBaseOffset
   }
 
   override def readExternal(in: ObjectInput): Unit = {
@@ -985,6 +973,26 @@ private[execution] final class LongToUnsafeRowMap(
   override def read(kryo: Kryo, in: Input): Unit = {
     read(() => in.readBoolean(), () => in.readLong(), in.readBytes)
   }
+
+  private class UnsafeLongArray(val length: Int) {
+    val memoryBlock = allocatePage(length * 8L)
+
+    // Zero out the backing memory so that empty slots read as key/address 0.
+    Platform.setMemory(
+      memoryBlock.getBaseObject, memoryBlock.getBaseOffset, length * 8L, 0.toByte)
+
+    def apply(index: Int): Long = {
+      Platform.getLong(memoryBlock.getBaseObject, memoryBlock.getBaseOffset + index * 8L)
+    }
+
+    def update(index: Int, value: Long): Unit = {
+      Platform.putLong(memoryBlock.getBaseObject, memoryBlock.getBaseOffset + index * 8L, value)
+    }
+
+    def free(): Unit = {
+      freePage(memoryBlock)
+    }
+  }
 }
 
 class LongHashedRelation(

From 0709a236bdbfa416a4a0f0f1819a02479b939cce Mon Sep 17 00:00:00 2001
From: Hongze Zhang
Date: Mon, 10 Nov 2025 19:02:25 +0100
Subject: [PATCH 2/6] fixup

---
 .../execution/joins/HashedRelationSuite.scala | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index b88a76bbfb57..0e83e9fda512 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -25,9 +25,10 @@ import scala.util.Random
 
 import org.apache.spark.SparkConf
 import org.apache.spark.SparkException
-import org.apache.spark.internal.config._
+import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE}
 import org.apache.spark.internal.config.Kryo._
 import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
+import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -39,9 +40,13 @@ import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.collection.CompactBuffer
 
-class HashedRelationSuite extends SharedSparkSession {
+abstract class HashedRelationSuite extends SharedSparkSession {
+  protected def useOffHeapMemoryMode: Boolean
+
   val umm = new UnifiedMemoryManager(
-    new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
+    new SparkConf()
+      .set(MEMORY_OFFHEAP_ENABLED, useOffHeapMemoryMode)
+      .set(MEMORY_OFFHEAP_SIZE, ByteUnit.GiB.toBytes(1L)),
     Long.MaxValue,
     Long.MaxValue / 2,
     1)
@@ -754,3 +759,11 @@ class HashedRelationSuite extends SharedSparkSession {
     }
   }
 }
+
+class HashedRelationOnHeapSuite extends HashedRelationSuite {
+  override protected def useOffHeapMemoryMode: Boolean = true
+}
+
+class HashedRelationOffHeapSuite extends HashedRelationSuite {
+  override protected def useOffHeapMemoryMode: Boolean = false
+}

From d533ff452e9cfb51478d6b6245927618837c1a48 Mon Sep 17 00:00:00 2001
From: Hongze Zhang
Date: Tue, 11 Nov 2025 14:50:12 +0100
Subject: [PATCH 3/6] fixup

---
 .../org/apache/spark/sql/errors/QueryExecutionErrors.scala | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 1f7d2a149a7b..157f094c13c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -1176,13 +1176,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
     cause = null)
   }
 
-  def cannotAcquireMemoryToBuildLongHashedRelationError(size: Long, got: Long): Throwable = {
-    new SparkException(
-      errorClass = "_LEGACY_ERROR_TEMP_2106",
-      messageParameters = Map("size" -> size.toString(), "got" -> got.toString()),
-      cause = null)
-  }
-
   def cannotAcquireMemoryToBuildUnsafeHashedRelationError(): Throwable = {
     new SparkOutOfMemoryError(
       "_LEGACY_ERROR_TEMP_2107",

From 745b73ae4fc8d4034d5f134e6eba76192c17cb48 Mon Sep 17 00:00:00 2001
From: Hongze Zhang
Date: Tue, 11 Nov 2025 15:00:21 +0100
Subject: [PATCH 4/6] fixup

---
 .../apache/spark/sql/execution/joins/HashedRelation.scala | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 5f11b253469f..62b8b5f1de0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -816,7 +816,8 @@ private[execution] final class LongToUnsafeRowMap(
 
   private def grow(inputRowSize: Int): Unit = {
     // There is 8 bytes for the pointer to next value
-    val neededNumWords = (cursor - page.getBaseOffset + 8 + inputRowSize + 7) / 8
+    val usedBytes = cursor - page.getBaseOffset
+    val neededNumWords = (usedBytes + 8 + inputRowSize + 7) / 8
     if (neededNumWords > page.size() / 8) {
       if (neededNumWords > (1 << 30)) {
         throw QueryExecutionErrors.cannotBuildHashedRelationLargerThan8GError()
@@ -824,9 +825,10 @@ private[execution] final class LongToUnsafeRowMap(
       val newNumWords = math.max(neededNumWords, math.min(page.size() / 8 * 2, 1 << 30))
-      val newPage = allocatePage(newNumWords.toInt * 8)
+      val newPage = allocatePage(newNumWords * 8)
       Platform.copyMemory(page.getBaseObject, page.getBaseOffset, newPage.getBaseObject,
-        newPage.getBaseOffset, cursor - page.getBaseOffset)
+        newPage.getBaseOffset, usedBytes)
       freePage(page)
       page = newPage
+      cursor = page.getBaseOffset + usedBytes
     }
   }

From b9b24d60e0ad4b926fb06e1c22788db6bcf6918e Mon Sep 17 00:00:00 2001
From: Hongze Zhang
Date: Wed, 12 Nov 2025 14:06:24 +0100
Subject: [PATCH 5/6] fixup

---
 .../org/apache/spark/sql/execution/joins/HashedRelation.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 62b8b5f1de0d..95224033f390 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -977,7 +977,7 @@ private[execution] final class LongToUnsafeRowMap(
   }
 
   private class UnsafeLongArray(val length: Int) {
-    val memoryBlock = allocatePage(length * 8L)
+    val memoryBlock: MemoryBlock = allocatePage(length * 8L)
 
     // Zero out the backing memory so that empty slots read as key/address 0.
     Platform.setMemory(

From c8d965ceceeb8768ca25915009c6eb082404347a Mon Sep 17 00:00:00 2001
From: Hongze Zhang
Date: Wed, 12 Nov 2025 14:26:51 +0100
Subject: [PATCH 6/6] fixup

---
 .../spark/sql/execution/joins/HashedRelationSuite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 0e83e9fda512..e3ccc5333d49 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -761,9 +761,9 @@ abstract class HashedRelationSuite extends SharedSparkSession {
 }
 
 class HashedRelationOnHeapSuite extends HashedRelationSuite {
-  override protected def useOffHeapMemoryMode: Boolean = true
+  override protected def useOffHeapMemoryMode: Boolean = false
 }
 
 class HashedRelationOffHeapSuite extends HashedRelationSuite {
-  override protected def useOffHeapMemoryMode: Boolean = false
+  override protected def useOffHeapMemoryMode: Boolean = true
 }
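
Aside (illustrative sketch, not part of the patch series): the two concrete suites only flip MEMORY_OFFHEAP_ENABLED, and LongToUnsafeRowMap then inherits its mode through MemoryConsumer(mm, mm.getTungstenMemoryMode), so the same map code is exercised both on-heap and off-heap. The sketch below shows how the flag maps to a Tungsten memory mode, mirroring the suite's umm setup; it assumes the code is compiled inside the org.apache.spark package tree, since UnifiedMemoryManager's constructor is package-private:

    import org.apache.spark.SparkConf
    import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE}
    import org.apache.spark.memory.{MemoryMode, TaskMemoryManager, UnifiedMemoryManager}

    object TungstenModeSketch {
      // Builds a TaskMemoryManager whose Tungsten mode follows the off-heap flag,
      // mirroring the `umm` defined at the top of the suite.
      def taskMemoryManager(offHeap: Boolean): TaskMemoryManager = {
        val conf = new SparkConf()
          .set(MEMORY_OFFHEAP_ENABLED, offHeap)
          .set(MEMORY_OFFHEAP_SIZE, 1L << 30) // a positive size is required when off-heap is on
        new TaskMemoryManager(
          new UnifiedMemoryManager(conf, Long.MaxValue, Long.MaxValue / 2, 1), 0)
      }

      def main(args: Array[String]): Unit = {
        assert(taskMemoryManager(offHeap = false).getTungstenMemoryMode == MemoryMode.ON_HEAP)
        assert(taskMemoryManager(offHeap = true).getTungstenMemoryMode == MemoryMode.OFF_HEAP)
      }
    }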