From f9151cc553ef45eb41e848ddf1c5cc0f82598062 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 19 Oct 2015 13:14:30 -0700 Subject: [PATCH] address comments --- .../org/apache/spark/sql/columnar/ColumnType.scala | 6 ++++-- .../sql/columnar/GenerateColumnAccessor.scala | 14 +++++++------- .../sql/columnar/InMemoryColumnarTableScan.scala | 9 ++++++++- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 1478207ff0827..df5f863a300c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -31,6 +31,8 @@ import org.apache.spark.unsafe.types.UTF8String /** * A help class for fast reading Int/Long/Float/Double from ByteBuffer in native order. + * + * WARNING: This only works with HeapByteBuffer */ object ByteBufferHelper { def getInt(buffer: ByteBuffer): Int = { @@ -360,7 +362,7 @@ private[sql] object STRING extends NativeColumnType(StringType, 8) { } override def extract(buffer: ByteBuffer): UTF8String = { - val length = ByteBufferHelper.getInt(buffer) + val length = buffer.getInt() assert(buffer.hasArray) val base = buffer.array() val offset = buffer.arrayOffset() @@ -426,7 +428,7 @@ private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: } override def extract(buffer: ByteBuffer): JvmType = { - val length = ByteBufferHelper.getInt(buffer) + val length = buffer.getInt() val bytes = new Array[Byte](length) buffer.get(bytes, 0, length) deserialize(bytes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala index 0cebbcb99a9ca..e04bcda5800c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala @@ -41,7 +41,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera protected def create(columnTypes: Seq[DataType]): ColumnarIterator = { val ctx = newCodeGenContext() - val (creaters, accesses) = columnTypes.zipWithIndex.map { case (dt, index) => + val (initializeAccessors, extractors) = columnTypes.zipWithIndex.map { case (dt, index) => val accessorName = ctx.freshName("accessor") val accessorCls = dt match { case NullType => classOf[NullColumnAccessor].getName @@ -92,7 +92,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private byte[][] buffers = null; private int currentRow = 0; - private int totalRows = 0; + private int numRowsInBatch = 0; private scala.collection.Iterator input = null; private MutableRow mutableRow = null; @@ -117,7 +117,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera } public boolean hasNext() { - if (currentRow < totalRows) { + if (currentRow < numRowsInBatch) { return true; } if (!input.hasNext()) { @@ -126,17 +126,17 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera ${classOf[CachedBatch].getName} batch = (${classOf[CachedBatch].getName}) input.next(); currentRow = 0; - totalRows = batch.count(); - for (int i=0; i