From 15ebb3b3794d6888e9e08a834511f66ec3934f00 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 21 Oct 2015 12:02:17 -0700 Subject: [PATCH 1/6] output UnsafeRow from columnar cache --- .../sql/catalyst/expressions/UnsafeRow.java | 19 +++++ .../expressions/codegen/UnsafeRowWriter.java | 46 ++++++++++- .../codegen/GenerateUnsafeProjection.scala | 9 +-- .../spark/sql/columnar/ColumnType.scala | 81 ++++++++++++++++--- .../sql/columnar/GenerateColumnAccessor.scala | 61 ++++++++++++-- .../columnar/InMemoryColumnarTableScan.scala | 6 +- 6 files changed, 192 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 366615f6fe69f..4b7f285d338c6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -618,6 +618,25 @@ public void writeTo(ByteBuffer buffer) { buffer.position(pos + sizeInBytes); } + /** + * Write the bytes of var-length field into ByteBuffer + */ + public void writeFieldTo(int ordinal, ByteBuffer buffer) { + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + + buffer.putInt(size); + int pos = buffer.position(); + buffer.position(pos + size); + Platform.copyMemory( + baseObject, + baseOffset + offset, + buffer.array(), + Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + pos, + size); + } + @Override public void writeExternal(ObjectOutput out) throws IOException { byte[] bytes = getBytes(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index e1f5a05d1d446..6568ef4af7d31 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -58,6 +58,10 @@ private void zeroOutPaddingBytes(int numBytes) { } } + public boolean isNullAt(int ordinal) { + return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal); + } + public void setNullAt(int ordinal) { BitSetMethods.set(holder.buffer, startingOffset, ordinal); Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L); @@ -95,6 +99,40 @@ public void alignToWords(int numBytes) { } } + public void write(int ordinal, boolean value) { + Platform.putBoolean(holder.buffer, getFieldOffset(ordinal), value); + } + + public void write(int ordinal, byte value) { + Platform.putByte(holder.buffer, getFieldOffset(ordinal), value); + } + + public void write(int ordinal, short value) { + Platform.putShort(holder.buffer, getFieldOffset(ordinal), value); + } + + public void write(int ordinal, int value) { + Platform.putInt(holder.buffer, getFieldOffset(ordinal), value); + } + + public void write(int ordinal, long value) { + Platform.putLong(holder.buffer, getFieldOffset(ordinal), value); + } + + public void write(int ordinal, float value) { + if (Float.isNaN(value)) { + value = Float.NaN; + } + Platform.putFloat(holder.buffer, getFieldOffset(ordinal), value); + } + + public void write(int ordinal, double value) { + if (Double.isNaN(value)) { + value = Double.NaN; + } + Platform.putDouble(holder.buffer, getFieldOffset(ordinal), value); + } + 
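A side note on the float/double overloads just added: the "value = Float.NaN" and
"value = Double.NaN" branches look like no-ops, but they canonicalize the NaN bit
pattern before it is stored. UnsafeRow comparison and hashing are byte-wise, and
IEEE-754 admits many distinct NaN encodings, so without this rewrite two
semantically equal rows could differ in their raw bytes. A minimal illustrative
check (a sketch, not part of the patch):

    val canonical = java.lang.Float.intBitsToFloat(0x7fc00000) // Float.NaN bits
    val other     = java.lang.Float.intBitsToFloat(0x7fc00001) // NaN, different payload
    assert(canonical.isNaN && other.isNaN)
    assert(java.lang.Float.floatToRawIntBits(canonical) !=
      java.lang.Float.floatToRawIntBits(other)) // same value, different bytes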
public void writeCompactDecimal(int ordinal, Decimal input, int precision, int scale) { // make sure Decimal object has the same scale as DecimalType if (input.changePrecision(precision, scale)) { @@ -151,7 +189,10 @@ public void write(int ordinal, UTF8String input) { } public void write(int ordinal, byte[] input) { - final int numBytes = input.length; + write(ordinal, input, 0, input.length); + } + + public void write(int ordinal, byte[] input, int offset, int numBytes) { final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. @@ -160,7 +201,8 @@ public void write(int ordinal, byte[] input) { zeroOutPaddingBytes(numBytes); // Write the bytes to the variable length portion. - Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes); + Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET + offset, + holder.buffer, holder.cursor, numBytes); setOffsetAndSize(ordinal, numBytes); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index dbe92d6a83502..23ee3b32b15d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -69,7 +69,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """ } - private def writeExpressionsToBuffer( + def writeExpressionsToBuffer( ctx: CodeGenContext, row: String, inputs: Seq[GeneratedExpressionCode], @@ -89,7 +89,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val setNull = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => // Can't call setNullAt() for DecimalType with precision larger than 18. - s"$rowWriter.write($index, null, ${t.precision}, ${t.scale});" + s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});" case _ => s"$rowWriter.setNullAt($index);" } @@ -124,11 +124,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """ case _ if ctx.isPrimitiveType(dt) => - val fieldOffset = ctx.freshName("fieldOffset") s""" - final long $fieldOffset = $rowWriter.getFieldOffset($index); - Platform.putLong($bufferHolder.buffer, $fieldOffset, 0L); - ${writePrimitiveType(ctx, input.value, dt, s"$bufferHolder.buffer", fieldOffset)} + $rowWriter.write($index, ${input.value}); """ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 72fa299aa937b..68e509eb5047d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -32,6 +32,13 @@ import org.apache.spark.unsafe.types.UTF8String /** * A help class for fast reading Int/Long/Float/Double from ByteBuffer in native order. * + * Note: There is not much difference between ByteBuffer.getByte/getShort and + * Unsafe.getByte/getShort, so we do not have helper methods for them. + * + * The unrolling (building columnar cache) is already slow, putLong/putDouble will not help much, + * so we do not have helper methods for them. 
+ * + * * WARNNING: This only works with HeapByteBuffer */ object ByteBufferHelper { @@ -351,7 +358,38 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 2) { } } -private[sql] object STRING extends NativeColumnType(StringType, 8) { +/** + * A fast path to copy var-length bytes between ByteBuffer and UnsafeRow without creating wrapper + * objects. + */ +private[sql] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { + + // copy the bytes from ByteBuffer to UnsafeRow + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + if (row.isInstanceOf[MutableUnsafeRow]) { + val numBytes = buffer.getInt + val cursor = buffer.position() + buffer.position(cursor + numBytes) + row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, buffer.array(), + buffer.arrayOffset() + cursor, numBytes) + } else { + setField(row, ordinal, extract(buffer)) + } + } + + // copy the bytes from UnsafeRow to ByteBuffer + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + if (row.isInstanceOf[UnsafeRow]) { + row.asInstanceOf[UnsafeRow].writeFieldTo(ordinal, buffer) + } else { + super.append(row, ordinal, buffer) + } + } +} + +private[sql] object STRING + extends NativeColumnType(StringType, 8) with DirectCopyColumnType[UTF8String] { + override def actualSize(row: InternalRow, ordinal: Int): Int = { row.getUTF8String(ordinal).numBytes() + 4 } @@ -363,16 +401,17 @@ private[sql] object STRING extends NativeColumnType(StringType, 8) { override def extract(buffer: ByteBuffer): UTF8String = { val length = buffer.getInt() - assert(buffer.hasArray) - val base = buffer.array() - val offset = buffer.arrayOffset() val cursor = buffer.position() buffer.position(cursor + length) - UTF8String.fromBytes(base, offset + cursor, length) + UTF8String.fromBytes(buffer.array(), buffer.arrayOffset() + cursor, length) } override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = { - row.update(ordinal, value.clone()) + if (row.isInstanceOf[MutableUnsafeRow]) { + row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, value) + } else { + row.update(ordinal, value.clone()) + } } override def getField(row: InternalRow, ordinal: Int): UTF8String = { @@ -393,10 +432,28 @@ private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int) Decimal(ByteBufferHelper.getLong(buffer), precision, scale) } + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + if (row.isInstanceOf[MutableUnsafeRow]) { + // copy it as Long + row.setLong(ordinal, ByteBufferHelper.getLong(buffer)) + } else { + setField(row, ordinal, extract(buffer)) + } + } + override def append(v: Decimal, buffer: ByteBuffer): Unit = { buffer.putLong(v.toUnscaledLong) } + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + if (row.isInstanceOf[UnsafeRow]) { + // copy it as Long + buffer.putLong(row.getLong(ordinal)) + } else { + append(getField(row, ordinal), buffer) + } + } + override def getField(row: InternalRow, ordinal: Int): Decimal = { row.getDecimal(ordinal, precision, scale) } @@ -417,7 +474,7 @@ private[sql] object COMPACT_DECIMAL { } private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int) - extends ColumnType[JvmType] { + extends ColumnType[JvmType] with DirectCopyColumnType[JvmType] { def serialize(value: JvmType): Array[Byte] def deserialize(bytes: Array[Byte]): JvmType @@ -488,7 +545,8 @@ private[sql] object LARGE_DECIMAL { } } -private[sql] case class STRUCT(dataType: 
StructType) extends ColumnType[UnsafeRow] { +private[sql] case class STRUCT(dataType: StructType) + extends ColumnType[UnsafeRow] with DirectCopyColumnType[UnsafeRow] { private val numOfFields: Int = dataType.fields.size @@ -528,7 +586,8 @@ private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRo override def clone(v: UnsafeRow): UnsafeRow = v.copy() } -private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArrayData] { +private[sql] case class ARRAY(dataType: ArrayType) + extends ColumnType[UnsafeArrayData] with DirectCopyColumnType[UnsafeArrayData] { override def defaultSize: Int = 16 @@ -566,7 +625,8 @@ private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy() } -private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] { +private[sql] case class MAP(dataType: MapType) + extends ColumnType[UnsafeMapData] with DirectCopyColumnType[UnsafeMapData] { override def defaultSize: Int = 32 @@ -590,7 +650,6 @@ private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] override def extract(buffer: ByteBuffer): UnsafeMapData = { val numBytes = buffer.getInt - assert(buffer.hasArray) val cursor = buffer.position() buffer.position(cursor + numBytes) val map = new UnsafeMapData diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala index e04bcda5800c7..052eec23d7452 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala @@ -20,17 +20,41 @@ package org.apache.spark.sql.columnar import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator} +import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, CodeFormatter, CodeGenerator} import org.apache.spark.sql.types._ /** - * An Iterator to walk throught the InternalRows from a CachedBatch + * An Iterator to walk through the InternalRows from a CachedBatch */ abstract class ColumnarIterator extends Iterator[InternalRow] { - def initialize(input: Iterator[CachedBatch], mutableRow: MutableRow, columnTypes: Array[DataType], + def initialize(input: Iterator[CachedBatch], columnTypes: Array[DataType], columnIndexes: Array[Int]): Unit } + +/** + * An helper class to update the fields of UnsafeRow, used by ColumnAccessor + */ +class MutableUnsafeRow(val writer: UnsafeRowWriter) extends GenericMutableRow(null) { + + override def isNullAt(i: Int): Boolean = writer.isNullAt(i) + override def setNullAt(i: Int): Unit = writer.setNullAt(i) + + override def setBoolean(i: Int, v: Boolean): Unit = writer.write(i, v) + override def setByte(i: Int, v: Byte): Unit = writer.write(i, v) + override def setShort(i: Int, v: Short): Unit = writer.write(i, v) + override def setInt(i: Int, v: Int): Unit = writer.write(i, v) + override def setLong(i: Int, v: Long): Unit = writer.write(i, v) + override def setFloat(i: Int, v: Float): Unit = writer.write(i, v) + override def setDouble(i: Int, v: Double): Unit = writer.write(i, v) + + // the writer will be used directly to avoid creating wrapper objects + override def setDecimal(i: Int, v: Decimal, precision: Int): Unit = ??? 
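These unimplemented setters are intentional: var-length values (strings, binary,
large decimals) go through the UnsafeRowWriter directly, which appends the bytes
at a moving cursor and packs (offset, size) into the field's fixed 8-byte slot.
A self-contained sketch of that mechanism (simplified, illustrative names, and
absolute offsets instead of row-relative ones):

    import java.nio.charset.StandardCharsets
    // Minimal model of the writer: append payload bytes, record offsetAndSize.
    class MiniWriter(buf: Array[Byte], var cursor: Int) {
      val slots = new Array[Long](8) // one fixed-width 8-byte slot per field
      def write(ordinal: Int, s: String): Unit = {
        val bytes = s.getBytes(StandardCharsets.UTF_8)
        System.arraycopy(bytes, 0, buf, cursor, bytes.length)
        slots(ordinal) = (cursor.toLong << 32) | bytes.length.toLong
        cursor += bytes.length
      }
    }
    val w = new MiniWriter(new Array[Byte](64), cursor = 16)
    w.write(0, "spark") // slot 0 now holds offset 16, size 5

Because data is appended at a cursor, these setters only work when called in
increasing ordinal order, which a later commit in this series documents
explicitly.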
+ override def update(i: Int, v: Any): Unit = ??? + + // all other methods inherited from GenericMutableRow are not need +} + /** * Generates bytecode for an [[ColumnarIterator]] for columnar cache. */ @@ -41,6 +65,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera protected def create(columnTypes: Seq[DataType]): ColumnarIterator = { val ctx = newCodeGenContext() + val numFields = columnTypes.size val (initializeAccessors, extractors) = columnTypes.zipWithIndex.map { case (dt, index) => val accessorName = ctx.freshName("accessor") val accessorCls = dt match { @@ -74,13 +99,25 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera } val extract = s"$accessorName.extractTo(mutableRow, $index);" - - (createCode, extract) + val patch = dt match { + case DecimalType.Fixed(p, s) if p > Decimal.MAX_LONG_DIGITS => + // For large Decimal, it should have 16 bytes for future update even it's null now. + s""" + if (mutableRow.isNullAt($index)) { + rowWriter.write($index, (Decimal) null, $p, $s); + } + """ + case other => "" + } + (createCode, extract + patch) }.unzip val code = s""" import java.nio.ByteBuffer; import java.nio.ByteOrder; + import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; + import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; + import org.apache.spark.sql.columnar.MutableUnsafeRow; public SpecificColumnarIterator generate($exprType[] expr) { return new SpecificColumnarIterator(); @@ -90,6 +127,10 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private ByteOrder nativeOrder = null; private byte[][] buffers = null; + private UnsafeRow unsafeRow = new UnsafeRow(); + private BufferHolder bufferHolder = new BufferHolder(); + private UnsafeRowWriter rowWriter = new UnsafeRowWriter(); + private MutableUnsafeRow mutableRow = null; private int currentRow = 0; private int numRowsInBatch = 0; @@ -104,11 +145,12 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera public SpecificColumnarIterator() { this.nativeOrder = ByteOrder.nativeOrder(); this.buffers = new byte[${columnTypes.length}][]; + this.mutableRow = new MutableUnsafeRow(rowWriter); ${initMutableStates(ctx)} } - public void initialize(scala.collection.Iterator input, MutableRow mutableRow, + public void initialize(scala.collection.Iterator input, ${classOf[DataType].getName}[] columnTypes, int[] columnIndexes) { this.input = input; this.mutableRow = mutableRow; @@ -136,9 +178,12 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera } public InternalRow next() { - ${extractors.mkString("\n")} currentRow += 1; - return mutableRow; + bufferHolder.reset(); + rowWriter.initialize(bufferHolder, $numFields); + ${extractors.mkString("\n")} + unsafeRow.pointTo(bufferHolder.buffer, $numFields, bufferHolder.totalSize()); + return unsafeRow; } }""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 9f76a61a1574b..b4607b12fcefa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -209,6 +209,8 @@ private[sql] case class InMemoryColumnarTableScan( override def output: Seq[Attribute] = attributes + override def outputsUnsafeRows: Boolean = true + private def statsFor(a: 
Attribute) = relation.partitionStatistics.forAttribute(a) // Returned filter predicate should return false iff it is impossible for the input expression @@ -317,14 +319,12 @@ private[sql] case class InMemoryColumnarTableScan( cachedBatchIterator } - val nextRow = new SpecificMutableRow(requestedColumnDataTypes) val columnTypes = requestedColumnDataTypes.map { case udt: UserDefinedType[_] => udt.sqlType case other => other }.toArray val columnarIterator = GenerateColumnAccessor.generate(columnTypes) - columnarIterator.initialize(cachedBatchesToScan, nextRow, columnTypes, - requestedColumnIndices.toArray) + columnarIterator.initialize(cachedBatchesToScan, columnTypes, requestedColumnIndices.toArray) if (enableAccumulators && columnarIterator.hasNext) { readPartitions += 1 } From 81e3fad9462eed6d7b32b20002bf1127dd435f18 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 21 Oct 2015 13:28:02 -0700 Subject: [PATCH 2/6] fix style and refactor --- .../codegen/UnsafeArrayWriter.java | 78 ++++++++++++++----- .../expressions/codegen/UnsafeRowWriter.java | 62 +++++++-------- .../codegen/GenerateUnsafeProjection.scala | 50 ------------ .../sql/columnar/GenerateColumnAccessor.scala | 6 +- 4 files changed, 95 insertions(+), 101 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 7f2a1cb07af01..7dd932d1981b7 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen; -import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.CalendarInterval; @@ -64,29 +63,72 @@ public void setOffset(int ordinal) { Platform.putInt(holder.buffer, getElementOffset(ordinal), relativeOffset); } - public void writeCompactDecimal(int ordinal, Decimal input, int precision, int scale) { - // make sure Decimal object has the same scale as DecimalType - if (input.changePrecision(precision, scale)) { - Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong()); - setOffset(ordinal); - holder.cursor += 8; - } else { - setNullAt(ordinal); + public void write(int ordinal, boolean value) { + Platform.putBoolean(holder.buffer, holder.cursor, value); + setOffset(ordinal); + holder.cursor += 1; + } + + public void write(int ordinal, byte value) { + Platform.putByte(holder.buffer, holder.cursor, value); + setOffset(ordinal); + holder.cursor += 1; + } + + public void write(int ordinal, short value) { + Platform.putShort(holder.buffer, holder.cursor, value); + setOffset(ordinal); + holder.cursor += 2; + } + + public void write(int ordinal, int value) { + Platform.putInt(holder.buffer, holder.cursor, value); + setOffset(ordinal); + holder.cursor += 4; + } + + public void write(int ordinal, long value) { + Platform.putLong(holder.buffer, holder.cursor, value); + setOffset(ordinal); + holder.cursor += 8; + } + + public void write(int ordinal, float value) { + if (Float.isNaN(value)) { + value = Float.NaN; + } + Platform.putFloat(holder.buffer, holder.cursor, value); + setOffset(ordinal); + holder.cursor += 4; + } + + public void write(int ordinal, double value) { + if (Double.isNaN(value)) { + 
value = Double.NaN; } + Platform.putDouble(holder.buffer, holder.cursor, value); + setOffset(ordinal); + holder.cursor += 8; } public void write(int ordinal, Decimal input, int precision, int scale) { // make sure Decimal object has the same scale as DecimalType if (input.changePrecision(precision, scale)) { - final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); - assert bytes.length <= 16; - holder.grow(bytes.length); - - // Write the bytes to the variable length portion. - Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length); - setOffset(ordinal); - holder.cursor += bytes.length; + if (precision <= Decimal.MAX_LONG_DIGITS()) { + Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong()); + setOffset(ordinal); + holder.cursor += 8; + } else { + final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); + assert bytes.length <= 16; + holder.grow(bytes.length); + + // Write the bytes to the variable length portion. + Platform.copyMemory( + bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length); + setOffset(ordinal); + holder.cursor += bytes.length; + } } else { setNullAt(ordinal); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 6568ef4af7d31..adbe2621870df 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -133,41 +133,41 @@ public void write(int ordinal, double value) { Platform.putDouble(holder.buffer, getFieldOffset(ordinal), value); } - public void writeCompactDecimal(int ordinal, Decimal input, int precision, int scale) { - // make sure Decimal object has the same scale as DecimalType - if (input.changePrecision(precision, scale)) { - Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong()); - } else { - setNullAt(ordinal); - } - } - public void write(int ordinal, Decimal input, int precision, int scale) { - // grow the global buffer before writing data. - holder.grow(16); - - // zero-out the bytes - Platform.putLong(holder.buffer, holder.cursor, 0L); - Platform.putLong(holder.buffer, holder.cursor + 8, 0L); - - // Make sure Decimal object has the same scale as DecimalType. - // Note that we may pass in null Decimal object to set null for it. - if (input == null || !input.changePrecision(precision, scale)) { - BitSetMethods.set(holder.buffer, startingOffset, ordinal); - // keep the offset for future update - setOffsetAndSize(ordinal, 0L); + if (precision <= Decimal.MAX_LONG_DIGITS()) { + // make sure Decimal object has the same scale as DecimalType + if (input.changePrecision(precision, scale)) { + Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong()); + } else { + setNullAt(ordinal); + } } else { - final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); - assert bytes.length <= 16; + // grow the global buffer before writing data. + holder.grow(16); + + // zero-out the bytes + Platform.putLong(holder.buffer, holder.cursor, 0L); + Platform.putLong(holder.buffer, holder.cursor + 8, 0L); + + // Make sure Decimal object has the same scale as DecimalType. + // Note that we may pass in null Decimal object to set null for it. 
+ if (input == null || !input.changePrecision(precision, scale)) { + BitSetMethods.set(holder.buffer, startingOffset, ordinal); + // keep the offset for future update + setOffsetAndSize(ordinal, 0L); + } else { + final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); + assert bytes.length <= 16; + + // Write the bytes to the variable length portion. + Platform.copyMemory( + bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length); + setOffsetAndSize(ordinal, bytes.length); + } - // Write the bytes to the variable length portion. - Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length); - setOffsetAndSize(ordinal, bytes.length); + // move the cursor forward. + holder.cursor += 16; } - - // move the cursor forward. - holder.cursor += 16; } public void write(int ordinal, UTF8String input) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 23ee3b32b15d7..a24d175aa093d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -128,10 +128,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $rowWriter.write($index, ${input.value}); """ - case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => - s"$rowWriter.writeCompactDecimal($index, ${input.value}, " + - s"${t.precision}, ${t.scale});" - case t: DecimalType => s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});" @@ -201,20 +197,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} """ - case _ if ctx.isPrimitiveType(et) => - // Should we do word align? 
- val dataSize = et.defaultSize - - s""" - $arrayWriter.setOffset($index); - ${writePrimitiveType(ctx, element, et, - s"$bufferHolder.buffer", s"$bufferHolder.cursor")} - $bufferHolder.cursor += $dataSize; - """ - - case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => - s"$arrayWriter.writeCompactDecimal($index, $element, ${t.precision}, ${t.scale});" - case t: DecimalType => s"$arrayWriter.write($index, $element, ${t.precision}, ${t.scale});" @@ -293,38 +275,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """ } - private def writePrimitiveType( - ctx: CodeGenContext, - input: String, - dt: DataType, - buffer: String, - offset: String) = { - assert(ctx.isPrimitiveType(dt)) - - val putMethod = s"put${ctx.primitiveTypeName(dt)}" - - dt match { - case FloatType | DoubleType => - val normalized = ctx.freshName("normalized") - val boxedType = ctx.boxedType(dt) - val handleNaN = - s""" - final ${ctx.javaType(dt)} $normalized; - if ($boxedType.isNaN($input)) { - $normalized = $boxedType.NaN; - } else { - $normalized = $input; - } - """ - - s""" - $handleNaN - Platform.$putMethod($buffer, $offset, $normalized); - """ - case _ => s"Platform.$putMethod($buffer, $offset, $input);" - } - } - def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = { val exprEvals = expressions.map(e => e.gen(ctx)) val exprTypes = expressions.map(_.dataType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala index 052eec23d7452..2d887e9a40871 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala @@ -49,8 +49,10 @@ class MutableUnsafeRow(val writer: UnsafeRowWriter) extends GenericMutableRow(nu override def setDouble(i: Int, v: Double): Unit = writer.write(i, v) // the writer will be used directly to avoid creating wrapper objects - override def setDecimal(i: Int, v: Decimal, precision: Int): Unit = ??? - override def update(i: Int, v: Any): Unit = ??? 
+ override def setDecimal(i: Int, v: Decimal, precision: Int): Unit = + throw new UnsupportedOperationException + override def update(i: Int, v: Any): Unit = throw new UnsupportedOperationException + // all other methods inherited from GenericMutableRow are not need } From e33170d793993bcfc48bdc0c159b3fe51b4f0861 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 21 Oct 2015 13:42:28 -0700 Subject: [PATCH 3/6] fix code style --- .../spark/sql/columnar/GenerateColumnAccessor.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala index 2d887e9a40871..188a64e9b66c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala @@ -31,7 +31,6 @@ abstract class ColumnarIterator extends Iterator[InternalRow] { columnIndexes: Array[Int]): Unit } - /** * An helper class to update the fields of UnsafeRow, used by ColumnAccessor */ @@ -53,7 +52,6 @@ class MutableUnsafeRow(val writer: UnsafeRowWriter) extends GenericMutableRow(nu throw new UnsupportedOperationException override def update(i: Int, v: Any): Unit = throw new UnsupportedOperationException - // all other methods inherited from GenericMutableRow are not need } @@ -117,6 +115,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera val code = s""" import java.nio.ByteBuffer; import java.nio.ByteOrder; + import scala.collection.Iterator; + import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; import org.apache.spark.sql.columnar.MutableUnsafeRow; @@ -139,7 +139,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private scala.collection.Iterator input = null; private MutableRow mutableRow = null; - private ${classOf[DataType].getName}[] columnTypes = null; + private DataType[] columnTypes = null; private int[] columnIndexes = null; ${declareMutableStates(ctx)} @@ -152,8 +152,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera ${initMutableStates(ctx)} } - public void initialize(scala.collection.Iterator input, - ${classOf[DataType].getName}[] columnTypes, int[] columnIndexes) { + public void initialize(Iterator input, DataType[] columnTypes, int[] columnIndexes) { this.input = input; this.mutableRow = mutableRow; this.columnTypes = columnTypes; From c76c759e30e2f89a8f4a56a1a854b9d3ee8fa697 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 21 Oct 2015 16:22:55 -0700 Subject: [PATCH 4/6] address comments --- .../sql/catalyst/expressions/UnsafeRow.java | 14 ++++---- .../codegen/UnsafeArrayWriter.java | 35 +++++++++---------- .../expressions/codegen/UnsafeRowWriter.java | 15 ++++---- .../org/apache/spark/sql/types/Decimal.scala | 4 +-- .../ArithmeticExpressionSuite.scala | 14 ++++++-- .../sql/catalyst/expressions/CastSuite.scala | 29 ++++++++------- .../expressions/DecimalExpressionSuite.scala | 13 +++---- .../expressions/LiteralExpressionSuite.scala | 4 +-- .../expressions/MathFunctionsSuite.scala | 14 ++++---- .../sql/columnar/GenerateColumnAccessor.scala | 2 ++ 10 files changed, 76 insertions(+), 68 deletions(-) diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 4b7f285d338c6..850838af9be35 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -402,7 +402,7 @@ public UTF8String getUTF8String(int ordinal) { if (isNullAt(ordinal)) return null; final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); - final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + final int size = (int) offsetAndSize; return UTF8String.fromAddress(baseObject, baseOffset + offset, size); } @@ -413,7 +413,7 @@ public byte[] getBinary(int ordinal) { } else { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); - final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + final int size = (int) offsetAndSize; final byte[] bytes = new byte[size]; Platform.copyMemory( baseObject, @@ -446,7 +446,7 @@ public UnsafeRow getStruct(int ordinal, int numFields) { } else { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); - final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + final int size = (int) offsetAndSize; final UnsafeRow row = new UnsafeRow(); row.pointTo(baseObject, baseOffset + offset, numFields, size); return row; @@ -460,7 +460,7 @@ public UnsafeArrayData getArray(int ordinal) { } else { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); - final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + final int size = (int) offsetAndSize; final UnsafeArrayData array = new UnsafeArrayData(); array.pointTo(baseObject, baseOffset + offset, size); return array; @@ -474,7 +474,7 @@ public UnsafeMapData getMap(int ordinal) { } else { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); - final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + final int size = (int) offsetAndSize; final UnsafeMapData map = new UnsafeMapData(); map.pointTo(baseObject, baseOffset + offset, size); return map; @@ -620,11 +620,13 @@ public void writeTo(ByteBuffer buffer) { /** * Write the bytes of var-length field into ByteBuffer + * + * Note: only work with HeapByteBuffer */ public void writeFieldTo(int ordinal, ByteBuffer buffer) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); - final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + final int size = (int) offsetAndSize; buffer.putInt(size); int pos = buffer.position(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 7dd932d1981b7..d9bee1f779b47 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -111,26 +111,23 @@ public void write(int ordinal, double value) { holder.cursor += 8; } - public void write(int ordinal, Decimal input, int precision, int scale) { - // make sure Decimal object has the same scale as DecimalType - if (input.changePrecision(precision, scale)) { - if (precision <= Decimal.MAX_LONG_DIGITS()) { - 
Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong()); - setOffset(ordinal); - holder.cursor += 8; - } else { - final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); - assert bytes.length <= 16; - holder.grow(bytes.length); - - // Write the bytes to the variable length portion. - Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length); - setOffset(ordinal); - holder.cursor += bytes.length; - } + public void write(int ordinal, Decimal input, int precision) { + assert(input != null); + assert(input.precision() == precision); + if (precision <= Decimal.MAX_LONG_DIGITS()) { + Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong()); + setOffset(ordinal); + holder.cursor += 8; } else { - setNullAt(ordinal); + final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); + assert bytes.length <= 16; + holder.grow(bytes.length); + + // Write the bytes to the variable length portion. + Platform.copyMemory( + bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length); + setOffset(ordinal); + holder.cursor += bytes.length; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index adbe2621870df..cbf91c0fbb073 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -135,12 +135,10 @@ public void write(int ordinal, double value) { public void write(int ordinal, Decimal input, int precision, int scale) { if (precision <= Decimal.MAX_LONG_DIGITS()) { - // make sure Decimal object has the same scale as DecimalType - if (input.changePrecision(precision, scale)) { - Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong()); - } else { - setNullAt(ordinal); - } + assert(input != null); + assert(input.precision() == precision); + assert(input.scale() == scale); + Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong()); } else { // grow the global buffer before writing data. holder.grow(16); @@ -149,13 +147,14 @@ public void write(int ordinal, Decimal input, int precision, int scale) { Platform.putLong(holder.buffer, holder.cursor, 0L); Platform.putLong(holder.buffer, holder.cursor + 8, 0L); - // Make sure Decimal object has the same scale as DecimalType. // Note that we may pass in null Decimal object to set null for it. 
- if (input == null || !input.changePrecision(precision, scale)) { + if (input == null) { BitSetMethods.set(holder.buffer, startingOffset, ordinal); // keep the offset for future update setOffsetAndSize(ordinal, 0L); } else { + assert(input.precision() == precision); + assert(input.scale() == scale); final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); assert bytes.length <= 16; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index c7a1a2e7469ee..6ce428b87fe0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -123,7 +123,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { def set(decimal: BigDecimal): Decimal = { this.decimalVal = decimal this.longVal = 0L - this._precision = decimal.precision + this._precision = math.max(decimal.precision, decimal.scale) this._scale = decimal.scale this } @@ -306,7 +306,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { def % (that: Decimal): Decimal = if (that.isZero) null - else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT)) + else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT), precision, scale) def remainder(that: Decimal): Decimal = this % that diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 72285c6a24199..5c10466902c07 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -38,7 +38,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper testFunc(_.toLong) testFunc(_.toFloat) testFunc(_.toDouble) - testFunc(Decimal(_)) } test("+ (Add)") { @@ -49,6 +48,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Add(Literal.create(null, left.dataType), right), null) checkEvaluation(Add(left, Literal.create(null, right.dataType)), null) } + checkEvaluation(Add(Literal(Decimal(1)), Literal(Decimal(2))), Decimal(3)) checkEvaluation(Add(positiveShortLit, negativeShortLit), -1.toShort) checkEvaluation(Add(positiveIntLit, negativeIntLit), -1) checkEvaluation(Add(positiveLongLit, negativeLongLit), -1L) @@ -65,6 +65,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(UnaryMinus(input), convert(-1)) checkEvaluation(UnaryMinus(Literal.create(null, dataType)), null) } + checkEvaluation(UnaryMinus(Literal(Decimal(1))), Decimal(-1)) checkEvaluation(UnaryMinus(Literal(Long.MinValue)), Long.MinValue) checkEvaluation(UnaryMinus(Literal(Int.MinValue)), Int.MinValue) checkEvaluation(UnaryMinus(Literal(Short.MinValue)), Short.MinValue) @@ -89,6 +90,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Subtract(Literal.create(null, left.dataType), right), null) checkEvaluation(Subtract(left, Literal.create(null, right.dataType)), null) } + checkEvaluation(Subtract(Literal(Decimal(1)), Literal(Decimal(2))), Decimal(-1)) checkEvaluation(Subtract(positiveShortLit, negativeShortLit), (positiveShort - negativeShort).toShort) 
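The decimal cases move out of the generic testFunc loop because decimal
arithmetic widens the result type, so expected values must pin an exact
precision and scale. As background (the precision rules Spark adopts from Hive,
not something introduced by this patch), for inputs DecimalType(p1, s1) and
DecimalType(p2, s2):

    // add/subtract: scale = max(s1, s2)
    //               precision = max(p1 - s1, p2 - s2) + max(s1, s2) + 1
    // multiply:     precision = p1 + p2 + 1, scale = s1 + s2

This widening is why the Multiply and Divide checks below wrap the result in a
Cast before comparing.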
checkEvaluation(Subtract(positiveIntLit, negativeIntLit), positiveInt - negativeInt) @@ -107,6 +109,8 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Multiply(Literal.create(null, left.dataType), right), null) checkEvaluation(Multiply(left, Literal.create(null, right.dataType)), null) } + checkEvaluation(Cast(Multiply(Literal(Decimal(1)), Literal(Decimal(2))), DecimalType(20, 0)), + Decimal(2, 20, 0)) checkEvaluation(Multiply(positiveShortLit, negativeShortLit), (positiveShort * negativeShort).toShort) checkEvaluation(Multiply(positiveIntLit, negativeIntLit), positiveInt * negativeInt) @@ -127,7 +131,8 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Divide(left, Literal.create(null, right.dataType)), null) checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero } - + checkEvaluation(Cast(Divide(Literal(Decimal(2)), Literal(Decimal(1))), DecimalType(18, 9)), + Decimal(2.0, 18, 9)) DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => checkConsistencyBetweenInterpretedAndCodegen(Divide, tpe, tpe) } @@ -146,7 +151,8 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper test("/ (Divide) for floating point") { checkEvaluation(Divide(Literal(1.0f), Literal(2.0f)), 0.5f) checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) - checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), Decimal(0.5)) + checkEvaluation(Cast(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), DecimalType(4, 3)), + Decimal(0.5, 4, 3)) } test("% (Remainder)") { @@ -158,6 +164,8 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Remainder(left, Literal.create(null, right.dataType)), null) checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0 } + checkEvaluation(Cast(Remainder(Literal(Decimal(1)), Literal(Decimal(2))), DecimalType(10, 0)), + Decimal(1)) checkEvaluation(Remainder(positiveShortLit, positiveShortLit), 0.toShort) checkEvaluation(Remainder(negativeShortLit, negativeShortLit), 0.toShort) checkEvaluation(Remainder(positiveIntLit, positiveIntLit), 0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index f4db4da7646f8..e88b9671886fd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -222,8 +222,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkCast(1, 1.0) checkCast(123, "123") - checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) - checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) + checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123, 10, 0)) + checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123, 3, 0)) checkEvaluation(cast(123, DecimalType(3, 1)), null) checkEvaluation(cast(123, DecimalType(2, 0)), null) } @@ -240,10 +240,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkCast(1L, 1.0) checkCast(123L, "123") - checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123)) - checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123)) + checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123, 10, 0)) + checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123, 3, 0)) checkEvaluation(cast(123L, 
DecimalType(3, 1)), null) - checkEvaluation(cast(123L, DecimalType(2, 0)), null) } @@ -261,8 +260,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(cast(1000, TimestampType), LongType), 1.toLong) checkEvaluation(cast(cast(-1200, TimestampType), LongType), -2.toLong) - checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) - checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) + checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123, 10, 0)) + checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123, 3, 0)) checkEvaluation(cast(123, DecimalType(3, 1)), null) checkEvaluation(cast(123, DecimalType(2, 0)), null) } @@ -329,7 +328,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast("abdef", StringType), "abdef") checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null) checkEvaluation(cast("abdef", TimestampType), null) - checkEvaluation(cast("12.65", DecimalType.SYSTEM_DEFAULT), Decimal(12.65)) + checkEvaluation(cast("12.65", DecimalType.SYSTEM_DEFAULT), Decimal(12.65, 38, 18)) checkEvaluation(cast(cast(sd, DateType), StringType), sd) checkEvaluation(cast(cast(d, StringType), DateType), 0) @@ -409,20 +408,20 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(cast(Decimal(10.03), DecimalType(2, 1)).nullable === true) - checkEvaluation(cast(10.03, DecimalType.SYSTEM_DEFAULT), Decimal(10.03)) + checkEvaluation(cast(10.03, DecimalType.SYSTEM_DEFAULT), Decimal(10.03, 38, 18)) checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03)) checkEvaluation(cast(10.03, DecimalType(3, 1)), Decimal(10.0)) - checkEvaluation(cast(10.03, DecimalType(2, 0)), Decimal(10)) + checkEvaluation(cast(10.03, DecimalType(2, 0)), Decimal(10, 2, 0)) checkEvaluation(cast(10.03, DecimalType(1, 0)), null) checkEvaluation(cast(10.03, DecimalType(2, 1)), null) checkEvaluation(cast(10.03, DecimalType(3, 2)), null) checkEvaluation(cast(Decimal(10.03), DecimalType(3, 1)), Decimal(10.0)) checkEvaluation(cast(Decimal(10.03), DecimalType(3, 2)), null) - checkEvaluation(cast(10.05, DecimalType.SYSTEM_DEFAULT), Decimal(10.05)) + checkEvaluation(cast(10.05, DecimalType.SYSTEM_DEFAULT), Decimal(10.05, 38, 18)) checkEvaluation(cast(10.05, DecimalType(4, 2)), Decimal(10.05)) checkEvaluation(cast(10.05, DecimalType(3, 1)), Decimal(10.1)) - checkEvaluation(cast(10.05, DecimalType(2, 0)), Decimal(10)) + checkEvaluation(cast(10.05, DecimalType(2, 0)), Decimal(10, 2, 0)) checkEvaluation(cast(10.05, DecimalType(1, 0)), null) checkEvaluation(cast(10.05, DecimalType(2, 1)), null) checkEvaluation(cast(10.05, DecimalType(3, 2)), null) @@ -431,7 +430,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(9.95, DecimalType(3, 2)), Decimal(9.95)) checkEvaluation(cast(9.95, DecimalType(3, 1)), Decimal(10.0)) - checkEvaluation(cast(9.95, DecimalType(2, 0)), Decimal(10)) + checkEvaluation(cast(9.95, DecimalType(2, 0)), Decimal(10, 2, 0)) checkEvaluation(cast(9.95, DecimalType(2, 1)), null) checkEvaluation(cast(9.95, DecimalType(1, 0)), null) checkEvaluation(cast(Decimal(9.95), DecimalType(3, 1)), Decimal(10.0)) @@ -439,7 +438,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(-9.95, DecimalType(3, 2)), Decimal(-9.95)) checkEvaluation(cast(-9.95, DecimalType(3, 1)), Decimal(-10.0)) - checkEvaluation(cast(-9.95, DecimalType(2, 0)), Decimal(-10)) + checkEvaluation(cast(-9.95, DecimalType(2, 0)), Decimal(-10, 2, 0)) 
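The expected values around this point trace Decimal.changePrecision: round
half-up to the target scale, then return null when the rounded value overflows
the target precision. Worked through for 9.95 (a plain restatement of the test
cases, not new behavior):

    // 9.95 -> DecimalType(3, 2): 3 digits at scale 2      -> fits: 9.95
    // 9.95 -> DecimalType(3, 1): rounds to 10.0, 3 digits -> fits: 10.0
    // 9.95 -> DecimalType(2, 0): rounds to 10, 2 digits   -> fits: 10
    // 9.95 -> DecimalType(2, 1): rounds to 10.0, needs 3  -> overflow: null
    // 9.95 -> DecimalType(1, 0): rounds to 10, needs 2    -> overflow: null

The -9.95 cases mirror these with the sign flipped.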
checkEvaluation(cast(-9.95, DecimalType(2, 1)), null) checkEvaluation(cast(-9.95, DecimalType(1, 0)), null) checkEvaluation(cast(Decimal(-9.95), DecimalType(3, 1)), Decimal(-10.0)) @@ -491,7 +490,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { millis.toDouble / 1000) checkEvaluation( cast(cast(Decimal(1), TimestampType), DecimalType.SYSTEM_DEFAULT), - Decimal(1)) + Decimal(1.0, 38, 18)) // A test for higher precision than millis checkEvaluation(cast(cast(0.000001, TimestampType), DoubleType), 0.000001) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala index 511f0307901df..7998475e7a359 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala @@ -46,15 +46,16 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { test("CheckOverflow") { val d1 = Decimal("10.1") - checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0)), Decimal("10")) - checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1)), d1) - checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2)), d1) + println(d1.precision) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0)), Decimal(10, 4, 0)) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1)), Decimal(101, 4, 1)) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2)), Decimal(1010, 4, 2)) checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3)), null) val d2 = Decimal(101, 3, 1) - checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0)), Decimal("10")) - checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1)), d2) - checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2)), d2) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0)), Decimal(10, 4, 0)) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1)), Decimal(101, 4, 1)) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2)), Decimal(1010, 4, 2)) checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3)), null) checkEvaluation(CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2)), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index 7b85286c4dc8c..6052098b874d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -55,8 +55,8 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.default(DoubleType), 0.0) checkEvaluation(Literal.default(StringType), "") checkEvaluation(Literal.default(BinaryType), "".getBytes) - checkEvaluation(Literal.default(DecimalType.USER_DEFAULT), Decimal(0)) - checkEvaluation(Literal.default(DecimalType.SYSTEM_DEFAULT), Decimal(0)) + checkEvaluation(Literal.default(DecimalType.USER_DEFAULT), Decimal(0, 10, 0)) + checkEvaluation(Literal.default(DecimalType.SYSTEM_DEFAULT), Decimal(0, 38, 18)) checkEvaluation(Literal.default(DateType), DateTimeUtils.toJavaDate(0)) checkEvaluation(Literal.default(TimestampType), 
DateTimeUtils.toJavaTimestamp(0L))
     checkEvaluation(Literal.default(CalendarIntervalType), new CalendarInterval(0, 0L))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 88ed9fdd6465f..78a930db0b0ce 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -20,16 +20,16 @@ package org.apache.spark.sql.catalyst.expressions
 import com.google.common.math.LongMath

 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
 import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer
 import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
 import org.apache.spark.sql.types._

 class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

-  import IntegralLiteralTestUtils._
+  import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils._

   /**
    * Used for testing leaf math expressions.
@@ -535,13 +535,13 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow)
     }

-    val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
-      BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
-      BigDecimal(3.141593), BigDecimal(3.1415927))
+    val results: Seq[Decimal] = Seq(Decimal(3, 8, 0), Decimal(31, 8, 1), Decimal(314, 8, 2),
+      Decimal(3142, 8, 3), Decimal(31416, 8, 4), Decimal(314159, 8, 5),
+      Decimal(3141593, 8, 6), Decimal(31415927, 8, 7))
     // round_scale > current_scale would result in precision increase
     // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
     (0 to 7).foreach { i =>
-      checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow)
+      checkEvaluation(Round(bdPi, i), results(i), EmptyRow)
     }
     (8 to 10).foreach { scale =>
       checkEvaluation(Round(bdPi, scale), null, EmptyRow)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
index 188a64e9b66c4..d0f5bfa1cd7bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
@@ -33,6 +33,8 @@ abstract class ColumnarIterator extends Iterator[InternalRow] {

 /**
  * An helper class to update the fields of UnsafeRow, used by ColumnAccessor
+ *
+ * WARNING: These setters MUST be called in increasing order of ordinals.
  */
 class MutableUnsafeRow(val writer: UnsafeRowWriter) extends GenericMutableRow(null) {

From f0eb10c3c6e2b50be8a8996b7127b5a7d7ba2f2f Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 21 Oct 2015 16:23:32 -0700
Subject: [PATCH 5/6] Revert "address comments"

This reverts commit c76c759e30e2f89a8f4a56a1a854b9d3ee8fa697.
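One of the simplifications being rolled back below is reading the size half of
offsetAndSize with a plain (int) cast instead of masking; the two forms are
equivalent for the low 32 bits, e.g.:

    val offsetAndSize = (123L << 32) | 456L
    assert((offsetAndSize & ((1L << 32) - 1)).toInt == 456) // masked form
    assert(offsetAndSize.toInt == 456)                      // plain truncation

The masked form reappears only because the whole commit is reverted wholesale,
not because the cast was wrong.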
---
 .../sql/catalyst/expressions/UnsafeRow.java        | 14 ++++----
 .../expressions/codegen/UnsafeArrayWriter.java     | 35 ++++++++++---------
 .../expressions/codegen/UnsafeRowWriter.java       | 15 ++++----
 .../org/apache/spark/sql/types/Decimal.scala       |  4 +--
 .../expressions/ArithmeticExpressionSuite.scala    | 14 ++------
 .../sql/catalyst/expressions/CastSuite.scala       | 29 +++++++--------
 .../expressions/DecimalExpressionSuite.scala       | 13 ++++---
 .../expressions/LiteralExpressionSuite.scala       |  4 +--
 .../expressions/MathFunctionsSuite.scala           | 14 ++++----
 .../sql/columnar/GenerateColumnAccessor.scala      |  2 --
 10 files changed, 68 insertions(+), 76 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 850838af9be35..4b7f285d338c6 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -402,7 +402,7 @@ public UTF8String getUTF8String(int ordinal) {
     if (isNullAt(ordinal)) return null;
     final long offsetAndSize = getLong(ordinal);
     final int offset = (int) (offsetAndSize >> 32);
-    final int size = (int) offsetAndSize;
+    final int size = (int) (offsetAndSize & ((1L << 32) - 1));
     return UTF8String.fromAddress(baseObject, baseOffset + offset, size);
   }
 
@@ -413,7 +413,7 @@ public byte[] getBinary(int ordinal) {
     } else {
       final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
-      final int size = (int) offsetAndSize;
+      final int size = (int) (offsetAndSize & ((1L << 32) - 1));
       final byte[] bytes = new byte[size];
       Platform.copyMemory(
         baseObject,
@@ -446,7 +446,7 @@ public UnsafeRow getStruct(int ordinal, int numFields) {
     } else {
       final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
-      final int size = (int) offsetAndSize;
+      final int size = (int) (offsetAndSize & ((1L << 32) - 1));
       final UnsafeRow row = new UnsafeRow();
       row.pointTo(baseObject, baseOffset + offset, numFields, size);
       return row;
@@ -460,7 +460,7 @@ public UnsafeArrayData getArray(int ordinal) {
     } else {
       final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
-      final int size = (int) offsetAndSize;
+      final int size = (int) (offsetAndSize & ((1L << 32) - 1));
       final UnsafeArrayData array = new UnsafeArrayData();
       array.pointTo(baseObject, baseOffset + offset, size);
       return array;
@@ -474,7 +474,7 @@ public UnsafeMapData getMap(int ordinal) {
     } else {
       final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
-      final int size = (int) offsetAndSize;
+      final int size = (int) (offsetAndSize & ((1L << 32) - 1));
       final UnsafeMapData map = new UnsafeMapData();
       map.pointTo(baseObject, baseOffset + offset, size);
       return map;
@@ -620,13 +620,11 @@ public void writeTo(ByteBuffer buffer) {
 
   /**
    * Write the bytes of var-length field into ByteBuffer
-   *
-   * Note: only works with HeapByteBuffer
    */
   public void writeFieldTo(int ordinal, ByteBuffer buffer) {
     final long offsetAndSize = getLong(ordinal);
     final int offset = (int) (offsetAndSize >> 32);
-    final int size = (int) offsetAndSize;
+    final int size = (int) (offsetAndSize & ((1L << 32) - 1));
 
     buffer.putInt(size);
     int pos = buffer.position();
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
index d9bee1f779b47..7dd932d1981b7 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -111,23 +111,26 @@ public void write(int ordinal, double value) {
     holder.cursor += 8;
   }
 
-  public void write(int ordinal, Decimal input, int precision) {
-    assert(input != null);
-    assert(input.precision() == precision);
-    if (precision <= Decimal.MAX_LONG_DIGITS()) {
-      Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong());
-      setOffset(ordinal);
-      holder.cursor += 8;
+  public void write(int ordinal, Decimal input, int precision, int scale) {
+    // make sure Decimal object has the same scale as DecimalType
+    if (input.changePrecision(precision, scale)) {
+      if (precision <= Decimal.MAX_LONG_DIGITS()) {
+        Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong());
+        setOffset(ordinal);
+        holder.cursor += 8;
+      } else {
+        final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
+        assert bytes.length <= 16;
+        holder.grow(bytes.length);
+
+        // Write the bytes to the variable length portion.
+        Platform.copyMemory(
+          bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
+        setOffset(ordinal);
+        holder.cursor += bytes.length;
+      }
     } else {
-      final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
-      assert bytes.length <= 16;
-      holder.grow(bytes.length);
-
-      // Write the bytes to the variable length portion.
-      Platform.copyMemory(
-        bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
-      setOffset(ordinal);
-      holder.cursor += bytes.length;
+      setNullAt(ordinal);
     }
   }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
index cbf91c0fbb073..adbe2621870df 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -135,10 +135,12 @@ public void write(int ordinal, double value) {
 
   public void write(int ordinal, Decimal input, int precision, int scale) {
     if (precision <= Decimal.MAX_LONG_DIGITS()) {
-      assert(input != null);
-      assert(input.precision() == precision);
-      assert(input.scale() == scale);
-      Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong());
+      // make sure Decimal object has the same scale as DecimalType
+      if (input.changePrecision(precision, scale)) {
+        Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong());
+      } else {
+        setNullAt(ordinal);
+      }
     } else {
       // grow the global buffer before writing data.
       holder.grow(16);
@@ -147,14 +149,13 @@ public void write(int ordinal, Decimal input, int precision, int scale) {
       Platform.putLong(holder.buffer, holder.cursor, 0L);
       Platform.putLong(holder.buffer, holder.cursor + 8, 0L);
 
+      // Make sure Decimal object has the same scale as DecimalType.
       // Note that we may pass in null Decimal object to set null for it.
-      if (input == null) {
+      if (input == null || !input.changePrecision(precision, scale)) {
         BitSetMethods.set(holder.buffer, startingOffset, ordinal);
         // keep the offset for future update
         setOffsetAndSize(ordinal, 0L);
       } else {
-        assert(input.precision() == precision);
-        assert(input.scale() == scale);
         final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
         assert bytes.length <= 16;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 6ce428b87fe0e..c7a1a2e7469ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -123,7 +123,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
   def set(decimal: BigDecimal): Decimal = {
     this.decimalVal = decimal
     this.longVal = 0L
-    this._precision = math.max(decimal.precision, decimal.scale)
+    this._precision = decimal.precision
     this._scale = decimal.scale
     this
   }
@@ -306,7 +306,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
 
   def % (that: Decimal): Decimal =
     if (that.isZero) null
-    else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT), precision, scale)
+    else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT))
 
   def remainder(that: Decimal): Decimal = this % that
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 5c10466902c07..72285c6a24199 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -38,6 +38,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     testFunc(_.toLong)
     testFunc(_.toFloat)
     testFunc(_.toDouble)
+    testFunc(Decimal(_))
   }
 
   test("+ (Add)") {
@@ -48,7 +49,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
       checkEvaluation(Add(Literal.create(null, left.dataType), right), null)
       checkEvaluation(Add(left, Literal.create(null, right.dataType)), null)
     }
-    checkEvaluation(Add(Literal(Decimal(1)), Literal(Decimal(2))), Decimal(3))
     checkEvaluation(Add(positiveShortLit, negativeShortLit), -1.toShort)
     checkEvaluation(Add(positiveIntLit, negativeIntLit), -1)
     checkEvaluation(Add(positiveLongLit, negativeLongLit), -1L)
@@ -65,7 +65,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
       checkEvaluation(UnaryMinus(input), convert(-1))
       checkEvaluation(UnaryMinus(Literal.create(null, dataType)), null)
     }
-    checkEvaluation(UnaryMinus(Literal(Decimal(1))), Decimal(-1))
     checkEvaluation(UnaryMinus(Literal(Long.MinValue)), Long.MinValue)
     checkEvaluation(UnaryMinus(Literal(Int.MinValue)), Int.MinValue)
     checkEvaluation(UnaryMinus(Literal(Short.MinValue)), Short.MinValue)
@@ -90,7 +89,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
       checkEvaluation(Subtract(Literal.create(null, left.dataType), right), null)
       checkEvaluation(Subtract(left, Literal.create(null, right.dataType)), null)
     }
-    checkEvaluation(Subtract(Literal(Decimal(1)), Literal(Decimal(2))), Decimal(-1))
     checkEvaluation(Subtract(positiveShortLit, negativeShortLit),
       (positiveShort - negativeShort).toShort)
     checkEvaluation(Subtract(positiveIntLit, negativeIntLit), positiveInt - negativeInt)
@@ -109,8 +107,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
       checkEvaluation(Multiply(Literal.create(null, left.dataType), right), null)
       checkEvaluation(Multiply(left, Literal.create(null, right.dataType)), null)
     }
-    checkEvaluation(Cast(Multiply(Literal(Decimal(1)), Literal(Decimal(2))), DecimalType(20, 0)),
-      Decimal(2, 20, 0))
     checkEvaluation(Multiply(positiveShortLit, negativeShortLit),
       (positiveShort * negativeShort).toShort)
     checkEvaluation(Multiply(positiveIntLit, negativeIntLit), positiveInt * negativeInt)
@@ -131,8 +127,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
       checkEvaluation(Divide(left, Literal.create(null, right.dataType)), null)
       checkEvaluation(Divide(left, Literal(convert(0))), null)  // divide by zero
     }
-    checkEvaluation(Cast(Divide(Literal(Decimal(2)), Literal(Decimal(1))), DecimalType(18, 9)),
-      Decimal(2.0, 18, 9))
+
     DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe =>
       checkConsistencyBetweenInterpretedAndCodegen(Divide, tpe, tpe)
     }
@@ -151,8 +146,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
   test("/ (Divide) for floating point") {
     checkEvaluation(Divide(Literal(1.0f), Literal(2.0f)), 0.5f)
     checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5)
-    checkEvaluation(Cast(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), DecimalType(4, 3)),
-      Decimal(0.5, 4, 3))
+    checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), Decimal(0.5))
   }
 
   test("% (Remainder)") {
@@ -164,8 +158,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
       checkEvaluation(Remainder(left, Literal.create(null, right.dataType)), null)
       checkEvaluation(Remainder(left, Literal(convert(0))), null)  // mod by 0
     }
-    checkEvaluation(Cast(Remainder(Literal(Decimal(1)), Literal(Decimal(2))), DecimalType(10, 0)),
-      Decimal(1))
     checkEvaluation(Remainder(positiveShortLit, positiveShortLit), 0.toShort)
     checkEvaluation(Remainder(negativeShortLit, negativeShortLit), 0.toShort)
     checkEvaluation(Remainder(positiveIntLit, positiveIntLit), 0)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index e88b9671886fd..f4db4da7646f8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -222,8 +222,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkCast(1, 1.0)
     checkCast(123, "123")
 
-    checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123, 10, 0))
-    checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123, 3, 0))
+    checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
+    checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
     checkEvaluation(cast(123, DecimalType(3, 1)), null)
     checkEvaluation(cast(123, DecimalType(2, 0)), null)
   }
@@ -240,9 +240,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkCast(1L, 1.0)
     checkCast(123L, "123")
 
-    checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123, 10, 0))
-    checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123, 3, 0))
+    checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123))
+    checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123))
     checkEvaluation(cast(123L, DecimalType(3, 1)), null)
+    checkEvaluation(cast(123L, DecimalType(2, 0)), null)
   }
 
@@ -260,8 +261,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(cast(cast(1000, TimestampType), LongType), 1.toLong)
     checkEvaluation(cast(cast(-1200, TimestampType), LongType), -2.toLong)
 
-    checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123, 10, 0))
-    checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123, 3, 0))
+    checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
+    checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
     checkEvaluation(cast(123, DecimalType(3, 1)), null)
     checkEvaluation(cast(123, DecimalType(2, 0)), null)
   }
@@ -328,7 +329,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(cast("abdef", StringType), "abdef")
     checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null)
     checkEvaluation(cast("abdef", TimestampType), null)
-    checkEvaluation(cast("12.65", DecimalType.SYSTEM_DEFAULT), Decimal(12.65, 38, 18))
+    checkEvaluation(cast("12.65", DecimalType.SYSTEM_DEFAULT), Decimal(12.65))
 
     checkEvaluation(cast(cast(sd, DateType), StringType), sd)
     checkEvaluation(cast(cast(d, StringType), DateType), 0)
@@ -408,20 +409,20 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
 
     assert(cast(Decimal(10.03), DecimalType(2, 1)).nullable === true)
 
-    checkEvaluation(cast(10.03, DecimalType.SYSTEM_DEFAULT), Decimal(10.03, 38, 18))
+    checkEvaluation(cast(10.03, DecimalType.SYSTEM_DEFAULT), Decimal(10.03))
     checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03))
     checkEvaluation(cast(10.03, DecimalType(3, 1)), Decimal(10.0))
-    checkEvaluation(cast(10.03, DecimalType(2, 0)), Decimal(10, 2, 0))
+    checkEvaluation(cast(10.03, DecimalType(2, 0)), Decimal(10))
     checkEvaluation(cast(10.03, DecimalType(1, 0)), null)
     checkEvaluation(cast(10.03, DecimalType(2, 1)), null)
     checkEvaluation(cast(10.03, DecimalType(3, 2)), null)
     checkEvaluation(cast(Decimal(10.03), DecimalType(3, 1)), Decimal(10.0))
     checkEvaluation(cast(Decimal(10.03), DecimalType(3, 2)), null)
 
-    checkEvaluation(cast(10.05, DecimalType.SYSTEM_DEFAULT), Decimal(10.05, 38, 18))
+    checkEvaluation(cast(10.05, DecimalType.SYSTEM_DEFAULT), Decimal(10.05))
     checkEvaluation(cast(10.05, DecimalType(4, 2)), Decimal(10.05))
     checkEvaluation(cast(10.05, DecimalType(3, 1)), Decimal(10.1))
-    checkEvaluation(cast(10.05, DecimalType(2, 0)), Decimal(10, 2, 0))
+    checkEvaluation(cast(10.05, DecimalType(2, 0)), Decimal(10))
     checkEvaluation(cast(10.05, DecimalType(1, 0)), null)
     checkEvaluation(cast(10.05, DecimalType(2, 1)), null)
     checkEvaluation(cast(10.05, DecimalType(3, 2)), null)
@@ -430,7 +431,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
 
     checkEvaluation(cast(9.95, DecimalType(3, 2)), Decimal(9.95))
     checkEvaluation(cast(9.95, DecimalType(3, 1)), Decimal(10.0))
-    checkEvaluation(cast(9.95, DecimalType(2, 0)), Decimal(10, 2, 0))
+    checkEvaluation(cast(9.95, DecimalType(2, 0)), Decimal(10))
     checkEvaluation(cast(9.95, DecimalType(2, 1)), null)
     checkEvaluation(cast(9.95, DecimalType(1, 0)), null)
     checkEvaluation(cast(Decimal(9.95), DecimalType(3, 1)), Decimal(10.0))
@@ -438,7 +439,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
 
     checkEvaluation(cast(-9.95, DecimalType(3, 2)), Decimal(-9.95))
     checkEvaluation(cast(-9.95, DecimalType(3, 1)), Decimal(-10.0))
-    checkEvaluation(cast(-9.95, DecimalType(2, 0)), Decimal(-10, 2, 0))
+    checkEvaluation(cast(-9.95, DecimalType(2, 0)), Decimal(-10))
     checkEvaluation(cast(-9.95, DecimalType(2, 1)), null)
     checkEvaluation(cast(-9.95, DecimalType(1, 0)), null)
     checkEvaluation(cast(Decimal(-9.95), DecimalType(3, 1)), Decimal(-10.0))
@@ -490,7 +491,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
       millis.toDouble / 1000)
     checkEvaluation(
       cast(cast(Decimal(1), TimestampType), DecimalType.SYSTEM_DEFAULT),
-      Decimal(1.0, 38, 18))
+      Decimal(1))
 
     // A test for higher precision than millis
     checkEvaluation(cast(cast(0.000001, TimestampType), DoubleType), 0.000001)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
index 7998475e7a359..511f0307901df 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
@@ -46,16 +46,15 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("CheckOverflow") {
     val d1 = Decimal("10.1")
-    println(d1.precision)
-    checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0)), Decimal(10, 4, 0))
-    checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1)), Decimal(101, 4, 1))
-    checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2)), Decimal(1010, 4, 2))
+    checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0)), Decimal("10"))
+    checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1)), d1)
+    checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2)), d1)
     checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3)), null)
 
     val d2 = Decimal(101, 3, 1)
-    checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0)), Decimal(10, 4, 0))
-    checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1)), Decimal(101, 4, 1))
-    checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2)), Decimal(1010, 4, 2))
+    checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0)), Decimal("10"))
+    checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1)), d2)
+    checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2)), d2)
     checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3)), null)
 
     checkEvaluation(CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2)), null)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index 6052098b874d4..7b85286c4dc8c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -55,8 +55,8 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Literal.default(DoubleType), 0.0)
     checkEvaluation(Literal.default(StringType), "")
     checkEvaluation(Literal.default(BinaryType), "".getBytes)
-    checkEvaluation(Literal.default(DecimalType.USER_DEFAULT), Decimal(0, 10, 0))
-    checkEvaluation(Literal.default(DecimalType.SYSTEM_DEFAULT), Decimal(0, 38, 18))
+    checkEvaluation(Literal.default(DecimalType.USER_DEFAULT), Decimal(0))
+    checkEvaluation(Literal.default(DecimalType.SYSTEM_DEFAULT), Decimal(0))
     checkEvaluation(Literal.default(DateType), DateTimeUtils.toJavaDate(0))
     checkEvaluation(Literal.default(TimestampType),
       DateTimeUtils.toJavaTimestamp(0L))
     checkEvaluation(Literal.default(CalendarIntervalType), new CalendarInterval(0, 0L))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 78a930db0b0ce..88ed9fdd6465f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -20,16 +20,16 @@ package org.apache.spark.sql.catalyst.expressions
 import com.google.common.math.LongMath
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer
 import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
 import org.apache.spark.sql.types._
 
 class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
-  import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils._
+  import IntegralLiteralTestUtils._
 
   /**
    * Used for testing leaf math expressions.
@@ -535,13 +535,13 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow)
     }
 
-    val results: Seq[Decimal] = Seq(Decimal(3, 8, 0), Decimal(31, 8, 1), Decimal(314, 8, 2),
-      Decimal(3142, 8, 3), Decimal(31416, 8, 4), Decimal(314159, 8, 5),
-      Decimal(3141593, 8, 6), Decimal(31415927, 8, 7))
+    val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
+      BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
+      BigDecimal(3.141593), BigDecimal(3.1415927))
     // round_scale > current_scale would result in precision increase
     // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
     (0 to 7).foreach { i =>
-      checkEvaluation(Round(bdPi, i), results(i), EmptyRow)
+      checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow)
     }
     (8 to 10).foreach { scale =>
       checkEvaluation(Round(bdPi, scale), null, EmptyRow)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
index d0f5bfa1cd7bc..188a64e9b66c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
@@ -33,8 +33,6 @@ abstract class ColumnarIterator extends Iterator[InternalRow] {
 
 /**
  * A helper class to update the fields of UnsafeRow, used by ColumnAccessor
- *
- * WARNING: These setters MUST be called in increasing order of ordinals.
  */
 class MutableUnsafeRow(val writer: UnsafeRowWriter) extends GenericMutableRow(null) {

From cab9286b15bd784a023cd2b2c5c3ad755ebb566b Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 21 Oct 2015 16:25:08 -0700
Subject: [PATCH 6/6] address comments

---
 .../spark/sql/catalyst/expressions/UnsafeRow.java  | 14 ++++++++------
 .../expressions/codegen/GenerateUnsafeProjection.scala | 2 +-
 .../sql/columnar/GenerateColumnAccessor.scala      |  2 ++
 3 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 4b7f285d338c6..850838af9be35 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -402,7 +402,7 @@ public UTF8String getUTF8String(int ordinal) {
     if (isNullAt(ordinal)) return null;
     final long offsetAndSize = getLong(ordinal);
     final int offset = (int) (offsetAndSize >> 32);
-    final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+    final int size = (int) offsetAndSize;
     return UTF8String.fromAddress(baseObject, baseOffset + offset, size);
   }
 
@@ -413,7 +413,7 @@ public byte[] getBinary(int ordinal) {
     } else {
       final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
-      final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+      final int size = (int) offsetAndSize;
       final byte[] bytes = new byte[size];
       Platform.copyMemory(
         baseObject,
@@ -446,7 +446,7 @@ public UnsafeRow getStruct(int ordinal, int numFields) {
     } else {
       final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
-      final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+      final int size = (int) offsetAndSize;
       final UnsafeRow row = new UnsafeRow();
       row.pointTo(baseObject, baseOffset + offset, numFields, size);
       return row;
@@ -460,7 +460,7 @@ public UnsafeArrayData getArray(int ordinal) {
     } else {
       final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
-      final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+      final int size = (int) offsetAndSize;
       final UnsafeArrayData array = new UnsafeArrayData();
       array.pointTo(baseObject, baseOffset + offset, size);
       return array;
@@ -474,7 +474,7 @@ public UnsafeMapData getMap(int ordinal) {
     } else {
       final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
-      final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+      final int size = (int) offsetAndSize;
       final UnsafeMapData map = new UnsafeMapData();
       map.pointTo(baseObject, baseOffset + offset, size);
       return map;
@@ -620,11 +620,13 @@ public void writeTo(ByteBuffer buffer) {
 
   /**
    * Write the bytes of var-length field into ByteBuffer
+   *
+   * Note: only works with HeapByteBuffer
    */
   public void writeFieldTo(int ordinal, ByteBuffer buffer) {
     final long offsetAndSize = getLong(ordinal);
     final int offset = (int) (offsetAndSize >> 32);
-    final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+    final int size = (int) offsetAndSize;
 
     buffer.putInt(size);
     int pos = buffer.position();
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index a24d175aa093d..2136f82ba4752 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -69,7 +69,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     """
   }
 
-  def writeExpressionsToBuffer(
+  private def writeExpressionsToBuffer(
       ctx: CodeGenContext,
       row: String,
       inputs: Seq[GeneratedExpressionCode],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
index 188a64e9b66c4..d0f5bfa1cd7bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
@@ -33,6 +33,8 @@ abstract class ColumnarIterator extends Iterator[InternalRow] {
 
 /**
  * A helper class to update the fields of UnsafeRow, used by ColumnAccessor
+ *
+ * WARNING: These setters MUST be called in increasing order of ordinals.
  */
 class MutableUnsafeRow(val writer: UnsafeRowWriter) extends GenericMutableRow(null) {
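
For reference, the cast-only form restored above is equivalent to the masked
form it replaces: narrowing a long to an int in Java/Scala keeps exactly the
low 32 bits, which is where UnsafeRow packs the field size. A standalone
sketch of the equivalence (illustrative, not part of the patches):

    object OffsetAndSizeDemo {
      def main(args: Array[String]): Unit = {
        // Pack a 32-bit offset and a 32-bit size into one long, as UnsafeRow does.
        val offset = 64
        val size = 0x7FFFFFFF
        val offsetAndSize = (offset.toLong << 32) | (size.toLong & 0xFFFFFFFFL)

        // The explicit mask and the plain narrowing cast agree for every input.
        assert((offsetAndSize & ((1L << 32) - 1)).toInt == offsetAndSize.toInt)
        assert((offsetAndSize >> 32).toInt == offset)
        assert(offsetAndSize.toInt == size)
      }
    }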