From ac1a40b3f04fd934ed2c4e82f7c62c28c4059e35 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 11 Mar 2016 15:18:20 -0800 Subject: [PATCH 1/4] fast serialization for collecting DataFrame/Dataset --- .../org/apache/spark/sql/DataFrame.scala | 4 +- .../spark/sql/execution/SparkPlan.scala | 58 ++++++++++++++++++- 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index f1791e6943bb7..69c79d25c49b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1803,14 +1803,14 @@ class Dataset[T] private[sql]( */ def collectAsList(): java.util.List[T] = withCallback("collectAsList", toDF()) { _ => withNewExecutionId { - val values = queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow) + val values = queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow) java.util.Arrays.asList(values : _*) } } private def collect(needCallback: Boolean): Array[T] = { def execute(): Array[T] = withNewExecutionId { - queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow) + queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow) } if (needCallback) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 3be4cce045fea..1b35789b8856a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.ArrayBuffer @@ -34,6 +35,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric} import org.apache.spark.sql.types.DataType +import org.apache.spark.unsafe.Platform import org.apache.spark.util.ThreadUtils /** @@ -220,7 +222,61 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * Runs this query returning the result as an array. */ def executeCollect(): Array[InternalRow] = { - execute().map(_.copy()).collect() + // Packing the UnsafeRows into byte array for faster serialization. + // The byte arrays are in the following format: + // [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1] + val byteArrayRdd = execute().mapPartitionsInternal { iter => + new Iterator[Array[Byte]] { + private var row: UnsafeRow = _ + override def hasNext: Boolean = row != null || iter.hasNext + override def next: Array[Byte] = { + var cap = 1 << 20 // 1 MB + if (row != null) { + // the buffered row could be larger than default buffer size + cap = Math.max(cap, 4 + row.getSizeInBytes + 4) // reverse 4 bytes for ending mark (-1). 
+ } + val buffer = ByteBuffer.allocate(cap) + if (row != null) { + buffer.putInt(row.getSizeInBytes) + row.writeTo(buffer) + row = null + } + while (iter.hasNext) { + row = iter.next().asInstanceOf[UnsafeRow] + // Reserve last 4 bytes for ending mark + if (4 + row.getSizeInBytes + 4 <= buffer.remaining()) { + buffer.putInt(row.getSizeInBytes) + row.writeTo(buffer) + row = null + } else { + buffer.putInt(-1) + return buffer.array() + } + } + buffer.putInt(-1) + // copy the used bytes to make it smaller + val bytes = new Array[Byte](buffer.limit()) + System.arraycopy(buffer.array(), 0, bytes, 0, buffer.limit()) + bytes + } + } + } + // Collect the byte arrays back to driver, then decode them as UnsafeRows. + val nFields = schema.length + byteArrayRdd.collect().flatMap { bytes => + val buffer = ByteBuffer.wrap(bytes) + new Iterator[InternalRow] { + private var sizeInBytes = buffer.getInt() + override def hasNext: Boolean = sizeInBytes >= 0 + override def next: InternalRow = { + val row = new UnsafeRow(nFields) + row.pointTo(buffer.array(), Platform.BYTE_ARRAY_OFFSET + buffer.position(), sizeInBytes) + buffer.position(buffer.position() + sizeInBytes) + sizeInBytes = buffer.getInt() + row + } + } + } } /** From a85939264610870d41402225ba8983b101814476 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 14 Mar 2016 13:55:18 -0700 Subject: [PATCH 2/4] compress the bytes --- .../spark/sql/execution/SparkPlan.scala | 83 ++++++++----------- .../BenchmarkWholeStageCodegen.scala | 25 ++++++ 2 files changed, 59 insertions(+), 49 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 842c018296c49..36b393ebc0224 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql.execution -import java.nio.ByteBuffer +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.ArrayBuffer import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ -import org.apache.spark.Logging -import org.apache.spark.broadcast +import org.apache.spark.{Logging, SparkEnv, broadcast} +import org.apache.spark.io.CompressionCodec import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -35,7 +35,6 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric} import org.apache.spark.sql.types.DataType -import org.apache.spark.unsafe.Platform import org.apache.spark.util.ThreadUtils /** @@ -225,58 +224,44 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ // Packing the UnsafeRows into byte array for faster serialization. // The byte arrays are in the following format: // [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1] + // + // UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also + // compressed. 
val byteArrayRdd = execute().mapPartitionsInternal { iter => - new Iterator[Array[Byte]] { - private var row: UnsafeRow = _ - override def hasNext: Boolean = row != null || iter.hasNext - override def next: Array[Byte] = { - var cap = 1 << 20 // 1 MB - if (row != null) { - // the buffered row could be larger than default buffer size - cap = Math.max(cap, 4 + row.getSizeInBytes + 4) // reverse 4 bytes for ending mark (-1). - } - val buffer = ByteBuffer.allocate(cap) - if (row != null) { - buffer.putInt(row.getSizeInBytes) - row.writeTo(buffer) - row = null - } - while (iter.hasNext) { - row = iter.next().asInstanceOf[UnsafeRow] - // Reserve last 4 bytes for ending mark - if (4 + row.getSizeInBytes + 4 <= buffer.remaining()) { - buffer.putInt(row.getSizeInBytes) - row.writeTo(buffer) - row = null - } else { - buffer.putInt(-1) - return buffer.array() - } - } - buffer.putInt(-1) - // copy the used bytes to make it smaller - val bytes = new Array[Byte](buffer.limit()) - System.arraycopy(buffer.array(), 0, bytes, 0, buffer.limit()) - bytes - } + val buffer = new Array[Byte](4 << 10) // 4K + val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + val bos = new ByteArrayOutputStream() + val out = new DataOutputStream(codec.compressedOutputStream(bos)) + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + out.writeInt(row.getSizeInBytes) + row.writeToStream(out, buffer) } + out.writeInt(-1) + out.flush() + out.close() + Iterator(bos.toByteArray) } + // Collect the byte arrays back to driver, then decode them as UnsafeRows. val nFields = schema.length - byteArrayRdd.collect().flatMap { bytes => - val buffer = ByteBuffer.wrap(bytes) - new Iterator[InternalRow] { - private var sizeInBytes = buffer.getInt() - override def hasNext: Boolean = sizeInBytes >= 0 - override def next: InternalRow = { - val row = new UnsafeRow(nFields) - row.pointTo(buffer.array(), Platform.BYTE_ARRAY_OFFSET + buffer.position(), sizeInBytes) - buffer.position(buffer.position() + sizeInBytes) - sizeInBytes = buffer.getInt() - row - } + val results = ArrayBuffer[InternalRow]() + + byteArrayRdd.collect().foreach { bytes => + val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + val bis = new ByteArrayInputStream(bytes) + val ins = new DataInputStream(codec.compressedInputStream(bis)) + var sizeOfNextRow = ins.readInt() + while (sizeOfNextRow >= 0) { + val bs = new Array[Byte](sizeOfNextRow) + ins.readFully(bs) + val row = new UnsafeRow(nFields) + row.pointTo(bs, sizeOfNextRow) + results += row + sizeOfNextRow = ins.readInt() } } + results.toArray } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 2d3e34d0e1292..9f33e4ab62298 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -428,4 +428,29 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { */ benchmark.run() } + + ignore("collect") { + val N = 1 << 20 + + val benchmark = new Benchmark("collect", N) + benchmark.addCase("collect 1 million") { iter => + sqlContext.range(N).collect() + } + benchmark.addCase("collect 2 millions") { iter => + sqlContext.range(N * 2).collect() + } + benchmark.addCase("collect 4 millions") { iter => + sqlContext.range(N * 4).collect() + } + benchmark.run() + + /** + * Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + 
collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + collect 1 million 775 / 1170 1.4 738.9 1.0X + collect 2 millions 1153 / 1758 0.9 1099.3 0.7X + collect 4 millions 4451 / 5124 0.2 4244.9 0.2X + */ + } } From 4f9cf91a50e15f0246087b70ee855d08f84b4c3e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 14 Mar 2016 14:02:26 -0700 Subject: [PATCH 3/4] fix style --- .../main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 36b393ebc0224..e04683c499a32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ -import org.apache.spark.{Logging, SparkEnv, broadcast} +import org.apache.spark.{broadcast, Logging, SparkEnv} import org.apache.spark.io.CompressionCodec import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.sql.{Row, SQLContext} From 5f00d67f8df4aef7d6010643cc33f8fe218d3660 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 14 Mar 2016 20:55:00 -0700 Subject: [PATCH 4/4] fix tests --- .../scala/org/apache/spark/sql/ExtraStrategiesSuite.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala index 2c4b4f80ff9ed..b1987c690811d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -29,7 +29,9 @@ case class FastOperator(output: Seq[Attribute]) extends SparkPlan { override protected def doExecute(): RDD[InternalRow] = { val str = Literal("so fast").value val row = new GenericInternalRow(Array[Any](str)) - sparkContext.parallelize(Seq(row)) + val unsafeProj = UnsafeProjection.create(schema) + val unsafeRow = unsafeProj(row).copy() + sparkContext.parallelize(Seq(unsafeRow)) } override def producedAttributes: AttributeSet = outputSet
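
The format that patches 1 and 2 converge on is simple to reason about in isolation: each partition is turned into a single compressed blob of length-prefixed row bytes terminated by a -1 marker, and the driver decodes the blobs back into rows. Below is a minimal, self-contained sketch of that framing, with plain byte arrays standing in for UnsafeRow payloads and java.util.zip's Deflater standing in for Spark's CompressionCodec; both stand-ins are assumptions made only to keep the example runnable outside Spark (the patch itself uses codec.compressedOutputStream, row.writeToStream and row.pointTo).

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.util.zip.{DeflaterOutputStream, InflaterInputStream}

object RowFramingSketch {
  // Encoder side (per partition): write [size][bytes] for every row, then a -1 end marker,
  // all through a compressed stream so one partition becomes one small blob.
  def encode(rows: Iterator[Array[Byte]]): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val out = new DataOutputStream(new DeflaterOutputStream(bos))
    rows.foreach { bytes =>
      out.writeInt(bytes.length)
      out.write(bytes)
    }
    out.writeInt(-1) // end-of-partition marker, mirroring the patch
    out.close()      // flushes and finishes the compressed stream
    bos.toByteArray
  }

  // Decoder side (driver): keep reading [size][bytes] records until the -1 marker.
  def decode(blob: Array[Byte]): Iterator[Array[Byte]] = {
    val in = new DataInputStream(new InflaterInputStream(new ByteArrayInputStream(blob)))
    Iterator.continually(in.readInt()).takeWhile(_ >= 0).map { size =>
      val bytes = new Array[Byte](size)
      in.readFully(bytes)
      bytes
    }
  }

  def main(args: Array[String]): Unit = {
    val payloads = Seq("a", "bb", "ccc").map(_.getBytes("UTF-8"))
    val decoded = decode(encode(payloads.iterator)).map(new String(_, "UTF-8")).toSeq
    assert(decoded == Seq("a", "bb", "ccc"))
  }
}

The same framing also explains why patch 1's ByteBuffer version has to reserve a trailing 4 bytes in its capacity check: every blob must leave room for the final -1 marker.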
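
Patch 4 follows from a new assumption in executeCollect(): every row coming out of a physical operator is now cast to UnsafeRow (the asInstanceOf in the mapPartitionsInternal closure), so an operator that emits a GenericInternalRow, as FastOperator in ExtraStrategiesSuite did, would fail at collect time. The fix in the patch is the standard conversion through UnsafeProjection; here is a minimal sketch of the same idea, with the single-string schema and literal chosen only for illustration.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

// A generic (non-Unsafe) internal row, as FastOperator produced before the fix.
val schema = StructType(Seq(StructField("a", StringType)))
val row: InternalRow = new GenericInternalRow(Array[Any](UTF8String.fromString("so fast")))

// Project it into the Unsafe format. The copy() matters: UnsafeProjection reuses its
// output buffer across rows, so a row that outlives the next apply() must be copied,
// which is why the patch calls unsafeProj(row).copy() before handing the row to an RDD.
val unsafeProj = UnsafeProjection.create(schema)
val unsafeRow: UnsafeRow = unsafeProj(row).copy()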