Skip to content

Commit

Permalink
Lots of TODO and doc cleanup.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Apr 24, 2015
1 parent a95291e commit 31eaabc
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 129 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@
import org.apache.spark.sql.types.UTF8String;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.string.UTF8StringMethods;

// TODO: pick a better name for this class, since this is potentially confusing.
// Maybe call it UnsafeMutableRow?

/**
* An Unsafe implementation of Row which is backed by raw memory instead of Java objects.
Expand All @@ -58,6 +54,7 @@ public final class UnsafeRow implements MutableRow {

private Object baseObject;
private long baseOffset;
/** The number of fields in this row, used for calculating the bitset width (and in assertions) */
private int numFields;
/** The width of the null tracking bit set, in bytes */
private int bitSetWidthInBytes;
Expand All @@ -74,7 +71,7 @@ private long getFieldOffset(int ordinal) {
}

public static int calculateBitSetWidthInBytes(int numFields) {
return ((numFields / 64) + ((numFields % 64 == 0 ? 0 : 1))) * 8;
return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8;
}

/**
Expand Down Expand Up @@ -211,7 +208,6 @@ public void setFloat(int ordinal, float value) {

@Override
public void setString(int ordinal, String value) {
// TODO: need to ensure that array has been suitably sized.
throw new UnsupportedOperationException();
}

Expand Down Expand Up @@ -240,23 +236,14 @@ public Object get(int i) {
assertIndexIsValid(i);
assert (schema != null) : "Schema must be defined when calling generic get() method";
final DataType dataType = schema.fields()[i].dataType();
// The ordering of these `if` statements is intentional: internally, it looks like this only
// gets invoked in JoinedRow when trying to access UTF8String columns. It's extremely unlikely
// that internal code will call this on non-string-typed columns, but we support that anyways
// just for the sake of completeness.
// TODO: complete this for the remaining types?
// UnsafeRow is only designed to be invoked by internal code, which only invokes this generic
// get() method when trying to access UTF8String-typed columns. If we refactor the codebase to
// separate the internal and external row interfaces, then internal code can fetch strings via
// a new getUTF8String() method and we'll be able to remove this method.
if (isNullAt(i)) {
return null;
} else if (dataType == StringType) {
return getUTF8String(i);
} else if (dataType == IntegerType) {
return getInt(i);
} else if (dataType == LongType) {
return getLong(i);
} else if (dataType == DoubleType) {
return getDouble(i);
} else if (dataType == FloatType) {
return getFloat(i);
} else {
throw new UnsupportedOperationException();
}
Expand Down Expand Up @@ -319,7 +306,7 @@ public UTF8String getUTF8String(int i) {
final byte[] strBytes = new byte[stringSizeInBytes];
PlatformDependent.copyMemory(
baseObject,
baseOffset + offsetToStringSize + 8, // The +8 is to skip past the size to get the data,
baseOffset + offsetToStringSize + 8, // The `+ 8` is to skip past the size to get the data
strBytes,
PlatformDependent.BYTE_ARRAY_OFFSET,
stringSizeInBytes
Expand All @@ -335,31 +322,26 @@ public String getString(int i) {

@Override
public BigDecimal getDecimal(int i) {
// TODO
throw new UnsupportedOperationException();
}

@Override
public Date getDate(int i) {
// TODO
throw new UnsupportedOperationException();
}

@Override
public <T> Seq<T> getSeq(int i) {
// TODO
throw new UnsupportedOperationException();
}

@Override
public <T> List<T> getList(int i) {
// TODO
throw new UnsupportedOperationException();
}

@Override
public <K, V> Map<K, V> getMap(int i) {
// TODO
throw new UnsupportedOperationException();
}

Expand All @@ -370,19 +352,16 @@ public <T> scala.collection.immutable.Map<String, T> getValuesMap(Seq<String> fi

@Override
public <K, V> java.util.Map<K, V> getJavaMap(int i) {
// TODO
throw new UnsupportedOperationException();
}

@Override
public Row getStruct(int i) {
// TODO
throw new UnsupportedOperationException();
}

@Override
public <T> T getAs(int i) {
// TODO
throw new UnsupportedOperationException();
}

Expand All @@ -398,7 +377,6 @@ public int fieldIndex(String name) {

@Override
public Row copy() {
// TODO
throw new UnsupportedOperationException();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,88 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.array.ByteArrayMethods

/** Write a column into an UnsafeRow */
/**
* Converts Rows into UnsafeRow format. This class is NOT thread-safe.
*
* @param fieldTypes the data types of the row's columns.
*/
class UnsafeRowConverter(fieldTypes: Array[DataType]) {

def this(schema: StructType) {
this(schema.fields.map(_.dataType))
}

/** Re-used pointer to the unsafe row being written */
private[this] val unsafeRow = new UnsafeRow()

/** Functions for encoding each column */
private[this] val writers: Array[UnsafeColumnWriter[Any]] = {
fieldTypes.map(t => UnsafeColumnWriter.forType(t).asInstanceOf[UnsafeColumnWriter[Any]])
}

/** The size, in bytes, of the fixed-length portion of the row, including the null bitmap */
private[this] val fixedLengthSize: Int =
(8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length)

/**
* Compute the amount of space, in bytes, required to encode the given row.
*/
def getSizeRequirement(row: Row): Int = {
var fieldNumber = 0
var variableLengthFieldSize: Int = 0
while (fieldNumber < writers.length) {
if (!row.isNullAt(fieldNumber)) {
variableLengthFieldSize += writers(fieldNumber).getSize(row(fieldNumber))
}
fieldNumber += 1
}
fixedLengthSize + variableLengthFieldSize
}

/**
* Convert the given row into UnsafeRow format.
*
* @param row the row to convert
* @param baseObject the base object of the destination address
* @param baseOffset the base offset of the destination address
* @return the number of bytes written. This should be equal to `getSizeRequirement(row)`.
*/
def writeRow(row: Row, baseObject: Object, baseOffset: Long): Long = {
unsafeRow.pointTo(baseObject, baseOffset, writers.length, null)
var fieldNumber = 0
var appendCursor: Int = fixedLengthSize
while (fieldNumber < writers.length) {
if (row.isNullAt(fieldNumber)) {
unsafeRow.setNullAt(fieldNumber)
// TODO: type-specific null value writing?
} else {
appendCursor += writers(fieldNumber).write(
row(fieldNumber),
fieldNumber,
unsafeRow,
baseObject,
baseOffset,
appendCursor)
}
fieldNumber += 1
}
appendCursor
}

}

/**
* Function for writing a column into an UnsafeRow.
*/
private abstract class UnsafeColumnWriter[T] {
/**
* Write a value into an UnsafeRow.
*
* @param value the value to write
* @param columnNumber what column to write it to
* @param row a pointer to the unsafe row
* @param baseObject
* @param baseOffset
* @param baseObject the base object of the target row's address
* @param baseOffset the base offset of the target row's address
* @param appendCursor the offset from the start of the unsafe row to the end of the row;
* used for calculating where variable-length data should be written
* @return the number of variable-length bytes written
Expand All @@ -50,6 +122,12 @@ private abstract class UnsafeColumnWriter[T] {
}

private object UnsafeColumnWriter {
private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter

def forType(dataType: DataType): UnsafeColumnWriter[_] = {
dataType match {
case IntegerType => IntUnsafeColumnWriter
Expand All @@ -63,34 +141,7 @@ private object UnsafeColumnWriter {
}
}

private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter[UTF8String] {
def getSize(value: UTF8String): Int = {
// round to nearest word
val numBytes = value.getBytes.length
8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
}

override def write(
value: UTF8String,
columnNumber: Int,
row: UnsafeRow,
baseObject: Object,
baseOffset: Long,
appendCursor: Int): Int = {
val numBytes = value.getBytes.length
PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes)
PlatformDependent.copyMemory(
value.getBytes,
PlatformDependent.BYTE_ARRAY_OFFSET,
baseObject,
baseOffset + appendCursor + 8,
numBytes
)
row.setLong(columnNumber, appendCursor)
8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
}
}
private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
// ------------------------------------------------------------------------------------------------

private abstract class PrimitiveUnsafeColumnWriter[T] extends UnsafeColumnWriter[T] {
def getSize(value: T): Int = 0
Expand All @@ -108,7 +159,6 @@ private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWrite
0
}
}
private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter

private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Long] {
override def write(
Expand All @@ -122,7 +172,6 @@ private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWrit
0
}
}
private case object LongUnsafeColumnWriter extends LongUnsafeColumnWriter

private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Float] {
override def write(
Expand All @@ -136,7 +185,6 @@ private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWri
0
}
}
private case object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter

private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Double] {
override def write(
Expand All @@ -150,55 +198,29 @@ private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWr
0
}
}
private case object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter

class UnsafeRowConverter(fieldTypes: Array[DataType]) {

def this(schema: StructType) {
this(schema.fields.map(_.dataType))
}

private[this] val unsafeRow = new UnsafeRow()

private[this] val writers: Array[UnsafeColumnWriter[Any]] = {
fieldTypes.map(t => UnsafeColumnWriter.forType(t).asInstanceOf[UnsafeColumnWriter[Any]])
}

private[this] val fixedLengthSize: Int =
(8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length)

def getSizeRequirement(row: Row): Int = {
var fieldNumber = 0
var variableLengthFieldSize: Int = 0
while (fieldNumber < writers.length) {
if (!row.isNullAt(fieldNumber)) {
variableLengthFieldSize += writers(fieldNumber).getSize(row(fieldNumber))
}
fieldNumber += 1
}
fixedLengthSize + variableLengthFieldSize
private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter[UTF8String] {
def getSize(value: UTF8String): Int = {
8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(value.getBytes.length)
}

def writeRow(row: Row, baseObject: Object, baseOffset: Long): Long = {
unsafeRow.pointTo(baseObject, baseOffset, writers.length, null)
var fieldNumber = 0
var appendCursor: Int = fixedLengthSize
while (fieldNumber < writers.length) {
if (row.isNullAt(fieldNumber)) {
unsafeRow.setNullAt(fieldNumber)
// TODO: type-specific null value writing?
} else {
appendCursor += writers(fieldNumber).write(
row(fieldNumber),
fieldNumber,
unsafeRow,
baseObject,
baseOffset,
appendCursor)
}
fieldNumber += 1
}
appendCursor
override def write(
value: UTF8String,
columnNumber: Int,
row: UnsafeRow,
baseObject: Object,
baseOffset: Long,
appendCursor: Int): Int = {
val numBytes = value.getBytes.length
PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes)
PlatformDependent.copyMemory(
value.getBytes,
PlatformDependent.BYTE_ARRAY_OFFSET,
baseObject,
baseOffset + appendCursor + 8,
numBytes
)
row.setLong(columnNumber, appendCursor)
8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
}

}
}
Loading

0 comments on commit 31eaabc

Please sign in to comment.