From 73eefa2643b70d68a07ce1473d190d9ba996e18a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 7 Oct 2015 17:00:02 -0700 Subject: [PATCH 01/11] improve unrolling of complex types --- .../catalyst/expressions/UnsafeMapData.java | 16 +++ .../spark/sql/columnar/ColumnAccessor.scala | 9 +- .../spark/sql/columnar/ColumnType.scala | 106 ++++++++---------- .../columnar/InMemoryColumnarTableScan.scala | 3 + .../sql/execution/rowFormatConverters.scala | 7 ++ .../spark/sql/columnar/ColumnTypeSuite.scala | 14 +-- .../NullableColumnAccessorSuite.scala | 7 +- .../columnar/NullableColumnBuilderSuite.scala | 13 ++- 8 files changed, 95 insertions(+), 80 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java index e9dab9edb6bd1..3db949f630a38 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java @@ -64,6 +64,22 @@ public UnsafeArrayData valueArray() { return values; } + @Override + public int hashCode() { + int h = numElements; + return (h * 31 + keys.hashCode()) * 31 + values.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof UnsafeMapData) { + UnsafeMapData map = (UnsafeMapData) obj; + return numElements == map.numElements && keys.equals(map.keyArray()) + && values.equals(map.valueArray()); + } + return false; + } + @Override public UnsafeMapData copy() { return new UnsafeMapData(keys.copy(), values.copy()); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index f04099f54c41d..65c751e0e6acc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.columnar import java.nio.{ByteBuffer, ByteOrder} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.MutableRow +import org.apache.spark.sql.catalyst.expressions.{MutableRow, UnsafeArrayData, UnsafeMapData, UnsafeRow} import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor import org.apache.spark.sql.types._ @@ -109,15 +108,15 @@ private[sql] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalTy with NullableColumnAccessor private[sql] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType) - extends BasicColumnAccessor[InternalRow](buffer, STRUCT(dataType)) + extends BasicColumnAccessor[UnsafeRow](buffer, STRUCT(dataType)) with NullableColumnAccessor private[sql] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType) - extends BasicColumnAccessor[ArrayData](buffer, ARRAY(dataType)) + extends BasicColumnAccessor[UnsafeArrayData](buffer, ARRAY(dataType)) with NullableColumnAccessor private[sql] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType) - extends BasicColumnAccessor[MapData](buffer, MAP(dataType)) + extends BasicColumnAccessor[UnsafeMapData](buffer, MAP(dataType)) with NullableColumnAccessor private[sql] object ColumnAccessor { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 3563eacb3a3e9..8b3304f31918b 100644 --- 
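The hashCode/equals pair added to UnsafeMapData above folds numElements in first and then the key and value arrays with the usual 31 multiplier, so two maps with identical content compare equal regardless of which buffer backs them. A minimal, self-contained sketch of the same contract; MapBlob is a hypothetical stand-in using plain byte arrays, not Spark's UnsafeArrayData:

import java.util.Arrays

final class MapBlob(val numElements: Int, val keys: Array[Byte], val values: Array[Byte]) {
  override def hashCode(): Int =
    (numElements * 31 + Arrays.hashCode(keys)) * 31 + Arrays.hashCode(values)

  override def equals(obj: Any): Boolean = obj match {
    case m: MapBlob =>
      numElements == m.numElements &&
        Arrays.equals(keys, m.keys) && Arrays.equals(values, m.values)
    case _ => false
  }
}

// new MapBlob(1, Array[Byte](1), Array[Byte](2)) == new MapBlob(1, Array[Byte](1), Array[Byte](2))  // true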
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -450,123 +450,111 @@ private[sql] object LARGE_DECIMAL { } private[sql] case class STRUCT(dataType: StructType) - extends ByteArrayColumnType[InternalRow](20) { + extends ByteArrayColumnType[UnsafeRow](20) { - private val projection: UnsafeProjection = - UnsafeProjection.create(dataType) private val numOfFields: Int = dataType.fields.size + val unsafeRow = new UnsafeRow - override def setField(row: MutableRow, ordinal: Int, value: InternalRow): Unit = { + override def setField(row: MutableRow, ordinal: Int, value: UnsafeRow): Unit = { row.update(ordinal, value) } - override def getField(row: InternalRow, ordinal: Int): InternalRow = { - row.getStruct(ordinal, numOfFields) + override def getField(row: InternalRow, ordinal: Int): UnsafeRow = { + row.getStruct(ordinal, numOfFields).asInstanceOf[UnsafeRow] } - override def serialize(value: InternalRow): Array[Byte] = { - val unsafeRow = if (value.isInstanceOf[UnsafeRow]) { - value.asInstanceOf[UnsafeRow] - } else { - projection(value) - } - unsafeRow.getBytes + override def actualSize(row: InternalRow, ordinal: Int): Int = { + 4 + getField(row, ordinal).getSizeInBytes + } + + override def serialize(value: UnsafeRow): Array[Byte] = { + value.getBytes } - override def deserialize(bytes: Array[Byte]): InternalRow = { - val unsafeRow = new UnsafeRow + override def deserialize(bytes: Array[Byte]): UnsafeRow = { unsafeRow.pointTo(bytes, numOfFields, bytes.length) unsafeRow } - override def clone(v: InternalRow): InternalRow = v.copy() + override def clone(v: UnsafeRow): UnsafeRow = v.copy() } private[sql] case class ARRAY(dataType: ArrayType) - extends ByteArrayColumnType[ArrayData](16) { + extends ByteArrayColumnType[UnsafeArrayData](16) { + private val array = new UnsafeArrayData - private lazy val projection = UnsafeProjection.create(Array[DataType](dataType)) - private val mutableRow = new GenericMutableRow(new Array[Any](1)) - - override def setField(row: MutableRow, ordinal: Int, value: ArrayData): Unit = { + override def setField(row: MutableRow, ordinal: Int, value: UnsafeArrayData): Unit = { row.update(ordinal, value) } - override def getField(row: InternalRow, ordinal: Int): ArrayData = { - row.getArray(ordinal) + override def getField(row: InternalRow, ordinal: Int): UnsafeArrayData = { + row.getArray(ordinal).asInstanceOf[UnsafeArrayData] } - override def serialize(value: ArrayData): Array[Byte] = { - val unsafeArray = if (value.isInstanceOf[UnsafeArrayData]) { - value.asInstanceOf[UnsafeArrayData] - } else { - mutableRow(0) = value - projection(mutableRow).getArray(0) - } - val outputBuffer = - ByteBuffer.allocate(4 + unsafeArray.getSizeInBytes).order(ByteOrder.nativeOrder()) - outputBuffer.putInt(unsafeArray.numElements()) + override def actualSize(row: InternalRow, ordinal: Int): Int = { + val unsafeArray = row.getArray(ordinal).asInstanceOf[UnsafeArrayData] + 4 + 4 + unsafeArray.getSizeInBytes + } + + override def serialize(value: UnsafeArrayData): Array[Byte] = { + val outputBuffer = ByteBuffer.allocate(4 + value.getSizeInBytes).order(ByteOrder.nativeOrder()) + outputBuffer.putInt(value.numElements()) val underlying = outputBuffer.array() - unsafeArray.writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 4) + value.writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 4) underlying } - override def deserialize(bytes: Array[Byte]): ArrayData = { + override def deserialize(bytes: 
Array[Byte]): UnsafeArrayData = { val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()) val numElements = buffer.getInt - val array = new UnsafeArrayData array.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 4, numElements, bytes.length - 4) array } - override def clone(v: ArrayData): ArrayData = v.copy() + override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy() } -private[sql] case class MAP(dataType: MapType) extends ByteArrayColumnType[MapData](32) { +private[sql] case class MAP(dataType: MapType) extends ByteArrayColumnType[UnsafeMapData](32) { - private lazy val projection: UnsafeProjection = UnsafeProjection.create(Array[DataType](dataType)) - private val mutableRow = new GenericMutableRow(new Array[Any](1)) + private val keyArray = new UnsafeArrayData + private val valueArray = new UnsafeArrayData - override def setField(row: MutableRow, ordinal: Int, value: MapData): Unit = { + override def setField(row: MutableRow, ordinal: Int, value: UnsafeMapData): Unit = { row.update(ordinal, value) } - override def getField(row: InternalRow, ordinal: Int): MapData = { - row.getMap(ordinal) + override def getField(row: InternalRow, ordinal: Int): UnsafeMapData = { + row.getMap(ordinal).asInstanceOf[UnsafeMapData] } - override def serialize(value: MapData): Array[Byte] = { - val unsafeMap = if (value.isInstanceOf[UnsafeMapData]) { - value.asInstanceOf[UnsafeMapData] - } else { - mutableRow(0) = value - projection(mutableRow).getMap(0) - } + override def actualSize(row: InternalRow, ordinal: Int): Int = { + val unsafeMap = getField(row, ordinal) + 4 + 8 + unsafeMap.keyArray().getSizeInBytes + unsafeMap.valueArray().getSizeInBytes + } + override def serialize(value: UnsafeMapData): Array[Byte] = { val outputBuffer = - ByteBuffer.allocate(8 + unsafeMap.getSizeInBytes).order(ByteOrder.nativeOrder()) - outputBuffer.putInt(unsafeMap.numElements()) - val keyBytes = unsafeMap.keyArray().getSizeInBytes + ByteBuffer.allocate(8 + value.getSizeInBytes).order(ByteOrder.nativeOrder()) + outputBuffer.putInt(value.numElements()) + val keyBytes = value.keyArray().getSizeInBytes outputBuffer.putInt(keyBytes) val underlying = outputBuffer.array() - unsafeMap.keyArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8) - unsafeMap.valueArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8 + keyBytes) + value.keyArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8) + value.valueArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8 + keyBytes) underlying } - override def deserialize(bytes: Array[Byte]): MapData = { + override def deserialize(bytes: Array[Byte]): UnsafeMapData = { val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()) val numElements = buffer.getInt val keyArraySize = buffer.getInt - val keyArray = new UnsafeArrayData - val valueArray = new UnsafeArrayData keyArray.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 8, numElements, keyArraySize) valueArray.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 8 + keyArraySize, numElements, bytes.length - 8 - keyArraySize) new UnsafeMapData(keyArray, valueArray) } - override def clone(v: MapData): MapData = v.copy() + override def clone(v: UnsafeMapData): UnsafeMapData = v.copy() } private[sql] object ColumnType { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index d7e145f9c2bb8..78887a545e043 100644 --- 
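The MAP serializer above writes a small self-describing layout: a 4-byte element count, a 4-byte key-array size, then the raw key bytes followed by the value bytes. A standalone round-trip sketch of that layout with plain java.nio buffers; byte arrays stand in for UnsafeArrayData, and serializeMap/deserializeMap are hypothetical helpers, not Spark API:

import java.nio.{ByteBuffer, ByteOrder}

object MapLayout {
  def serializeMap(numElements: Int, keyBytes: Array[Byte], valueBytes: Array[Byte]): Array[Byte] = {
    val buf = ByteBuffer.allocate(8 + keyBytes.length + valueBytes.length)
      .order(ByteOrder.nativeOrder())
    buf.putInt(numElements).putInt(keyBytes.length).put(keyBytes).put(valueBytes)
    buf.array()
  }

  def deserializeMap(bytes: Array[Byte]): (Int, Array[Byte], Array[Byte]) = {
    val buf = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder())
    val numElements = buf.getInt()
    val keys = new Array[Byte](buf.getInt())
    buf.get(keys)
    val values = new Array[Byte](bytes.length - 8 - keys.length)
    buf.get(values)
    (numElements, keys, values)
  }
}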
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -199,6 +199,9 @@ private[sql] case class InMemoryColumnarTableScan( @transient relation: InMemoryRelation) extends LeafNode { + override def canProcessSafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + override def output: Seq[Attribute] = attributes private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala index 855555dd1d4c4..e2b20a8fa40c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} /** * :: DeveloperApi :: @@ -79,6 +80,12 @@ private[sql] object EnsureRowFormats extends Rule[SparkPlan] { operator.canProcessSafeRows && operator.canProcessUnsafeRows override def apply(operator: SparkPlan): SparkPlan = operator.transformUp { + case operator: InMemoryColumnarTableScan if !operator.relation.child.outputsUnsafeRows => + val cache = operator.relation + val newCache = InMemoryRelation(cache.useCompression, cache.batchSize, cache.storageLevel, + ConvertToUnsafe(cache.child), cache.tableName) + operator.copy(relation = newCache) + case operator: SparkPlan if onlyHandlesSafeRows(operator) => if (operator.children.exists(_.outputsUnsafeRows)) { operator.withNewChildren { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index ceb8ad97bb320..ccac486e6810b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types._ import org.apache.spark.{Logging, SparkFunSuite} @@ -55,7 +55,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { assertResult(expected, s"Wrong actualSize for $columnType") { val row = new GenericMutableRow(1) row.update(0, CatalystTypeConverters.convertToCatalyst(value)) - columnType.actualSize(row, 0) + val proj = UnsafeProjection.create(Array[DataType](columnType.dataType)) + columnType.actualSize(proj(row), 0) } } @@ -100,19 +101,18 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { def testColumnType[JvmType](columnType: ColumnType[JvmType]): Unit = { val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE) - val seq = (0 until 4).map(_ => makeRandomValue(columnType)) - val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType) + val proj = 
UnsafeProjection.create(Array[DataType](columnType.dataType)) + val seq = (0 until 4).map(_ => proj(makeRandomRow(columnType)).copy()) test(s"$columnType append/extract") { buffer.rewind() - seq.foreach(columnType.append(_, buffer)) + seq.foreach(columnType.append(_, 0, buffer)) buffer.rewind() seq.foreach { expected => logInfo("buffer = " + buffer + ", expected = " + expected) val extracted = columnType.extract(buffer) - assert( - converter(expected) === converter(extracted), + assert(expected.get(0, columnType.dataType) === extracted, "Extracted value didn't equal to the original one. " + hexDump(expected) + " != " + hexDump(extracted) + ", buffer = " + dumpBuffer(buffer.duplicate().rewind().asInstanceOf[ByteBuffer])) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 78cebbf3cc934..aa1605fee8c73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} import org.apache.spark.sql.types._ class TestNullableColumnAccessor[JvmType]( @@ -64,10 +64,11 @@ class NullableColumnAccessorSuite extends SparkFunSuite { test(s"Nullable $typeName column accessor: access null values") { val builder = TestNullableColumnBuilder(columnType) val randomRow = makeRandomRow(columnType) + val proj = UnsafeProjection.create(Array[DataType](columnType.dataType)) (0 until 4).foreach { _ => - builder.appendFrom(randomRow, 0) - builder.appendFrom(nullRow, 0) + builder.appendFrom(proj(randomRow), 0) + builder.appendFrom(proj(nullRow), 0) } val accessor = TestNullableColumnAccessor(builder.build(), columnType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index fba08e626d720..91404577832a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} import org.apache.spark.sql.types._ class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType]) @@ -51,6 +51,9 @@ class NullableColumnBuilderSuite extends SparkFunSuite { columnType: ColumnType[JvmType]): Unit = { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") + val dataType = columnType.dataType + val proj = UnsafeProjection.create(Array[DataType](dataType)) + val converter = CatalystTypeConverters.createToScalaConverter(dataType) test(s"$typeName column builder: empty column") { val columnBuilder = TestNullableColumnBuilder(columnType) @@ -65,7 +68,7 @@ class NullableColumnBuilderSuite extends SparkFunSuite { val randomRow = makeRandomRow(columnType) (0 until 4).foreach { _ => - 
columnBuilder.appendFrom(randomRow, 0) + columnBuilder.appendFrom(proj(randomRow), 0) } val buffer = columnBuilder.build() @@ -77,12 +80,10 @@ class NullableColumnBuilderSuite extends SparkFunSuite { val columnBuilder = TestNullableColumnBuilder(columnType) val randomRow = makeRandomRow(columnType) val nullRow = makeNullRow(1) - val dataType = columnType.dataType - val converter = CatalystTypeConverters.createToScalaConverter(dataType) (0 until 4).foreach { _ => - columnBuilder.appendFrom(randomRow, 0) - columnBuilder.appendFrom(nullRow, 0) + columnBuilder.appendFrom(proj(randomRow), 0) + columnBuilder.appendFrom(proj(nullRow), 0) } val buffer = columnBuilder.build() From c8a0ba3f7fa80078a3593220d21b33cf93091557 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 8 Oct 2015 12:41:08 -0700 Subject: [PATCH 02/11] cleanup --- .../catalyst/expressions/UnsafeMapData.java | 16 ------------- .../columnar/InMemoryColumnarTableScan.scala | 16 ++++++++----- .../sql/execution/rowFormatConverters.scala | 6 ----- .../spark/sql/columnar/ColumnTypeSuite.scala | 23 +++++++------------ 4 files changed, 18 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java index 3db949f630a38..e9dab9edb6bd1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java @@ -64,22 +64,6 @@ public UnsafeArrayData valueArray() { return values; } - @Override - public int hashCode() { - int h = numElements; - return (h * 31 + keys.hashCode()) * 31 + values.hashCode(); - } - - @Override - public boolean equals(Object obj) { - if (obj instanceof UnsafeMapData) { - UnsafeMapData map = (UnsafeMapData) obj; - return numElements == map.numElements && keys.equals(map.keyArray()) - && values.equals(map.valueArray()); - } - return false; - } - @Override public UnsafeMapData copy() { return new UnsafeMapData(keys.copy(), values.copy()); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 78887a545e043..17a4490b40f67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} -import org.apache.spark.sql.execution.{LeafNode, SparkPlan} +import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan} import org.apache.spark.storage.StorageLevel import org.apache.spark.{Accumulable, Accumulator, Accumulators} @@ -37,8 +37,15 @@ private[sql] object InMemoryRelation { batchSize: Int, storageLevel: StorageLevel, child: SparkPlan, - tableName: Option[String]): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)() + tableName: Option[String]): InMemoryRelation = { + val newChild = if (child.outputsUnsafeRows) { + child + } else { + ConvertToUnsafe(child) + } + new InMemoryRelation(newChild.output, useCompression, batchSize, storageLevel, 
newChild, + tableName)() + } } private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: InternalRow) @@ -199,9 +206,6 @@ private[sql] case class InMemoryColumnarTableScan( @transient relation: InMemoryRelation) extends LeafNode { - override def canProcessSafeRows: Boolean = false - override def canProcessUnsafeRows: Boolean = true - override def output: Seq[Attribute] = attributes private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala index e2b20a8fa40c9..a2b4b4a410064 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala @@ -80,12 +80,6 @@ private[sql] object EnsureRowFormats extends Rule[SparkPlan] { operator.canProcessSafeRows && operator.canProcessUnsafeRows override def apply(operator: SparkPlan): SparkPlan = operator.transformUp { - case operator: InMemoryColumnarTableScan if !operator.relation.child.outputsUnsafeRows => - val cache = operator.relation - val newCache = InMemoryRelation(cache.useCompression, cache.batchSize, cache.storageLevel, - ConvertToUnsafe(cache.child), cache.tableName) - operator.copy(relation = newCache) - case operator: SparkPlan if onlyHandlesSafeRows(operator) => if (operator.children.exists(_.outputsUnsafeRows)) { operator.withNewChildren { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index ccac486e6810b..039f73695b759 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -102,6 +102,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE) val proj = UnsafeProjection.create(Array[DataType](columnType.dataType)) + val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType) val seq = (0 until 4).map(_ => proj(makeRandomRow(columnType)).copy()) test(s"$columnType append/extract") { @@ -109,25 +110,17 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { seq.foreach(columnType.append(_, 0, buffer)) buffer.rewind() - seq.foreach { expected => - logInfo("buffer = " + buffer + ", expected = " + expected) - val extracted = columnType.extract(buffer) - assert(expected.get(0, columnType.dataType) === extracted, - "Extracted value didn't equal to the original one. " + - hexDump(expected) + " != " + hexDump(extracted) + - ", buffer = " + dumpBuffer(buffer.duplicate().rewind().asInstanceOf[ByteBuffer])) + seq.foreach { row => + logInfo("buffer = " + buffer + ", expected = " + row) + val expected = converter(row.get(0, columnType.dataType)) + val extracted = converter(columnType.extract(buffer)) + assert(expected === extracted, + s"Extracted value didn't equal to the original one. 
$expected != $extracted, buffer =" + + dumpBuffer(buffer.duplicate().rewind().asInstanceOf[ByteBuffer])) } } } - private def hexDump(value: Any): String = { - if (value == null) { - "" - } else { - value.toString.map(ch => Integer.toHexString(ch & 0xffff)).mkString(" ") - } - } - private def dumpBuffer(buff: ByteBuffer): Any = { val sb = new StringBuilder() while (buff.hasRemaining) { From 23e127c2a34ab75a3d6c662d907b1e8f4a0fbde8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 8 Oct 2015 13:53:58 -0700 Subject: [PATCH 03/11] avoid the memory copy in reading path --- .../catalyst/expressions/UnsafeArrayData.java | 10 ++ .../sql/catalyst/expressions/UnsafeRow.java | 10 ++ .../spark/sql/columnar/ColumnType.scala | 153 ++++++++---------- .../columnar/InMemoryColumnarTableScan.scala | 11 +- .../sql/execution/rowFormatConverters.scala | 1 - .../apache/spark/unsafe/types/UTF8String.java | 10 ++ 6 files changed, 102 insertions(+), 93 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index fdd9125613a26..cdc28becd24ea 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -19,6 +19,7 @@ import java.math.BigDecimal; import java.math.BigInteger; +import java.nio.ByteBuffer; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; @@ -306,6 +307,15 @@ public void writeToMemory(Object target, long targetOffset) { Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); } + public void writeTo(ByteBuffer buffer) { + assert(buffer.hasArray()); + byte[] target = buffer.array(); + int offset = buffer.arrayOffset(); + int pos = buffer.position(); + writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos); + buffer.position(pos + sizeInBytes); + } + @Override public UnsafeArrayData copy() { UnsafeArrayData arrayCopy = new UnsafeArrayData(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index e8ac2999c2d29..1c530f748564c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -21,6 +21,7 @@ import java.io.OutputStream; import java.math.BigDecimal; import java.math.BigInteger; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; @@ -596,4 +597,13 @@ public boolean anyNull() { public void writeToMemory(Object target, long targetOffset) { Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); } + + public void writeTo(ByteBuffer buffer) { + assert(buffer.hasArray()); + byte[] target = buffer.array(); + int offset = buffer.arrayOffset(); + int pos = buffer.position(); + writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos); + buffer.position(pos + sizeInBytes); + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 8b3304f31918b..a8b87551e5269 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ 
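The writeTo(ByteBuffer) methods introduced in this commit copy straight into the buffer's backing array at arrayOffset() + position() and then advance the position, so building a column batch no longer allocates an intermediate Array[Byte] per value. A self-contained sketch of the same idiom; writeBytesTo is a hypothetical helper, whereas the real methods delegate to Platform.copyMemory:

import java.nio.ByteBuffer

object WriteTo {
  // Copy src into the heap buffer without an intermediate allocation.
  def writeBytesTo(src: Array[Byte], buffer: ByteBuffer): Unit = {
    require(buffer.hasArray, "only heap buffers expose a backing array")
    val pos = buffer.position()
    System.arraycopy(src, 0, buffer.array(), buffer.arrayOffset() + pos, src.length)
    buffer.position(pos + src.length)
  }
}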
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import java.math.{BigDecimal, BigInteger} -import java.nio.{ByteOrder, ByteBuffer} +import java.nio.ByteBuffer import scala.reflect.runtime.universe.TypeTag @@ -92,7 +92,7 @@ private[sql] sealed abstract class ColumnType[JvmType] { * boxing/unboxing costs whenever possible. */ def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { - to.update(toOrdinal, from.get(fromOrdinal, dataType)) + setField(to, toOrdinal, getField(from, fromOrdinal)) } /** @@ -146,10 +146,6 @@ private[sql] object INT extends NativeColumnType(IntegerType, 4) { } override def getField(row: InternalRow, ordinal: Int): Int = row.getInt(ordinal) - - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - to.setInt(toOrdinal, from.getInt(fromOrdinal)) - } } private[sql] object LONG extends NativeColumnType(LongType, 8) { @@ -174,10 +170,6 @@ private[sql] object LONG extends NativeColumnType(LongType, 8) { } override def getField(row: InternalRow, ordinal: Int): Long = row.getLong(ordinal) - - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - to.setLong(toOrdinal, from.getLong(fromOrdinal)) - } } private[sql] object FLOAT extends NativeColumnType(FloatType, 4) { @@ -202,10 +194,6 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 4) { } override def getField(row: InternalRow, ordinal: Int): Float = row.getFloat(ordinal) - - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - to.setFloat(toOrdinal, from.getFloat(fromOrdinal)) - } } private[sql] object DOUBLE extends NativeColumnType(DoubleType, 8) { @@ -230,10 +218,6 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 8) { } override def getField(row: InternalRow, ordinal: Int): Double = row.getDouble(ordinal) - - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - to.setDouble(toOrdinal, from.getDouble(fromOrdinal)) - } } private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 1) { @@ -256,10 +240,6 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 1) { } override def getField(row: InternalRow, ordinal: Int): Boolean = row.getBoolean(ordinal) - - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal)) - } } private[sql] object BYTE extends NativeColumnType(ByteType, 1) { @@ -284,10 +264,6 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 1) { } override def getField(row: InternalRow, ordinal: Int): Byte = row.getByte(ordinal) - - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - to.setByte(toOrdinal, from.getByte(fromOrdinal)) - } } private[sql] object SHORT extends NativeColumnType(ShortType, 2) { @@ -312,10 +288,6 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 2) { } override def getField(row: InternalRow, ordinal: Int): Short = row.getShort(ordinal) - - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - to.setShort(toOrdinal, from.getShort(fromOrdinal)) - } } private[sql] object STRING extends NativeColumnType(StringType, 8) { @@ -324,15 +296,18 @@ private[sql] object STRING extends NativeColumnType(StringType, 8) { } override def append(v: UTF8String, buffer: 
ByteBuffer): Unit = { - val stringBytes = v.getBytes - buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length) + buffer.putInt(v.numBytes()) + v.writeTo(buffer) } override def extract(buffer: ByteBuffer): UTF8String = { val length = buffer.getInt() - val stringBytes = new Array[Byte](length) - buffer.get(stringBytes, 0, length) - UTF8String.fromBytes(stringBytes) + assert(buffer.hasArray) + val base = buffer.array() + val offset = buffer.arrayOffset() + val cursor = buffer.position() + buffer.position(cursor + length) + UTF8String.fromBytes(base, offset + cursor, length) } override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = { @@ -343,10 +318,6 @@ private[sql] object STRING extends NativeColumnType(StringType, 8) { row.getUTF8String(ordinal) } - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - setField(to, toOrdinal, getField(from, fromOrdinal)) - } - override def clone(v: UTF8String): UTF8String = v.clone() } @@ -368,10 +339,6 @@ private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int) override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = { row.setDecimal(ordinal, value, precision) } - - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - setField(to, toOrdinal, getField(from, fromOrdinal)) - } } private[sql] object COMPACT_DECIMAL { @@ -386,11 +353,6 @@ private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: def serialize(value: JvmType): Array[Byte] def deserialize(bytes: Array[Byte]): JvmType - override def actualSize(row: InternalRow, ordinal: Int): Int = { - // TODO: grow the buffer in append(), so serialize() will not be called twice - serialize(getField(row, ordinal)).length + 4 - } - override def append(v: JvmType, buffer: ByteBuffer): Unit = { val bytes = serialize(v) buffer.putInt(bytes.length).put(bytes, 0, bytes.length) @@ -416,6 +378,10 @@ private[sql] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { row.getBinary(ordinal) } + override def actualSize(row: InternalRow, ordinal: Int): Int = { + row.getBinary(ordinal).length + 4 + } + def serialize(value: Array[Byte]): Array[Byte] = value def deserialize(bytes: Array[Byte]): Array[Byte] = bytes } @@ -433,6 +399,10 @@ private[sql] case class LARGE_DECIMAL(precision: Int, scale: Int) row.setDecimal(ordinal, value, precision) } + override def actualSize(row: InternalRow, ordinal: Int): Int = { + 4 + getField(row, ordinal).toJavaBigDecimal.unscaledValue().bitLength() / 8 + 1 + } + override def serialize(value: Decimal): Array[Byte] = { value.toJavaBigDecimal.unscaledValue().toByteArray } @@ -449,11 +419,12 @@ private[sql] object LARGE_DECIMAL { } } -private[sql] case class STRUCT(dataType: StructType) - extends ByteArrayColumnType[UnsafeRow](20) { +private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRow] { private val numOfFields: Int = dataType.fields.size - val unsafeRow = new UnsafeRow + private val unsafeRow = new UnsafeRow + + override def defaultSize: Int = 20 override def setField(row: MutableRow, ordinal: Int, value: UnsafeRow): Unit = { row.update(ordinal, value) @@ -467,22 +438,30 @@ private[sql] case class STRUCT(dataType: StructType) 4 + getField(row, ordinal).getSizeInBytes } - override def serialize(value: UnsafeRow): Array[Byte] = { - value.getBytes + override def append(value: UnsafeRow, buffer: ByteBuffer): Unit = { + buffer.putInt(value.getSizeInBytes) + value.writeTo(buffer) } - 
override def deserialize(bytes: Array[Byte]): UnsafeRow = { - unsafeRow.pointTo(bytes, numOfFields, bytes.length) + override def extract(buffer: ByteBuffer): UnsafeRow = { + val sizeInBytes = buffer.getInt() + assert(buffer.hasArray) + val base = buffer.array() + val offset = buffer.arrayOffset() + val cursor = buffer.position() + unsafeRow.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor, numOfFields, sizeInBytes) + buffer.position(cursor + sizeInBytes) unsafeRow } override def clone(v: UnsafeRow): UnsafeRow = v.copy() } -private[sql] case class ARRAY(dataType: ArrayType) - extends ByteArrayColumnType[UnsafeArrayData](16) { +private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArrayData] { private val array = new UnsafeArrayData + override def defaultSize: Int = 16 + override def setField(row: MutableRow, ordinal: Int, value: UnsafeArrayData): Unit = { row.update(ordinal, value) } @@ -492,33 +471,38 @@ private[sql] case class ARRAY(dataType: ArrayType) } override def actualSize(row: InternalRow, ordinal: Int): Int = { - val unsafeArray = row.getArray(ordinal).asInstanceOf[UnsafeArrayData] + val unsafeArray = getField(row, ordinal) 4 + 4 + unsafeArray.getSizeInBytes } - override def serialize(value: UnsafeArrayData): Array[Byte] = { - val outputBuffer = ByteBuffer.allocate(4 + value.getSizeInBytes).order(ByteOrder.nativeOrder()) - outputBuffer.putInt(value.numElements()) - val underlying = outputBuffer.array() - value.writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 4) - underlying + override def append(value: UnsafeArrayData, buffer: ByteBuffer): Unit = { + buffer.putInt(value.numElements()) + buffer.putInt(value.getSizeInBytes) + value.writeTo(buffer) } - override def deserialize(bytes: Array[Byte]): UnsafeArrayData = { - val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()) + override def extract(buffer: ByteBuffer): UnsafeArrayData = { val numElements = buffer.getInt - array.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 4, numElements, bytes.length - 4) + val sizeInBytes = buffer.getInt + assert(buffer.hasArray) + val base = buffer.array() + val offset = buffer.arrayOffset() + val cursor = buffer.position() + array.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor, numElements, sizeInBytes) + buffer.position(cursor + sizeInBytes) array } override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy() } -private[sql] case class MAP(dataType: MapType) extends ByteArrayColumnType[UnsafeMapData](32) { +private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] { private val keyArray = new UnsafeArrayData private val valueArray = new UnsafeArrayData + override def defaultSize: Int = 32 + override def setField(row: MutableRow, ordinal: Int, value: UnsafeMapData): Unit = { row.update(ordinal, value) } @@ -529,28 +513,29 @@ private[sql] case class MAP(dataType: MapType) extends ByteArrayColumnType[Unsaf override def actualSize(row: InternalRow, ordinal: Int): Int = { val unsafeMap = getField(row, ordinal) - 4 + 8 + unsafeMap.keyArray().getSizeInBytes + unsafeMap.valueArray().getSizeInBytes + 12 + unsafeMap.keyArray().getSizeInBytes + unsafeMap.valueArray().getSizeInBytes } - override def serialize(value: UnsafeMapData): Array[Byte] = { - val outputBuffer = - ByteBuffer.allocate(8 + value.getSizeInBytes).order(ByteOrder.nativeOrder()) - outputBuffer.putInt(value.numElements()) - val keyBytes = value.keyArray().getSizeInBytes - outputBuffer.putInt(keyBytes) - val underlying = outputBuffer.array() - 
value.keyArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8) - value.valueArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8 + keyBytes) - underlying + override def append(value: UnsafeMapData, buffer: ByteBuffer): Unit = { + buffer.putInt(value.numElements()) + buffer.putInt(value.keyArray().getSizeInBytes) + buffer.putInt(value.valueArray().getSizeInBytes) + value.keyArray().writeTo(buffer) + value.valueArray().writeTo(buffer) } - override def deserialize(bytes: Array[Byte]): UnsafeMapData = { - val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()) + override def extract(buffer: ByteBuffer): UnsafeMapData = { val numElements = buffer.getInt val keyArraySize = buffer.getInt - keyArray.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 8, numElements, keyArraySize) - valueArray.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 8 + keyArraySize, numElements, - bytes.length - 8 - keyArraySize) + val valueArraySize = buffer.getInt + assert(buffer.hasArray) + val base = buffer.array() + val offset = buffer.arrayOffset() + val cursor = buffer.position() + keyArray.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor, numElements, keyArraySize) + valueArray.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor + keyArraySize, + numElements, valueArraySize) + buffer.position(cursor + keyArraySize + valueArraySize) new UnsafeMapData(keyArray, valueArray) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 17a4490b40f67..d967814f627cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -37,15 +37,10 @@ private[sql] object InMemoryRelation { batchSize: Int, storageLevel: StorageLevel, child: SparkPlan, - tableName: Option[String]): InMemoryRelation = { - val newChild = if (child.outputsUnsafeRows) { - child - } else { - ConvertToUnsafe(child) - } - new InMemoryRelation(newChild.output, useCompression, batchSize, storageLevel, newChild, + tableName: Option[String]): InMemoryRelation = + new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, + if (child.outputsUnsafeRows) child else ConvertToUnsafe(child), tableName)() - } } private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: InternalRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala index a2b4b4a410064..855555dd1d4c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} /** * :: DeveloperApi :: diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 216aeea60d1c8..b7aecb5102ba6 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ 
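On the read path the extract() methods above do the mirror image: read the length prefix, point the returned object at a region of the batch's backing array, and skip the position past the payload, with no copying. A sketch of that pattern; Slice is a hypothetical stand-in for pointing an UnsafeRow or UnsafeArrayData at the buffer:

import java.nio.ByteBuffer

final case class Slice(bytes: Array[Byte], offset: Int, length: Int)

object Extract {
  def extract(buffer: ByteBuffer): Slice = {
    require(buffer.hasArray)
    val sizeInBytes = buffer.getInt()
    val cursor = buffer.position()
    buffer.position(cursor + sizeInBytes)                             // skip past the payload
    Slice(buffer.array(), buffer.arrayOffset() + cursor, sizeInBytes) // a view, no copy
  }
}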
-19,6 +19,7 @@ import javax.annotation.Nonnull; import java.io.*; +import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; import java.util.Map; @@ -137,6 +138,15 @@ public void writeToMemory(Object target, long targetOffset) { Platform.copyMemory(base, offset, target, targetOffset, numBytes); } + public void writeTo(ByteBuffer buffer) { + assert(buffer.hasArray()); + byte[] target = buffer.array(); + int offset = buffer.arrayOffset(); + int pos = buffer.position(); + writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos); + buffer.position(pos + numBytes); + } + /** * Returns the number of bytes for a code point with the first byte as `b` * @param b The first byte of a code point From 297b06ef1468a658ae68dc4e6884d65c915fb628 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 8 Oct 2015 16:17:31 -0700 Subject: [PATCH 04/11] specialized --- .../scala/org/apache/spark/sql/columnar/ColumnType.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index a8b87551e5269..049e8451f0ea7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -87,10 +87,7 @@ private[sql] sealed abstract class ColumnType[JvmType] { */ def setField(row: MutableRow, ordinal: Int, value: JvmType): Unit - /** - * Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid - * boxing/unboxing costs whenever possible. - */ + @specialized(Boolean, Byte, Short, Int, Long) def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { setField(to, toOrdinal, getField(from, fromOrdinal)) } From 6d16b4baf9a6eafe473c9576bedeb126f8ed3c06 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 8 Oct 2015 16:30:48 -0700 Subject: [PATCH 05/11] can't reuse the objections --- .../org/apache/spark/sql/columnar/ColumnType.scala | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 049e8451f0ea7..e90a01723cad6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -419,7 +419,6 @@ private[sql] object LARGE_DECIMAL { private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRow] { private val numOfFields: Int = dataType.fields.size - private val unsafeRow = new UnsafeRow override def defaultSize: Int = 20 @@ -446,8 +445,9 @@ private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRo val base = buffer.array() val offset = buffer.arrayOffset() val cursor = buffer.position() - unsafeRow.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor, numOfFields, sizeInBytes) buffer.position(cursor + sizeInBytes) + val unsafeRow = new UnsafeRow + unsafeRow.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor, numOfFields, sizeInBytes) unsafeRow } @@ -455,7 +455,6 @@ private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRo } private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArrayData] { - private val array = new UnsafeArrayData override def defaultSize: Int = 16 @@ -485,8 +484,9 @@ private[sql] case class 
ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra val base = buffer.array() val offset = buffer.arrayOffset() val cursor = buffer.position() - array.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor, numElements, sizeInBytes) buffer.position(cursor + sizeInBytes) + val array = new UnsafeArrayData + array.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor, numElements, sizeInBytes) array } @@ -495,9 +495,6 @@ private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] { - private val keyArray = new UnsafeArrayData - private val valueArray = new UnsafeArrayData - override def defaultSize: Int = 32 override def setField(row: MutableRow, ordinal: Int, value: UnsafeMapData): Unit = { @@ -529,7 +526,9 @@ private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] val base = buffer.array() val offset = buffer.arrayOffset() val cursor = buffer.position() + val keyArray = new UnsafeArrayData keyArray.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor, numElements, keyArraySize) + val valueArray = new UnsafeArrayData valueArray.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor + keyArraySize, numElements, valueArraySize) buffer.position(cursor + keyArraySize + valueArraySize) From 96661a893a01c195c7eb372aae4660a6ef01c637 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 8 Oct 2015 16:36:58 -0700 Subject: [PATCH 06/11] fix specialized --- .../main/scala/org/apache/spark/sql/columnar/ColumnType.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index e90a01723cad6..87e00f35ea2df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -34,7 +34,8 @@ import org.apache.spark.unsafe.types.UTF8String * * @tparam JvmType Underlying Java type to represent the elements. */ -private[sql] sealed abstract class ColumnType[JvmType] { +private[sql] +sealed abstract class ColumnType[@specialized(Boolean, Byte, Short, Int, Long) JvmType] { // The catalyst data type of this column. 
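Patch 05 above drops the shared unsafeRow/array/keyArray fields because the objects cannot be reused: a caller may retain the value returned by extract() (for example by storing it into a row) while the shared instance is re-pointed at the next element. A minimal illustration of that aliasing hazard with a hypothetical Pointer class standing in for a re-pointable UnsafeRow:

object ReuseHazard {
  final class Pointer { var offset: Int = 0; def pointTo(o: Int): this.type = { offset = o; this } }

  def main(args: Array[String]): Unit = {
    val shared = new Pointer
    val retained = shared.pointTo(0) // caller keeps a reference to "row 0"
    shared.pointTo(64)               // extract() re-points the same instance for the next row
    println(retained.offset)         // prints 64: the value the caller retained has changed
  }
}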
def dataType: DataType @@ -87,7 +88,6 @@ private[sql] sealed abstract class ColumnType[JvmType] { */ def setField(row: MutableRow, ordinal: Int, value: JvmType): Unit - @specialized(Boolean, Byte, Short, Int, Long) def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { setField(to, toOrdinal, getField(from, fromOrdinal)) } From b29314b180d475b53adbc4fcd696bfe061b6ae12 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 8 Oct 2015 17:41:42 -0700 Subject: [PATCH 07/11] support UDT in UnsafeProjection --- .../expressions/codegen/CodeGenerator.scala | 5 ++++ .../codegen/GenerateSafeProjection.scala | 1 + .../codegen/GenerateUnsafeProjection.scala | 28 +++++++++++++------ 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2dd680454b4cf..51444b9f5757e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -129,6 +129,7 @@ class CodeGenContext { case _: ArrayType => s"$input.getArray($ordinal)" case _: MapType => s"$input.getMap($ordinal)" case NullType => "null" + case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal) case _ => s"($jt)$input.get($ordinal, null)" } } @@ -143,6 +144,7 @@ class CodeGenContext { case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" // The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes) case StringType => s"$row.update($ordinal, $value.clone())" + case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) case _ => s"$row.update($ordinal, $value)" } } @@ -177,6 +179,7 @@ class CodeGenContext { case _: MapType => "MapData" case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName + case udt: UserDefinedType[_] => javaType(udt.sqlType) case _ => "Object" } @@ -220,6 +223,7 @@ class CodeGenContext { case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2" case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2" case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" + case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2) case other => s"$c1.equals($c2)" } @@ -253,6 +257,7 @@ class CodeGenContext { addNewFunction(compareFunc, funcCode) s"this.$compareFunc($c1, $c2)" case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" + case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2) case _ => throw new IllegalArgumentException("cannot generate compare code for un-comparable type") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 9873630937d31..ee50587ed097e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -124,6 +124,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], 
Projection] case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe. case StringType => GeneratedExpressionCode("", "false", s"$input.clone()") + case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) case _ => GeneratedExpressionCode("", "false", input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 3e0e81733fb1f..c2abb78e6c3b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -39,6 +39,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) case t: ArrayType if canSupport(t.elementType) => true case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true + case udt: UserDefinedType[_] => canSupport(udt.sqlType) case _ => false } @@ -77,7 +78,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();") val writeFields = inputs.zip(inputTypes).zipWithIndex.map { - case ((input, dt), index) => + case ((input, dataType), index) => + val dt = dataType match { + case udt: UserDefinedType[_] => udt.sqlType + case other => other + } val tmpCursor = ctx.freshName("tmpCursor") val setNull = dt match { @@ -167,15 +172,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val index = ctx.freshName("index") val element = ctx.freshName("element") - val jt = ctx.javaType(elementType) + val et = elementType match { + case udt: UserDefinedType[_] => udt.sqlType + case other => other + } + + val jt = ctx.javaType(et) - val fixedElementSize = elementType match { + val fixedElementSize = et match { case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8 - case _ if ctx.isPrimitiveType(jt) => elementType.defaultSize + case _ if ctx.isPrimitiveType(jt) => et.defaultSize case _ => 0 } - val writeElement = elementType match { + val writeElement = et match { case t: StructType => s""" $arrayWriter.setOffset($index); @@ -194,13 +204,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} """ - case _ if ctx.isPrimitiveType(elementType) => + case _ if ctx.isPrimitiveType(et) => // Should we do word align? 
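The codegen changes in this commit all follow one rule: wherever a UserDefinedType appears, recurse into its sqlType and generate exactly the code its underlying SQL type would get. A toy sketch of that unwrapping; DT, UdtDT, UdtUnwrap and javaTypeOf are hypothetical, not Catalyst types:

sealed trait DT
case object IntDT extends DT
case object StringDT extends DT
final case class UdtDT(sqlType: DT) extends DT

object UdtUnwrap {
  def javaTypeOf(dt: DT): String = dt match {
    case IntDT        => "int"
    case StringDT     => "UTF8String"
    case UdtDT(inner) => javaTypeOf(inner) // delegate to the underlying SQL type
  }
}

// UdtUnwrap.javaTypeOf(UdtDT(UdtDT(IntDT))) == "int"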
- val dataSize = elementType.defaultSize + val dataSize = et.defaultSize s""" $arrayWriter.setOffset($index); - ${writePrimitiveType(ctx, element, elementType, + ${writePrimitiveType(ctx, element, et, s"$bufferHolder.buffer", s"$bufferHolder.cursor")} $bufferHolder.cursor += $dataSize; """ @@ -237,7 +247,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro if ($input.isNullAt($index)) { $arrayWriter.setNullAt($index); } else { - final $jt $element = ${ctx.getValue(input, elementType, index)}; + final $jt $element = ${ctx.getValue(input, et, index)}; $writeElement } } From 6e050a7a0f9519e014dfd87342306b49a3fcc384 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 8 Oct 2015 22:08:47 -0700 Subject: [PATCH 08/11] fix tests --- .../catalyst/expressions/codegen/GenerateUnsafeProjection.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index c2abb78e6c3b6..1b957a508d10e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -39,6 +39,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) case t: ArrayType if canSupport(t.elementType) => true case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true + case dt: OpenHashSetUDT => false // it's not a standard UDT case udt: UserDefinedType[_] => canSupport(udt.sqlType) case _ => false } From 1716bcd8cddc28b5b2ab34c0dc8bb45bbbc31410 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 9 Oct 2015 08:27:31 -0700 Subject: [PATCH 09/11] udf in unsafeRow --- .../apache/spark/sql/catalyst/expressions/UnsafeArrayData.java | 2 ++ .../org/apache/spark/sql/catalyst/expressions/UnsafeRow.java | 2 ++ 2 files changed, 4 insertions(+) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index cdc28becd24ea..796f8abec9a1d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -146,6 +146,8 @@ public Object get(int ordinal, DataType dataType) { return getArray(ordinal); } else if (dataType instanceof MapType) { return getMap(ordinal); + } else if (dataType instanceof UserDefinedType) { + return get(ordinal, ((UserDefinedType)dataType).sqlType()); } else { throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString()); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index a7ece49b7ef07..36859fbab9744 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -327,6 +327,8 @@ public Object get(int ordinal, DataType dataType) { return getArray(ordinal); } else if (dataType instanceof MapType) { return getMap(ordinal); + } else if (dataType 
From 55a92ba9be5afd3a20a563fd819b2d99e0512114 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 12 Oct 2015 16:05:48 -0700
Subject: [PATCH 10/11] rollback copyField

---
 .../spark/sql/columnar/ColumnType.scala | 44 ++++++++++++++++++-
 1 file changed, 42 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 87e00f35ea2df..d3871f03b4343 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -34,8 +34,7 @@ import org.apache.spark.unsafe.types.UTF8String
  *
  * @tparam JvmType Underlying Java type to represent the elements.
  */
-private[sql]
-sealed abstract class ColumnType[@specialized(Boolean, Byte, Short, Int, Long) JvmType] {
+private[sql] sealed abstract class ColumnType[JvmType] {
 
   // The catalyst data type of this column.
   def dataType: DataType
@@ -88,6 +87,10 @@ sealed abstract class ColumnType[@specialized(Boolean, Byte, Short, Int, Long) J
    */
   def setField(row: MutableRow, ordinal: Int, value: JvmType): Unit
 
+  /**
+   * Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid
+   * boxing/unboxing costs whenever possible.
+   */
   def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
     setField(to, toOrdinal, getField(from, fromOrdinal))
   }
@@ -143,6 +146,11 @@ private[sql] object INT extends NativeColumnType(IntegerType, 4) {
   }
 
   override def getField(row: InternalRow, ordinal: Int): Int = row.getInt(ordinal)
+
+
+  override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
+    to.setInt(toOrdinal, from.getInt(fromOrdinal))
+  }
 }
 
 private[sql] object LONG extends NativeColumnType(LongType, 8) {
@@ -167,6 +175,10 @@ private[sql] object LONG extends NativeColumnType(LongType, 8) {
   }
 
   override def getField(row: InternalRow, ordinal: Int): Long = row.getLong(ordinal)
+
+  override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
+    to.setLong(toOrdinal, from.getLong(fromOrdinal))
+  }
 }
 
 private[sql] object FLOAT extends NativeColumnType(FloatType, 4) {
@@ -191,6 +203,10 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 4) {
   }
 
   override def getField(row: InternalRow, ordinal: Int): Float = row.getFloat(ordinal)
+
+  override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
+    to.setFloat(toOrdinal, from.getFloat(fromOrdinal))
+  }
 }
 
 private[sql] object DOUBLE extends NativeColumnType(DoubleType, 8) {
@@ -215,6 +231,10 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 8) {
   }
 
   override def getField(row: InternalRow, ordinal: Int): Double = row.getDouble(ordinal)
+
+  override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
+    to.setDouble(toOrdinal, from.getDouble(fromOrdinal))
+  }
 }
 
 private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 1) {
@@ -237,6 +257,10 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 1) {
   }
 
   override def getField(row: InternalRow, ordinal: Int): Boolean = row.getBoolean(ordinal)
+
+  override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
+    to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal))
+  }
 }
 
 private[sql] object BYTE extends NativeColumnType(ByteType, 1) {
@@ -261,6 +285,10 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 1) {
   }
 
   override def getField(row: InternalRow, ordinal: Int): Byte = row.getByte(ordinal)
+
+  override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
+    to.setByte(toOrdinal, from.getByte(fromOrdinal))
+  }
 }
 
 private[sql] object SHORT extends NativeColumnType(ShortType, 2) {
@@ -285,6 +313,10 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 2) {
   }
 
   override def getField(row: InternalRow, ordinal: Int): Short = row.getShort(ordinal)
+
+  override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
+    to.setShort(toOrdinal, from.getShort(fromOrdinal))
+  }
 }
 
 private[sql] object STRING extends NativeColumnType(StringType, 8) {
@@ -315,6 +347,10 @@ private[sql] object STRING extends NativeColumnType(StringType, 8) {
     row.getUTF8String(ordinal)
   }
 
+  override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
+    setField(to, toOrdinal, getField(from, fromOrdinal))
+  }
+
   override def clone(v: UTF8String): UTF8String = v.clone()
 }
 
@@ -336,6 +372,10 @@ private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int)
   override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = {
     row.setDecimal(ordinal, value, precision)
   }
+
+  override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
+    setField(to, toOrdinal, getField(from, fromOrdinal))
+  }
 }
 
 private[sql] object COMPACT_DECIMAL {
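These overrides exist to keep scans over cached primitive columns free of boxing: the base-class default routes through getField/setField, which for an Int means allocating a java.lang.Integer on every copied value, while the specialized versions stay on primitive getters and setters. A rough sketch of the kind of hot loop this helps, where copyColumn is illustrative and not code from the patch:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.MutableRow

    // Illustrative only: copy one column of every row into a reusable target row.
    // With the specialized copyField, the INT/LONG/... variants never allocate.
    def copyColumn(
        rows: Seq[InternalRow],
        ordinal: Int,
        target: MutableRow,
        columnType: ColumnType[_]): Unit = {
      var i = 0
      while (i < rows.length) {
        columnType.copyField(rows(i), ordinal, target, ordinal)
        i += 1
      }
    }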
From 615d9a320c04d4ece116da8e652bea82c8af65a2 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 12 Oct 2015 16:32:10 -0700
Subject: [PATCH 11/11] using UnsafeReader

---
 .../spark/sql/columnar/ColumnType.scala      | 36 ++++++++-----------
 .../spark/sql/columnar/ColumnTypeSuite.scala |  4 +--
 2 files changed, 16 insertions(+), 24 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index d3871f03b4343..2bc2c96b61634 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -512,22 +512,20 @@ private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra
   }
 
   override def append(value: UnsafeArrayData, buffer: ByteBuffer): Unit = {
+    buffer.putInt(4 + value.getSizeInBytes)
     buffer.putInt(value.numElements())
-    buffer.putInt(value.getSizeInBytes)
     value.writeTo(buffer)
   }
 
   override def extract(buffer: ByteBuffer): UnsafeArrayData = {
-    val numElements = buffer.getInt
-    val sizeInBytes = buffer.getInt
+    val numBytes = buffer.getInt
     assert(buffer.hasArray)
-    val base = buffer.array()
-    val offset = buffer.arrayOffset()
     val cursor = buffer.position()
-    buffer.position(cursor + sizeInBytes)
-    val array = new UnsafeArrayData
-    array.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor, numElements, sizeInBytes)
-    array
+    buffer.position(cursor + numBytes)
+    UnsafeReaders.readArray(
+      buffer.array(),
+      Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
+      numBytes)
   }
 
   override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy()
@@ -551,28 +549,22 @@ private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData]
   }
 
   override def append(value: UnsafeMapData, buffer: ByteBuffer): Unit = {
+    buffer.putInt(8 + value.keyArray().getSizeInBytes + value.valueArray().getSizeInBytes)
     buffer.putInt(value.numElements())
     buffer.putInt(value.keyArray().getSizeInBytes)
-    buffer.putInt(value.valueArray().getSizeInBytes)
     value.keyArray().writeTo(buffer)
     value.valueArray().writeTo(buffer)
   }
 
   override def extract(buffer: ByteBuffer): UnsafeMapData = {
-    val numElements = buffer.getInt
-    val keyArraySize = buffer.getInt
-    val valueArraySize = buffer.getInt
+    val numBytes = buffer.getInt
     assert(buffer.hasArray)
-    val base = buffer.array()
-    val offset = buffer.arrayOffset()
     val cursor = buffer.position()
-    val keyArray = new UnsafeArrayData
-    keyArray.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor, numElements, keyArraySize)
-    val valueArray = new UnsafeArrayData
-    valueArray.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor + keyArraySize,
-      numElements, valueArraySize)
-    buffer.position(cursor + keyArraySize + valueArraySize)
-    new UnsafeMapData(keyArray, valueArray)
+    buffer.position(cursor + numBytes)
+    UnsafeReaders.readMap(
+      buffer.array(),
+      Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
+      numBytes)
   }
 
   override def clone(v: UnsafeMapData): UnsafeMapData = v.copy()
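The serialized layout is now prefixed with a single total-byte count, so a reader can skip or slice a complex value without understanding its contents: an array is stored as [numBytes][numElements][element bytes] and a map as [numBytes][numElements][key-array size][key bytes][value bytes], where numBytes covers everything after the prefix. A minimal sketch of what that prefix buys, with skipValue as an illustrative name rather than something the patch adds:

    import java.nio.ByteBuffer

    // Illustrative only: advance past one serialized array/map value using just
    // the leading length prefix written by ARRAY.append / MAP.append above.
    def skipValue(buffer: ByteBuffer): Unit = {
      val numBytes = buffer.getInt                   // total bytes after the prefix
      buffer.position(buffer.position() + numBytes)  // jump over the value
    }

In extract, UnsafeReaders.readArray and UnsafeReaders.readMap then rebuild the UnsafeArrayData or UnsafeMapData directly from those bytes instead of re-deriving each size field by field.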
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 039f73695b759..0e6e1bcf72896 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.columnar
 
-import java.nio.ByteBuffer
+import java.nio.{ByteOrder, ByteBuffer}
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
@@ -100,7 +100,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
 
   def testColumnType[JvmType](columnType: ColumnType[JvmType]): Unit = {
 
-    val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE)
+    val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE).order(ByteOrder.nativeOrder())
     val proj = UnsafeProjection.create(Array[DataType](columnType.dataType))
     val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType)
     val seq = (0 until 4).map(_ => proj(makeRandomRow(columnType)).copy())