From 338e6bfae91302b4b3264a568b235ca5365b0dc2 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Wed, 8 Jul 2015 18:24:29 -0700
Subject: [PATCH] Support copy for UnsafeRows that do not use ObjectPools.

---
 .../UnsafeFixedWidthAggregationMap.java       | 12 +++--
 .../sql/catalyst/expressions/UnsafeRow.java   | 32 +++++++++++-
 .../expressions/UnsafeRowConverter.scala      | 10 +++-
 .../expressions/UnsafeRowConverterSuite.scala | 52 ++++++++++++++-----
 4 files changed, 87 insertions(+), 19 deletions(-)
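
The new size argument threads a row's length in bytes through
UnsafeRowConverter.writeRow() and UnsafeRow.pointTo() so that UnsafeRow.copy()
knows how many bytes to duplicate out of its backing page. One reviewer
caveat: the emptyBuffer re-initialization in getAggregationBuffer() below
passes groupingKeySize as that length; this looks benign, since the
converter's scratch row is discarded after the write, but emptyBuffer.length
would be the more natural argument.

A minimal sketch of the new calling convention for a fixed-width row with no
ObjectPool, in the style of UnsafeRowConverterSuite. The SpecificMutableRow
construction and the import paths are assumed from the surrounding codebase
rather than shown in this patch:

    import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, UnsafeRow, UnsafeRowConverter}
    import org.apache.spark.sql.types.{DataType, IntegerType, LongType}
    import org.apache.spark.unsafe.PlatformDependent

    val fieldTypes: Array[DataType] = Array(LongType, IntegerType)
    val converter = new UnsafeRowConverter(fieldTypes)

    // Build a source row: one long field and one int field.
    val row = new SpecificMutableRow(fieldTypes)
    row.setLong(0, 42L)
    row.setInt(1, 7)

    // The size is computed once and passed to both writeRow() and pointTo().
    val sizeRequired: Int = converter.getSizeRequirement(row)
    val buffer = new Array[Long](sizeRequired / 8)
    converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)

    val unsafeRow = new UnsafeRow()
    unsafeRow.pointTo(
      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null)
    assert(unsafeRow.getLong(0) == 42L)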
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index 1e79f4b2e88e5..79d55b36dab01 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -120,9 +120,11 @@ public UnsafeFixedWidthAggregationMap(
     this.bufferPool = new ObjectPool(initialCapacity);
 
     InternalRow initRow = initProjection.apply(emptyRow);
-    this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)];
+    int emptyBufferSize = bufferConverter.getSizeRequirement(initRow);
+    this.emptyBuffer = new byte[emptyBufferSize];
     int writtenLength = bufferConverter.writeRow(
-      initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool);
+      initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, emptyBufferSize,
+      bufferPool);
     assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!";
     // re-use the empty buffer only when there is no object saved in pool.
     reuseEmptyBuffer = bufferPool.size() == 0;
@@ -142,6 +144,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
       groupingKey,
       groupingKeyConversionScratchSpace,
       PlatformDependent.BYTE_ARRAY_OFFSET,
+      groupingKeySize,
       keyPool);
     assert (groupingKeySize == actualGroupingKeySize) :
       "Size requirement calculation was wrong!";
@@ -157,7 +160,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
         // There is some objects referenced by emptyBuffer, so generate a new one
         InternalRow initRow = initProjection.apply(emptyRow);
         bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET,
-          bufferPool);
+          groupingKeySize, bufferPool);
       }
       loc.putNewKey(
         groupingKeyConversionScratchSpace,
@@ -175,6 +178,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
       address.getBaseObject(),
       address.getBaseOffset(),
       bufferConverter.numFields(),
+      loc.getValueLength(),
       bufferPool
     );
     return currentBuffer;
@@ -214,12 +218,14 @@ public MapEntry next() {
         keyAddress.getBaseObject(),
         keyAddress.getBaseOffset(),
         keyConverter.numFields(),
+        loc.getKeyLength(),
         keyPool
       );
       entry.value.pointTo(
         valueAddress.getBaseObject(),
         valueAddress.getBaseOffset(),
         bufferConverter.numFields(),
+        loc.getValueLength(),
         bufferPool
       );
       return entry;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index aeb64b045812f..edb7202245289 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -68,6 +68,9 @@ public final class UnsafeRow extends MutableRow {
   /** The number of fields in this row, used for calculating the bitset width (and in assertions) */
   private int numFields;
 
+  /** The size of this row's backing data, in bytes */
+  private int sizeInBytes;
+
   public int length() { return numFields; }
 
   /** The width of the null tracking bit set, in bytes */
@@ -95,14 +98,17 @@ public UnsafeRow() { }
    * @param baseObject the base object
    * @param baseOffset the offset within the base object
    * @param numFields the number of fields in this row
+   * @param sizeInBytes the size of this row's backing data, in bytes
    * @param pool the object pool to hold arbitrary objects
    */
-  public void pointTo(Object baseObject, long baseOffset, int numFields, ObjectPool pool) {
+  public void pointTo(
+      Object baseObject, long baseOffset, int numFields, int sizeInBytes, ObjectPool pool) {
     assert numFields >= 0 : "numFields should >= 0";
     this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
     this.baseObject = baseObject;
     this.baseOffset = baseOffset;
     this.numFields = numFields;
+    this.sizeInBytes = sizeInBytes;
     this.pool = pool;
   }
 
@@ -336,9 +342,31 @@ public double getDouble(int i) {
     }
   }
 
+  /**
+   * Copies this row, returning a self-contained UnsafeRow that stores its data in an internal
+   * byte array rather than referencing data stored in a data page.
+   * <p>
+   * This method is only supported on UnsafeRows that do not use ObjectPools.
+   */
   @Override
   public InternalRow copy() {
-    throw new UnsupportedOperationException();
+    if (pool != null) {
+      throw new UnsupportedOperationException(
+        "Copy is not supported for UnsafeRows that use object pools");
+    } else {
+      UnsafeRow rowCopy = new UnsafeRow();
+      final byte[] rowDataCopy = new byte[sizeInBytes];
+      PlatformDependent.copyMemory(
+        baseObject,
+        baseOffset,
+        rowDataCopy,
+        PlatformDependent.BYTE_ARRAY_OFFSET,
+        sizeInBytes
+      );
+      rowCopy.pointTo(
+        rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes, null);
+      return rowCopy;
+    }
   }
 
   @Override
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index 1f395497a9839..6af5e6200e57b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -70,10 +70,16 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
    * @param row the row to convert
    * @param baseObject the base object of the destination address
    * @param baseOffset the base offset of the destination address
+   * @param rowLengthInBytes the length calculated by `getSizeRequirement(row)`
    * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`.
    */
-  def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long, pool: ObjectPool): Int = {
-    unsafeRow.pointTo(baseObject, baseOffset, writers.length, pool)
+  def writeRow(
+      row: InternalRow,
+      baseObject: Object,
+      baseOffset: Long,
+      rowLengthInBytes: Int,
+      pool: ObjectPool): Int = {
+    unsafeRow.pointTo(baseObject, baseOffset, writers.length, rowLengthInBytes, pool)
 
     if (writers.length > 0) {
       // zero-out the bitset
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 96d4e64ea344a..d00aeb4dfbf47 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -44,19 +44,32 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val sizeRequired: Int = converter.getSizeRequirement(row)
     assert(sizeRequired === 8 + (3 * 8))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+    val numBytesWritten =
+      converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getLong(1) === 1)
     assert(unsafeRow.getInt(2) === 2)
 
+    // We can copy UnsafeRows as long as they don't reference ObjectPools
+    val unsafeRowCopy = unsafeRow.copy()
+    assert(unsafeRowCopy.getLong(0) === 0)
+    assert(unsafeRowCopy.getLong(1) === 1)
+    assert(unsafeRowCopy.getInt(2) === 2)
+
     unsafeRow.setLong(1, 3)
     assert(unsafeRow.getLong(1) === 3)
     unsafeRow.setInt(2, 4)
     assert(unsafeRow.getInt(2) === 4)
+
+    // Mutating the original row should not have changed the copy
+    assert(unsafeRowCopy.getLong(0) === 0)
+    assert(unsafeRowCopy.getLong(1) === 1)
+    assert(unsafeRowCopy.getInt(2) === 2)
   }
 
   test("basic conversion with primitive, string and binary types") {
@@ -73,12 +86,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) +
       ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+    val numBytesWritten = converter.writeRow(
+      row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val unsafeRow = new UnsafeRow()
     val pool = new ObjectPool(10)
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getString(1) === "Hello")
     assert(unsafeRow.get(2) === "World".getBytes)
@@ -96,6 +111,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     unsafeRow.update(2, "Hello World".getBytes)
     assert(unsafeRow.get(2) === "Hello World".getBytes)
     assert(pool.size === 2)
+
+    // We do not support copy() for UnsafeRows that reference ObjectPools
+    intercept[UnsupportedOperationException] {
+      unsafeRow.copy()
+    }
   }
 
   test("basic conversion with primitive, decimal and array") {
@@ -111,12 +131,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val sizeRequired: Int = converter.getSizeRequirement(row)
     assert(sizeRequired === 8 + (8 * 3))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, pool)
+    val numBytesWritten =
+      converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, pool)
     assert(numBytesWritten === sizeRequired)
     assert(pool.size === 2)
 
     val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.get(1) === Decimal(1))
     assert(unsafeRow.get(2) === Array(2))
@@ -142,11 +164,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(sizeRequired === 8 + (8 * 4) +
       ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+    val numBytesWritten =
+      converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getString(1) === "Hello")
     // Date is represented as Int in unsafeRow
@@ -190,12 +214,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns)
     val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
     val numBytesWritten = converter.writeRow(
-      rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+      rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET,
+      sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val createdFromNull = new UnsafeRow()
     createdFromNull.pointTo(
-      createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+      createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length,
+      sizeRequired, null)
     for (i <- 0 to fieldTypes.length - 1) {
       assert(createdFromNull.isNullAt(i))
     }
@@ -233,10 +259,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val pool = new ObjectPool(1)
     val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2)
     converter.writeRow(
-      rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, pool)
+      rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET,
+      sizeRequired, pool)
     val setToNullAfterCreation = new UnsafeRow()
     setToNullAfterCreation.pointTo(
-      setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+      setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length,
+      sizeRequired, pool)
     assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
     assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1))
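
A companion sketch of the copy() contract introduced above, continuing from
the snippet in the notes before the first diff (same converter, buffer,
unsafeRow, and sizeRequired names; ObjectPool's import and ScalaTest's
intercept are assumed, as in the suite):

    // The copy owns its bytes, so mutating the original does not affect it.
    val unsafeRowCopy = unsafeRow.copy()
    unsafeRow.setLong(0, -1L)
    assert(unsafeRowCopy.getLong(0) == 42L)

    // Rows that reference an ObjectPool refuse to copy, because pooled
    // values (strings, decimals, arrays) live outside the row's bytes.
    val pooledRow = new UnsafeRow()
    pooledRow.pointTo(
      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired,
      new ObjectPool(1))
    intercept[UnsupportedOperationException] { pooledRow.copy() }

Passing sizeInBytes into pointTo() rather than storing a length header in the
row keeps the written bytes exactly equal to the fixed-width payload; the
trade-off is that every caller must already know the row's length, which
writeRow() callers and the aggregation map's loc.getKeyLength() and
loc.getValueLength() both provide.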