Skip to content

Commit

Permalink
Null handling improvements in UnsafeRow.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Apr 24, 2015
1 parent 31eaabc commit 6ffdaa1
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ private void assertIndexIsValid(int index) {
public void setNullAt(int i) {
assertIndexIsValid(i);
BitSetMethods.set(baseObject, baseOffset, i);
// To preserve row equality, zero out the value when setting the column to null.
// Since this row does does not currently support updates to variable-length values, we don't
// have to worry about zeroing out that data.
PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), 0);
}

private void setNotNullAt(int i) {
Expand Down Expand Up @@ -288,13 +292,21 @@ public long getLong(int i) {
@Override
public float getFloat(int i) {
assertIndexIsValid(i);
return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(i));
if (isNullAt(i)) {
return Float.NaN;
} else {
return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(i));
}
}

@Override
public double getDouble(int i) {
assertIndexIsValid(i);
return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(i));
if (isNullAt(i)) {
return Float.NaN;
} else {
return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(i));
}
}

public UTF8String getUTF8String(int i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
while (fieldNumber < writers.length) {
if (row.isNullAt(fieldNumber)) {
unsafeRow.setNullAt(fieldNumber)
// TODO: type-specific null value writing?
} else {
appendCursor += writers(fieldNumber).write(
row(fieldNumber),
Expand Down Expand Up @@ -122,11 +121,6 @@ private abstract class UnsafeColumnWriter[T] {
}

private object UnsafeColumnWriter {
private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter

def forType(dataType: DataType): UnsafeColumnWriter[_] = {
dataType match {
Expand All @@ -143,6 +137,12 @@ private object UnsafeColumnWriter {

// ------------------------------------------------------------------------------------------------

private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter

private abstract class PrimitiveUnsafeColumnWriter[T] extends UnsafeColumnWriter[T] {
def getSize(value: T): Int = 0
}
Expand Down Expand Up @@ -205,12 +205,12 @@ private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter[UTF8
}

override def write(
value: UTF8String,
columnNumber: Int,
row: UnsafeRow,
baseObject: Object,
baseOffset: Long,
appendCursor: Int): Int = {
value: UTF8String,
columnNumber: Int,
row: UnsafeRow,
baseObject: Object,
baseOffset: Long,
appendCursor: Int): Int = {
val numBytes = value.getBytes.length
PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes)
PlatformDependent.copyMemory(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,31 @@

package org.apache.spark.sql.catalyst.expressions

import java.util.Arrays

import org.scalatest.{FunSuite, Matchers}

import org.apache.spark.sql.types.{StringType, DataType, LongType, IntegerType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.array.ByteArrayMethods

class UnsafeRowConverterSuite extends FunSuite with Matchers {

test("basic conversion with only primitive types") {
val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
val converter = new UnsafeRowConverter(fieldTypes)

val row = new SpecificMutableRow(fieldTypes)
row.setLong(0, 0)
row.setLong(1, 1)
row.setInt(2, 2)
val converter = new UnsafeRowConverter(fieldTypes)

val sizeRequired: Int = converter.getSizeRequirement(row)
sizeRequired should be (8 + (3 * 8))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
numBytesWritten should be (sizeRequired)

val unsafeRow = new UnsafeRow()
unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
unsafeRow.getLong(0) should be (0)
Expand All @@ -46,22 +51,83 @@ class UnsafeRowConverterSuite extends FunSuite with Matchers {

test("basic conversion with primitive and string types") {
val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType)
val converter = new UnsafeRowConverter(fieldTypes)

val row = new SpecificMutableRow(fieldTypes)
row.setLong(0, 0)
row.setString(1, "Hello")
row.setString(2, "World")
val converter = new UnsafeRowConverter(fieldTypes)

val sizeRequired: Int = converter.getSizeRequirement(row)
sizeRequired should be (8 + (8 * 3) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
numBytesWritten should be (sizeRequired)

val unsafeRow = new UnsafeRow()
unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
unsafeRow.getLong(0) should be (0)
unsafeRow.getString(1) should be ("Hello")
unsafeRow.getString(2) should be ("World")
}

test("null handling") {
val fieldTypes: Array[DataType] = Array(IntegerType, LongType, FloatType, DoubleType)
val converter = new UnsafeRowConverter(fieldTypes)

val rowWithAllNullColumns: Row = {
val r = new SpecificMutableRow(fieldTypes)
for (i <- 0 to 3) {
r.setNullAt(i)
}
r
}

val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns)
val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(
rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET)
numBytesWritten should be (sizeRequired)

val createdFromNull = new UnsafeRow()
createdFromNull.pointTo(
createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
for (i <- 0 to 3) {
assert(createdFromNull.isNullAt(i))
}
createdFromNull.getInt(0) should be (0)
createdFromNull.getLong(1) should be (0)
assert(java.lang.Float.isNaN(createdFromNull.getFloat(2)))
assert(java.lang.Double.isNaN(createdFromNull.getFloat(3)))

// If we have an UnsafeRow with columns that are initially non-null and we null out those
// columns, then the serialized row representation should be identical to what we would get by
// creating an entirely null row via the converter
val rowWithNoNullColumns: Row = {
val r = new SpecificMutableRow(fieldTypes)
r.setInt(0, 100)
r.setLong(1, 200)
r.setFloat(2, 300)
r.setDouble(3, 400)
r
}
val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
converter.writeRow(
rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET)
val setToNullAfterCreation = new UnsafeRow()
setToNullAfterCreation.pointTo(
setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
setToNullAfterCreation.getInt(0) should be (rowWithNoNullColumns.getInt(0))
setToNullAfterCreation.getLong(1) should be (rowWithNoNullColumns.getLong(1))
setToNullAfterCreation.getFloat(2) should be (rowWithNoNullColumns.getFloat(2))
setToNullAfterCreation.getDouble(3) should be (rowWithNoNullColumns.getDouble(3))

for (i <- 0 to 3) {
setToNullAfterCreation.setNullAt(i)
}
assert(Arrays.equals(createdFromNullBuffer, setToNullAfterCreationBuffer))
}

}

0 comments on commit 6ffdaa1

Please sign in to comment.