From 6ffdaa16652fda6882d14023aebbe4fb9d2ece71 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Fri, 24 Apr 2015 12:24:33 -0700
Subject: [PATCH] Null handling improvements in UnsafeRow.

---
 .../sql/catalyst/expressions/UnsafeRow.java   | 16 ++++-
 .../expressions/UnsafeRowConverter.scala      | 24 +++----
 .../expressions/UnsafeRowConverterSuite.scala | 72 ++++++++++++++++++-
 3 files changed, 95 insertions(+), 17 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 1a4b21f441a8b..d2f25fd2e692e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -145,6 +145,10 @@ private void assertIndexIsValid(int index) {
   public void setNullAt(int i) {
     assertIndexIsValid(i);
     BitSetMethods.set(baseObject, baseOffset, i);
+    // To preserve row equality, zero out the value when setting the column to null.
+    // Since this row does not currently support updates to variable-length values, we don't
+    // have to worry about zeroing out that data.
+    PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), 0);
   }
 
   private void setNotNullAt(int i) {
@@ -288,13 +292,21 @@ public long getLong(int i) {
   @Override
   public float getFloat(int i) {
     assertIndexIsValid(i);
-    return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(i));
+    if (isNullAt(i)) {
+      return Float.NaN;
+    } else {
+      return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(i));
+    }
   }
 
   @Override
   public double getDouble(int i) {
     assertIndexIsValid(i);
-    return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(i));
+    if (isNullAt(i)) {
+      return Double.NaN;
+    } else {
+      return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(i));
+    }
   }
 
   public UTF8String getUTF8String(int i) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index e52fc8177771b..4418c92fd6bc1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -74,7 +74,6 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
     while (fieldNumber < writers.length) {
       if (row.isNullAt(fieldNumber)) {
         unsafeRow.setNullAt(fieldNumber)
-        // TODO: type-specific null value writing?
       } else {
         appendCursor += writers(fieldNumber).write(
           row(fieldNumber),
@@ -122,11 +121,6 @@ private abstract class UnsafeColumnWriter[T] {
 }
 
 private object UnsafeColumnWriter {
-  private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
-  private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
-  private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
-  private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
-  private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
 
   def forType(dataType: DataType): UnsafeColumnWriter[_] = {
     dataType match {
@@ -143,6 +137,12 @@ private object UnsafeColumnWriter {
 
 // ------------------------------------------------------------------------------------------------
 
+private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
+private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
+private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
+private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
+private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
+
 private abstract class PrimitiveUnsafeColumnWriter[T] extends UnsafeColumnWriter[T] {
   def getSize(value: T): Int = 0
 }
@@ -205,12 +205,12 @@ private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter[UTF8
   }
 
   override def write(
-    value: UTF8String,
-    columnNumber: Int,
-    row: UnsafeRow,
-    baseObject: Object,
-    baseOffset: Long,
-    appendCursor: Int): Int = {
+      value: UTF8String,
+      columnNumber: Int,
+      row: UnsafeRow,
+      baseObject: Object,
+      baseOffset: Long,
+      appendCursor: Int): Int = {
     val numBytes = value.getBytes.length
     PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes)
     PlatformDependent.copyMemory(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 6009ded1d58dc..211bc3333e386 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -17,9 +17,11 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.util.Arrays
+
 import org.scalatest.{FunSuite, Matchers}
 
-import org.apache.spark.sql.types.{StringType, DataType, LongType, IntegerType}
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.array.ByteArrayMethods
 
@@ -27,16 +29,19 @@ class UnsafeRowConverterSuite extends FunSuite with Matchers {
 
   test("basic conversion with only primitive types") {
     val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
+    val converter = new UnsafeRowConverter(fieldTypes)
+
     val row = new SpecificMutableRow(fieldTypes)
     row.setLong(0, 0)
     row.setLong(1, 1)
     row.setInt(2, 2)
-    val converter = new UnsafeRowConverter(fieldTypes)
+
     val sizeRequired: Int = converter.getSizeRequirement(row)
     sizeRequired should be (8 + (3 * 8))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
     val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
     numBytesWritten should be (sizeRequired)
+
     val unsafeRow = new UnsafeRow()
     unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
     unsafeRow.getLong(0) should be (0)
@@ -46,11 +51,13 @@ class UnsafeRowConverterSuite extends FunSuite with Matchers {
 
   test("basic conversion with primitive and string types") {
     val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType)
+    val converter = new UnsafeRowConverter(fieldTypes)
+
     val row = new SpecificMutableRow(fieldTypes)
     row.setLong(0, 0)
     row.setString(1, "Hello")
     row.setString(2, "World")
-    val converter = new UnsafeRowConverter(fieldTypes)
+
     val sizeRequired: Int = converter.getSizeRequirement(row)
     sizeRequired should be (8 + (8 * 3) +
       ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) +
@@ -58,10 +65,69 @@ class UnsafeRowConverterSuite extends FunSuite with Matchers {
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
     val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
     numBytesWritten should be (sizeRequired)
+
     val unsafeRow = new UnsafeRow()
     unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
     unsafeRow.getLong(0) should be (0)
     unsafeRow.getString(1) should be ("Hello")
     unsafeRow.getString(2) should be ("World")
   }
+
+  test("null handling") {
+    val fieldTypes: Array[DataType] = Array(IntegerType, LongType, FloatType, DoubleType)
+    val converter = new UnsafeRowConverter(fieldTypes)
+
+    val rowWithAllNullColumns: Row = {
+      val r = new SpecificMutableRow(fieldTypes)
+      for (i <- 0 to 3) {
+        r.setNullAt(i)
+      }
+      r
+    }
+
+    val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns)
+    val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
+    val numBytesWritten = converter.writeRow(
+      rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET)
+    numBytesWritten should be (sizeRequired)
+
+    val createdFromNull = new UnsafeRow()
+    createdFromNull.pointTo(
+      createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+    for (i <- 0 to 3) {
+      assert(createdFromNull.isNullAt(i))
+    }
+    createdFromNull.getInt(0) should be (0)
+    createdFromNull.getLong(1) should be (0)
+    assert(java.lang.Float.isNaN(createdFromNull.getFloat(2)))
+    assert(java.lang.Double.isNaN(createdFromNull.getDouble(3)))
+
+    // If we have an UnsafeRow with columns that are initially non-null and we null out those
+    // columns, then the serialized row representation should be identical to what we would get
+    // by creating an entirely null row via the converter.
+    val rowWithNoNullColumns: Row = {
+      val r = new SpecificMutableRow(fieldTypes)
+      r.setInt(0, 100)
+      r.setLong(1, 200)
+      r.setFloat(2, 300)
+      r.setDouble(3, 400)
+      r
+    }
+    val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
+    converter.writeRow(
+      rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET)
+    val setToNullAfterCreation = new UnsafeRow()
+    setToNullAfterCreation.pointTo(
+      setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+    setToNullAfterCreation.getInt(0) should be (rowWithNoNullColumns.getInt(0))
+    setToNullAfterCreation.getLong(1) should be (rowWithNoNullColumns.getLong(1))
+    setToNullAfterCreation.getFloat(2) should be (rowWithNoNullColumns.getFloat(2))
+    setToNullAfterCreation.getDouble(3) should be (rowWithNoNullColumns.getDouble(3))
+
+    for (i <- 0 to 3) {
+      setToNullAfterCreation.setNullAt(i)
+    }
+    assert(Arrays.equals(createdFromNullBuffer, setToNullAfterCreationBuffer))
+  }
+
 }
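
A note on the design, with a sketch that is not part of the patch: setNullAt() zeroes the
nulled field so that two logically equal rows are byte-for-byte identical, which lets their
backing buffers be compared or hashed directly; that is exactly what the new test's
Arrays.equals assertion checks. The standalone Scala sketch below models the invariant with
a hypothetical MiniRow class (a Long-array-backed row with a one-word null bitset). It is an
illustration of the technique only, not Spark's actual UnsafeRow API.

    import java.util.Arrays

    // Hypothetical fixed-width row: words(0) is a null bitset (up to 64 columns),
    // words(1 + i) is the 8-byte slot for column i.
    class MiniRow(numFields: Int) {
      require(numFields <= 64)
      private val words = new Array[Long](1 + numFields)

      def setLong(i: Int, value: Long): Unit = {
        words(0) &= ~(1L << i) // clear the null bit
        words(1 + i) = value
      }

      def setNullAt(i: Int): Unit = {
        words(0) |= (1L << i) // set the null bit
        // Zero the slot so no stale bytes survive; without this, the buffers of
        // two logically equal rows could differ.
        words(1 + i) = 0L
      }

      def isNullAt(i: Int): Boolean = (words(0) & (1L << i)) != 0

      def backingWords: Array[Long] = words
    }

    object MiniRowDemo extends App {
      val fromScratch = new MiniRow(2)
      fromScratch.setNullAt(0)
      fromScratch.setNullAt(1)

      val nulledLater = new MiniRow(2)
      nulledLater.setLong(0, 42L) // stale values that must not leak after nulling
      nulledLater.setLong(1, -1L)
      nulledLater.setNullAt(0)
      nulledLater.setNullAt(1)

      // Because setNullAt zeroes the slot, the two buffers are identical and can
      // be compared (or hashed) byte-for-byte.
      assert(Arrays.equals(fromScratch.backingWords, nulledLater.backingWords))
      println("logically equal rows share identical backing buffers")
    }

The patch applies the same idea via the putLong(baseObject, getFieldOffset(i), 0) call in
setNullAt(), and the "null handling" test verifies it at the buffer level by nulling out a
fully populated row and comparing its buffer against one that was written from an all-null
row.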