diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 242185e803577..75797c02c8b5e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -125,6 +125,16 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation {
 
 private[execution] object HashedRelation {
 
+  def createTaskMemoryManager(): TaskMemoryManager = {
+    new TaskMemoryManager(
+      new UnifiedMemoryManager(
+        new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
+        Runtime.getRuntime.maxMemory,
+        Runtime.getRuntime.maxMemory / 2,
+        1),
+      0)
+  }
+
   /**
    * Create a HashedRelation from an Iterator of InternalRow.
    *
@@ -142,13 +152,7 @@
       allowsNullKey: Boolean = false,
       ignoresDuplicatedKey: Boolean = false): HashedRelation = {
     val mm = Option(taskMemoryManager).getOrElse {
-      new TaskMemoryManager(
-        new UnifiedMemoryManager(
-          new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
-          Runtime.getRuntime.maxMemory,
-          Runtime.getRuntime.maxMemory / 2,
-          1),
-        0)
+      createTaskMemoryManager()
     }
 
     if (!input.hasNext && !allowsNullKey) {
@@ -400,13 +404,7 @@ private[joins] class UnsafeHashedRelation(
     // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
     // TODO(josh): This needs to be revisited before we merge this patch; making this change now
     // so that tests compile:
-    val taskMemoryManager = new TaskMemoryManager(
-      new UnifiedMemoryManager(
-        new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
-        Runtime.getRuntime.maxMemory,
-        Runtime.getRuntime.maxMemory / 2,
-        1),
-      0)
+    val taskMemoryManager = HashedRelation.createTaskMemoryManager()
     val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
       .getOrElse(new SparkConf().get(BUFFER_PAGESIZE).getOrElse(16L * 1024 * 1024))
 
@@ -574,15 +572,7 @@ private[execution] final class LongToUnsafeRowMap(
 
   // needed by serializer
   def this() = {
-    this(
-      new TaskMemoryManager(
-        new UnifiedMemoryManager(
-          new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
-          Runtime.getRuntime.maxMemory,
-          Runtime.getRuntime.maxMemory / 2,
-          1),
-        0),
-      0)
+    this(HashedRelation.createTaskMemoryManager(), 0)
   }
 
   private def init(): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index f8fa2f5fe35f4..38c07aee2d74d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.SparkConf
 import org.apache.spark.SparkException
 import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE}
 import org.apache.spark.internal.config.Kryo._
-import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryConsumer, UnifiedMemoryManager}
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.sql.catalyst.InternalRow
@@ -758,6 +758,13 @@ abstract class HashedRelationSuite extends SharedSparkSession {
       map.free()
     }
   }
+
+  test("Verify TaskMemoryManager creation") {
+    val taskMemoryManager = HashedRelation.createTaskMemoryManager()
+    val testMemoryConsumer = new TestMemoryConsumer(taskMemoryManager)
+    val memoryPage = taskMemoryManager.allocatePage(100, testMemoryConsumer)
+    assert(memoryPage.size() > 0)
+  }
 }
 
 class HashedRelationOnHeapSuite extends HashedRelationSuite {