diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/BitSet.java similarity index 98% rename from unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java rename to core/src/main/java/org/apache/spark/util/collection/unsafe/map/BitSet.java index 7c124173b0bbb..2abbf0b9af51f 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/BitSet.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.unsafe.bitset; +package org.apache.spark.util.collection.unsafe.map; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/BitSetMethods.java similarity index 98% rename from unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java rename to core/src/main/java/org/apache/spark/util/collection/unsafe/map/BitSetMethods.java index 27462c7fa5e62..a173b7912380d 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/BitSetMethods.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.unsafe.bitset; +package org.apache.spark.util.collection.unsafe.map; import org.apache.spark.unsafe.PlatformDependent; diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/BytesToBytesMap.java similarity index 79% rename from unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java rename to core/src/main/java/org/apache/spark/util/collection/unsafe/map/BytesToBytesMap.java index d0bde69cc1068..a4eb66d655be4 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/BytesToBytesMap.java @@ -15,22 +15,21 @@ * limitations under the License. */ -package org.apache.spark.unsafe.map; +package org.apache.spark.util.collection.unsafe.map; -import java.lang.Override; -import java.lang.UnsupportedOperationException; +import java.util.Comparator; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import com.google.common.annotations.VisibleForTesting; - -import org.apache.spark.unsafe.*; +import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; -import org.apache.spark.unsafe.array.LongArray; -import org.apache.spark.unsafe.bitset.BitSet; -import org.apache.spark.unsafe.hash.Murmur3_x86_32; -import org.apache.spark.unsafe.memory.*; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.collection.Sorter; +import org.apache.spark.util.collection.SortDataFormat; /** * An append-only hash map where keys and values are contiguous regions of bytes. @@ -98,7 +97,8 @@ public final class BytesToBytesMap { * Position {@code 2 * i} in the array is used to track a pointer to the key at index {@code i}, * while position {@code 2 * i + 1} in the array holds key's full 32-bit hashcode. */ - private LongArray longArray; + private long[] longArray; + // TODO: replace longArray with LongArray // TODO: we're wasting 32 bits of space here; we can probably store fewer bits of the hashcode // and exploit word-alignment to use fewer bits to hold the address. This might let us store // only one long per map entry, increasing the chance that this array will fit in cache at the @@ -184,6 +184,30 @@ public BytesToBytesMap( */ public int size() { return size; } + public boolean hasSpaceForAnotherRecord() { + return size < growthThreshold; + } + + public int getNextCapacity() { + int nextCapacity = + Math.min(growthStrategy.nextCapacity((int)bitset.capacity()), MAX_CAPACITY); + // The capacity needs to be divisible by 64 so that our bit set can be sized properly + return Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(nextCapacity)), 64); + } + + public Location getNewLocation() { + return new Location(); + } + + public static int getCapacity(int initialCapacity) { + // The capacity needs to be divisible by 64 so that our bit set can be sized properly + return Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(initialCapacity)), 64); + } + + public static long getMemoryUsage(int capacity) { + return capacity * 8L * 2 + capacity / 8; + } + private static final class BytesToBytesMapIterator implements Iterator { private final int numRecords; @@ -245,6 +269,84 @@ public Iterator iterator() { return new BytesToBytesMapIterator(size, dataPages.iterator(), loc); } + private static final class KVLongArraySortDataFormat extends SortDataFormat { + + @Override + public Long getKey(long[] data, int pos) { + return data[2 * pos]; + } + + @Override + public Long newKey() { + return 0L; + } + + @Override + public void swap(long[] data,int pos0, int pos1) { + long tmpKey = data[2 * pos0]; + long tmpVal = data[2 * pos0 + 1]; + data[2 * pos0] = data[2 * pos1]; + data[2 * pos0 + 1] = data[2 * pos1 + 1]; + data[2 * pos1] = tmpKey; + data[2 * pos1 + 1] = tmpVal; + } + + @Override + public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { + dst[2 * dstPos] = src[2 * srcPos]; + dst[2 * dstPos + 1] = src[2 * srcPos + 1]; + } + + @Override + public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { + System.arraycopy(src, 2 * srcPos, dst, 2 * dstPos, 2 * length); + } + + @Override + public long[] allocate(int length) { + return new long[2 * length]; + } + } + + public Iterator getSortedIterator(Comparator sortComparator) { + // Pack KV pairs into the front of the underlying array + int keyIndex = 0; + int newIndex = 0; + while (keyIndex < bitset.capacity()) { + if (bitset.isSet(keyIndex)) { + longArray[2 * newIndex] = longArray[2 * keyIndex]; + longArray[2 * newIndex + 1] = longArray[2 * keyIndex + 1]; + newIndex += 1; + } + keyIndex += 1; + } + + new Sorter(new KVLongArraySortDataFormat()).sort(longArray, 0, newIndex, sortComparator); + + final int numRecords = newIndex; + return new Iterator() { + + private int currentRecordNumber = 0; + + @Override + public boolean hasNext() { + return currentRecordNumber != numRecords; + } + + @Override + public Location next() { + loc.with(longArray[currentRecordNumber * 2]); + currentRecordNumber++; + return loc; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + /** * Looks up a key, and return a {@link Location} handle that can be used to test existence * and read/write values. @@ -269,7 +371,7 @@ public Location lookup( // This is a new key. return loc.with(pos, hashcode, false); } else { - long stored = longArray.get(pos * 2 + 1); + long stored = longArray[pos * 2 + 1]; if ((int) (stored) == hashcode) { // Full hash code matches. Let's compare the keys for equality. loc.with(pos, hashcode, true); @@ -303,6 +405,7 @@ public Location lookup( * Handle returned by {@link BytesToBytesMap#lookup(Object, long, int)} function. */ public final class Location { + /** An index into the hash map's Long array */ private int pos; /** True if this location points to a position where a key is defined, false otherwise */ @@ -334,23 +437,39 @@ private void updateAddressesAndSizes(Object page, long keyOffsetInPage) { valueMemoryLocation.setObjAndOffset(page, position); } + public Location with(long fullKeyAddress) { + this.isDefined = true; + updateAddressesAndSizes(fullKeyAddress); + return this; + } + Location with(int pos, int keyHashcode, boolean isDefined) { this.pos = pos; this.isDefined = isDefined; this.keyHashcode = keyHashcode; if (isDefined) { - final long fullKeyAddress = longArray.get(pos * 2); + final long fullKeyAddress = longArray[pos * 2]; updateAddressesAndSizes(fullKeyAddress); } return this; } - Location with(Object page, long keyOffsetInPage) { + public Location with(Object page, long keyOffsetInPage) { this.isDefined = true; updateAddressesAndSizes(page, keyOffsetInPage); return this; } + public Location with(int keyLength, byte[] keyArray, int valueLength, + byte[] valueArray) { + this.isDefined = true; + this.keyLength = keyLength; + keyMemoryLocation.setObjAndOffset(keyArray, PlatformDependent.BYTE_ARRAY_OFFSET); + this.valueLength = valueLength; + valueMemoryLocation.setObjAndOffset(valueArray, PlatformDependent.BYTE_ARRAY_OFFSET); + return this; + } + /** * Returns true if the key is defined at this position, and false otherwise. */ @@ -378,6 +497,22 @@ public int getKeyLength() { return keyLength; } + /** + * Returns the base object of the key. + */ + public Object getKeyBaseObject() { + assert (isDefined); + return keyMemoryLocation.getBaseObject(); + } + + /** + * Returns the base offset of the key. + */ + public long getKeyBaseOffset() { + assert (isDefined); + return keyMemoryLocation.getBaseOffset(); + } + /** * Returns the address of the value defined at this position. * This points to the first byte of the value data. @@ -398,6 +533,23 @@ public int getValueLength() { return valueLength; } + /** + * Returns the base object of the value. + */ + public Object getValueBaseObject() { + assert (isDefined); + return valueMemoryLocation.getBaseObject(); + } + + /** + * Return the base offset of the value. + */ + public long getValueBaseOffset() { + assert (isDefined); + return valueMemoryLocation.getBaseOffset(); + } + + /** * Store a new key and value. This method may only be called once for a given key; if you want * to update the value associated with a key, then you can directly manipulate the bytes stored @@ -425,6 +577,23 @@ public int getValueLength() { * Unspecified behavior if the key is not defined. *

*/ + + public void putNewKey(long storedKeyAddress) { + assert (!isDefined) : "Can only set value once for a key"; + if (size == MAX_CAPACITY) { + throw new IllegalStateException("BytesToBytesMap has reached maximum capacity"); + } + size++; + bitset.set(pos); + longArray[pos * 2] = storedKeyAddress; + longArray[pos * 2 + 1] = keyHashcode; + updateAddressesAndSizes(storedKeyAddress); + isDefined = true; + if (size > growthThreshold && longArray.length < MAX_CAPACITY) { + growAndRehash(); + } + } + public void putNewKey( Object keyBaseObject, long keyBaseOffset, @@ -432,20 +601,15 @@ public void putNewKey( Object valueBaseObject, long valueBaseOffset, int valueLengthBytes) { - assert (!isDefined) : "Can only set value once for a key"; assert (keyLengthBytes % 8 == 0); assert (valueLengthBytes % 8 == 0); - if (size == MAX_CAPACITY) { - throw new IllegalStateException("BytesToBytesMap has reached maximum capacity"); - } + // Here, we'll copy the data into our data pages. Because we only store a relative offset from // the key address instead of storing the absolute address of the value, the key and value // must be stored in the same memory page. // (8 byte key length) (key) (8 byte value length) (value) final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes; assert (requiredSize <= PAGE_SIZE_BYTES - 8); // Reserve 8 bytes for the end-of-page marker. - size++; - bitset.set(pos); // If there's not enough space in the current page, allocate a new page (8 bytes are reserved // for the end-of-page marker). @@ -485,13 +649,7 @@ public void putNewKey( final long storedKeyAddress = memoryManager.encodePageNumberAndOffset( currentDataPage, keySizeOffsetInPage); - longArray.set(pos * 2, storedKeyAddress); - longArray.set(pos * 2 + 1, keyHashcode); - updateAddressesAndSizes(storedKeyAddress); - isDefined = true; - if (size > growthThreshold && longArray.size() < MAX_CAPACITY) { - growAndRehash(); - } + this.putNewKey(storedKeyAddress); } } @@ -506,7 +664,7 @@ private void allocate(int capacity) { // The capacity needs to be divisible by 64 so that our bit set can be sized properly capacity = Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(capacity)), 64); assert (capacity <= MAX_CAPACITY); - longArray = new LongArray(memoryManager.allocate(capacity * 8L * 2)); + longArray = new long[capacity * 2]; bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64])); this.growthThreshold = (int) (capacity * loadFactor); @@ -521,7 +679,6 @@ private void allocate(int capacity) { */ public void free() { if (longArray != null) { - memoryManager.free(longArray.memoryBlock()); longArray = null; } if (bitset != null) { @@ -541,7 +698,7 @@ public long getTotalMemoryConsumption() { return ( dataPages.size() * PAGE_SIZE_BYTES + bitset.memoryBlock().size() + - longArray.memoryBlock().size()); + longArray.length * 8L); } /** @@ -581,13 +738,13 @@ int getNumDataPages() { * Grows the size of the hash table and re-hash everything. */ @VisibleForTesting - void growAndRehash() { + public void growAndRehash() { long resizeStartTime = -1; if (enablePerfMetrics) { resizeStartTime = System.nanoTime(); } // Store references to the old data structures to be used when we re-hash - final LongArray oldLongArray = longArray; + final long[] oldLongArray = longArray; final BitSet oldBitSet = bitset; final int oldCapacity = (int) oldBitSet.capacity(); @@ -596,8 +753,8 @@ void growAndRehash() { // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it) for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) { - final long keyPointer = oldLongArray.get(pos * 2); - final int hashcode = (int) oldLongArray.get(pos * 2 + 1); + final long keyPointer = oldLongArray[pos * 2]; + final int hashcode = (int) oldLongArray[pos * 2 + 1]; int newPos = hashcode & mask; int step = 1; boolean keepGoing = true; @@ -607,8 +764,8 @@ void growAndRehash() { while (keepGoing) { if (!bitset.isSet(newPos)) { bitset.set(newPos); - longArray.set(newPos * 2, keyPointer); - longArray.set(newPos * 2 + 1, hashcode); + longArray[newPos * 2] = keyPointer; + longArray[newPos * 2 + 1] = hashcode; keepGoing = false; } else { newPos = (newPos + step) & mask; @@ -617,8 +774,6 @@ void growAndRehash() { } } - // Deallocate the old data structures. - memoryManager.free(oldLongArray.memoryBlock()); if (enablePerfMetrics) { timeSpentResizingNs += System.nanoTime() - resizeStartTime; } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/HashMapGrowthStrategy.java similarity index 96% rename from unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java rename to core/src/main/java/org/apache/spark/util/collection/unsafe/map/HashMapGrowthStrategy.java index 20654e4eeaa02..d89a3213e7b23 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/HashMapGrowthStrategy.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.unsafe.map; +package org.apache.spark.util.collection.unsafe.map; /** * Interface that defines how we can grow the size of a hash map when it is over a threshold. diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/Murmur3_x86_32.java similarity index 98% rename from unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java rename to core/src/main/java/org/apache/spark/util/collection/unsafe/map/Murmur3_x86_32.java index 61f483ced3217..2b295d922b0b4 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/Murmur3_x86_32.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.unsafe.hash; +package org.apache.spark.util.collection.unsafe.map; import org.apache.spark.unsafe.PlatformDependent; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/AbstractBytesToBytesMapSuite.java similarity index 99% rename from unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java rename to core/src/test/java/org/apache/spark/util/collection/unsafe/map/AbstractBytesToBytesMapSuite.java index dae47e4bab0cb..e448d9f840328 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.unsafe.map; +package org.apache.spark.util.collection.unsafe.map; import java.lang.Exception; import java.nio.ByteBuffer; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/BitSetSuite.java similarity index 98% rename from unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java rename to core/src/test/java/org/apache/spark/util/collection/unsafe/map/BitSetSuite.java index a93fc0ee297c4..efee5c1327eb0 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/BitSetSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.unsafe.bitset; +package org.apache.spark.util.collection.unsafe.map; import junit.framework.Assert; import org.junit.Test; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/BytesToBytesMapOffHeapSuite.java similarity index 95% rename from unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java rename to core/src/test/java/org/apache/spark/util/collection/unsafe/map/BytesToBytesMapOffHeapSuite.java index 5a10de49f54fe..68a4534d0d41b 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/BytesToBytesMapOffHeapSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.unsafe.map; +package org.apache.spark.util.collection.unsafe.map; import org.apache.spark.unsafe.memory.MemoryAllocator; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/BytesToBytesMapOnHeapSuite.java similarity index 95% rename from unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java rename to core/src/test/java/org/apache/spark/util/collection/unsafe/map/BytesToBytesMapOnHeapSuite.java index 12cc9b25d93b3..6ec919ba17c17 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/BytesToBytesMapOnHeapSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.unsafe.map; +package org.apache.spark.util.collection.unsafe.map; import org.apache.spark.unsafe.memory.MemoryAllocator; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/Murmur3_x86_32Suite.java similarity index 98% rename from unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java rename to core/src/test/java/org/apache/spark/util/collection/unsafe/map/Murmur3_x86_32Suite.java index 3b9175835229c..9aba304b3a1d5 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/Murmur3_x86_32Suite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.unsafe.hash; +package org.apache.spark.util.collection.unsafe.map; import java.util.HashSet; import java.util.Random; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index 684de6e81d67c..fb2198854b10f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -23,7 +23,7 @@ import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.PlatformDependent; -import org.apache.spark.unsafe.map.BytesToBytesMap; +import org.apache.spark.util.collection.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryLocation; import org.apache.spark.unsafe.memory.TaskMemoryManager; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 87e5a89c19658..03f3bb4a4a14a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -27,8 +27,8 @@ import org.apache.spark.sql.types.DataType; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; -import org.apache.spark.unsafe.bitset.BitSetMethods; -import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.util.collection.unsafe.map.BitSetMethods; +import org.apache.spark.util.collection.unsafe.map.Murmur3_x86_32; import org.apache.spark.unsafe.types.UTF8String; import static org.apache.spark.sql.types.DataTypes.*; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalAggregation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalAggregation.java new file mode 100644 index 0000000000000..585f9f93d521a --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalAggregation.java @@ -0,0 +1,979 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution; + +import java.io.*; +import java.util.*; + +import scala.Tuple2; +import scala.math.Ordering; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.ByteStreams; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.DummySerializerInstance; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.SparkConf; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2; +import org.apache.spark.sql.catalyst.expressions.*; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.DiskBlockObjectWriter; +import org.apache.spark.storage.TempLocalBlockId; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.collection.unsafe.map.BytesToBytesMap; +import org.apache.spark.util.Utils; + +/** + * Unsafe sort-based external aggregation. + */ +public final class UnsafeExternalAggregation { + + private final Logger logger = LoggerFactory.getLogger(UnsafeExternalAggregation.class); + + /** + * Special record length that is placed after the last record in a data page. + */ + private static final int END_OF_PAGE_MARKER = -1; + + private final TaskMemoryManager memoryManager; + + private final ShuffleMemoryManager shuffleMemoryManager; + + private final BlockManager blockManager; + + private final TaskContext taskContext; + + private ShuffleWriteMetrics writeMetrics; + + /** + * The buffer size to use when writing spills using DiskBlockObjectWriter + */ + private final int fileBufferSizeBytes; + + /** + * A linked list for tracking all allocated data pages so that we can free all of our memory. + */ + private final List dataPages = new LinkedList(); + + /** + * The data page that will be used to store keys and values for new hashtable entries. When this + * page becomes full, a new page will be allocated and this pointer will change to point to that + * new page. + */ + private MemoryBlock currentDataPage = null; + + /** + * Offset into `currentDataPage` that points to the location where new data can be inserted into + * the page. This does not incorporate the page's base offset. + */ + private long pageCursor = 0; + + private long freeSpaceInCurrentPage = 0; + + /** + * The size of the data pages that hold key and value data. Map entries cannot span multiple + * pages, so this limits the maximum entry size. + */ + private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes + + private final int initialCapacity; + + /** + * A hashmap which maps from opaque byteArray keys to byteArray values. + */ + private BytesToBytesMap map; + + /** + * Re-used pointer to the current aggregation buffer + */ + private final UnsafeRow currentBuffer = new UnsafeRow(); + + private final JoinedRow joinedRow = new JoinedRow(); + + private MutableProjection algebraicUpdateProjection; + + private MutableProjection algebraicMergeProjection; + + private AggregateFunction2[] nonAlgebraicAggregateFunctions; + + /** + * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the + * map, we copy this buffer and use it as the value. + */ + private final byte[] emptyAggregationBuffer; + + private final StructType aggregationBufferSchema; + + private final StructType groupingKeySchema; + + private final UnsafeProjection groupingKeyProjection; + + /** + * Encodes grouping keys or buffers as UnsafeRows. + */ + private final Ordering groupingKeyOrdering; + + private boolean enablePerfMetrics; + + private int testSpillFrequency = 0; + + private long numRowsInserted = 0; + + private final LinkedList spillWriters = new LinkedList<>(); + + public UnsafeExternalAggregation( + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + MutableProjection algebraicUpdateProjection, + MutableProjection algebraicMergeProjection, + AggregateFunction2[] nonAlgebraicAggregateFunctions, + InternalRow emptyAggregationBuffer, + StructType aggregationBufferSchema, + StructType groupingKeySchema, + Ordering groupingKeyOrdering, + int initialCapacity, + SparkConf conf, + boolean enablePerfMetrics) throws IOException { + this.memoryManager = memoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.blockManager = blockManager; + this.taskContext = taskContext; + this.algebraicUpdateProjection = algebraicUpdateProjection; + this.algebraicMergeProjection = algebraicMergeProjection; + this.nonAlgebraicAggregateFunctions = nonAlgebraicAggregateFunctions; + this.aggregationBufferSchema = aggregationBufferSchema; + this.groupingKeySchema = groupingKeySchema; + this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); + this.groupingKeyOrdering = groupingKeyOrdering; + this.initialCapacity = initialCapacity; + this.enablePerfMetrics = enablePerfMetrics; + + // Initialize the buffer for aggregation value + final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); + this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); + assert(this.emptyAggregationBuffer.length == aggregationBufferSchema.length() * 8 + + UnsafeRow.calculateBitSetWidthInBytes(aggregationBufferSchema.length())); + + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.testSpillFrequency = conf.getInt("spark.test.aggregate.spillFrequency", 0); + + initializeUnsafeAppendMap(); + } + + /** + * Allocates new sort data structures. Called when creating the sorter and after each spill. + */ + private void initializeUnsafeAppendMap() throws IOException { + this.writeMetrics = new ShuffleWriteMetrics(); + + int capacity = BytesToBytesMap.getCapacity(initialCapacity); + long memoryRequested = BytesToBytesMap.getMemoryUsage(capacity); + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); + if (memoryAcquired != memoryRequested) { + shuffleMemoryManager.release(memoryAcquired); + throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); + } + this.map = new BytesToBytesMap(memoryManager, capacity, enablePerfMetrics); + } + + /** + * Return true if it can external aggregate with the groupKey schema & aggregationBuffer schema, + * false otherwise + */ + public static boolean supportSchema(StructType groupKeySchema, + StructType aggregationBufferSchema) { + for (StructField field : groupKeySchema.fields()) { + if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) { + return false; + } + } + for (StructField field : aggregationBufferSchema.fields()) { + if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) { + return false; + } + } + return true; + } + + /** + * Forces spills to occur every `frequency` records. Only for use in tests. + */ + @VisibleForTesting + public void setTestSpillFrequency(int frequency) { + assert frequency > 0 : "Frequency must be positive"; + testSpillFrequency = frequency; + } + + public void insertRow(InternalRow groupingKey, InternalRow currentRow) + throws IOException { + UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey); + numRowsInserted++; + if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) { + spill(); + } + UnsafeRow aggregationBuffer = this.getAggregationBuffer(unsafeGroupingKeyRow); + + // Process all algebraic aggregate functions. + this.algebraicUpdateProjection.target(aggregationBuffer).apply( + joinedRow.apply(aggregationBuffer, currentRow)); + // Process all non-algebraic aggregate functions. + int i = 0; + while (i < nonAlgebraicAggregateFunctions.length) { + nonAlgebraicAggregateFunctions[i].update(aggregationBuffer, currentRow); + i += 1; + } + } + + /** + * Return the aggregation buffer for the current group. For efficiency, all calls to this method + * return the same object. + */ + public UnsafeRow getAggregationBuffer(UnsafeRow unsafeGroupingKeyRow) + throws IOException { + // Probe our map using the serialized key + final BytesToBytesMap.Location loc = map.lookup( + unsafeGroupingKeyRow.getBaseObject(), + unsafeGroupingKeyRow.getBaseOffset(), + unsafeGroupingKeyRow.getSizeInBytes()); + if (!loc.isDefined()) { + if (!this.putNewKey( + unsafeGroupingKeyRow.getBaseObject(), + unsafeGroupingKeyRow.getBaseOffset(), + unsafeGroupingKeyRow.getSizeInBytes(), + emptyAggregationBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET, + emptyAggregationBuffer.length, + loc)) { + // because spill makes putting new key failed, it should get AggregationBuffer again + return this.getAggregationBuffer(unsafeGroupingKeyRow); + } + } + // Reset the pointer to point to the value that we just stored or looked up: + final MemoryLocation address = loc.getValueAddress(); + currentBuffer.pointTo( + address.getBaseObject(), + address.getBaseOffset(), + aggregationBufferSchema.length(), + loc.getValueLength() + ); + return currentBuffer; + } + + public boolean putNewKey( + Object keyBaseObject, + long keyBaseOffset, + int keyLengthBytes, + Object valueBaseObject, + long valueBaseOffset, + int valueLengthBytes, + BytesToBytesMap.Location location) throws IOException { + assert (!location.isDefined()) : "Can only set value once for a key"; + + assert (keyLengthBytes % 8 == 0); + assert (valueLengthBytes % 8 == 0); + + // Here, we'll copy the data into our data pages. Because we only store a relative offset from + // the key address instead of storing the absolute address of the value, the key and value + // must be stored in the same memory page. + // (8 byte key length) (key) (8 byte value length) (value) + final int requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes; + if (!haveSpaceForRecord(requiredSize)) { + if (!allocateSpaceForRecord(requiredSize)){ + // if spill have been happened, re-insert current groupingKey + return false; + } + } + + freeSpaceInCurrentPage -= requiredSize; + // Compute all of our offsets up-front: + final Object pageBaseObject = currentDataPage.getBaseObject(); + final long pageBaseOffset = currentDataPage.getBaseOffset(); + final long keySizeOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += 8; // word used to store the key size + final long keyDataOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += keyLengthBytes; + final long valueSizeOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += 8; // word used to store the value size + final long valueDataOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += valueLengthBytes; + + // Copy the key + PlatformDependent.UNSAFE.putLong(pageBaseObject, keySizeOffsetInPage, keyLengthBytes); + PlatformDependent.copyMemory( + keyBaseObject, keyBaseOffset, pageBaseObject, keyDataOffsetInPage, keyLengthBytes); + // Copy the value + PlatformDependent.UNSAFE.putLong(pageBaseObject, valueSizeOffsetInPage, valueLengthBytes); + PlatformDependent.copyMemory( + valueBaseObject, valueBaseOffset, pageBaseObject, valueDataOffsetInPage, valueLengthBytes); + + final long storedKeyAddress = memoryManager.encodePageNumberAndOffset( + currentDataPage, keySizeOffsetInPage); + location.putNewKey(storedKeyAddress); + return true; + } + + /** + * Checks whether there is enough space to insert a new record into the sorter. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + + * @return true if the record can be inserted without requiring more allocations, false otherwise. + */ + private boolean haveSpaceForRecord(int requiredSpace) { + assert (requiredSpace > 0); + return (map.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the {@link org.apache.spark.shuffle.ShuffleMemoryManager} and + * spill if the requested memory can not be obtained. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + */ + private boolean allocateSpaceForRecord(int requiredSpace) throws IOException { + boolean noSpill = true; + assert (requiredSpace <= PAGE_SIZE_BYTES - 8); // Reserve 8 bytes for the end-of-page marker. + if (!map.hasSpaceForAnotherRecord()) { + logger.debug("Attempting to grow size of hash table"); + final int nextCapacity = map.getNextCapacity(); + final long oldPointerArrayMemoryUsage = map.getTotalMemoryConsumption(); + final long memoryToGrowPointerArray = BytesToBytesMap.getMemoryUsage(nextCapacity); + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); + if (memoryAcquired < memoryToGrowPointerArray) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + } else { + map.growAndRehash(); + shuffleMemoryManager.release(oldPointerArrayMemoryUsage); + } + } + + // If there's not enough space in the current page, allocate a new page (8 bytes are reserved + // for the end-of-page marker). + if (currentDataPage == null || freeSpaceInCurrentPage < requiredSpace) { + logger.trace("Required space {} is less than free space in current page ({})", + requiredSpace, freeSpaceInCurrentPage); + if (currentDataPage != null) { + // There wasn't enough space in the current page, so write an end-of-page marker: + final Object pageBaseObject = currentDataPage.getBaseObject(); + final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; + PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); + } + long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE_BYTES); + if (memoryAcquired < PAGE_SIZE_BYTES) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + noSpill = false; + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE_BYTES); + if (memoryAcquiredAfterSpilling != PAGE_SIZE_BYTES) { + shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + throw new IOException("Unable to acquire " + PAGE_SIZE_BYTES + " bytes of memory"); + } else { + memoryAcquired = memoryAcquiredAfterSpilling; + } + } + MemoryBlock newPage = memoryManager.allocatePage(memoryAcquired); + dataPages.add(newPage); + pageCursor = 0; + freeSpaceInCurrentPage = PAGE_SIZE_BYTES - 8; + currentDataPage = newPage; + } + return noSpill; + } + + private static final class SortComparator implements Comparator { + + private final TaskMemoryManager memoryManager; + private final Ordering ordering; + private final int numFields; + private final UnsafeRow row1 = new UnsafeRow(); + private final UnsafeRow row2 = new UnsafeRow(); + + SortComparator(TaskMemoryManager memoryManager, Ordering ordering, int numFields) { + this.memoryManager = memoryManager; + this.numFields = numFields; + this.ordering = ordering; + } + + @Override + public int compare(Long fullKeyAddress1, Long fullKeyAddress2) { + final Object baseObject1 = memoryManager.getPage(fullKeyAddress1); + final long baseOffset1 = memoryManager.getOffsetInPage(fullKeyAddress1) + 8; + + final Object baseObject2 = memoryManager.getPage(fullKeyAddress2); + final long baseOffset2 = memoryManager.getOffsetInPage(fullKeyAddress2) + 8; + + row1.pointTo(baseObject1, baseOffset1, numFields, -1); + row2.pointTo(baseObject2, baseOffset2, numFields, -1); + return ordering.compare(row1, row2); + } + } + + /** + * Sort and spill the current records in response to memory pressure. + */ + @VisibleForTesting + void spill() throws IOException { + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spillWriters.size(), + spillWriters.size() > 1 ? " times" : " time"); + + final UnsafeSorterKVSpillWriter spillWriter = + new UnsafeSorterKVSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, + map.size()); + spillWriters.add(spillWriter); + final Iterator sortedRecords = map.getSortedIterator( + new SortComparator(this.memoryManager, groupingKeyOrdering, groupingKeySchema.size())); + while (sortedRecords.hasNext()) { + BytesToBytesMap.Location location = sortedRecords.next(); + spillWriter.write(location); + } + spillWriter.close(); + + final long spillSize = freeMemory(); + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + + initializeUnsafeAppendMap(); + } + + public AbstractScalaIterator getSortedIterator() { + return new AbstractScalaIterator() { + + Iterator sorter = map.getSortedIterator( + new SortComparator(memoryManager, groupingKeyOrdering, groupingKeySchema.size())); + + @Override + public boolean hasNext() { + return sorter.hasNext(); + } + + @Override + public BytesToBytesMap.Location next() { + return sorter.next(); + } + }; + } + + /** + * Return an iterator that merges the in-memory map with the spilled files. + * If no spill has occurred, simply return the in-memory map's iterator. + */ + public AbstractScalaIterator iterator() throws IOException { + if (spillWriters.isEmpty()) { + return this.getMemoryIterator(); + } else { + return this.merge(this.getSortedIterator()); + } + } + + public UnsafeRow getKey(BytesToBytesMap.Location location) { + UnsafeRow key = new UnsafeRow(); + key.pointTo( + location.getKeyBaseObject(), + location.getKeyBaseOffset(), + 1, + location.getKeyLength()); + return key; + } + + public UnsafeRow getValue(BytesToBytesMap.Location location) { + UnsafeRow value = new UnsafeRow(); + value.pointTo( + location.getValueBaseObject(), + location.getValueBaseOffset(), + 1, + location.getValueLength()); + return value; + } + + /** + * Returns an iterator over the keys and values in in-memory map. + */ + public AbstractScalaIterator getMemoryIterator() { + return new BytesToBytesMapIterator(map.size()); + } + + /** + * Merge aggregate of the in-memory map with the spilled files, giving an iterator over elements. + */ + private AbstractScalaIterator merge( + AbstractScalaIterator inMemory) throws IOException { + + final Comparator keyOrdering = + new Comparator() { + private final UnsafeRow row1 = new UnsafeRow(); + private final UnsafeRow row2 = new UnsafeRow(); + + public int compare(BufferedIterator o1, BufferedIterator o2){ + row1.pointTo( + o1.getRecordLocation().getKeyBaseObject(), + o1.getRecordLocation().getKeyBaseOffset(), + groupingKeySchema.size(), + o1.getRecordLocation().getKeyLength()); + row2.pointTo( + o2.getRecordLocation().getKeyBaseObject(), + o2.getRecordLocation().getKeyBaseOffset(), + groupingKeySchema.size(), + o2.getRecordLocation().getKeyLength()); + return groupingKeyOrdering.compare(row1, row2); + } + }; + final Queue priorityQueue = + new PriorityQueue(spillWriters.size() + 1, keyOrdering); + BufferedIterator inMemoryBuffer = this.asBuffered(inMemory); + if (inMemoryBuffer.hasNext()) { + inMemoryBuffer.loadNext(); + priorityQueue.add(inMemoryBuffer); + } + + for (int i = 0; i < spillWriters.size(); i++) { + BufferedIterator spillBuffer = this.asBuffered(spillWriters.get(i).getReader(blockManager)); + if (spillBuffer.hasNext()) { + spillBuffer.loadNext(); + priorityQueue.add(spillBuffer); + } + } + final AbstractScalaIterator mergeIter = + new AbstractScalaIterator() { + + BufferedIterator topIter = null; + + @Override + public boolean hasNext() { + return !priorityQueue.isEmpty() || (topIter != null && topIter.hasNext()); + } + + @Override + public BytesToBytesMap.Location next() throws IOException { + if (topIter != null && topIter.hasNext()) { + topIter.loadNext(); + priorityQueue.add(topIter); + } + topIter = priorityQueue.poll(); + return topIter.getRecordLocation(); + } + }; + + final BufferedIterator sorted = asBuffered(mergeIter); + return new AbstractScalaIterator() { + + private UnsafeRow currentKey = new UnsafeRow(); + private UnsafeRow currentValue = new UnsafeRow(); + private UnsafeRow nextKey = new UnsafeRow(); + private UnsafeRow nextValue = new UnsafeRow(); + private BytesToBytesMap.Location currentLocation = null; + + @Override + public boolean hasNext() { + return currentLocation != null || sorted.hasNext(); + } + + @Override + public MapEntry next() throws IOException { + try { + if (currentLocation == null) { + sorted.loadNext(); + currentLocation = sorted.getRecordLocation(); + } + currentKey.pointTo( + currentLocation.getKeyBaseObject(), + currentLocation.getKeyBaseOffset(), + groupingKeySchema.size(), + currentLocation.getKeyLength()); + currentKey = currentKey.copy(); + currentValue.pointTo( + currentLocation.getValueBaseObject(), + currentLocation.getValueBaseOffset(), + aggregationBufferSchema.size(), + currentLocation.getValueLength()); + currentValue = currentValue.copy(); + currentLocation = null; + while (sorted.hasNext()) { + sorted.loadNext(); + BytesToBytesMap.Location nextLocation = sorted.getRecordLocation(); + nextKey.pointTo( + nextLocation.getKeyBaseObject(), + nextLocation.getKeyBaseOffset(), + groupingKeySchema.size(), + nextLocation.getKeyLength()); + nextValue.pointTo( + nextLocation.getValueBaseObject(), + nextLocation.getValueBaseOffset(), + aggregationBufferSchema.size(), + nextLocation.getValueLength()); + + if (groupingKeyOrdering.compare(currentKey, nextKey) != 0) { + currentLocation = nextLocation; + break; + } + + // Process all algebraic aggregate functions. + algebraicMergeProjection.target(currentValue).apply( + joinedRow.apply(currentValue, nextValue)); + // Process all non-algebraic aggregate functions. + int i = 0; + while (i < nonAlgebraicAggregateFunctions.length) { + nonAlgebraicAggregateFunctions[i].merge(currentValue, nextValue); + i += 1; + } + } + + return new MapEntry(currentKey, currentValue); + } catch (IOException e) { + cleanupResources(); + // Scala iterators don't declare any checked exceptions, so we need to use this hack + // to re-throw the exception: + PlatformDependent.throwException(e); + } + throw new RuntimeException("Exception should have been re-thrown in next()"); + } + }; + } + + public BufferedIterator asBuffered(AbstractScalaIterator iterator) { + return new BufferedIterator(iterator); + } + + public interface AbstractScalaIterator { + + public abstract boolean hasNext(); + + public abstract E next() throws IOException; + + } + + public class BufferedIterator { + + private BytesToBytesMap.Location location = null; + private AbstractScalaIterator iterator; + + public BufferedIterator(AbstractScalaIterator iterator) { + this.iterator = iterator; + } + + public boolean hasNext() { + return iterator.hasNext(); + } + + public void loadNext() throws IOException { + location = iterator.next(); + } + + public BytesToBytesMap.Location getRecordLocation() { + return location; + } + } + + /** + * Mutable pair object + */ + public class MapEntry { + + public UnsafeRow key; + public UnsafeRow value; + + public MapEntry() { + this.key = new UnsafeRow(); + this.value = new UnsafeRow(); + } + + public MapEntry(UnsafeRow key, UnsafeRow value) { + this.key = key; + this.value = value; + } + } + + + public class UnsafeSorterKVSpillWriter { + + static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; + + // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to + // be an API to directly transfer bytes from managed memory to the disk writer, we buffer + // data through a byte array. + private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; + + private final File file; + private final BlockId blockId; + private final int numRecordsToWrite; + private DiskBlockObjectWriter writer; + private int numRecordsSpilled = 0; + + public UnsafeSorterKVSpillWriter( + BlockManager blockManager, + int fileBufferSize, + ShuffleWriteMetrics writeMetrics, + int numRecordsToWrite) throws IOException { + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempLocalBlock(); + this.file = spilledFileInfo._2(); + this.blockId = spilledFileInfo._1(); + this.numRecordsToWrite = numRecordsToWrite; + // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + // Our write path doesn't actually use this serializer (since we end up calling the `write()` + // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + // around this, we pass a dummy no-op serializer. + writer = blockManager.getDiskWriter( + blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics); + // Write the number of records + writeIntToBuffer(numRecordsToWrite, 0); + writer.write(writeBuffer, 0, 4); + } + + // Based on DataOutputStream.writeInt. + private void writeIntToBuffer(int v, int offset) throws IOException { + writeBuffer[offset + 0] = (byte)(v >>> 24); + writeBuffer[offset + 1] = (byte)(v >>> 16); + writeBuffer[offset + 2] = (byte)(v >>> 8); + writeBuffer[offset + 3] = (byte)(v >>> 0); + } + + public void write(BytesToBytesMap.Location loc) throws IOException { + if (numRecordsSpilled == numRecordsToWrite) { + throw new IllegalStateException( + "Number of records written exceeded numRecordsToWrite = " + numRecordsToWrite); + } else { + numRecordsSpilled++; + } + this.write(loc.getKeyAddress().getBaseObject(), loc.getKeyAddress().getBaseOffset(), + loc.getKeyLength()); + this.write(loc.getValueAddress().getBaseObject(), loc.getValueAddress().getBaseOffset(), + loc.getValueLength()); + writer.recordWritten(); + } + + + /** + * Write a record to a spill file. + * + * @param baseObject the base object / memory page containing the record + * @param baseOffset the base offset which points directly to the record data. + * @param recordLength the length of the record. + */ + public void write( + Object baseObject, + long baseOffset, + int recordLength) throws IOException { + writeIntToBuffer(recordLength, 0); + int dataRemaining = recordLength; + int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4; // space used by len + long recordReadPosition = baseOffset; + while (dataRemaining > 0) { + final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining); + PlatformDependent.copyMemory( + baseObject, + recordReadPosition, + writeBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), + toTransfer); + writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer); + recordReadPosition += toTransfer; + dataRemaining -= toTransfer; + freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE; + } + if (freeSpaceInWriteBuffer < DISK_WRITE_BUFFER_SIZE) { + writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer)); + } + } + + public void close() throws IOException { + writer.commitAndClose(); + writer = null; + writeBuffer = null; + } + + public AbstractScalaIterator getReader(BlockManager blockManager) + throws IOException { + return new UnsafeSorterKVSpillReader(blockManager, file, blockId); + } + } + + final class UnsafeSorterKVSpillReader implements AbstractScalaIterator { + + private InputStream in; + private DataInputStream din; + + // Variables that change with every kv read: + private int numRecordsRemaining; + + private int keyLength; + private byte[] keyArray = new byte[1024 * 1024]; + + private int valueLength; + private byte[] valueArray = new byte[1024 * 1024]; + private final BytesToBytesMap.Location location = map.getNewLocation(); + + public UnsafeSorterKVSpillReader( + BlockManager blockManager, + File file, + BlockId blockId) throws IOException { + assert (file.length() > 0); + final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); + this.in = blockManager.wrapForCompression(blockId, bs); + this.din = new DataInputStream(this.in); + numRecordsRemaining = din.readInt(); + } + + @Override + public boolean hasNext() { + return (numRecordsRemaining > 0); + } + + @Override + public BytesToBytesMap.Location next() throws IOException { + keyLength = din.readInt(); + if (keyLength > keyArray.length) { + keyArray = new byte[keyLength]; + } + ByteStreams.readFully(in, keyArray, 0, keyLength); + valueLength = din.readInt(); + if (valueLength > valueArray.length) { + valueArray = new byte[valueLength]; + } + ByteStreams.readFully(in, valueArray, 0, valueLength); + numRecordsRemaining--; + if (numRecordsRemaining == 0) { + in.close(); + in = null; + din = null; + } + location.with(keyLength, keyArray, valueLength, valueArray); + return location; + } + } + + + public class BytesToBytesMapIterator implements AbstractScalaIterator { + + private final MapEntry entry = new MapEntry(); + private final BytesToBytesMap.Location loc = map.getNewLocation(); + private int currentRecordNumber = 0; + private Object pageBaseObject; + private long offsetInPage; + private int numRecords; + + public BytesToBytesMapIterator(int numRecords) { + this.numRecords = numRecords; + if (dataPages.iterator().hasNext()) { + advanceToNextPage(); + } + } + + private void advanceToNextPage() { + final MemoryBlock currentPage = dataPages.iterator().next(); + pageBaseObject = currentPage.getBaseObject(); + offsetInPage = currentPage.getBaseOffset(); + } + + @Override + public boolean hasNext() { + return currentRecordNumber != numRecords; + } + + @Override + public MapEntry next() { + int keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage); + if (keyLength == END_OF_PAGE_MARKER) { + advanceToNextPage(); + keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage); + } + loc.with(pageBaseObject, offsetInPage); + offsetInPage += 8 + 8 + keyLength + loc.getValueLength(); + + MemoryLocation keyAddress = loc.getKeyAddress(); + MemoryLocation valueAddress = loc.getValueAddress(); + entry.key.pointTo( + keyAddress.getBaseObject(), + keyAddress.getBaseOffset(), + groupingKeySchema.length(), + loc.getKeyLength() + ); + entry.value.pointTo( + valueAddress.getBaseObject(), + valueAddress.getBaseOffset(), + aggregationBufferSchema.length(), + loc.getValueLength() + ); + currentRecordNumber++; + return entry; + } + } + + private long getMemoryUsage() { + return map.getTotalMemoryConsumption() + (dataPages.size() * PAGE_SIZE_BYTES); + } + + private void cleanupResources() { + this.freeMemory(); + } + + /** + * Free the unsafe memory associated with this map. + */ + public long freeMemory() { + long memoryFreed = 0; + for (MemoryBlock block : dataPages) { + memoryManager.freePage(block); + shuffleMemoryManager.release(block.size()); + memoryFreed += block.size(); + } + if (map != null) { + long sorterMemoryUsage = map.getTotalMemoryConsumption(); + map.free(); + map = null; + shuffleMemoryManager.release(sorterMemoryUsage); + memoryFreed += sorterMemoryUsage; + } + dataPages.clear(); + currentDataPage = null; + pageCursor = 0; + freeSpaceInCurrentPage = 0; + return memoryFreed; + } + + @SuppressWarnings("UseOfSystemOutOrSystemErr") + public void printPerfMetrics() { + if (!enablePerfMetrics) { + throw new IllegalStateException("Perf metrics not enabled"); + } + System.out.println("Average probes per lookup: " + map.getAverageProbesPerLookup()); + System.out.println("Number of hash collisions: " + map.getNumHashCollisions()); + System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs()); + System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption()); + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index cc89d74146b34..21f34a0009be9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -22,6 +22,24 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjecti import org.apache.spark.sql.types.{StructType, DataType} import org.apache.spark.unsafe.types.UTF8String +/** + * Converts a [[InternalRow]] to another Row given a sequence of expression that define each + * column of the new row. If the schema of the input row is specified, then the given expression + * will be bound to that schema. + * + * In contrast to a normal projection, a MutableProjection reuses the same underlying row object + * each time an input row is added. This significantly reduces the cost of calculating the + * projection, but means that it is not safe to hold on to a reference to a [[InternalRow]] after + * `next()` has been called on the [[Iterator]] that produced it. Instead, the user must call + * `InternalRow.copy()` and hold on to the returned [[InternalRow]] before calling `next()`. + */ +abstract class MutableProjection extends Projection { + def currentValue: InternalRow + + /** Uses the given row to store the output of the projection. */ + def target(row: MutableRow): MutableProjection +} + /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. * @param expressions a sequence of expressions that determine the value of each column of the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 30b7f8d3766a5..aba2c7bb4f40c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -63,21 +63,4 @@ package object expressions { */ abstract class Projection extends (InternalRow => InternalRow) - /** - * Converts a [[InternalRow]] to another Row given a sequence of expression that define each - * column of the new row. If the schema of the input row is specified, then the given expression - * will be bound to that schema. - * - * In contrast to a normal projection, a MutableProjection reuses the same underlying row object - * each time an input row is added. This significantly reduces the cost of calculating the - * projection, but means that it is not safe to hold on to a reference to a [[InternalRow]] after - * `next()` has been called on the [[Iterator]] that produced it. Instead, the user must call - * `InternalRow.copy()` and hold on to the returned [[InternalRow]] before calling `next()`. - */ - abstract class MutableProjection extends Projection { - def currentValue: InternalRow - - /** Uses the given row to store the output of the projection. */ - def target(row: MutableRow): MutableProjection - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 2a641b9d64a95..1dbf85a1a16b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -401,6 +401,9 @@ private[spark] object SQLConf { val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2", defaultValue = Some(true), doc = "") + val USE_HYBRID_AGGREGATE = booleanConf("spark.sql.useHybridAggregate", + defaultValue = Some(false), doc = "") + val USE_SQL_SERIALIZER2 = booleanConf( "spark.sql.useSerializer2", defaultValue = Some(true), isPublic = false) @@ -474,6 +477,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) + private[spark] def useHybridAggregate: Boolean = getConf(USE_HYBRID_AGGREGATE) + private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index eb4be1900b153..459111a0554a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} @@ -221,7 +221,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { expr.collect { case agg: AggregateExpression2 => agg } - }.toSet.toSeq + } // For those distinct aggregate expressions, we create a map from the // aggregate function to the corresponding attribute of the function. val aggregateFunctionMap = aggregateExpressions.map { agg => @@ -247,6 +247,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { aggregateExpressions, aggregateFunctionMap, resultExpressions, + sqlContext.conf.useHybridAggregate, planLater(child)) } else { aggregate.Utils.planAggregateWithOneDistinct( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala new file mode 100644 index 0000000000000..0d7732fd8ab8e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -0,0 +1,112 @@ +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ + +import scala.collection.mutable.ArrayBuffer + +private[sql] abstract class AggregationIterator( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression2], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]){ + + /////////////////////////////////////////////////////////////////////////// + // Static fields for this iterator + /////////////////////////////////////////////////////////////////////////// + + protected val aggregateFunctions: Array[AggregateFunction2] = { + var bufferOffset = initialBufferOffset + val functions = new Array[AggregateFunction2](aggregateExpressions.length) + var i = 0 + while (i < aggregateExpressions.length) { + val func = aggregateExpressions(i).aggregateFunction + val funcWithBoundReferences = aggregateExpressions(i).mode match { + case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] => + // We need to create BoundReferences if the function is not an + // AlgebraicAggregate (it does not support code-gen) and the mode of + // this function is Partial or Complete because we will call eval of this + // function's children in the update method of this aggregate function. + // Those eval calls require BoundReferences to work. + BindReferences.bindReference(func, inputAttributes) + case _ => func + } + // Set bufferOffset for this function. It is important that setting bufferOffset + // happens after all potential bindReference operations because bindReference + // will create a new instance of the function. + funcWithBoundReferences.bufferOffset = bufferOffset + bufferOffset += funcWithBoundReferences.bufferSchema.length + functions(i) = funcWithBoundReferences + i += 1 + } + functions + } + + // Positions of those non-algebraic aggregate functions in aggregateFunctions. + // For example, we have func1, func2, func3, func4 in aggregateFunctions, and + // func2 and func3 are non-algebraic aggregate functions. + // nonAlgebraicAggregateFunctionPositions will be [1, 2]. + protected val nonAlgebraicAggregateFunctionPositions: Array[Int] = { + val positions = new ArrayBuffer[Int]() + var i = 0 + while (i < aggregateFunctions.length) { + aggregateFunctions(i) match { + case agg: AlgebraicAggregate => + case _ => positions += i + } + i += 1 + } + positions.toArray + } + + // All non-algebraic aggregate functions. + protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = + nonAlgebraicAggregateFunctionPositions.map(aggregateFunctions) + + // The underlying buffer shared by all aggregate functions. + protected val buffer: MutableRow = { + // The number of elements of the underlying buffer of this operator. + // All aggregate functions are sharing this underlying buffer and they find their + // buffer values through bufferOffset. + var size = initialBufferOffset + var i = 0 + while (i < aggregateFunctions.length) { + size += aggregateFunctions(i).bufferSchema.length + i += 1 + } + new GenericMutableRow(size) + } + + /** Initializes buffer values for all aggregate functions. */ + protected def initializeBuffer(): Unit = { + algebraicInitialProjection(EmptyRow) + var i = 0 + while (i < nonAlgebraicAggregateFunctions.length) { + nonAlgebraicAggregateFunctions(i).initialize(buffer) + i += 1 + } + } + + protected val joinedRow = new JoinedRow + + // This is used to project expressions for the grouping expressions. + protected val groupGenerator = + newMutableProjection(groupingExpressions, inputAttributes)() + + protected def initialBufferOffset: Int + + protected val placeholderExpressions = Seq.fill(initialBufferOffset)(NoOp) + + // This projection is used to initialize buffer values for all AlgebraicAggregates. + protected val algebraicInitialProjection = { + val initExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.initialValues + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + newMutableProjection(initExpressions, Nil)().target(buffer) + } + + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HybridAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HybridAggregationIterator.scala new file mode 100644 index 0000000000000..92734162f8c65 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HybridAggregationIterator.scala @@ -0,0 +1,273 @@ +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.{Logging, SparkEnv, TaskContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering +import org.apache.spark.sql.execution.UnsafeExternalAggregation +import org.apache.spark.sql.types.{StructField, StructType, NullType} + +/** + * An iterator used to do partial aggregations for hybrid aggregate + */ +class HybridPartialAggregationIterator( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression2], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]) + extends AggregationIterator( + groupingExpressions, + aggregateExpressions, + newMutableProjection, + inputAttributes, + inputIter) + with Logging { + + override protected def initialBufferOffset: Int = 0 + + // This projection is used to update buffer values for all AlgebraicAggregates. + private val algebraicUpdateProjection = { + val bufferSchema = aggregateFunctions.flatMap(_.bufferAttributes) + val updateExpressions = aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)() + } + + // This projection is used to merge buffer values for all AlgebraicAggregates. + private val algebraicMergeProjection = { + val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes) ++ + aggregateFunctions.flatMap(_.cloneBufferAttributes) + val mergeExpressions = aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + + newMutableProjection(mergeExpressions, bufferSchemata)() + } + + def iterator: Iterator[InternalRow] = { + + val groupingKeyOrdering: Ordering[InternalRow] = GenerateOrdering.generate( + groupingExpressions.map(_.dataType).zipWithIndex.map { case(dt, index) => + new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + }) + + val aggregationBufferSchema: StructType = StructType.fromAttributes( + aggregateFunctions.flatMap(_.bufferAttributes) + ) + + val groupKeySchema: StructType = { + val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) => + StructField(idx.toString, expr.dataType, expr.nullable) + } + StructType(fields) + } + + this.initializeBuffer() + + log.info(s"GroupKey Schema: ${groupKeySchema.mkString(",")}") + log.info(s"AggregationBuffer Schema: ${aggregationBufferSchema.mkString(",")}") + + val sparkEnv: SparkEnv = SparkEnv.get + val taskContext: TaskContext = TaskContext.get + val aggregationMap = new UnsafeExternalAggregation( + taskContext.taskMemoryManager, + sparkEnv.shuffleMemoryManager, + sparkEnv.blockManager, + taskContext, + algebraicUpdateProjection, + algebraicMergeProjection, + nonAlgebraicAggregateFunctions, + buffer, + aggregationBufferSchema, + groupKeySchema, + groupingKeyOrdering, + 1024 * 16, + sparkEnv.conf, + false + ) + + while (inputIter.hasNext) { + val currentRow: InternalRow = inputIter.next() + val groupKey: InternalRow = groupGenerator(currentRow).copy() + aggregationMap.insertRow(groupKey, currentRow) + } + + new Iterator[InternalRow] { + private[this] val mapIterator = aggregationMap.iterator() + + def hasNext: Boolean = mapIterator.hasNext + + def next(): InternalRow = { + val entry = mapIterator.next() + if (hasNext) { + joinedRow(entry.key, entry.value) + } else { + // This is the last element in the iterator, so let's free the buffer. Before we do, + // though, we need to make a defensive copy of the result so that we don't return an + // object that might contain dangling pointers to the freed memory + val resultCopy = joinedRow(entry.key.copy(), entry.value.copy()) + aggregationMap.freeMemory() + resultCopy + } + } + } + } +} + +/** + * An iterator used to do final aggregations for hybrid aggregate + */ +class HybridFinalAggregationIterator( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression2], + aggregateAttributes: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]) + extends AggregationIterator( + groupingExpressions, + aggregateExpressions, + newMutableProjection, + inputAttributes, + inputIter) + with Logging { + + // The result of aggregate functions. + private val aggregateResult: MutableRow = new GenericMutableRow(aggregateAttributes.length) + + // The projection used to generate the output rows of this operator. + // This is only used when we are generating final results of aggregate functions. + private val resultProjection = + newMutableProjection( + resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)() + + protected def initialBufferOffset: Int = 0 + + private val offsetAttributes = + Seq.fill(groupingExpressions.length)(AttributeReference("placeholder", NullType)()) + + // This projection is used to partial merge buffer values for all AlgebraicAggregates. + private val algebraicPartialMergeProjection = { + val bufferSchema = aggregateFunctions.flatMap(_.bufferAttributes) ++ + offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes) + val mergeExpressions = aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + + newMutableProjection(mergeExpressions, bufferSchema)() + } + + // This projection is used to merge buffer values for all AlgebraicAggregates. + private val algebraicMergeProjection = { + val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes) ++ + aggregateFunctions.flatMap(_.cloneBufferAttributes) + val mergeExpressions = aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + + newMutableProjection(mergeExpressions, bufferSchemata)() + } + + // This projection is used to evaluate all AlgebraicAggregates. + private val algebraicEvalProjection = { + val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes) ++ + aggregateFunctions.flatMap(_.cloneBufferAttributes) + val evalExpressions = aggregateFunctions.map { + case ae: AlgebraicAggregate => ae.evaluateExpression + case agg: AggregateFunction2 => NoOp + } + + newMutableProjection(evalExpressions, bufferSchemata)() + } + + def iterator: Iterator[InternalRow] = { + + val groupingKeyOrdering: Ordering[InternalRow] = GenerateOrdering.generate( + groupingExpressions.map(_.dataType).zipWithIndex.map { case(dt, index) => + new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + }) + + val aggregationBufferSchema: StructType = StructType.fromAttributes( + aggregateFunctions.flatMap(_.bufferAttributes) + ) + + val groupKeySchema: StructType = { + val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) => + // This is a dummy field name + StructField(idx.toString, expr.dataType, expr.nullable) + } + StructType(fields) + } + + this.initializeBuffer() + + log.info(s"GroupKey Schema: ${groupKeySchema.mkString(",")}") + log.info(s"AggregationBuffer Schema: ${aggregationBufferSchema.mkString(",")}") + + val sparkEnv: SparkEnv = SparkEnv.get + val taskContext: TaskContext = TaskContext.get + val aggregationMap = new UnsafeExternalAggregation( + taskContext.taskMemoryManager, + sparkEnv.shuffleMemoryManager, + sparkEnv.blockManager, + taskContext, + algebraicPartialMergeProjection, + algebraicMergeProjection, + nonAlgebraicAggregateFunctions, + buffer, + aggregationBufferSchema, + groupKeySchema, + groupingKeyOrdering, + 1024 * 16, + sparkEnv.conf, + false + ) + + while (inputIter.hasNext) { + val currentRow: InternalRow = inputIter.next() + val groupKey: InternalRow = groupGenerator(currentRow).copy() + aggregationMap.insertRow(groupKey, currentRow) + } + + new Iterator[InternalRow] { + private[this] val mapIterator = aggregationMap.iterator() + private[this] var nextKey: InternalRow = _ + private[this] var nextValue: InternalRow = _ + + def hasNext: Boolean = mapIterator.hasNext + + def next(): InternalRow = { + val entry = mapIterator.next() + if (hasNext) { + nextKey = entry.key + nextValue = entry.value + } else { + nextKey = entry.key.copy() + nextValue = entry.value.copy() + aggregationMap.freeMemory() + } + // Generate results for all algebraic aggregate functions. + algebraicEvalProjection.target(aggregateResult)(nextValue) + // Generate results for all non-algebraic aggregate functions. + var i = 0 + while (i < nonAlgebraicAggregateFunctions.length) { + aggregateResult.update( + nonAlgebraicAggregateFunctionPositions(i), + nonAlgebraicAggregateFunctions(i).eval(nextValue)) + i += 1 + } + resultProjection(joinedRow(nextKey, aggregateResult)) + } + } + } + +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala index 0c9082897f390..0e59bb8f0c0cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala @@ -117,6 +117,70 @@ case class Aggregate2Sort( } } +case class Aggregate2Hybrid( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression2], + aggregateAttributes: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override def canProcessUnsafeRows: Boolean = true + + override def references: AttributeSet = { + val referencesInResults = + AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes) + + AttributeSet( + groupingExpressions.flatMap(_.references) ++ + aggregateExpressions.flatMap(_.references) ++ + referencesInResults) + } + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.length == 0 => AllTuples :: Nil + case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + assert(aggregateExpressions.length > 0, "Use Aggregate2Sort when no aggregate function") + val aggregationIterator: Iterator[InternalRow] = { + aggregateExpressions.map(_.mode).distinct.toList match { + case Partial :: Nil => + new HybridPartialAggregationIterator( + groupingExpressions, + aggregateExpressions, + newMutableProjection, + child.output, + iter).iterator + case Final :: Nil => + new HybridFinalAggregationIterator( + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + resultExpressions, + newMutableProjection, + child.output, + iter).iterator + case other => + sys.error( + s"Could not evaluate ${aggregateExpressions} because we do not support evaluate " + + s"modes $other in this operator.") + } + } + + aggregationIterator + } + } +} + case class FinalAndCompleteAggregate2Sort( previousGroupingExpressions: Seq[NamedExpression], groupingExpressions: Seq[NamedExpression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala index 1b89edafa8dad..666f8429068dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala @@ -34,90 +34,13 @@ private[sql] abstract class SortAggregationIterator( newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), inputAttributes: Seq[Attribute], inputIter: Iterator[InternalRow]) - extends Iterator[InternalRow] { - - /////////////////////////////////////////////////////////////////////////// - // Static fields for this iterator - /////////////////////////////////////////////////////////////////////////// - - protected val aggregateFunctions: Array[AggregateFunction2] = { - var bufferOffset = initialBufferOffset - val functions = new Array[AggregateFunction2](aggregateExpressions.length) - var i = 0 - while (i < aggregateExpressions.length) { - val func = aggregateExpressions(i).aggregateFunction - val funcWithBoundReferences = aggregateExpressions(i).mode match { - case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] => - // We need to create BoundReferences if the function is not an - // AlgebraicAggregate (it does not support code-gen) and the mode of - // this function is Partial or Complete because we will call eval of this - // function's children in the update method of this aggregate function. - // Those eval calls require BoundReferences to work. - BindReferences.bindReference(func, inputAttributes) - case _ => func - } - // Set bufferOffset for this function. It is important that setting bufferOffset - // happens after all potential bindReference operations because bindReference - // will create a new instance of the function. - funcWithBoundReferences.bufferOffset = bufferOffset - bufferOffset += funcWithBoundReferences.bufferSchema.length - functions(i) = funcWithBoundReferences - i += 1 - } - functions - } - - // Positions of those non-algebraic aggregate functions in aggregateFunctions. - // For example, we have func1, func2, func3, func4 in aggregateFunctions, and - // func2 and func3 are non-algebraic aggregate functions. - // nonAlgebraicAggregateFunctionPositions will be [1, 2]. - protected val nonAlgebraicAggregateFunctionPositions: Array[Int] = { - val positions = new ArrayBuffer[Int]() - var i = 0 - while (i < aggregateFunctions.length) { - aggregateFunctions(i) match { - case agg: AlgebraicAggregate => - case _ => positions += i - } - i += 1 - } - positions.toArray - } - - // All non-algebraic aggregate functions. - protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = - nonAlgebraicAggregateFunctionPositions.map(aggregateFunctions) - - // This is used to project expressions for the grouping expressions. - protected val groupGenerator = - newMutableProjection(groupingExpressions, inputAttributes)() - - // The underlying buffer shared by all aggregate functions. - protected val buffer: MutableRow = { - // The number of elements of the underlying buffer of this operator. - // All aggregate functions are sharing this underlying buffer and they find their - // buffer values through bufferOffset. - var size = initialBufferOffset - var i = 0 - while (i < aggregateFunctions.length) { - size += aggregateFunctions(i).bufferSchema.length - i += 1 - } - new GenericMutableRow(size) - } - - protected val joinedRow = new JoinedRow - - protected val placeholderExpressions = Seq.fill(initialBufferOffset)(NoOp) - - // This projection is used to initialize buffer values for all AlgebraicAggregates. - protected val algebraicInitialProjection = { - val initExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.initialValues - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) - } - newMutableProjection(initExpressions, Nil)().target(buffer) - } + extends AggregationIterator( + groupingExpressions, + aggregateExpressions, + newMutableProjection, + inputAttributes, + inputIter) + with Iterator[InternalRow] { /////////////////////////////////////////////////////////////////////////// // Mutable states @@ -136,16 +59,6 @@ private[sql] abstract class SortAggregationIterator( // Private methods /////////////////////////////////////////////////////////////////////////// - /** Initializes buffer values for all aggregate functions. */ - protected def initializeBuffer(): Unit = { - algebraicInitialProjection(EmptyRow) - var i = 0 - while (i < nonAlgebraicAggregateFunctions.length) { - nonAlgebraicAggregateFunctions(i).initialize(buffer) - i += 1 - } - } - protected def initialize(): Unit = { if (inputIter.hasNext) { initializeBuffer() @@ -218,8 +131,6 @@ private[sql] abstract class SortAggregationIterator( // Methods that need to be implemented /////////////////////////////////////////////////////////////////////////// - protected def initialBufferOffset: Int - protected def processRow(row: InternalRow): Unit protected def generateOutput(): InternalRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 5bbe6c162ff4b..3a881f2fa5858 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -114,7 +114,7 @@ object Utils { expr.collect { case agg: AggregateExpression2 => agg } - }.toSet.toSeq + } val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct) val hasMultipleDistinctColumnSets = if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { @@ -178,6 +178,7 @@ object Utils { aggregateExpressions: Seq[AggregateExpression2], aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute], resultExpressions: Seq[NamedExpression], + useHybridAggregate: Boolean, child: SparkPlan): Seq[SparkPlan] = { // 1. Create an Aggregate Operator for partial aggregations. val namedGroupingExpressions = groupingExpressions.map { @@ -195,7 +196,20 @@ object Utils { val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => agg.aggregateFunction.bufferAttributes } - val partialAggregate = + + var useHybridAggregateNew = useHybridAggregate + if (aggregateExpressions.length == 0) { + useHybridAggregateNew = false + } + val partialAggregate = if (useHybridAggregateNew) { + Aggregate2Hybrid( + None: Option[Seq[Expression]], + namedGroupingExpressions.map(_._2), + partialAggregateExpressions, + partialAggregateAttributes, + namedGroupingAttributes ++ partialAggregateAttributes, + child) + } else { Aggregate2Sort( None: Option[Seq[Expression]], namedGroupingExpressions.map(_._2), @@ -203,6 +217,7 @@ object Utils { partialAggregateAttributes, namedGroupingAttributes ++ partialAggregateAttributes, child) + } // 2. Create an Aggregate Operator for final aggregations. val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) @@ -222,13 +237,23 @@ object Utils { }.getOrElse(expression) }.asInstanceOf[NamedExpression] } - val finalAggregate = Aggregate2Sort( - Some(namedGroupingAttributes), - namedGroupingAttributes, - finalAggregateExpressions, - finalAggregateAttributes, - rewrittenResultExpressions, - partialAggregate) + val finalAggregate = if (useHybridAggregateNew) { + Aggregate2Hybrid( + Some(namedGroupingAttributes), + namedGroupingAttributes, + finalAggregateExpressions, + finalAggregateAttributes, + rewrittenResultExpressions, + partialAggregate) + } else { + Aggregate2Sort( + Some(namedGroupingAttributes), + namedGroupingAttributes, + finalAggregateExpressions, + finalAggregateAttributes, + rewrittenResultExpressions, + partialAggregate) + } finalAggregate :: Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index cd386b7a3ecf9..a0e9858d43210 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -21,9 +21,10 @@ import org.scalatest.BeforeAndAfterAll import java.sql.Timestamp +import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException -import org.apache.spark.sql.execution.aggregate.Aggregate2Sort +import org.apache.spark.sql.execution.aggregate.{Aggregate2Hybrid, Aggregate2Sort} import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ @@ -305,6 +306,105 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } } + test("aggregation with hybridAggregate") { + val originalValue = sqlContext.conf.codegenEnabled + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) + val originalUseAggregate2 = sqlContext.conf.useSqlAggregate2 + sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2, true) + val originalUseHybridAggregate = sqlContext.conf.useHybridAggregate + sqlContext.setConf(SQLConf.USE_HYBRID_AGGREGATE, true) + SparkEnv.get.conf.set("spark.test.aggregate.spillFrequency","10") + // Prepare a table that we can group some rows. + sqlContext.table("testData") + .unionAll(sqlContext.table("testData")) + .unionAll(sqlContext.table("testData")) + .registerTempTable("testData3x") + + def testHybridAggregate(sqlText: String, expectedResults: Seq[Row]): Unit = { + val df = sql(sqlText) + // First, check if we have GeneratedAggregate. + var hasAggregate2Hybrid = false + df.queryExecution.executedPlan.foreach { + case newAggregate: Aggregate2Hybrid => hasAggregate2Hybrid = true + case _ => + } + if (!hasAggregate2Hybrid) { + fail( + s""" + |Codegen is enabled, but query $sqlText does not have Aggregate2Hybrid in the plan. + |${df.queryExecution.simpleString} + """.stripMargin) + } + // Then, check results. + checkAnswer(df, expectedResults) + } + + try { + // COUNT + testHybridAggregate( + "SELECT key, count(value) FROM testData3x GROUP BY key", + (1 to 100).map(i => Row(i, 3))) + testHybridAggregate( + "SELECT count(key) FROM testData3x", + Row(300) :: Nil) + // SUM + testHybridAggregate( + "SELECT value, sum(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 3 * i))) + testHybridAggregate( + "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", + Row(5050 * 3, 5050 * 3.0) :: Nil) + // AVERAGE + testHybridAggregate( + "SELECT value, avg(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testHybridAggregate( + "SELECT avg(key) FROM testData3x", + Row(50.5) :: Nil) + // MAX + testHybridAggregate( + "SELECT value, max(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testHybridAggregate( + "SELECT max(key) FROM testData3x", + Row(100) :: Nil) + // MIN + testHybridAggregate( + "SELECT value, min(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testHybridAggregate( + "SELECT min(key) FROM testData3x", + Row(1) :: Nil) + // Some combinations. + testHybridAggregate( + """ + |SELECT + | value, + | sum(key), + | max(key), + | min(key), + | avg(key), + | count(key) + |FROM testData3x + |GROUP BY value + """.stripMargin, + (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3))) + testHybridAggregate( + "SELECT max(key), min(key), avg(key), count(key) FROM testData3x", + Row(100, 1, 50.5, 300) :: Nil) + // Aggregate with Code generation handling all null values + testHybridAggregate( + "SELECT sum('a'), avg('a'), count(null) FROM testData", + Row(null, null, 0) :: Nil) + } finally { + sqlContext.dropTempTable("testData3x") + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) + sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2, originalUseAggregate2) + sqlContext.setConf(SQLConf.USE_HYBRID_AGGREGATE, originalUseHybridAggregate) + SparkEnv.get.conf.set("spark.test.aggregate.spillFrequency","0") + } + } + test("Add Parser of SQL COALESCE()") { checkAnswer( sql("""SELECT COALESCE(1, 2)"""),