From 90da7dc241f8eec2348c0434312c97c116330bc4 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 15 Jun 2018 13:47:48 -0700 Subject: [PATCH 1/2] [SPARK-24452][SQL][CORE] Avoid possible overflow in int add or multiple ## What changes were proposed in this pull request? This PR fixes possible overflow in int add or multiply. In particular, their overflows in multiply are detected by [Spotbugs](https://spotbugs.github.io/) The following assignments may cause overflow in right hand side. As a result, the result may be negative. ``` long = int * int long = int + int ``` To avoid this problem, this PR performs cast from int to long in right hand side. ## How was this patch tested? Existing UTs. Author: Kazuaki Ishizaki Closes #21481 from kiszk/SPARK-24452. --- .../spark/unsafe/map/BytesToBytesMap.java | 2 +- .../spark/deploy/worker/DriverRunner.scala | 2 +- .../apache/spark/rdd/AsyncRDDActions.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 2 +- .../catalyst/expressions/UnsafeArrayData.java | 14 +-- .../VariableLengthRowBasedKeyValueBatch.java | 2 +- .../vectorized/OffHeapColumnVector.java | 106 +++++++++--------- .../vectorized/OnHeapColumnVector.java | 10 +- .../sources/RateStreamMicroBatchReader.scala | 2 +- .../spark/sql/hive/client/HiveShim.scala | 2 +- .../util/FileBasedWriteAheadLog.scala | 2 +- 11 files changed, 73 insertions(+), 73 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 5f0045507aaab..9a767dd739b91 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -703,7 +703,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff // must be stored in the same memory page. // (8 byte key length) (key) (value) (8 byte pointer to next value) int uaoSize = UnsafeAlignedOffset.getUaoSize(); - final long recordLength = (2 * uaoSize) + klen + vlen + 8; + final long recordLength = (2L * uaoSize) + klen + vlen + 8; if (currentPage == null || currentPage.size() - pageCursor < recordLength) { if (!acquireNewPage(recordLength + uaoSize)) { return false; diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 58a181128eb4d..a6d13d12fc28d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -225,7 +225,7 @@ private[deploy] class DriverRunner( // check if attempting another run keepTrying = supervise && exitCode != 0 && !killed if (keepTrying) { - if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) { + if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000L) { waitSeconds = 1 } logInfo(s"Command exited with status $exitCode, re-launching after $waitSeconds s.") diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 13db4985b0b80..ba9dae4ad48ec 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -95,7 +95,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi // the left side of max is >=1 whenever partsScanned >= 2 numPartsToTry = Math.max(1, (1.5 * num * partsScanned / results.size).toInt - partsScanned) - numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4L) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e0276a4dc4224..df1a4bef616b2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -291,7 +291,7 @@ private[spark] class BlockManager( case e: Exception if i < MAX_ATTEMPTS => logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}" + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) - Thread.sleep(SLEEP_TIME_SECS * 1000) + Thread.sleep(SLEEP_TIME_SECS * 1000L) case NonFatal(e) => throw new SparkException("Unable to register with external shuffle server due to : " + e.getMessage, e) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index d5d934bc91cab..4dd2b7365652a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -83,7 +83,7 @@ public static long calculateSizeOfUnderlyingByteArray(long numFields, int elemen private long elementOffset; private long getElementOffset(int ordinal, int elementSize) { - return elementOffset + ordinal * elementSize; + return elementOffset + ordinal * (long)elementSize; } public Object getBaseObject() { return baseObject; } @@ -414,7 +414,7 @@ public byte[] toByteArray() { public short[] toShortArray() { short[] values = new short[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2); + baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2L); return values; } @@ -422,7 +422,7 @@ public short[] toShortArray() { public int[] toIntArray() { int[] values = new int[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4); + baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4L); return values; } @@ -430,7 +430,7 @@ public int[] toIntArray() { public long[] toLongArray() { long[] values = new long[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8); + baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8L); return values; } @@ -438,7 +438,7 @@ public long[] toLongArray() { public float[] toFloatArray() { float[] values = new float[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4); + baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4L); return values; } @@ -446,14 +446,14 @@ public float[] toFloatArray() { public double[] toDoubleArray() { double[] values = new double[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8); + baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8L); return values; } private static UnsafeArrayData fromPrimitiveArray( Object arr, int offset, int length, int elementSize) { final long headerInBytes = calculateHeaderPortionInBytes(length); - final long valueRegionInBytes = elementSize * length; + final long valueRegionInBytes = (long)elementSize * length; final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; if (totalSizeInLongs > Integer.MAX_VALUE / 8) { throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " + diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java index 905e6820ce6e2..c823de4810f2b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java @@ -41,7 +41,7 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueB @Override public UnsafeRow appendRow(Object kbase, long koff, int klen, Object vbase, long voff, int vlen) { - final long recordLength = 8 + klen + vlen + 8; + final long recordLength = 8L + klen + vlen + 8; // if run out of max supported rows or page size, return null if (numRows >= capacity || page == null || page.size() - pageCursor < recordLength) { return null; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 4733f36174f42..6fdadde628551 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -216,12 +216,12 @@ protected UTF8String getBytesAsUTF8String(int rowId, int count) { @Override public void putShort(int rowId, short value) { - Platform.putShort(null, data + 2 * rowId, value); + Platform.putShort(null, data + 2L * rowId, value); } @Override public void putShorts(int rowId, int count, short value) { - long offset = data + 2 * rowId; + long offset = data + 2L * rowId; for (int i = 0; i < count; ++i, offset += 2) { Platform.putShort(null, offset, value); } @@ -229,20 +229,20 @@ public void putShorts(int rowId, int count, short value) { @Override public void putShorts(int rowId, int count, short[] src, int srcIndex) { - Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2, - null, data + 2 * rowId, count * 2); + Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2L, + null, data + 2L * rowId, count * 2L); } @Override public void putShorts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 2, count * 2); + null, data + rowId * 2L, count * 2L); } @Override public short getShort(int rowId) { if (dictionary == null) { - return Platform.getShort(null, data + 2 * rowId); + return Platform.getShort(null, data + 2L * rowId); } else { return (short) dictionary.decodeToInt(dictionaryIds.getDictId(rowId)); } @@ -252,7 +252,7 @@ public short getShort(int rowId) { public short[] getShorts(int rowId, int count) { assert(dictionary == null); short[] array = new short[count]; - Platform.copyMemory(null, data + rowId * 2, array, Platform.SHORT_ARRAY_OFFSET, count * 2); + Platform.copyMemory(null, data + rowId * 2L, array, Platform.SHORT_ARRAY_OFFSET, count * 2L); return array; } @@ -262,12 +262,12 @@ public short[] getShorts(int rowId, int count) { @Override public void putInt(int rowId, int value) { - Platform.putInt(null, data + 4 * rowId, value); + Platform.putInt(null, data + 4L * rowId, value); } @Override public void putInts(int rowId, int count, int value) { - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putInt(null, offset, value); } @@ -275,24 +275,24 @@ public void putInts(int rowId, int count, int value) { @Override public void putInts(int rowId, int count, int[] src, int srcIndex) { - Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4, - null, data + 4 * rowId, count * 4); + Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4L, + null, data + 4L * rowId, count * 4L); } @Override public void putInts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 4, count * 4); + null, data + rowId * 4L, count * 4L); } @Override public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 4 * rowId, count * 4); + null, data + 4L * rowId, count * 4L); } else { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4, srcOffset += 4) { Platform.putInt(null, offset, java.lang.Integer.reverseBytes(Platform.getInt(src, srcOffset))); @@ -303,7 +303,7 @@ public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) @Override public int getInt(int rowId) { if (dictionary == null) { - return Platform.getInt(null, data + 4 * rowId); + return Platform.getInt(null, data + 4L * rowId); } else { return dictionary.decodeToInt(dictionaryIds.getDictId(rowId)); } @@ -313,7 +313,7 @@ public int getInt(int rowId) { public int[] getInts(int rowId, int count) { assert(dictionary == null); int[] array = new int[count]; - Platform.copyMemory(null, data + rowId * 4, array, Platform.INT_ARRAY_OFFSET, count * 4); + Platform.copyMemory(null, data + rowId * 4L, array, Platform.INT_ARRAY_OFFSET, count * 4L); return array; } @@ -325,7 +325,7 @@ public int[] getInts(int rowId, int count) { public int getDictId(int rowId) { assert(dictionary == null) : "A ColumnVector dictionary should not have a dictionary for itself."; - return Platform.getInt(null, data + 4 * rowId); + return Platform.getInt(null, data + 4L * rowId); } // @@ -334,12 +334,12 @@ public int getDictId(int rowId) { @Override public void putLong(int rowId, long value) { - Platform.putLong(null, data + 8 * rowId, value); + Platform.putLong(null, data + 8L * rowId, value); } @Override public void putLongs(int rowId, int count, long value) { - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putLong(null, offset, value); } @@ -347,24 +347,24 @@ public void putLongs(int rowId, int count, long value) { @Override public void putLongs(int rowId, int count, long[] src, int srcIndex) { - Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8, - null, data + 8 * rowId, count * 8); + Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8L, + null, data + 8L * rowId, count * 8L); } @Override public void putLongs(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 8, count * 8); + null, data + rowId * 8L, count * 8L); } @Override public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 8 * rowId, count * 8); + null, data + 8L * rowId, count * 8L); } else { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8, srcOffset += 8) { Platform.putLong(null, offset, java.lang.Long.reverseBytes(Platform.getLong(src, srcOffset))); @@ -375,7 +375,7 @@ public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) @Override public long getLong(int rowId) { if (dictionary == null) { - return Platform.getLong(null, data + 8 * rowId); + return Platform.getLong(null, data + 8L * rowId); } else { return dictionary.decodeToLong(dictionaryIds.getDictId(rowId)); } @@ -385,7 +385,7 @@ public long getLong(int rowId) { public long[] getLongs(int rowId, int count) { assert(dictionary == null); long[] array = new long[count]; - Platform.copyMemory(null, data + rowId * 8, array, Platform.LONG_ARRAY_OFFSET, count * 8); + Platform.copyMemory(null, data + rowId * 8L, array, Platform.LONG_ARRAY_OFFSET, count * 8L); return array; } @@ -395,12 +395,12 @@ public long[] getLongs(int rowId, int count) { @Override public void putFloat(int rowId, float value) { - Platform.putFloat(null, data + rowId * 4, value); + Platform.putFloat(null, data + rowId * 4L, value); } @Override public void putFloats(int rowId, int count, float value) { - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putFloat(null, offset, value); } @@ -408,18 +408,18 @@ public void putFloats(int rowId, int count, float value) { @Override public void putFloats(int rowId, int count, float[] src, int srcIndex) { - Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4, - null, data + 4 * rowId, count * 4); + Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4L, + null, data + 4L * rowId, count * 4L); } @Override public void putFloats(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 4, count * 4); + null, data + rowId * 4L, count * 4L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putFloat(null, offset, bb.getFloat(srcIndex + (4 * i))); } @@ -429,7 +429,7 @@ public void putFloats(int rowId, int count, byte[] src, int srcIndex) { @Override public float getFloat(int rowId) { if (dictionary == null) { - return Platform.getFloat(null, data + rowId * 4); + return Platform.getFloat(null, data + rowId * 4L); } else { return dictionary.decodeToFloat(dictionaryIds.getDictId(rowId)); } @@ -439,7 +439,7 @@ public float getFloat(int rowId) { public float[] getFloats(int rowId, int count) { assert(dictionary == null); float[] array = new float[count]; - Platform.copyMemory(null, data + rowId * 4, array, Platform.FLOAT_ARRAY_OFFSET, count * 4); + Platform.copyMemory(null, data + rowId * 4L, array, Platform.FLOAT_ARRAY_OFFSET, count * 4L); return array; } @@ -450,12 +450,12 @@ public float[] getFloats(int rowId, int count) { @Override public void putDouble(int rowId, double value) { - Platform.putDouble(null, data + rowId * 8, value); + Platform.putDouble(null, data + rowId * 8L, value); } @Override public void putDoubles(int rowId, int count, double value) { - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putDouble(null, offset, value); } @@ -463,18 +463,18 @@ public void putDoubles(int rowId, int count, double value) { @Override public void putDoubles(int rowId, int count, double[] src, int srcIndex) { - Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8, - null, data + 8 * rowId, count * 8); + Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8L, + null, data + 8L * rowId, count * 8L); } @Override public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 8, count * 8); + null, data + rowId * 8L, count * 8L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putDouble(null, offset, bb.getDouble(srcIndex + (8 * i))); } @@ -484,7 +484,7 @@ public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { @Override public double getDouble(int rowId) { if (dictionary == null) { - return Platform.getDouble(null, data + rowId * 8); + return Platform.getDouble(null, data + rowId * 8L); } else { return dictionary.decodeToDouble(dictionaryIds.getDictId(rowId)); } @@ -494,7 +494,7 @@ public double getDouble(int rowId) { public double[] getDoubles(int rowId, int count) { assert(dictionary == null); double[] array = new double[count]; - Platform.copyMemory(null, data + rowId * 8, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8); + Platform.copyMemory(null, data + rowId * 8L, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8L); return array; } @@ -504,26 +504,26 @@ public double[] getDoubles(int rowId, int count) { @Override public void putArray(int rowId, int offset, int length) { assert(offset >= 0 && offset + length <= childColumns[0].capacity); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, offset); + Platform.putInt(null, lengthData + 4L * rowId, length); + Platform.putInt(null, offsetData + 4L * rowId, offset); } @Override public int getArrayLength(int rowId) { - return Platform.getInt(null, lengthData + 4 * rowId); + return Platform.getInt(null, lengthData + 4L * rowId); } @Override public int getArrayOffset(int rowId) { - return Platform.getInt(null, offsetData + 4 * rowId); + return Platform.getInt(null, offsetData + 4L * rowId); } // APIs dealing with ByteArrays @Override public int putByteArray(int rowId, byte[] value, int offset, int length) { int result = arrayData().appendBytes(length, value, offset); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, result); + Platform.putInt(null, lengthData + 4L * rowId, length); + Platform.putInt(null, offsetData + 4L * rowId, result); return result; } @@ -533,19 +533,19 @@ protected void reserveInternal(int newCapacity) { int oldCapacity = (nulls == 0L) ? 0 : capacity; if (isArray() || type instanceof MapType) { this.lengthData = - Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); + Platform.reallocateMemory(lengthData, oldCapacity * 4L, newCapacity * 4L); this.offsetData = - Platform.reallocateMemory(offsetData, oldCapacity * 4, newCapacity * 4); + Platform.reallocateMemory(offsetData, oldCapacity * 4L, newCapacity * 4L); } else if (type instanceof ByteType || type instanceof BooleanType) { this.data = Platform.reallocateMemory(data, oldCapacity, newCapacity); } else if (type instanceof ShortType) { - this.data = Platform.reallocateMemory(data, oldCapacity * 2, newCapacity * 2); + this.data = Platform.reallocateMemory(data, oldCapacity * 2L, newCapacity * 2L); } else if (type instanceof IntegerType || type instanceof FloatType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { - this.data = Platform.reallocateMemory(data, oldCapacity * 4, newCapacity * 4); + this.data = Platform.reallocateMemory(data, oldCapacity * 4L, newCapacity * 4L); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) { - this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8); + this.data = Platform.reallocateMemory(data, oldCapacity * 8L, newCapacity * 8L); } else if (childColumns != null) { // Nothing to store. } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 23dcc104e67c4..577eab6ed14c8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -231,7 +231,7 @@ public void putShorts(int rowId, int count, short[] src, int srcIndex) { @Override public void putShorts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, shortData, - Platform.SHORT_ARRAY_OFFSET + rowId * 2, count * 2); + Platform.SHORT_ARRAY_OFFSET + rowId * 2L, count * 2L); } @Override @@ -276,7 +276,7 @@ public void putInts(int rowId, int count, int[] src, int srcIndex) { @Override public void putInts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, intData, - Platform.INT_ARRAY_OFFSET + rowId * 4, count * 4); + Platform.INT_ARRAY_OFFSET + rowId * 4L, count * 4L); } @Override @@ -342,7 +342,7 @@ public void putLongs(int rowId, int count, long[] src, int srcIndex) { @Override public void putLongs(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, longData, - Platform.LONG_ARRAY_OFFSET + rowId * 8, count * 8); + Platform.LONG_ARRAY_OFFSET + rowId * 8L, count * 8L); } @Override @@ -394,7 +394,7 @@ public void putFloats(int rowId, int count, float[] src, int srcIndex) { public void putFloats(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, floatData, - Platform.DOUBLE_ARRAY_OFFSET + rowId * 4, count * 4); + Platform.DOUBLE_ARRAY_OFFSET + rowId * 4L, count * 4L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); for (int i = 0; i < count; ++i) { @@ -443,7 +443,7 @@ public void putDoubles(int rowId, int count, double[] src, int srcIndex) { public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, doubleData, - Platform.DOUBLE_ARRAY_OFFSET + rowId * 8, count * 8); + Platform.DOUBLE_ARRAY_OFFSET + rowId * 8L, count * 8L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); for (int i = 0; i < count; ++i) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index fbff8db987110..b393c48baee8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -202,7 +202,7 @@ class RateStreamMicroBatchInputPartitionReader( rangeEnd: Long, localStartTimeMs: Long, relativeMsPerValue: Double) extends InputPartitionReader[Row] { - private var count = 0 + private var count: Long = 0 override def next(): Boolean = { rangeStart + partitionId + numPartitions * count < rangeEnd diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 130e258e78ca2..8620f3f6d99fb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -342,7 +342,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { } override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { - conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000 + conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000L } override def loadPartition( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index ab7c8558321c8..2e8599026ea1d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -222,7 +222,7 @@ private[streaming] class FileBasedWriteAheadLog( pastLogs += LogInfo(currentLogWriterStartTime, currentLogWriterStopTime, _) } currentLogWriterStartTime = currentTime - currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000) + currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000L) val newLogPath = new Path(logDirectory, timeToLogFile(currentLogWriterStartTime, currentLogWriterStopTime)) currentLogPath = Some(newLogPath.toString) From e4fee395ecd93ad4579d9afbf0861f82a303e563 Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Fri, 15 Jun 2018 13:56:48 -0700 Subject: [PATCH 2/2] [SPARK-24525][SS] Provide an option to limit number of rows in a MemorySink ## What changes were proposed in this pull request? Provide an option to limit number of rows in a MemorySink. Currently, MemorySink and MemorySinkV2 have unbounded size, meaning that if they're used on big data, they can OOM the stream. This change adds a maxMemorySinkRows option to limit how many rows MemorySink and MemorySinkV2 can hold. By default, they are still unbounded. ## How was this patch tested? Added new unit tests. Author: Mukul Murthy Closes #21559 from mukulmurthy/SPARK-24525. --- .../sql/execution/streaming/memory.scala | 70 ++++++++++++++-- .../streaming/sources/memoryV2.scala | 44 +++++++--- .../sql/streaming/DataStreamWriter.scala | 4 +- .../execution/streaming/MemorySinkSuite.scala | 62 +++++++++++++- .../streaming/MemorySinkV2Suite.scala | 80 ++++++++++++++++++- .../spark/sql/streaming/StreamTest.scala | 4 +- 6 files changed, 239 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index b137f98045c5a..7fa13c4aa2c01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode @@ -221,19 +222,60 @@ class MemoryStreamInputPartition(records: Array[UnsafeRow]) } /** A common trait for MemorySinks with methods used for testing */ -trait MemorySinkBase extends BaseStreamingSink { +trait MemorySinkBase extends BaseStreamingSink with Logging { def allData: Seq[Row] def latestBatchData: Seq[Row] def dataSinceBatch(sinceBatchId: Long): Seq[Row] def latestBatchId: Option[Long] + + /** + * Truncates the given rows to return at most maxRows rows. + * @param rows The data that may need to be truncated. + * @param batchLimit Number of rows to keep in this batch; the rest will be truncated + * @param sinkLimit Total number of rows kept in this sink, for logging purposes. + * @param batchId The ID of the batch that sent these rows, for logging purposes. + * @return Truncated rows. + */ + protected def truncateRowsIfNeeded( + rows: Array[Row], + batchLimit: Int, + sinkLimit: Int, + batchId: Long): Array[Row] = { + if (rows.length > batchLimit && batchLimit >= 0) { + logWarning(s"Truncating batch $batchId to $batchLimit rows because of sink limit $sinkLimit") + rows.take(batchLimit) + } else { + rows + } + } +} + +/** + * Companion object to MemorySinkBase. + */ +object MemorySinkBase { + val MAX_MEMORY_SINK_ROWS = "maxRows" + val MAX_MEMORY_SINK_ROWS_DEFAULT = -1 + + /** + * Gets the max number of rows a MemorySink should store. This number is based on the memory + * sink row limit option if it is set. If not, we use a large value so that data truncates + * rather than causing out of memory errors. + * @param options Options for writing from which we get the max rows option + * @return The maximum number of rows a memorySink should store. + */ + def getMemorySinkCapacity(options: DataSourceOptions): Int = { + val maxRows = options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT) + if (maxRows >= 0) maxRows else Int.MaxValue - 10 + } } /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink - with MemorySinkBase with Logging { +class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSourceOptions) + extends Sink with MemorySinkBase with Logging { private case class AddedData(batchId: Long, data: Array[Row]) @@ -241,6 +283,12 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink @GuardedBy("this") private val batches = new ArrayBuffer[AddedData]() + /** The number of rows in this MemorySink. */ + private var numRows = 0 + + /** The capacity in rows of this sink. */ + val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) + /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { batches.flatMap(_.data) @@ -273,14 +321,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - val rows = AddedData(batchId, data.collect()) - synchronized { batches += rows } + var rowsToAdd = data.collect() + synchronized { + rowsToAdd = + truncateRowsIfNeeded(rowsToAdd, sinkCapacity - numRows, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) + batches += rows + numRows += rowsToAdd.length + } case Complete => - val rows = AddedData(batchId, data.collect()) + var rowsToAdd = data.collect() synchronized { + rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows + numRows = rowsToAdd.length } case _ => @@ -294,6 +351,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink def clear(): Unit = synchronized { batches.clear() + numRows = 0 } override def toString(): String = "MemorySink" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 468313bfe8c3c..47b482007822d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -46,7 +46,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB schema: StructType, mode: OutputMode, options: DataSourceOptions): StreamWriter = { - new MemoryStreamWriter(this, mode) + new MemoryStreamWriter(this, mode, options) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -55,6 +55,9 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB @GuardedBy("this") private val batches = new ArrayBuffer[AddedData]() + /** The number of rows in this MemorySink. */ + private var numRows = 0 + /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { batches.flatMap(_.data) @@ -81,7 +84,11 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB }.mkString("\n") } - def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = { + def write( + batchId: Long, + outputMode: OutputMode, + newRows: Array[Row], + sinkCapacity: Int): Unit = { val notCommitted = synchronized { latestBatchId.isEmpty || batchId > latestBatchId.get } @@ -89,14 +96,21 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - val rows = AddedData(batchId, newRows) - synchronized { batches += rows } + synchronized { + val rowsToAdd = + truncateRowsIfNeeded(newRows, sinkCapacity - numRows, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) + batches += rows + numRows += rowsToAdd.length + } case Complete => - val rows = AddedData(batchId, newRows) synchronized { + val rowsToAdd = truncateRowsIfNeeded(newRows, sinkCapacity, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows + numRows = rowsToAdd.length } case _ => @@ -110,6 +124,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB def clear(): Unit = synchronized { batches.clear() + numRows = 0 } override def toString(): String = "MemorySinkV2" @@ -117,16 +132,22 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} -class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) +class MemoryWriter( + sink: MemorySinkV2, + batchId: Long, + outputMode: OutputMode, + options: DataSourceOptions) extends DataSourceWriter with Logging { + val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) def commit(messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(batchId, outputMode, newRows) + sink.write(batchId, outputMode, newRows, sinkCapacity) } override def abort(messages: Array[WriterCommitMessage]): Unit = { @@ -134,16 +155,21 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) } } -class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) +class MemoryStreamWriter( + val sink: MemorySinkV2, + outputMode: OutputMode, + options: DataSourceOptions) extends StreamWriter { + val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(epochId, outputMode, newRows) + sink.write(epochId, outputMode, newRows, sinkCapacity) } override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index e035c9cdc379e..43e80e4e54239 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.StreamWriteSupport +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -249,7 +249,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes)) (s, r) case _ => - val s = new MemorySink(df.schema, outputMode) + val s = new MemorySink(df.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s)) (s, r) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index 3bc36ce55d902..b2fd6ba27ebb8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.execution.streaming +import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{OutputMode, StreamTest} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -36,7 +38,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Append output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append) + val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty()) // Before adding data, check output assert(sink.latestBatchId === None) @@ -68,9 +70,35 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, 1 to 9) } + test("directly add data in Append output mode with row limit") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + + var optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) + var options = new DataSourceOptions(optionsMap.toMap.asJava) + val sink = new MemorySink(schema, OutputMode.Append, options) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 6) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 5) + checkAnswer(sink.allData, 1 to 5) // new data should not go over the limit + } + test("directly add data in Update output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Update) + val sink = new MemorySink(schema, OutputMode.Update, DataSourceOptions.empty()) // Before adding data, check output assert(sink.latestBatchId === None) @@ -104,7 +132,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Complete output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Complete) + val sink = new MemorySink(schema, OutputMode.Complete, DataSourceOptions.empty()) // Before adding data, check output assert(sink.latestBatchId === None) @@ -136,6 +164,32 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, 7 to 9) } + test("directly add data in Complete output mode with row limit") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + + var optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) + var options = new DataSourceOptions(optionsMap.toMap.asJava) + val sink = new MemorySink(schema, OutputMode.Complete, options) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 10) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 8) + checkAnswer(sink.allData, 4 to 8) // new data should replace old data + } + test("registering as a table in Append output mode") { val input = MemoryStream[Int] @@ -211,7 +265,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("MemoryPlan statistics") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append) + val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty()) val plan = new MemoryPlan(sink) // Before adding data, check output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 9be22d94b5654..e539510e15755 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql.execution.streaming +import scala.collection.JavaConverters._ + import org.scalatest.BeforeAndAfter import org.apache.spark.sql.Row import org.apache.spark.sql.execution.streaming.sources._ +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{OutputMode, StreamTest} +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.StructType class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("data writer") { @@ -40,7 +45,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("continuous writer") { val sink = new MemorySinkV2 - val writer = new MemoryStreamWriter(sink, OutputMode.Append()) + val writer = new MemoryStreamWriter(sink, OutputMode.Append(), DataSourceOptions.empty()) writer.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), @@ -62,7 +67,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("microbatch writer") { val sink = new MemorySinkV2 - new MemoryWriter(sink, 0, OutputMode.Append()).commit( + new MemoryWriter(sink, 0, OutputMode.Append(), DataSourceOptions.empty()).commit( Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), @@ -70,7 +75,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - new MemoryWriter(sink, 19, OutputMode.Append()).commit( + new MemoryWriter(sink, 19, OutputMode.Append(), DataSourceOptions.empty()).commit( Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), MemoryWriterCommitMessage(0, Seq(Row(33))) @@ -80,4 +85,73 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) } + + test("continuous writer with row limit") { + val sink = new MemorySinkV2 + val optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 7.toString()) + val options = new DataSourceOptions(optionsMap.toMap.asJava) + val appendWriter = new MemoryStreamWriter(sink, OutputMode.Append(), options) + appendWriter.commit(0, Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), + MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))))) + assert(sink.latestBatchId.contains(0)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) + appendWriter.commit(19, Array( + MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), + MemoryWriterCommitMessage(0, Seq(Row(33))))) + assert(sink.latestBatchId.contains(19)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11)) + + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11)) + + val completeWriter = new MemoryStreamWriter(sink, OutputMode.Complete(), options) + completeWriter.commit(20, Array( + MemoryWriterCommitMessage(4, Seq(Row(11), Row(22))), + MemoryWriterCommitMessage(5, Seq(Row(33))))) + assert(sink.latestBatchId.contains(20)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) + completeWriter.commit(21, Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2), Row(3))), + MemoryWriterCommitMessage(1, Seq(Row(4), Row(5), Row(6))), + MemoryWriterCommitMessage(2, Seq(Row(7), Row(8), Row(9))))) + assert(sink.latestBatchId.contains(21)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7)) + + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7)) + } + + test("microbatch writer with row limit") { + val sink = new MemorySinkV2 + val optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) + val options = new DataSourceOptions(optionsMap.toMap.asJava) + + new MemoryWriter(sink, 25, OutputMode.Append(), options).commit(Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))))) + assert(sink.latestBatchId.contains(25)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) + new MemoryWriter(sink, 26, OutputMode.Append(), options).commit(Array( + MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))), + MemoryWriterCommitMessage(3, Seq(Row(7), Row(8))))) + assert(sink.latestBatchId.contains(26)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(5)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5)) + + new MemoryWriter(sink, 27, OutputMode.Complete(), options).commit(Array( + MemoryWriterCommitMessage(4, Seq(Row(9), Row(10))), + MemoryWriterCommitMessage(5, Seq(Row(11), Row(12))))) + assert(sink.latestBatchId.contains(27)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) + new MemoryWriter(sink, 28, OutputMode.Complete(), options).commit(Array( + MemoryWriterCommitMessage(4, Seq(Row(13), Row(14), Row(15))), + MemoryWriterCommitMessage(5, Seq(Row(16), Row(17), Row(18))))) + assert(sink.latestBatchId.contains(28)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 4c3fd58cb2e45..e41b4534ed51d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -337,7 +338,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be var currentStream: StreamExecution = null var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for - val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode) + val sink = if (useV2Sink) new MemorySinkV2 + else new MemorySink(stream.schema, outputMode, DataSourceOptions.empty()) val resetConfValues = mutable.Map[String, Option[String]]() val defaultCheckpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath