From 1265fda3c7399a635b985565b0d915d901d48382 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev
Date: Thu, 26 Jan 2017 18:30:25 +0100
Subject: [PATCH] [SPARK-19353][CORE] Generalize PipedRDD to use I/O formats

This commit allows using arbitrary input and output formats when streaming
data to and from the piped process. The API uses java.io.Data{Input,Output}
for I/O; therefore, all methods operating on multibyte primitives assume
big-endian byte order.

The change is fully backward-compatible in terms of both the public API and
behaviour. Additionally, the existing line-based format is available via
TextInputWriter/TextOutputReader.
---
 .../scala/org/apache/spark/rdd/PipedRDD.scala | 188 +++++++++++++-----
 .../main/scala/org/apache/spark/rdd/RDD.scala |  35 +++-
 .../org/apache/spark/rdd/PipedRDDSuite.scala  |  75 ++++++-
 3 files changed, 231 insertions(+), 67 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index 02b28b72fb0e7..3739c036d3524 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -17,41 +17,37 @@
 
 package org.apache.spark.rdd
 
-import java.io.BufferedWriter
-import java.io.File
-import java.io.FilenameFilter
-import java.io.IOException
-import java.io.OutputStreamWriter
-import java.io.PrintWriter
+import java.io._
+import java.util
 import java.util.StringTokenizer
 import java.util.concurrent.atomic.AtomicReference
 
 import scala.collection.JavaConverters._
 import scala.collection.Map
 import scala.collection.mutable.ArrayBuffer
-import scala.io.Source
+import scala.io.Codec
 import scala.reflect.ClassTag
 
+import com.google.common.io.ByteStreams
+
 import org.apache.spark.{Partition, SparkEnv, TaskContext}
 import org.apache.spark.util.Utils
 
-
 /**
- * An RDD that pipes the contents of each parent partition through an external command
- * (printing them one per line) and returns the output as a collection of strings.
+ * An RDD that pipes the contents of each parent partition through an
+ * external command and returns the output.
  */
-private[spark] class PipedRDD[T: ClassTag](
-    prev: RDD[T],
+private[spark] class PipedRDD[I: ClassTag, O: ClassTag](
+    prev: RDD[I],
     command: Seq[String],
     envVars: Map[String, String],
-    printPipeContext: (String => Unit) => Unit,
-    printRDDElement: (T, String => Unit) => Unit,
    separateWorkingDir: Boolean,
     bufferSize: Int,
-    encoding: String)
-  extends RDD[String](prev) {
+    inputWriter: InputWriter[I],
+    outputReader: OutputReader[O]
+) extends RDD[O](prev) {
 
-  override def getPartitions: Array[Partition] = firstParent[T].partitions
+  override def getPartitions: Array[Partition] = firstParent[I].partitions
 
   /**
    * A FilenameFilter that accepts anything that isn't equal to the name passed in.
@@ -63,7 +59,7 @@ private[spark] class PipedRDD[T: ClassTag](
     }
   }
 
-  override def compute(split: Partition, context: TaskContext): Iterator[String] = {
+  override def compute(split: Partition, context: TaskContext): Iterator[O] = {
     val pb = new ProcessBuilder(command.asJava)
     // Add the environmental variables to the process.
     val currentEnvVars = pb.environment()
@@ -115,17 +111,13 @@ private[spark] class PipedRDD[T: ClassTag](
     // Start a thread to print the process's stderr to ours
     new Thread(s"stderr reader for $command") {
       override def run(): Unit = {
-        val err = proc.getErrorStream
+        val os = System.err
         try {
-          for (line <- Source.fromInputStream(err)(encoding).getLines) {
-            // scalastyle:off println
-            System.err.println(line)
-            // scalastyle:on println
-          }
+          ByteStreams.copy(proc.getErrorStream, os)
         } catch {
           case t: Throwable => childThreadException.set(t)
         } finally {
-          err.close()
+          os.close()
         }
       }
     }.start()
@@ -134,54 +126,47 @@ private[spark] class PipedRDD[T: ClassTag](
     new Thread(s"stdin writer for $command") {
       override def run(): Unit = {
         TaskContext.setTaskContext(context)
-        val out = new PrintWriter(new BufferedWriter(
-          new OutputStreamWriter(proc.getOutputStream, encoding), bufferSize))
+        val dos = new DataOutputStream(
+          new BufferedOutputStream(proc.getOutputStream, bufferSize))
         try {
-          // scalastyle:off println
-          // input the pipe context firstly
-          if (printPipeContext != null) {
-            printPipeContext(out.println)
+          for (elem <- firstParent[I].iterator(split, context)) {
+            inputWriter.write(dos, elem)
           }
-          for (elem <- firstParent[T].iterator(split, context)) {
-            if (printRDDElement != null) {
-              printRDDElement(elem, out.println)
-            } else {
-              out.println(elem)
-            }
-          }
-          // scalastyle:on println
         } catch {
           case t: Throwable => childThreadException.set(t)
         } finally {
-          out.close()
+          dos.close()
         }
       }
     }.start()
 
-    // Return an iterator that read lines from the process's stdout
-    val lines = Source.fromInputStream(proc.getInputStream)(encoding).getLines
-    new Iterator[String] {
-      def next(): String = {
+    val dis = new DataInputStream(
+      new BufferedInputStream(proc.getInputStream, bufferSize))
+    new Iterator[O] {
+      def next(): O = {
         if (!hasNext()) {
           throw new NoSuchElementException()
         }
-        lines.next()
+
+        outputReader.read(dis)
       }
 
      def hasNext(): Boolean = {
-        val result = if (lines.hasNext) {
-          true
-        } else {
+        dis.mark(1)
+        val eof = dis.read() < 0
+        dis.reset()
+
+        if (eof) {
           val exitStatus = proc.waitFor()
           cleanup()
           if (exitStatus != 0) {
             throw new IllegalStateException(s"Subprocess exited with status $exitStatus. " +
               s"Command ran: " + command.mkString(" "))
           }
-          false
         }
+
         propagateChildException()
-        result
+        !eof
       }
 
       private def cleanup(): Unit = {
@@ -198,10 +183,10 @@ private[spark] class PipedRDD[T: ClassTag](
       val t = childThreadException.get()
       if (t != null) {
         val commandRan = command.mkString(" ")
-        logError(s"Caught exception while running pipe() operator. Command ran: $commandRan. " +
-          s"Exception: ${t.getMessage}")
-        proc.destroy()
+        logError("Caught exception while running pipe() operator. " +
+          s"Command ran: $commandRan.", t)
         cleanup()
+        proc.destroy()
         throw t
       }
     }
@@ -209,6 +194,103 @@ private[spark] class PipedRDD[T: ClassTag](
   }
 }
 
+/** Specifies how to write the elements of the input [[RDD]] into the pipe. */
+trait InputWriter[T] extends Serializable {
+  def write(dos: DataOutput, elem: T): Unit
+}
+
+/** Specifies how to read the elements from the pipe into the output [[RDD]]. */
+trait OutputReader[T] extends Serializable {
+  /**
+   * Reads the next element.
+   *
+   * The input is guaranteed to have at least one byte.
+   */
+  def read(dis: DataInput): T
+}
+
+class TextInputWriter[I](
+    encoding: String = Codec.defaultCharsetCodec.name,
+    printPipeContext: (String => Unit) => Unit = null,
+    printRDDElement: (I, String => Unit) => Unit = null
+) extends InputWriter[I] {
+
+  private[this] val lineSeparator = System.lineSeparator().getBytes(encoding)
+  private[this] var initialized = printPipeContext == null
+
+  private def writeLine(dos: DataOutput, s: String): Unit = {
+    dos.write(s.getBytes(encoding))
+    dos.write(lineSeparator)
+  }
+
+  override def write(dos: DataOutput, elem: I): Unit = {
+    if (!initialized) {
+      printPipeContext(writeLine(dos, _))
+      initialized = true
+    }
+
+    if (printRDDElement == null) {
+      writeLine(dos, String.valueOf(elem))
+    } else {
+      printRDDElement(elem, writeLine(dos, _))
+    }
+  }
+}
+
+class TextOutputReader(
+    encoding: String = Codec.defaultCharsetCodec.name
+) extends OutputReader[String] {
+
+  private[this] val lf = "\n".getBytes(encoding)
+  private[this] val cr = "\r".getBytes(encoding)
+  private[this] val crlf = cr ++ lf
+  private[this] var buf = Array.ofDim[Byte](64)
+  private[this] var used = 0
+
+  @inline
+  /** Checks that the suffix of [[buf]] matches [[other]]. */
+  private def endsWith(other: Array[Byte]): Boolean = {
+    var i = used - 1
+    var j = other.length - 1
+    (j <= i) && {
+      while (j >= 0) {
+        if (buf(i) != other(j)) {
+          return false
+        }
+        i -= 1
+        j -= 1
+      }
+      true
+    }
+  }
+
+  override def read(dis: DataInput): String = {
+    used = 0
+
+    try {
+      do {
+        val ch = dis.readByte()
+        if (buf.length <= used) {
+          buf = util.Arrays.copyOf(buf, used + (used >>> 1)) // 1.5x
+        }
+
+        buf(used) = ch
+        used += 1
+      } while (!(endsWith(lf) || endsWith(cr)))
+
+      if (endsWith(crlf)) {
+        used -= crlf.length
+      } else { // endsWith(lf) || endsWith(cr)
+        used -= lf.length
+      }
+    } catch {
+      case _: EOFException =>
+    }
+
+    new String(buf, 0, used, encoding)
+  }
+}
+
 private object PipedRDD {
   // Split a string into words using a standard StringTokenizer
   def tokenize(command: String): Seq[String] = {
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index e524675332d1b..7b12eb971aff1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -774,12 +774,41 @@ abstract class RDD[T: ClassTag](
       separateWorkingDir: Boolean = false,
       bufferSize: Int = 8192,
       encoding: String = Codec.defaultCharsetCodec.name): RDD[String] = withScope {
-    new PipedRDD(this, command, env,
+    val inputWriter = new TextInputWriter[T](
+      encoding,
       if (printPipeContext ne null) sc.clean(printPipeContext) else null,
-      if (printRDDElement ne null) sc.clean(printRDDElement) else null,
+      if (printRDDElement ne null) sc.clean(printRDDElement) else null)
+    val outputReader = new TextOutputReader(encoding)
+    pipeFormatted(command, env, separateWorkingDir, bufferSize, inputWriter, outputReader)
+  }
+
+  /**
+   * Return an RDD created by piping elements to a forked external process. The resulting RDD
+   * is computed by executing the given process once per partition. All elements
+   * of each input partition are written to a process's stdin. The resulting partition
+   * consists of the process's stdout output.
+   *
+   * @param command command to run in forked process.
+   * @param env environment variables to set.
+   * @param separateWorkingDir Use separate working directories for each task.
+   * @param bufferSize Buffer size for the stdin writer for the piped process.
+   * @param inputWriter the format to use for serializing the elements of this RDD into
+   *                    the process's stdin.
+   * @param outputReader the format to use for reading elements into the resulting RDD
+   *                     from process's stdout.
+   * @return the result RDD
+   */
+  def pipeFormatted[O: ClassTag](
+      command: Seq[String],
+      env: Map[String, String] = Map(),
+      separateWorkingDir: Boolean = false,
+      bufferSize: Int = 8192,
+      inputWriter: InputWriter[T],
+      outputReader: OutputReader[O]): RDD[O] = withScope {
+    new PipedRDD(this, command, env,
       separateWorkingDir,
       bufferSize,
-      encoding)
+      inputWriter, outputReader)
   }
 
   /**
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index 1a0eb250e7cdc..1e3e005878d15 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.rdd
 
-import java.io.File
+import java.io.{DataInput, DataOutput, File}
 
 import scala.collection.Map
 import scala.io.Codec
@@ -220,16 +220,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext {
         }
       }
       val hadoopPart1 = generateFakeHadoopPartition()
-      val pipedRdd =
-        new PipedRDD(
-          nums,
-          PipedRDD.tokenize(s"$envCommand $varName"),
-          Map(),
-          null,
-          null,
-          false,
-          4092,
-          Codec.defaultCharsetCodec.name)
+      val pipedRdd = nums.pipe(s"$envCommand $varName")
       val tContext = TaskContext.empty()
       val rddIter = pipedRdd.compute(hadoopPart1, tContext)
       val arr = rddIter.toArray
@@ -244,4 +235,66 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext {
     new HadoopPartition(sc.newRddId(), 1, split)
   }
 
+  test("pipe works for non-default encoding") {
+    if (TestUtils.testCommandAvailable("cat")) {
+      val elems = sc.parallelize(Array("foobar"))
+        .pipe(Seq("cat"), encoding = "utf-32")
+        .collect()
+
+      assert(elems.size === 1)
+      assert(elems.head === "foobar")
+    }
+  }
+
+  test("pipe works for null") {
+    if (TestUtils.testCommandAvailable("cat")) {
+      val elems = sc.parallelize(Array(null))
+        .pipe(Seq("cat"))
+        .collect()
+
+      assert(elems.size === 1)
+      assert(elems.head === "null")
+    }
+  }
+
+  test("pipe works for rawbytes") {
+    if (TestUtils.testCommandAvailable("cat")) {
+      val kv = "foo".getBytes -> "bar".getBytes
+      val elems = sc.parallelize(Array(kv)).pipeFormatted(Seq("cat"),
+        inputWriter = new RawBytesInputWriter(),
+        outputReader = new RawBytesOutputReader()
+      ).collect()
+
+      assert(elems.size === 1)
+      val Array((key, value)) = elems
+      assert(key sameElements kv._1)
+      assert(value sameElements kv._2)
+    }
+  }
+}
+
+class RawBytesInputWriter extends InputWriter[(Array[Byte], Array[Byte])] {
+  override def write(dos: DataOutput, elem: (Array[Byte], Array[Byte])): Unit = {
+    val (key, value) = elem
+    dos.writeInt(key.length)
+    dos.write(key)
+    dos.writeInt(value.length)
+    dos.write(value)
+  }
+}
+
+class RawBytesOutputReader extends OutputReader[(Array[Byte], Array[Byte])] {
+  private def readLengthPrefixed(dis: DataInput): Array[Byte] = {
+    val length = dis.readInt()
+    assert(length >= 0)
+    val result = Array.ofDim[Byte](length)
+    dis.readFully(result)
+    result
+  }
+
+  override def read(dis: DataInput): (Array[Byte], Array[Byte]) = {
+    val key = readLengthPrefixed(dis)
+    val value = readLengthPrefixed(dis)
+    key -> value
+  }
 }
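
Usage sketch (illustrative only, not part of the patch): assuming the patch above is applied
and an active SparkContext named sc, a custom binary format can be piped through an external
command roughly as follows. BytesWriter and BytesReader are hypothetical names here, analogous
to the RawBytes* helpers in PipedRDDSuite; "cat" simply echoes the length-prefixed framing back.

    import java.io.{DataInput, DataOutput}

    import org.apache.spark.rdd.{InputWriter, OutputReader}

    // Writes each element as a big-endian length prefix followed by the raw bytes.
    class BytesWriter extends InputWriter[Array[Byte]] {
      override def write(dos: DataOutput, elem: Array[Byte]): Unit = {
        dos.writeInt(elem.length)
        dos.write(elem)
      }
    }

    // Reads elements back using the same length-prefixed framing.
    class BytesReader extends OutputReader[Array[Byte]] {
      override def read(dis: DataInput): Array[Byte] = {
        val buf = Array.ofDim[Byte](dis.readInt())
        dis.readFully(buf)
        buf
      }
    }

    val roundTripped = sc.parallelize(Seq("foo".getBytes, "bar".getBytes))
      .pipeFormatted(Seq("cat"),
        inputWriter = new BytesWriter,
        outputReader = new BytesReader)
      .collect()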