Skip to content

Commit

Permalink
Add converters for Null, Boolean, Byte, and Short columns.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Apr 29, 2015
1 parent 81f34f8 commit eeee512
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,14 @@ public static int calculateBitSetWidthInBytes(int numFields) {
settableFieldTypes = Collections.unmodifiableSet(
new HashSet<DataType>(
Arrays.asList(new DataType[] {
IntegerType,
LongType,
DoubleType,
NullType,
BooleanType,
ShortType,
ByteType,
FloatType
ShortType,
IntegerType,
LongType,
FloatType,
DoubleType
})));

// We support get() on a superset of the types for which we support set():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ private object UnsafeColumnWriter {

def forType(dataType: DataType): UnsafeColumnWriter = {
dataType match {
case NullType => NullUnsafeColumnWriter
case BooleanType => BooleanUnsafeColumnWriter
case ByteType => ByteUnsafeColumnWriter
case ShortType => ShortUnsafeColumnWriter
case IntegerType => IntUnsafeColumnWriter
case LongType => LongUnsafeColumnWriter
case FloatType => FloatUnsafeColumnWriter
Expand All @@ -123,6 +127,10 @@ private object UnsafeColumnWriter {

// ------------------------------------------------------------------------------------------------

private object NullUnsafeColumnWriter extends NullUnsafeColumnWriter
private object BooleanUnsafeColumnWriter extends BooleanUnsafeColumnWriter
private object ByteUnsafeColumnWriter extends ByteUnsafeColumnWriter
private object ShortUnsafeColumnWriter extends ShortUnsafeColumnWriter
private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
Expand All @@ -134,6 +142,34 @@ private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
def getSize(sourceRow: Row, column: Int): Int = 0
}

private class NullUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
target.setNullAt(column)
0
}
}

private class BooleanUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
target.setBoolean(column, source.getBoolean(column))
0
}
}

private class ByteUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
target.setByte(column, source.getByte(column))
0
}
}

private class ShortUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
target.setShort(column, source.getShort(column))
0
}
}

private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
target.setInt(column, source.getInt(column))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,20 @@ class UnsafeRowConverterSuite extends FunSuite with Matchers {
}

test("null handling") {
val fieldTypes: Array[DataType] = Array(IntegerType, LongType, FloatType, DoubleType)
val fieldTypes: Array[DataType] = Array(
NullType,
BooleanType,
ByteType,
ShortType,
IntegerType,
LongType,
FloatType,
DoubleType)
val converter = new UnsafeRowConverter(fieldTypes)

val rowWithAllNullColumns: Row = {
val r = new SpecificMutableRow(fieldTypes)
for (i <- 0 to 3) {
for (i <- 0 to fieldTypes.length - 1) {
r.setNullAt(i)
}
r
Expand All @@ -94,23 +102,30 @@ class UnsafeRowConverterSuite extends FunSuite with Matchers {
val createdFromNull = new UnsafeRow()
createdFromNull.pointTo(
createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
for (i <- 0 to 3) {
for (i <- 0 to fieldTypes.length - 1) {
assert(createdFromNull.isNullAt(i))
}
createdFromNull.getInt(0) should be (0)
createdFromNull.getLong(1) should be (0)
assert(java.lang.Float.isNaN(createdFromNull.getFloat(2)))
assert(java.lang.Double.isNaN(createdFromNull.getFloat(3)))
createdFromNull.getBoolean(1) should be (false)
createdFromNull.getByte(2) should be (0)
createdFromNull.getShort(3) should be (0)
createdFromNull.getInt(4) should be (0)
createdFromNull.getLong(5) should be (0)
assert(java.lang.Float.isNaN(createdFromNull.getFloat(6)))
assert(java.lang.Double.isNaN(createdFromNull.getFloat(7)))

// If we have an UnsafeRow with columns that are initially non-null and we null out those
// columns, then the serialized row representation should be identical to what we would get by
// creating an entirely null row via the converter
val rowWithNoNullColumns: Row = {
val r = new SpecificMutableRow(fieldTypes)
r.setInt(0, 100)
r.setLong(1, 200)
r.setFloat(2, 300)
r.setDouble(3, 400)
r.setNullAt(0)
r.setBoolean(1, false)
r.setByte(2, 20)
r.setShort(3, 30)
r.setInt(4, 400)
r.setLong(5, 500)
r.setFloat(6, 600)
r.setDouble(7, 700)
r
}
val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
Expand All @@ -119,12 +134,17 @@ class UnsafeRowConverterSuite extends FunSuite with Matchers {
val setToNullAfterCreation = new UnsafeRow()
setToNullAfterCreation.pointTo(
setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
setToNullAfterCreation.getInt(0) should be (rowWithNoNullColumns.getInt(0))
setToNullAfterCreation.getLong(1) should be (rowWithNoNullColumns.getLong(1))
setToNullAfterCreation.getFloat(2) should be (rowWithNoNullColumns.getFloat(2))
setToNullAfterCreation.getDouble(3) should be (rowWithNoNullColumns.getDouble(3))

for (i <- 0 to 3) {
setToNullAfterCreation.isNullAt(0) should be (rowWithNoNullColumns.isNullAt(0))
setToNullAfterCreation.getBoolean(1) should be (rowWithNoNullColumns.getBoolean(1))
setToNullAfterCreation.getByte(2) should be (rowWithNoNullColumns.getByte(2))
setToNullAfterCreation.getShort(3) should be (rowWithNoNullColumns.getShort(3))
setToNullAfterCreation.getInt(4) should be (rowWithNoNullColumns.getInt(4))
setToNullAfterCreation.getLong(5) should be (rowWithNoNullColumns.getLong(5))
setToNullAfterCreation.getFloat(6) should be (rowWithNoNullColumns.getFloat(6))
setToNullAfterCreation.getDouble(7) should be (rowWithNoNullColumns.getDouble(7))

for (i <- 0 to fieldTypes.length - 1) {
setToNullAfterCreation.setNullAt(i)
}
assert(Arrays.equals(createdFromNullBuffer, setToNullAfterCreationBuffer))
Expand Down

0 comments on commit eeee512

Please sign in to comment.