From 31eaabcddcc5e3dda88a70645a28d476f853849f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 24 Apr 2015 11:50:58 -0700 Subject: [PATCH] Lots of TODO and doc cleanup. --- .../sql/catalyst/expressions/UnsafeRow.java | 36 +--- .../expressions/UnsafeRowConverter.scala | 186 ++++++++++-------- .../UnsafeFixedWidthAggregationMapSuite.scala | 7 +- .../expressions/UnsafeRowConverterSuite.scala | 4 +- .../spark/unsafe/map/BytesToBytesMap.java | 37 ++-- 5 files changed, 141 insertions(+), 129 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 63a4fac2ff4a0..1a4b21f441a8b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -33,10 +33,6 @@ import org.apache.spark.sql.types.UTF8String; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.bitset.BitSetMethods; -import org.apache.spark.unsafe.string.UTF8StringMethods; - -// TODO: pick a better name for this class, since this is potentially confusing. -// Maybe call it UnsafeMutableRow? /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. @@ -58,6 +54,7 @@ public final class UnsafeRow implements MutableRow { private Object baseObject; private long baseOffset; + /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ private int numFields; /** The width of the null tracking bit set, in bytes */ private int bitSetWidthInBytes; @@ -74,7 +71,7 @@ private long getFieldOffset(int ordinal) { } public static int calculateBitSetWidthInBytes(int numFields) { - return ((numFields / 64) + ((numFields % 64 == 0 ? 0 : 1))) * 8; + return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8; } /** @@ -211,7 +208,6 @@ public void setFloat(int ordinal, float value) { @Override public void setString(int ordinal, String value) { - // TODO: need to ensure that array has been suitably sized. throw new UnsupportedOperationException(); } @@ -240,23 +236,14 @@ public Object get(int i) { assertIndexIsValid(i); assert (schema != null) : "Schema must be defined when calling generic get() method"; final DataType dataType = schema.fields()[i].dataType(); - // The ordering of these `if` statements is intentional: internally, it looks like this only - // gets invoked in JoinedRow when trying to access UTF8String columns. It's extremely unlikely - // that internal code will call this on non-string-typed columns, but we support that anyways - // just for the sake of completeness. - // TODO: complete this for the remaining types? + // UnsafeRow is only designed to be invoked by internal code, which only invokes this generic + // get() method when trying to access UTF8String-typed columns. If we refactor the codebase to + // separate the internal and external row interfaces, then internal code can fetch strings via + // a new getUTF8String() method and we'll be able to remove this method. 
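The null-tracking bitset reserves one bit per field and rounds up to whole 8-byte words, which is what `calculateBitSetWidthInBytes` above computes. A quick restatement of that arithmetic, purely for illustration:

```scala
// One null-tracking bit per field, rounded up to whole 8-byte (64-bit) words.
def bitSetWidthInBytes(numFields: Int): Int =
  ((numFields / 64) + (if (numFields % 64 == 0) 0 else 1)) * 8

bitSetWidthInBytes(1)   // 8  -- even a single field occupies one full word
bitSetWidthInBytes(64)  // 8  -- exactly one word
bitSetWidthInBytes(65)  // 16 -- the 65th field spills into a second word
```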
if (isNullAt(i)) { return null; } else if (dataType == StringType) { return getUTF8String(i); - } else if (dataType == IntegerType) { - return getInt(i); - } else if (dataType == LongType) { - return getLong(i); - } else if (dataType == DoubleType) { - return getDouble(i); - } else if (dataType == FloatType) { - return getFloat(i); } else { throw new UnsupportedOperationException(); } @@ -319,7 +306,7 @@ public UTF8String getUTF8String(int i) { final byte[] strBytes = new byte[stringSizeInBytes]; PlatformDependent.copyMemory( baseObject, - baseOffset + offsetToStringSize + 8, // The +8 is to skip past the size to get the data, + baseOffset + offsetToStringSize + 8, // The `+ 8` is to skip past the size to get the data strBytes, PlatformDependent.BYTE_ARRAY_OFFSET, stringSizeInBytes @@ -335,31 +322,26 @@ public String getString(int i) { @Override public BigDecimal getDecimal(int i) { - // TODO throw new UnsupportedOperationException(); } @Override public Date getDate(int i) { - // TODO throw new UnsupportedOperationException(); } @Override public Seq getSeq(int i) { - // TODO throw new UnsupportedOperationException(); } @Override public List getList(int i) { - // TODO throw new UnsupportedOperationException(); } @Override public Map getMap(int i) { - // TODO throw new UnsupportedOperationException(); } @@ -370,19 +352,16 @@ public scala.collection.immutable.Map getValuesMap(Seq fi @Override public java.util.Map getJavaMap(int i) { - // TODO throw new UnsupportedOperationException(); } @Override public Row getStruct(int i) { - // TODO throw new UnsupportedOperationException(); } @Override public T getAs(int i) { - // TODO throw new UnsupportedOperationException(); } @@ -398,7 +377,6 @@ public int fieldIndex(String name) { @Override public Row copy() { - // TODO throw new UnsupportedOperationException(); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 8e09d76a320a5..e52fc8177771b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -21,7 +21,79 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods -/** Write a column into an UnsafeRow */ +/** + * Converts Rows into UnsafeRow format. This class is NOT thread-safe. + * + * @param fieldTypes the data types of the row's columns. + */ +class UnsafeRowConverter(fieldTypes: Array[DataType]) { + + def this(schema: StructType) { + this(schema.fields.map(_.dataType)) + } + + /** Re-used pointer to the unsafe row being written */ + private[this] val unsafeRow = new UnsafeRow() + + /** Functions for encoding each column */ + private[this] val writers: Array[UnsafeColumnWriter[Any]] = { + fieldTypes.map(t => UnsafeColumnWriter.forType(t).asInstanceOf[UnsafeColumnWriter[Any]]) + } + + /** The size, in bytes, of the fixed-length portion of the row, including the null bitmap */ + private[this] val fixedLengthSize: Int = + (8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length) + + /** + * Compute the amount of space, in bytes, required to encode the given row. 
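As a worked example of this size computation (values chosen for illustration, not taken from the patch): for the schema `[LongType, StringType]` and the row `(42L, "cats")`, the fixed-length region holds one 8-byte slot per field plus one 8-byte null-bitset word, and the string column adds an 8-byte length prefix plus its bytes rounded up to the next word:

```scala
// schema = [LongType, StringType], row = (42L, "cats")   -- illustrative values only
val fixedLengthSize = 8 * 2 + 8  // two field slots + one null-bitset word        = 24 bytes
val stringFieldSize = 8 + 8      // 8-byte length prefix + "cats" (4 bytes) padded = 16 bytes
val sizeRequirement = fixedLengthSize + stringFieldSize  // getSizeRequirement => 40 bytes
```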
+ */ + def getSizeRequirement(row: Row): Int = { + var fieldNumber = 0 + var variableLengthFieldSize: Int = 0 + while (fieldNumber < writers.length) { + if (!row.isNullAt(fieldNumber)) { + variableLengthFieldSize += writers(fieldNumber).getSize(row(fieldNumber)) + } + fieldNumber += 1 + } + fixedLengthSize + variableLengthFieldSize + } + + /** + * Convert the given row into UnsafeRow format. + * + * @param row the row to convert + * @param baseObject the base object of the destination address + * @param baseOffset the base offset of the destination address + * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`. + */ + def writeRow(row: Row, baseObject: Object, baseOffset: Long): Long = { + unsafeRow.pointTo(baseObject, baseOffset, writers.length, null) + var fieldNumber = 0 + var appendCursor: Int = fixedLengthSize + while (fieldNumber < writers.length) { + if (row.isNullAt(fieldNumber)) { + unsafeRow.setNullAt(fieldNumber) + // TODO: type-specific null value writing? + } else { + appendCursor += writers(fieldNumber).write( + row(fieldNumber), + fieldNumber, + unsafeRow, + baseObject, + baseOffset, + appendCursor) + } + fieldNumber += 1 + } + appendCursor + } + +} + +/** + * Function for writing a column into an UnsafeRow. + */ private abstract class UnsafeColumnWriter[T] { /** * Write a value into an UnsafeRow. @@ -29,8 +101,8 @@ private abstract class UnsafeColumnWriter[T] { * @param value the value to write * @param columnNumber what column to write it to * @param row a pointer to the unsafe row - * @param baseObject - * @param baseOffset + * @param baseObject the base object of the target row's address + * @param baseOffset the base offset of the target row's address * @param appendCursor the offset from the start of the unsafe row to the end of the row; * used for calculating where variable-length data should be written * @return the number of variable-length bytes written @@ -50,6 +122,12 @@ private abstract class UnsafeColumnWriter[T] { } private object UnsafeColumnWriter { + private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter + private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter + private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter + private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter + private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter + def forType(dataType: DataType): UnsafeColumnWriter[_] = { dataType match { case IntegerType => IntUnsafeColumnWriter @@ -63,34 +141,7 @@ private object UnsafeColumnWriter { } } -private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter[UTF8String] { - def getSize(value: UTF8String): Int = { - // round to nearest word - val numBytes = value.getBytes.length - 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) - } - - override def write( - value: UTF8String, - columnNumber: Int, - row: UnsafeRow, - baseObject: Object, - baseOffset: Long, - appendCursor: Int): Int = { - val numBytes = value.getBytes.length - PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes) - PlatformDependent.copyMemory( - value.getBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - baseOffset + appendCursor + 8, - numBytes - ) - row.setLong(columnNumber, appendCursor) - 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) - } -} -private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter +// 
------------------------------------------------------------------------------------------------ private abstract class PrimitiveUnsafeColumnWriter[T] extends UnsafeColumnWriter[T] { def getSize(value: T): Int = 0 @@ -108,7 +159,6 @@ private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWrite 0 } } -private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Long] { override def write( @@ -122,7 +172,6 @@ private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWrit 0 } } -private case object LongUnsafeColumnWriter extends LongUnsafeColumnWriter private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Float] { override def write( @@ -136,7 +185,6 @@ private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWri 0 } } -private case object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Double] { override def write( @@ -150,55 +198,29 @@ private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWr 0 } } -private case object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter -class UnsafeRowConverter(fieldTypes: Array[DataType]) { - - def this(schema: StructType) { - this(schema.fields.map(_.dataType)) - } - - private[this] val unsafeRow = new UnsafeRow() - - private[this] val writers: Array[UnsafeColumnWriter[Any]] = { - fieldTypes.map(t => UnsafeColumnWriter.forType(t).asInstanceOf[UnsafeColumnWriter[Any]]) - } - - private[this] val fixedLengthSize: Int = - (8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length) - - def getSizeRequirement(row: Row): Int = { - var fieldNumber = 0 - var variableLengthFieldSize: Int = 0 - while (fieldNumber < writers.length) { - if (!row.isNullAt(fieldNumber)) { - variableLengthFieldSize += writers(fieldNumber).getSize(row(fieldNumber)) - } - fieldNumber += 1 - } - fixedLengthSize + variableLengthFieldSize +private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter[UTF8String] { + def getSize(value: UTF8String): Int = { + 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(value.getBytes.length) } - def writeRow(row: Row, baseObject: Object, baseOffset: Long): Long = { - unsafeRow.pointTo(baseObject, baseOffset, writers.length, null) - var fieldNumber = 0 - var appendCursor: Int = fixedLengthSize - while (fieldNumber < writers.length) { - if (row.isNullAt(fieldNumber)) { - unsafeRow.setNullAt(fieldNumber) - // TODO: type-specific null value writing? 
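`writeRow` makes a single pass over the columns: null columns just get their bit set in the null bitset, everything else is delegated to the per-type writer, and `appendCursor` advances by however many variable-length bytes that writer appended. A minimal round-trip sketch of the converter (illustrative only; it loosely mirrors `UnsafeRowConverterSuite`, and assumes `PlatformDependent.LONG_ARRAY_OFFSET` exists alongside the `BYTE_ARRAY_OFFSET` used elsewhere in this patch):

```scala
import org.apache.spark.sql.catalyst.expressions.{GenericRow, UnsafeRow, UnsafeRowConverter}
import org.apache.spark.sql.types.{DataType, IntegerType, StringType, UTF8String}
import org.apache.spark.unsafe.PlatformDependent

val fieldTypes: Array[DataType] = Array(IntegerType, StringType)
val converter  = new UnsafeRowConverter(fieldTypes)
val row        = new GenericRow(Array[Any](1, UTF8String("cats")))

// Size the destination buffer with getSizeRequirement, then write the row into it.
val sizeInBytes  = converter.getSizeRequirement(row)
val buffer       = new Array[Long](sizeInBytes / 8)
val bytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
assert(bytesWritten == sizeInBytes)

// Read the data back by pointing an UnsafeRow at the same buffer; the schema may be
// null as long as the generic get() method is never called.
val unsafeRow = new UnsafeRow()
unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
assert(unsafeRow.getInt(0) == 1)
assert(unsafeRow.getString(1) == "cats")
```

Note that every `PrimitiveUnsafeColumnWriter` reports `getSize = 0`: fixed-width values live entirely in their 8-byte slot, so a row with no variable-length columns needs exactly `fixedLengthSize` bytes and `appendCursor` never moves.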
- } else { - appendCursor += writers(fieldNumber).write( - row(fieldNumber), - fieldNumber, - unsafeRow, - baseObject, - baseOffset, - appendCursor) - } - fieldNumber += 1 - } - appendCursor + override def write( + value: UTF8String, + columnNumber: Int, + row: UnsafeRow, + baseObject: Object, + baseOffset: Long, + appendCursor: Int): Int = { + val numBytes = value.getBytes.length + PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes) + PlatformDependent.copyMemory( + value.getBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + baseOffset + appendCursor + 8, + numBytes + ) + row.setLong(columnNumber, appendCursor) + 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) } - -} \ No newline at end of file +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index 956a80ade2f02..ba0b05514322d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -46,7 +46,8 @@ class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers { aggBufferSchema, groupKeySchema, MemoryAllocator.HEAP, - 1024 + 1024, + false ) assert(!map.iterator().hasNext) map.free() @@ -58,7 +59,8 @@ class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers { aggBufferSchema, groupKeySchema, MemoryAllocator.HEAP, - 1024 + 1024, + false ) val groupKey = new GenericRow(Array[Any](UTF8String("cats"))) @@ -77,5 +79,4 @@ class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers { map.free() } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 5bf2d808a7252..6009ded1d58dc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import org.scalatest.{FunSuite, Matchers} + import org.apache.spark.sql.types.{StringType, DataType, LongType, IntegerType} import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods -import org.scalatest.{FunSuite, Matchers} - class UnsafeRowConverterSuite extends FunSuite with Matchers { diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 3f48dfa4f94a0..20099f56141fd 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -69,8 +69,6 @@ public final class BytesToBytesMap { */ private final List dataPages = new LinkedList(); - private static final long PAGE_SIZE_BYTES = 64000000; - /** * The data page that will be used to store keys and values for new hashtable entries. When this * page becomes full, a new page will be allocated and this pointer will change to point to that @@ -102,16 +100,20 @@ public final class BytesToBytesMap { /** * The number of entries in the page table. 
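For reference while reading `StringUnsafeColumnWriter` above: the column's fixed-length slot stores the offset (relative to the start of the row) of the appended string data, which consists of an 8-byte length followed by the bytes padded out to a word boundary. A worked layout, assuming the value `"hello world"` (11 bytes of UTF-8):

```scala
import org.apache.spark.unsafe.array.ByteArrayMethods

// fixed-length slot for the column : offset of the string data, relative to the row start
// [offset, offset + 8)             : number of string bytes, stored as a long (11)
// [offset + 8, offset + 8 + 16)    : "hello world" bytes, padded up to the next 8-byte word
val numBytes  = "hello world".getBytes("UTF-8").length                          // 11
val sizeInRow = 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)  // 8 + 16 = 24
```

`getUTF8String` reverses this: it reads the offset from the slot, the length from the first 8 bytes at that offset, and then copies exactly that many bytes starting 8 bytes further in.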
*/ - private static final int PAGE_TABLE_SIZE = 8096; // Use the upper 13 bits to address the table. + private static final int PAGE_TABLE_SIZE = (int) 1L << 13; - // TODO: This page table size places a limit on the maximum page size. We should account for this - // somewhere as part of final cleanup in this file. + /** + * The size of the data pages that hold key and value data. Map entries cannot span multiple + * pages, so this limits the maximum entry size. + */ + private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes + // This choice of page table size and page size means that we can address up to 500 gigabytes + // of memory. /** * A single array to store the key and value. * - * // TODO this comment may be out of date; fix it: * Position {@code 2 * i} in the array is used to track a pointer to the key at index {@code i}, * while position {@code 2 * i + 1} in the array holds the upper bits of the key's hashcode plus * the relative offset from the key pointer to the value at index {@code i}. @@ -131,18 +133,25 @@ public final class BytesToBytesMap { */ private int size; + /** + * The map will be expanded once the number of keys exceeds this threshold. + */ private int growthThreshold; + /** + * Mask for truncating hashcodes so that they do not exceed the long array's size. + */ private int mask; + /** + * Return value of {@link BytesToBytesMap#lookup(Object, long, int)}. + */ private final Location loc; private final boolean enablePerfMetrics; private long timeSpentResizingMs = 0; - private int numResizes = 0; - private long numProbes = 0; private long numKeyLookups = 0; @@ -191,7 +200,7 @@ protected void finalize() throws Throwable { /** * Returns an iterator for iterating over the entries of this map. * - * For efficiency, all calls to `next()` will return the same `Location` object. + * For efficiency, all calls to `next()` will return the same {@link Location} object. * * If any other lookups or operations are performed on this map while iterating over it, including * `lookup()`, the behavior of the returned iterator is undefined. @@ -479,6 +488,12 @@ public void putNewKey( } } + /** + * Allocate new data structures for this map. When calling this outside of the constructor, + * make sure to keep references to the old data structures so that you can free them. + * + * @param capacity the new map capacity + */ private void allocate(int capacity) { capacity = Math.max((int) Math.min(Integer.MAX_VALUE, nextPowerOf2(capacity)), 64); longArray = new LongArray(allocator.allocate(capacity * 8 * 2)); @@ -553,7 +568,6 @@ public long getNumHashCollisions() { private void growAndRehash() { long resizeStartTime = -1; if (enablePerfMetrics) { - numResizes++; resizeStartTime = System.currentTimeMillis(); } // Store references to the old data structures to be used when we re-hash @@ -588,9 +602,6 @@ private void growAndRehash() { } } - // TODO: we should probably have a try-finally block here to make sure that we free the allocated - // memory even if an error occurs. - // Deallocate the old data structures. allocator.free(oldLongArray.memoryBlock()); allocator.free(oldBitSet.memoryBlock());
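A back-of-the-envelope check of the addressing comment above (illustrative arithmetic only): with 2^13 page-table entries and 2^26-byte pages, the map can address 2^39 bytes, i.e. 512 GiB, which is the "up to 500 gigabytes" figure cited in the comment. Since `allocate` rounds capacities up to a power of two, the `mask` is presumably `capacity - 1`, so `hashcode & mask` is a cheap modulo that keeps probe positions inside the long array.

```scala
// Addressable memory = (number of page-table entries) * (bytes per page).
val pageTableEntries = 1L << 13                          // PAGE_TABLE_SIZE: upper 13 bits of an address
val pageSizeBytes    = 1L << 26                          // PAGE_SIZE_BYTES: 64 megabytes per page
val addressableBytes = pageTableEntries * pageSizeBytes  // 2^39 bytes = 512 GiB
```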