[PYSPARK][SQL] Updates to RowQueue

DiskRowQueue and HybridRowQueue now take a SerializerManager and wrap the
streams for rows spilled to disk with wrapForEncryption, so spilled data
respects spark.io.encryption.enabled.

Tested with updates to RowQueueSuite
squito committed Sep 6, 2018
1 parent 09dd34c commit 6d742d1bd71aa3803dce91a830b37284cb18cf70
sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala

@@ -21,9 +21,10 @@ import java.io._
 
 import com.google.common.io.Closeables
 
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkEnv, SparkException}
 import org.apache.spark.io.NioBufferedFileInputStream
 import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager}
+import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.memory.MemoryBlock
@@ -108,9 +109,13 @@ private[python] abstract class InMemoryRowQueue(val page: MemoryBlock, numFields
  * A RowQueue that is backed by a file on disk. This queue will stop accepting new rows once any
  * reader has begun reading from the queue.
  */
-private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueue {
-  private var out = new DataOutputStream(
-    new BufferedOutputStream(new FileOutputStream(file.toString)))
+private[python] case class DiskRowQueue(
+    file: File,
+    fields: Int,
+    serMgr: SerializerManager) extends RowQueue {
+
+  private var out = new DataOutputStream(serMgr.wrapForEncryption(
+    new BufferedOutputStream(new FileOutputStream(file.toString))))
   private var unreadBytes = 0L
 
   private var in: DataInputStream = _
@@ -131,7 +136,8 @@ private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueu
     if (out != null) {
      out.close()
      out = null
-      in = new DataInputStream(new NioBufferedFileInputStream(file))
+      in = new DataInputStream(serMgr.wrapForEncryption(
+        new NioBufferedFileInputStream(file)))
    }
 
    if (unreadBytes > 0) {
@@ -166,7 +172,8 @@ private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueu
 private[python] case class HybridRowQueue(
     memManager: TaskMemoryManager,
     tempDir: File,
-    numFields: Int)
+    numFields: Int,
+    serMgr: SerializerManager)
   extends MemoryConsumer(memManager) with RowQueue {
 
   // Each buffer should have at least one row
@@ -212,7 +219,7 @@ private[python] case class HybridRowQueue(
   }
 
   private def createDiskQueue(): RowQueue = {
-    DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields)
+    DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields, serMgr)
   }
 
   private def createNewQueue(required: Long): RowQueue = {
@@ -279,3 +286,9 @@ private[python] case class HybridRowQueue(
     }
   }
 }
+
+private[python] object HybridRowQueue {
+  def apply(taskMemoryMgr: TaskMemoryManager, file: File, fields: Int): HybridRowQueue = {
+    HybridRowQueue(taskMemoryMgr, file, fields, SparkEnv.get.serializerManager)
+  }
+}
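
The core of the change is symmetry: the same SerializerManager must wrap both the write path (DiskRowQueue.out) and the read path (DiskRowQueue.in), or the spilled bytes cannot be decrypted on the way back. The new HybridRowQueue companion object defaults to SparkEnv.get.serializerManager, so production callers pick up the environment's encryption key without touching it directly. Below is a minimal round-trip sketch of the wrapForEncryption pattern; it assumes the code compiles inside Spark's own sources (SerializerManager, CryptoStreamUtils, and IO_ENCRYPTION_ENABLED are private[spark]), and the object name EncryptedSpillRoundTrip is hypothetical, not part of the patch.

package org.apache.spark.sql.execution.python

import java.io._

import org.apache.spark.SparkConf
import org.apache.spark.internal.config._
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{JavaSerializer, SerializerManager}

// Hypothetical demo object, not part of the patch.
object EncryptedSpillRoundTrip {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, true)
    // No running SparkEnv here, so create the key explicitly, the same way
    // RowQueueSuite's createSerializerManager helper does.
    val serMgr = new SerializerManager(
      new JavaSerializer(conf), conf, Some(CryptoStreamUtils.createKey(conf)))

    val file = File.createTempFile("buffer", "")
    // Write path, mirroring DiskRowQueue.out: bytes are encrypted as they spill.
    val out = new DataOutputStream(serMgr.wrapForEncryption(
      new BufferedOutputStream(new FileOutputStream(file))))
    out.writeLong(42L)
    out.close()

    // Read path, mirroring DiskRowQueue.in: the same manager decrypts on read.
    val in = new DataInputStream(serMgr.wrapForEncryption(new FileInputStream(file)))
    assert(in.readLong() == 42L) // plaintext round-trips; the bytes on disk do not
    in.close()
    file.delete()
  }
}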
sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala

@@ -20,12 +20,15 @@ package org.apache.spark.sql.execution.python
 import java.io.File
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager}
+import org.apache.spark.internal.config._
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
+import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite}
+import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.unsafe.memory.MemoryBlock
 import org.apache.spark.util.Utils
 
-class RowQueueSuite extends SparkFunSuite {
+class RowQueueSuite extends SparkFunSuite with EncryptionFunSuite {
 
   test("in-memory queue") {
     val page = MemoryBlock.fromLongArray(new Array[Long](1<<10))
@@ -53,10 +56,20 @@ class RowQueueSuite extends SparkFunSuite {
     queue.close()
   }
 
-  test("disk queue") {
+  private def createSerializerManager(conf: SparkConf): SerializerManager = {
+    val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) {
+      Some(CryptoStreamUtils.createKey(conf))
+    } else {
+      None
+    }
+    new SerializerManager(new JavaSerializer(conf), conf, ioEncryptionKey)
+  }
+
+  encryptionTest("disk queue") { conf =>
+    val serManager = createSerializerManager(conf)
     val dir = Utils.createTempDir().getCanonicalFile
     dir.mkdirs()
-    val queue = DiskRowQueue(new File(dir, "buffer"), 1)
+    val queue = DiskRowQueue(new File(dir, "buffer"), 1, serManager)
     val row = new UnsafeRow(1)
     row.pointTo(new Array[Byte](16), 16)
     val n = 1000
@@ -81,11 +94,12 @@ class RowQueueSuite extends SparkFunSuite {
     queue.close()
   }
 
-  test("hybrid queue") {
-    val mem = new TestMemoryManager(new SparkConf())
+  encryptionTest("hybrid queue") { conf =>
+    val serManager = createSerializerManager(conf)
+    val mem = new TestMemoryManager(conf)
     mem.limit(4<<10)
     val taskM = new TaskMemoryManager(mem, 0)
-    val queue = HybridRowQueue(taskM, Utils.createTempDir().getCanonicalFile, 1)
+    val queue = HybridRowQueue(taskM, Utils.createTempDir().getCanonicalFile, 1, serManager)
     val row = new UnsafeRow(1)
     row.pointTo(new Array[Byte](16), 16)
     val n = (4<<10) / 16 * 3
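
For context, encryptionTest comes from EncryptionFunSuite and is what makes each test body cover both configurations: it registers the test twice, once with spark.io.encryption.enabled off and once with it on, handing the prepared SparkConf to the body. A sketch of that helper, under the assumption that the real trait behaves this way (the verbatim source lives at org.apache.spark.security.EncryptionFunSuite):

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.internal.config._

trait EncryptionFunSuite { self: SparkFunSuite =>
  // Runs `f` twice, with IO encryption disabled and then enabled.
  final protected def encryptionTest(name: String)(f: SparkConf => Unit): Unit = {
    Seq(false, true).foreach { encrypt =>
      test(s"$name (encryption = ${if (encrypt) "on" else "off"})") {
        f(new SparkConf().set(IO_ENCRYPTION_ENABLED, encrypt))
      }
    }
  }
}

This is why the suite's createSerializerManager helper checks conf.get(IO_ENCRYPTION_ENABLED) before creating a key: the same test body must also pass with encryption disabled, where wrapForEncryption leaves the streams untouched.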
