From 4d7a3102d1627eeb877128bce4233c43991f5a13 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Thu, 16 Jul 2015 00:27:42 +0800 Subject: [PATCH 1/5] init commit for external aggregation --- .../expressions/UnsafeRowLocation.java | 199 ++++ .../sql/execution/UnsafeAppendOnlyMap.java | 499 +++++++++ .../execution/UnsafeExternalAggregation.java | 989 ++++++++++++++++++ .../sql/catalyst/expressions/Projection.scala | 18 + .../sql/catalyst/expressions/package.scala | 18 - .../sql/execution/GeneratedAggregate.scala | 173 ++- .../UnsafeExternalAggregationSuite.scala | 120 +++ 7 files changed, 1963 insertions(+), 53 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowLocation.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeAppendOnlyMap.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalAggregation.java create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalAggregationSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowLocation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowLocation.java new file mode 100644 index 0000000000000..56dacb7137b63 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowLocation.java @@ -0,0 +1,199 @@ +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +/** + * memory location of a pair + */ +public class UnsafeRowLocation { + + private final TaskMemoryManager memoryManager; + + /** + * An index into the hash map's Long array + */ + private int pos; + /** + * True if this location points to a position where a key is defined, false otherwise + */ + private boolean isDefined; + /** + * The hashcode of the most recent key passed to caching this hashcode here allows us to + * avoid re-hashing the key when storing a value for that key. 
+ */ + private int keyHashcode; + + private final MemoryLocation keyMemoryLocation = new MemoryLocation(); + private final MemoryLocation valueMemoryLocation = new MemoryLocation(); + private int keyLength; + private int valueLength; + + public UnsafeRowLocation(TaskMemoryManager memoryManager) { + this.memoryManager = memoryManager; + } + + public UnsafeRowLocation with(long fullKeyAddress) { + this.isDefined = true; + updateAddressesAndSizes(fullKeyAddress); + return this; + } + + public UnsafeRowLocation with(Object page, long offsetInPage) { + this.isDefined = true; + updateAddressesAndSizes(page, offsetInPage); + return this; + } + + public UnsafeRowLocation with(int keyLength, byte[] keyArray, int valueLength, + byte[] valueArray) { + this.isDefined = true; + this.keyLength = keyLength; + keyMemoryLocation.setObjAndOffset(keyArray, PlatformDependent.BYTE_ARRAY_OFFSET); + this.valueLength = valueLength; + valueMemoryLocation.setObjAndOffset(valueArray, PlatformDependent.BYTE_ARRAY_OFFSET); + return this; + } + + public UnsafeRowLocation with(int pos, int keyHashcode, boolean isDefined, long fullKeyAddress) { + this.pos = pos; + this.isDefined = isDefined; + this.keyHashcode = keyHashcode; + if (isDefined) { + updateAddressesAndSizes(fullKeyAddress); + } + return this; + } + + public void updateAddressesAndSizes(long fullKeyAddress) { + updateAddressesAndSizes( + memoryManager.getPage(fullKeyAddress), + memoryManager.getOffsetInPage(fullKeyAddress)); + } + + private void updateAddressesAndSizes(Object page, long offsetInPage) { + long position = offsetInPage; + keyLength = (int) PlatformDependent.UNSAFE.getLong(page, position); + position += 8; // word used to store the key size + keyMemoryLocation.setObjAndOffset(page, position); + position += keyLength; + valueLength = (int) PlatformDependent.UNSAFE.getLong(page, position); + position += 8; // word used to store the key size + valueMemoryLocation.setObjAndOffset(page, position); + } + + /** + * Returns true if the key is defined at this position, and false otherwise. + */ + public boolean isDefined() { + return isDefined; + } + + /** + * Set whether the key is defined. + */ + public void setDefined(boolean isDefined) { + this.isDefined = isDefined; + } + + /** + * Returns the hashcode of the key. + */ + public long getKeyHashcode() { + return this.keyHashcode; + } + + /** + * Set the hashcode of the key. + */ + public void setKeyHashcode(int keyHashcode) { + this.keyHashcode = keyHashcode; + } + + /** + * Set an index into the hash map's Long array. + */ + public void setPos(int pos) { + this.pos = pos; + } + + /** + * Returns the index into the hash map's Long array. + */ + public int getPos() { + return this.pos; + } + + /** + * Returns the address of the key defined at this position. + * This points to the first byte of the key data. + * Unspecified behavior if the key is not defined. + * For efficiency reasons, calls to this method always returns the same MemoryLocation object. + */ + public MemoryLocation getKeyAddress() { + assert (isDefined); + return keyMemoryLocation; + } + + /** + * Returns the base object of the key. + */ + public Object getKeyBaseObject() { + assert (isDefined); + return keyMemoryLocation.getBaseObject(); + } + + /** + * Returns the base offset of the key. + */ + public long getKeyBaseOffset() { + assert (isDefined); + return keyMemoryLocation.getBaseOffset(); + } + + /** + * Returns the length of the key defined at this position. + * Unspecified behavior if the key is not defined. 
+ */ + public int getKeyLength() { + assert (isDefined); + return keyLength; + } + + /** + * Returns the address of the value defined at this position. + * This points to the first byte of the value data. + * Unspecified behavior if the key is not defined. + * For efficiency reasons, calls to this method always returns the same MemoryLocation object. + */ + public MemoryLocation getValueAddress() { + assert (isDefined); + return valueMemoryLocation; + } + + /** + * Returns the base object of the value. + */ + public Object getValueBaseObject() { + assert (isDefined); + return valueMemoryLocation.getBaseObject(); + } + + /** + * Return the base offset of the value. + */ + public long getValueBaseOffset() { + assert (isDefined); + return valueMemoryLocation.getBaseOffset(); + } + + /** + * Returns the length of the value defined at this position. + * Unspecified behavior if the key is not defined. + */ + public int getValueLength() { + assert (isDefined); + return valueLength; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeAppendOnlyMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeAppendOnlyMap.java new file mode 100644 index 0000000000000..daed2e8b767d3 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeAppendOnlyMap.java @@ -0,0 +1,499 @@ +package org.apache.spark.sql.execution; + +import java.util.Comparator; +import java.util.Iterator; + +import scala.math.Ordering; + +import com.google.common.annotations.VisibleForTesting; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeRowLocation; +import org.apache.spark.sql.catalyst.util.ObjectPool; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.bitset.BitSet; +import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.map.HashMapGrowthStrategy; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.collection.SortDataFormat; +import org.apache.spark.util.collection.Sorter; + +/** + * An append-only hash map where keys and values are contiguous regions of bytes. + *
+ * This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers,
+ * which is guaranteed to exhaust the space.
+ *
+ * The map can support up to 2^29 keys. If the key cardinality is higher than this, you should
+ * probably be using sorting instead of hashing for better cache locality.
+ *
+ * This class is not thread safe. + */ +public class UnsafeAppendOnlyMap { + + /** + * The maximum number of keys that BytesToBytesMap supports. The hash table has to be + * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, since + * that's the largest power-of-2 that's less than Integer.MAX_VALUE. We need two long array + * entries per key, giving us a maximum capacity of (1 << 29). + */ + @VisibleForTesting + static final int MAX_CAPACITY = (1 << 29); + + private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0); + + private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING; + + private final TaskMemoryManager memoryManager; + + /** + * A single array to store the key and value. + *
+ * Position {@code 2 * i} in the array is used to track a pointer to the key at index {@code i}, + * while position {@code 2 * i + 1} in the array holds key's full 32-bit hashcode. + */ + private long[] longArray; + + /** + * A {@link org.apache.spark.unsafe.bitset.BitSet} used to track location of the map where the key is set. + * Size of the bitset should be half of the size of the long array. + */ + private BitSet bitset; + + private final double loadFactor; + + /** + * Number of keys defined in the map. + */ + private int size; + + /** + * The map will be expanded once the number of keys exceeds this threshold. + */ + private int growthThreshold; + + private int capacity; + + /** + * Mask for truncating hashcodes so that they do not exceed the long array's size. + * This is a strength reduction optimization; we're essentially performing a modulus operation, + * but doing so with a bitmask because this is a power-of-2-sized hash map. + */ + private int mask; + + /** + * Return value of {@link UnsafeAppendOnlyMap#lookup(Object, long, int)}. + */ + private final UnsafeRowLocation loc; + + private final boolean enablePerfMetrics; + private long timeSpentResizingNs = 0; + + private long numKeyLookups = 0; + private long numProbes = 0; + private long numHashCollisions = 0; + + public UnsafeAppendOnlyMap( + TaskMemoryManager memoryManager, + int initialCapacity, + double loadFactor, + boolean enablePerfMetrics) { + this.memoryManager = memoryManager; + this.loadFactor = loadFactor; + this.loc = new UnsafeRowLocation(memoryManager); + this.enablePerfMetrics = enablePerfMetrics; + + if (initialCapacity <= 0) { + throw new IllegalArgumentException("Initial capacity must be greater than 0"); + } + if (initialCapacity > MAX_CAPACITY) { + throw new IllegalArgumentException("Initial capacity " + initialCapacity + + " exceeds maximum capacity of " + MAX_CAPACITY); + } + + this.capacity = initialCapacity; + allocate(initialCapacity); + } + + public UnsafeAppendOnlyMap( + TaskMemoryManager memoryManager, + int initialCapacity, + boolean enablePerfMetrics) { + this(memoryManager, initialCapacity, 0.70, enablePerfMetrics); + } + + private static final class SortComparator implements Comparator { + + private final TaskMemoryManager memoryManager; + private final Ordering ordering; + private final int numFields; + private final ObjectPool objPool; + private final UnsafeRow row1 = new UnsafeRow(); + private final UnsafeRow row2 = new UnsafeRow(); + + SortComparator(TaskMemoryManager memoryManager, + Ordering ordering, int numFields, ObjectPool objPool) { + this.memoryManager = memoryManager; + this.numFields = numFields; + this.ordering = ordering; + this.objPool = objPool; + } + + @Override + public int compare(Long fullKeyAddress1, Long fullKeyAddress2) { + final Object baseObject1 = memoryManager.getPage(fullKeyAddress1); + final long baseOffset1 = memoryManager.getOffsetInPage(fullKeyAddress1) + 8; + + final Object baseObject2 = memoryManager.getPage(fullKeyAddress2); + final long baseOffset2 = memoryManager.getOffsetInPage(fullKeyAddress2) + 8; + + row1.pointTo(baseObject1, baseOffset1, numFields, -1, objPool); + row2.pointTo(baseObject2, baseOffset2, numFields, -1, objPool); + return ordering.compare(row1, row2); + } + } + + private static final class KVArraySortDataFormat extends SortDataFormat { + + @Override + public Long getKey(long[] data, int pos) { + return data[2 * pos]; + } + + @Override + public Long newKey() { + return 0L; + } + + @Override + public void swap(long[] data,int pos0, int 
pos1) { + long tmpKey = data[2 * pos0]; + long tmpVal = data[2 * pos0 + 1]; + data[2 * pos0] = data[2 * pos1]; + data[2 * pos0 + 1] = data[2 * pos1 + 1]; + data[2 * pos1] = tmpKey; + data[2 * pos1 + 1] = tmpVal; + } + + @Override + public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { + dst[2 * dstPos] = src[2 * srcPos]; + dst[2 * dstPos + 1] = src[2 * srcPos + 1]; + } + + @Override + public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { + System.arraycopy(src, 2 * srcPos, dst, 2 * dstPos, 2 * length); + } + + @Override + public long[] allocate(int length) { + return new long[2 * length]; + } + } + + public Iterator getSortedIterator(Ordering ordering, + int numFields, ObjectPool objPool) { + // Pack KV pairs into the front of the underlying array + int keyIndex = 0; + int newIndex = 0; + while (keyIndex < capacity) { + if (bitset.isSet(keyIndex)) { + longArray[2 * newIndex] = longArray[2 * keyIndex]; + longArray[2 * newIndex + 1] = longArray[2 * keyIndex + 1]; + newIndex += 1; + } + keyIndex += 1; + } + Comparator sortComparator = new SortComparator(this.memoryManager, + ordering, numFields, objPool); + Sorter sorter = new Sorter<>(new KVArraySortDataFormat()); + sorter.sort(longArray, 0, newIndex, sortComparator); + return new UnsafeMapSorterIterator(newIndex, longArray, this.loc); + } + + /** + * Iterate through the data and memory location of records are returned in order of the key. + */ + public static final class UnsafeMapSorterIterator implements Iterator { + + private final long[] pointerArray; + private final int numRecords; + private int currentRecordNumber = 0; + private final UnsafeRowLocation loc; + + public UnsafeMapSorterIterator(int numRecords, long[] pointerArray, + UnsafeRowLocation loc) { + this.numRecords = numRecords; + this.pointerArray = pointerArray; + this.loc = loc; + } + + @Override + public boolean hasNext() { + return currentRecordNumber != numRecords; + } + + @Override + public UnsafeRowLocation next() { + loc.with(pointerArray[currentRecordNumber * 2]); + currentRecordNumber++; + return loc; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + } + + /** + * Allocate new data structures for this map. When calling this outside of the constructor, + * make sure to keep references to the old data structures so that you can free them. + * + * @param capacity the new map capacity + */ + private void allocate(int capacity) { + assert (capacity >= 0); + assert (capacity <= MAX_CAPACITY); + + longArray = new long[capacity * 2]; + bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64])); + + this.growthThreshold = (int) (capacity * loadFactor); + this.mask = capacity - 1; + } + + public boolean hasSpaceForAnotherRecord() { + return size < growthThreshold; + } + + /** + * Returns the total amount of memory, in bytes, consumed by this map's managed structures. 
+ */ + public long getMemoryUsage() { + return capacity * 8L * 2 + capacity / 8; + } + + public static int getCapacity(int initialCapacity) { + // The capacity needs to be divisible by 64 so that our bit set can be sized properly + return Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(initialCapacity)), 64); + } + + public static long getMemoryUsage(int capacity) { + return capacity * 8L * 2 + capacity / 8; + } + + public int getNextCapacity() { + int nextCapacity = + Math.min(growthStrategy.nextCapacity(capacity), MAX_CAPACITY); + // The capacity needs to be divisible by 64 so that our bit set can be sized properly + return Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(nextCapacity)), 64); + } + + public int numRecords() { + return size; + } + + /** + * Free all allocated memory associated with this map, including the storage for keys and values + * as well as the hash map array itself. + *
+ * This method is idempotent. + */ + public void free() { + if (longArray != null) { + longArray = null; + } + if (bitset != null) { + // The bitset's heap memory isn't managed by a memory manager, so no need to free it here. + bitset = null; + } + } + + /** + * Returns the total amount of time spent resizing this map (in nanoseconds). + */ + public long getTimeSpentResizingNs() { + if (!enablePerfMetrics) { + throw new IllegalStateException(); + } + return timeSpentResizingNs; + } + + /** + * Returns the average number of probes per key lookup. + */ + public double getAverageProbesPerLookup() { + if (!enablePerfMetrics) { + throw new IllegalStateException(); + } + return (1.0 * numProbes) / numKeyLookups; + } + + public long getNumHashCollisions() { + if (!enablePerfMetrics) { + throw new IllegalStateException(); + } + return numHashCollisions; + } + + /** + * Looks up a key, and return a UnsafeRowLocation handle that can be used to test existence + * and read/write values. + *
+ * This function always return the same UnsafeRowLocation instance to avoid object allocation. + */ + public UnsafeRowLocation lookup( + Object keyBaseObject, + long keyBaseOffset, + int keyRowLengthBytes) { + if (enablePerfMetrics) { + numKeyLookups++; + } + final int hashcode = HASHER.hashUnsafeWords(keyBaseObject, keyBaseOffset, keyRowLengthBytes); + int pos = hashcode & mask; + int step = 1; + while (true) { + if (enablePerfMetrics) { + numProbes++; + } + if (!bitset.isSet(pos)) { + // This is a new key. + return loc.with(pos, hashcode, false, longArray[pos * 2]); + } else { + long stored = longArray[pos * 2 + 1]; + if ((int) (stored) == hashcode) { + // Full hash code matches. Let's compare the keys for equality. + loc.with(pos, hashcode, true, longArray[pos * 2]); + if (loc.getKeyLength() == keyRowLengthBytes) { + final MemoryLocation keyAddress = loc.getKeyAddress(); + final Object storedKeyBaseObject = keyAddress.getBaseObject(); + final long storedKeyBaseOffset = keyAddress.getBaseOffset(); + final boolean areEqual = ByteArrayMethods.arrayEquals( + keyBaseObject, + keyBaseOffset, + storedKeyBaseObject, + storedKeyBaseOffset, + keyRowLengthBytes + ); + if (areEqual) { + return loc; + } else { + if (enablePerfMetrics) { + numHashCollisions++; + } + } + } + } + } + pos = (pos + step) & mask; + step++; + } + } + + /** + * Store a new key and value. This method may only be called once for a given key; if you want + * to update the value associated with a key, then you can directly manipulate the bytes stored + * at the value address. + *
+   * It is only valid to call this method immediately after calling `lookup()` using the same key.
+   *
+   * The key and value must be word-aligned (that is, their sizes must be multiples of 8).
+   *
+   * After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length`
+   * will return information on the data stored by this `putNewKey` call.
+   *
+   * As an example usage, here's the proper way to store a new key:
+   *
+   *   UnsafeRowLocation loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
+   *   if (!loc.isDefined()) {
+   *     // copy the key and value into a data page, encode that record address, then:
+   *     map.putNewKey(storedKeyAddress, loc);
+   *   }
+   *
+ * Unspecified behavior if the key is not defined. + */ + public void putNewKey( + long storedKeyAddress, + UnsafeRowLocation location) { + if (size == MAX_CAPACITY) { + throw new IllegalStateException("BytesToBytesMap has reached maximum capacity"); + } + size++; + bitset.set(location.getPos()); + + longArray[location.getPos() * 2] = storedKeyAddress; + longArray[location.getPos() * 2 + 1] = location.getKeyHashcode(); + location.updateAddressesAndSizes(storedKeyAddress); + location.setDefined(true); + if (size > growthThreshold && longArray.length < MAX_CAPACITY) { + growAndRehash(); + } + } + + /** + * Grows the size of the hash table and re-hash everything. + */ + @VisibleForTesting + public void growAndRehash() { + long resizeStartTime = -1; + if (enablePerfMetrics) { + resizeStartTime = System.nanoTime(); + } + // Store references to the old data structures to be used when we re-hash + final long[] oldLongArray = longArray; + final BitSet oldBitSet = bitset; + final int oldCapacity = (int) oldBitSet.capacity(); + + int nextCapacity = + Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY); + // The capacity needs to be divisible by 64 so that our bit set can be sized properly + capacity = + Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(nextCapacity)), 64); + // Allocate the new data structures + allocate(capacity); + + // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it) + for (int pos = oldBitSet.nextSetBit(0); + pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) { + final long keyPointer = oldLongArray[pos * 2]; + final int hashcode = (int) oldLongArray[pos * 2 + 1]; + int newPos = hashcode & mask; + int step = 1; + boolean keepGoing = true; + + // No need to check for equality here when we insert so this has one less if branch than + // the similar code path in addWithoutResize. + while (keepGoing) { + if (!bitset.isSet(newPos)) { + bitset.set(newPos); + longArray[newPos * 2] = keyPointer; + longArray[newPos * 2 + 1] = hashcode; + keepGoing = false; + } else { + newPos = (newPos + step) & mask; + step++; + } + } + } + + // Deallocate the old data structures. + //memoryManager.free(oldLongArray.memoryBlock()); + if (enablePerfMetrics) { + timeSpentResizingNs += System.nanoTime() - resizeStartTime; + } + } + + /** + * Returns the next number greater or equal num that is power of 2. + */ + private static long nextPowerOf2(long num) { + final long highBit = Long.highestOneBit(num); + return (highBit == num) ? num : highBit << 1; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalAggregation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalAggregation.java new file mode 100644 index 0000000000000..2a2fa5a11ffe3 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalAggregation.java @@ -0,0 +1,989 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution; + +import java.io.*; +import java.util.*; + +import scala.Function1; +import scala.Tuple2; +import scala.math.Ordering; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.ByteStreams; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.SparkEnv; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.DummySerializerInstance; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.*; +import org.apache.spark.sql.catalyst.util.ObjectPool; +import org.apache.spark.sql.catalyst.util.UniqueObjectPool; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.BlockObjectWriter; +import org.apache.spark.storage.TempLocalBlockId; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +/** + * Unsafe sort-based external aggregation. + */ +public final class UnsafeExternalAggregation { + + private final Logger logger = LoggerFactory.getLogger(UnsafeExternalAggregation.class); + + /** + * Special record length that is placed after the last record in a data page. + */ + private static final int END_OF_PAGE_MARKER = -1; + + private final TaskMemoryManager memoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; + private final BlockManager blockManager; + private final TaskContext taskContext = TaskContext.get(); + private ShuffleWriteMetrics writeMetrics; + + /** + * The buffer size to use when writing spills using DiskBlockObjectWriter + */ + private final int fileBufferSizeBytes; + + /** + * A linked list for tracking all allocated data pages so that we can free all of our memory. + */ + private final List dataPages = new LinkedList(); + + /** + * The data page that will be used to store keys and values for new hashtable entries. When this + * page becomes full, a new page will be allocated and this pointer will change to point to that + * new page. + */ + private MemoryBlock currentDataPage = null; + + /** + * Offset into `currentDataPage` that points to the location where new data can be inserted into + * the page. This does not incorporate the page's base offset. + */ + private long pageCursor = 0; + + private long freeSpaceInCurrentPage = 0; + + /** + * The size of the data pages that hold key and value data. Map entries cannot span multiple + * pages, so this limits the maximum entry size. + */ + private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes + + /** + * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the + * map, we copy this buffer and use it as the value. 
+ */ + private byte[] emptyBuffer; + + /** + * An empty row used by `initProjection` + */ + private static final InternalRow emptyRow = new GenericInternalRow(); + + /** + * Whether can the empty aggregation buffer be reuse without calling `initProjection` or not. + */ + private boolean reuseEmptyBuffer; + + /** + * The projection used to initialize the emptyBuffer + */ + private final Function1 initProjection; + + private final MutableProjection updateProjection; + private final MutableProjection mergeProjection; + + /** + * Encodes grouping keys or buffers as UnsafeRows. + */ + private final UnsafeRowConverter keyConverter; + private final UnsafeRowConverter bufferConverter; + private final Ordering groupingKeyOrdering; + private final int groupingKeyNum; + private final int aggregationBufferNum; + + private final int initialCapacity; + + /** + * A hashmap which maps from opaque bytearray keys to bytearray values. + */ + private UnsafeAppendOnlyMap map; + + /** + * An object pool for objects that are used in grouping keys. + */ + private UniqueObjectPool keyPool; + + /** + * An object pool for objects that are used in aggregation buffers. + */ + private ObjectPool bufferPool; + + /** + * Re-used pointer to the current aggregation buffer + */ + private final UnsafeRow currentBuffer = new UnsafeRow(); + + /** + * Scratch space that is used when encoding grouping keys into UnsafeRow format. + *
+ * By default, this is a 8 kb array, but it will grow as necessary in case larger keys are + * encountered. + */ + private byte[] groupingKeyConversionScratchSpace = new byte[1024 * 8]; + + private boolean enablePerfMetrics; + + private int testSpillFrequency = 0; + + private long numRowsInserted = 0; + + private final LinkedList spillWriters = new LinkedList<>(); + + /** + * Create a new UnsafeFixedWidthAggregationMap. + * + * @param initProjection the default value for new keys (a "zero" of the agg. function) + * @param updateProjection update values for the same key + * @param mergeProjection merge values for the same key + * @param keyConverter the converter of the grouping key, used for row conversion. + * @param bufferConverter the converter of the aggregation buffer, used for row conversion. + * @param groupingKeyOrdering a comparator which sorts groupingkey. + * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). + * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) + */ + public UnsafeExternalAggregation( + Function1 initProjection, + MutableProjection updateProjection, + MutableProjection mergeProjection, + UnsafeRowConverter keyConverter, + UnsafeRowConverter bufferConverter, + Ordering groupingKeyOrdering, + int initialCapacity, + boolean enablePerfMetrics) throws IOException { + this.initProjection = initProjection; + this.updateProjection = updateProjection; + this.mergeProjection = mergeProjection; + this.keyConverter = keyConverter; + this.bufferConverter = bufferConverter; + this.groupingKeyOrdering = groupingKeyOrdering; + this.groupingKeyNum = keyConverter.numFields(); + this.aggregationBufferNum = bufferConverter.numFields(); + this.initialCapacity = initialCapacity; + this.enablePerfMetrics = enablePerfMetrics; + + this.memoryManager = TaskContext.get().taskMemoryManager(); + final SparkEnv sparkEnv = SparkEnv.get(); + this.shuffleMemoryManager = sparkEnv.shuffleMemoryManager(); + this.blockManager = sparkEnv.blockManager(); + + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + this.fileBufferSizeBytes = + (int) sparkEnv.conf().getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.testSpillFrequency = sparkEnv.conf().getInt("spark.test.aggregate.spillFrequency", 0); + + initializeUnsafeAppendMap(); + } + + /** + * Allocates new sort data structures. Called when creating the sorter and after each spill. + */ + private void initializeUnsafeAppendMap() throws IOException { + this.writeMetrics = new ShuffleWriteMetrics(); + this.keyPool = new UniqueObjectPool(100); + this.bufferPool = new ObjectPool(initialCapacity); + + if (emptyBuffer == null) { + InternalRow initRow = initProjection.apply(emptyRow); + int sizeRequirement = bufferConverter.getSizeRequirement(initRow); + this.emptyBuffer = new byte[sizeRequirement]; + int writtenLength = bufferConverter.writeRow( + initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, sizeRequirement, null); + assert (writtenLength == emptyBuffer.length) : "Size requirement calculation was wrong!"; + // re-use the empty buffer only when there is no object saved in pool. 
+ reuseEmptyBuffer = bufferPool.size() == 0; + } + + int capacity = UnsafeAppendOnlyMap.getCapacity(initialCapacity); + long memoryRequested = UnsafeAppendOnlyMap.getMemoryUsage(capacity); + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); + if (memoryAcquired != memoryRequested) { + shuffleMemoryManager.release(memoryAcquired); + throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); + } + this.map = new UnsafeAppendOnlyMap(memoryManager, capacity, enablePerfMetrics); + } + + /** + * Return true if it can external aggregate with the groupKey schema & aggregationBuffer schema, + * false otherwise + */ + public static boolean supportSchema(StructType groupKeySchema, + StructType aggregationBufferSchema) { + for (StructField field : groupKeySchema.fields()) { + if (UnsafeColumnWriter.forType(field.dataType()) instanceof ObjectUnsafeColumnWriter) { + return false; + } + } + for (StructField field : aggregationBufferSchema.fields()) { + if (UnsafeColumnWriter.forType(field.dataType()) instanceof ObjectUnsafeColumnWriter) { + return false; + } + } + return true; + } + + /** + * Forces spills to occur every `frequency` records. Only for use in tests. + */ + @VisibleForTesting + public void setTestSpillFrequency(int frequency) { + assert frequency > 0 : "Frequency must be positive"; + testSpillFrequency = frequency; + } + + public void insertRow(InternalRow groupingKey, InternalRow currentRow) + throws IOException { + int groupingKeySize = keyConverter.getSizeRequirement(groupingKey); + // Make sure that the buffer is large enough to hold the key. If it's not, grow it: + if (groupingKeySize > groupingKeyConversionScratchSpace.length) { + groupingKeyConversionScratchSpace = new byte[groupingKeySize]; + } + numRowsInserted++; + if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) { + spill(); + } + UnsafeRow aggregationBuffer = this.getAggregationBuffer(groupingKey, groupingKeySize); + JoinedRow3 joinedRow = new JoinedRow3(aggregationBuffer, currentRow); + this.updateProjection.target(aggregationBuffer).apply(joinedRow); + } + + /** + * Return the aggregation buffer for the current group. For efficiency, all calls to this method + * return the same object. 
+ */ + public UnsafeRow getAggregationBuffer(InternalRow groupingKey, int groupingKeySize) + throws IOException { + final int actualGroupingKeySize = keyConverter.writeRow( + groupingKey, + groupingKeyConversionScratchSpace, + PlatformDependent.BYTE_ARRAY_OFFSET, + groupingKeySize, + keyPool); + assert (groupingKeySize + == actualGroupingKeySize) : "Size requirement calculation was wrong!"; + + Object groupingKeyBaseObject = groupingKeyConversionScratchSpace; + // Probe our map using the serialized key + final UnsafeRowLocation loc = map.lookup( + groupingKeyBaseObject, + PlatformDependent.BYTE_ARRAY_OFFSET, + groupingKeySize); + if (!loc.isDefined()) { + // This is the first time that we've seen this grouping key, so we'll insert a copy of the + // empty aggregation buffer into the map: + if (!reuseEmptyBuffer) { + // There is some objects referenced by emptyBuffer, so generate a new one + InternalRow initRow = initProjection.apply(emptyRow); + bufferConverter + .writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, groupingKeySize, + bufferPool); + } + if (!this.putNewKey( + groupingKeyBaseObject, + PlatformDependent.BYTE_ARRAY_OFFSET, + groupingKeySize, + emptyBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET, + emptyBuffer.length, + loc)) { + // because spill makes putting new key failed, it should get AggregationBuffer again + return this.getAggregationBuffer(groupingKey, groupingKeySize); + } + } + // Reset the pointer to point to the value that we just stored or looked up: + final MemoryLocation address = loc.getValueAddress(); + currentBuffer.pointTo( + address.getBaseObject(), + address.getBaseOffset(), + bufferConverter.numFields(), + loc.getValueLength(), + bufferPool + ); + return currentBuffer; + } + + public boolean putNewKey( + Object keyBaseObject, + long keyBaseOffset, + int keyLengthBytes, + Object valueBaseObject, + long valueBaseOffset, + int valueLengthBytes, + UnsafeRowLocation location) throws IOException { + assert (!location.isDefined()) : "Can only set value once for a key"; + + assert (keyLengthBytes % 8 == 0); + assert (valueLengthBytes % 8 == 0); + + // Here, we'll copy the data into our data pages. Because we only store a relative offset from + // the key address instead of storing the absolute address of the value, the key and value + // must be stored in the same memory page. 
+ // (8 byte key length) (key) (8 byte value length) (value) + final int requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes; + if (!haveSpaceForRecord(requiredSize)) { + if (!allocateSpaceForRecord(requiredSize)){ + // if spill have been happened, re-insert current groupingKey + return false; + } + } + + freeSpaceInCurrentPage -= requiredSize; + // Compute all of our offsets up-front: + final Object pageBaseObject = currentDataPage.getBaseObject(); + final long pageBaseOffset = currentDataPage.getBaseOffset(); + final long keySizeOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += 8; // word used to store the key size + final long keyDataOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += keyLengthBytes; + final long valueSizeOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += 8; // word used to store the value size + final long valueDataOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += valueLengthBytes; + + // Copy the key + PlatformDependent.UNSAFE.putLong(pageBaseObject, keySizeOffsetInPage, keyLengthBytes); + PlatformDependent.copyMemory( + keyBaseObject, keyBaseOffset, pageBaseObject, keyDataOffsetInPage, keyLengthBytes); + // Copy the value + PlatformDependent.UNSAFE.putLong(pageBaseObject, valueSizeOffsetInPage, valueLengthBytes); + PlatformDependent.copyMemory( + valueBaseObject, valueBaseOffset, pageBaseObject, valueDataOffsetInPage, valueLengthBytes); + + final long storedKeyAddress = memoryManager.encodePageNumberAndOffset( + currentDataPage, keySizeOffsetInPage); + this.map.putNewKey(storedKeyAddress, location); + return true; + } + + /** + * Checks whether there is enough space to insert a new record into the sorter. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + + * @return true if the record can be inserted without requiring more allocations, false otherwise. + */ + private boolean haveSpaceForRecord(int requiredSpace) { + assert (requiredSpace > 0); + return (map.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the {@link org.apache.spark.shuffle.ShuffleMemoryManager} and + * spill if the requested memory can not be obtained. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + */ + private boolean allocateSpaceForRecord(int requiredSpace) throws IOException { + boolean noSpill = true; + assert (requiredSpace <= PAGE_SIZE_BYTES - 8); // Reserve 8 bytes for the end-of-page marker. + if (!map.hasSpaceForAnotherRecord()) { + logger.debug("Attempting to grow size of hash table"); + final int nextCapacity = map.getNextCapacity(); + final long oldPointerArrayMemoryUsage = map.getMemoryUsage(); + final long memoryToGrowPointerArray = UnsafeAppendOnlyMap.getMemoryUsage(nextCapacity); + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); + if (memoryAcquired < memoryToGrowPointerArray) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + } else { + map.growAndRehash(); + shuffleMemoryManager.release(oldPointerArrayMemoryUsage); + } + } + + // If there's not enough space in the current page, allocate a new page (8 bytes are reserved + // for the end-of-page marker). 
+ if (currentDataPage == null || freeSpaceInCurrentPage < requiredSpace) { + logger.trace("Required space {} is less than free space in current page ({})", + requiredSpace, freeSpaceInCurrentPage); + if (currentDataPage != null) { + // There wasn't enough space in the current page, so write an end-of-page marker: + final Object pageBaseObject = currentDataPage.getBaseObject(); + final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; + PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); + } + long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE_BYTES); + if (memoryAcquired < PAGE_SIZE_BYTES) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + noSpill = false; + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE_BYTES); + if (memoryAcquiredAfterSpilling != PAGE_SIZE_BYTES) { + shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + throw new IOException("Unable to acquire " + PAGE_SIZE_BYTES + " bytes of memory"); + } else { + memoryAcquired = memoryAcquiredAfterSpilling; + } + } + MemoryBlock newPage = memoryManager.allocatePage(memoryAcquired); + dataPages.add(newPage); + pageCursor = 0; + freeSpaceInCurrentPage = PAGE_SIZE_BYTES - 8; + currentDataPage = newPage; + } + return noSpill; + } + + /** + * Sort and spill the current records in response to memory pressure. + */ + @VisibleForTesting + void spill() throws IOException { + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spillWriters.size(), + spillWriters.size() > 1 ? " times" : " time"); + + final UnsafeSorterKVSpillWriter spillWriter = + new UnsafeSorterKVSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, + map.numRecords()); + spillWriters.add(spillWriter); + final Iterator sortedRecords = map.getSortedIterator( + groupingKeyOrdering, groupingKeyNum, keyPool); + while (sortedRecords.hasNext()) { + UnsafeRowLocation location = sortedRecords.next(); + spillWriter.write(location); + } + spillWriter.close(); + + final long sorterMemoryUsage = map.getMemoryUsage(); + map = null; + shuffleMemoryManager.release(sorterMemoryUsage); + final long spillSize = freeMemory(); + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + + initializeUnsafeAppendMap(); + } + + public AbstractScalaIterator getSortedIterator() { + return new AbstractScalaIterator() { + + Iterator sorter = + map.getSortedIterator(groupingKeyOrdering, groupingKeyNum, keyPool); + + @Override + public boolean hasNext() { + return sorter.hasNext(); + } + + @Override + public UnsafeRowLocation next() { + return sorter.next(); + } + }; + } + + /** + * Return an iterator that merges the in-memory map with the spilled files. + * If no spill has occurred, simply return the in-memory map's iterator. + */ + public AbstractScalaIterator iterator() throws IOException { + if (spillWriters.isEmpty()) { + return this.getMemoryIterator(); + } else { + return this.merge(this.getSortedIterator()); + } + } + + /** + * Returns an iterator over the keys and values in in-memory map. + */ + public AbstractScalaIterator getMemoryIterator() { + return new UnsafeAppendOnlyMapIterator(map.numRecords()); + } + + /** + * Merge aggregate of the in-memory map with the spilled files, giving an iterator over elements. 
+ */ + private AbstractScalaIterator merge(AbstractScalaIterator inMemory) + throws IOException { + + final Comparator keyOrdering = + new Comparator() { + private final UnsafeRow row1 = new UnsafeRow(); + private final UnsafeRow row2 = new UnsafeRow(); + + public int compare(BufferedIterator o1, BufferedIterator o2){ + row1.pointTo( + o1.getRecordLocation().getKeyBaseObject(), + o1.getRecordLocation().getKeyBaseOffset(), + groupingKeyNum, + o1.getRecordLocation().getKeyLength(), + null); + row2.pointTo( + o2.getRecordLocation().getKeyBaseObject(), + o2.getRecordLocation().getKeyBaseOffset(), + groupingKeyNum, + o2.getRecordLocation().getKeyLength(), + null); + return groupingKeyOrdering.compare(row1, row2); + } + }; + final Queue priorityQueue = + new PriorityQueue + (spillWriters.size() + 1, keyOrdering); + BufferedIterator inMemoryBuffer = this.asBuffered(inMemory); + if (inMemoryBuffer.hasNext()) { + inMemoryBuffer.loadNext(); + priorityQueue.add(inMemoryBuffer); + } + + for (int i = 0; i < spillWriters.size(); i++) { + BufferedIterator spillBuffer = this.asBuffered(spillWriters.get(i).getReader(blockManager)); + if (spillBuffer.hasNext()) { + spillBuffer.loadNext(); + priorityQueue.add(spillBuffer); + } + } + final AbstractScalaIterator mergeIter = + new AbstractScalaIterator() { + + BufferedIterator topIter = null; + + @Override + public boolean hasNext() { + return !priorityQueue.isEmpty() || (topIter != null && topIter.hasNext()); + } + + @Override + public UnsafeRowLocation next() throws IOException { + if (topIter != null && topIter.hasNext()) { + topIter.loadNext(); + priorityQueue.add(topIter); + } + topIter = priorityQueue.poll(); + return topIter.getRecordLocation(); + } + }; + + final BufferedIterator sorted = asBuffered(mergeIter); + return new AbstractScalaIterator() { + + private UnsafeRow currentKey = new UnsafeRow(); + private UnsafeRow currentValue = new UnsafeRow(); + private UnsafeRow nextKey = new UnsafeRow(); + private UnsafeRow nextValue = new UnsafeRow(); + private UnsafeRowLocation currentLocation = null; + + @Override + public boolean hasNext() { + return currentLocation != null || sorted.hasNext(); + } + + @Override + public MapEntry next() throws IOException { + try { + if (currentLocation == null) { + sorted.loadNext(); + currentLocation = sorted.getRecordLocation(); + } + currentKey.pointTo( + currentLocation.getKeyBaseObject(), + currentLocation.getKeyBaseOffset(), + groupingKeyNum, + currentLocation.getKeyLength(), + null); + currentKey = (UnsafeRow)currentKey.copy(); + currentValue.pointTo( + currentLocation.getValueBaseObject(), + currentLocation.getValueBaseOffset(), + aggregationBufferNum, + currentLocation.getValueLength(), + null); + currentValue = (UnsafeRow)currentValue.copy(); + currentLocation = null; + while (sorted.hasNext()) { + sorted.loadNext(); + UnsafeRowLocation nextLocation = sorted.getRecordLocation(); + nextKey.pointTo( + nextLocation.getKeyBaseObject(), + nextLocation.getKeyBaseOffset(), + groupingKeyNum, + nextLocation.getKeyLength(), + null); + nextValue.pointTo( + nextLocation.getValueBaseObject(), + nextLocation.getValueBaseOffset(), + aggregationBufferNum, + nextLocation.getValueLength(), + null); + + if (groupingKeyOrdering.compare(currentKey, nextKey) != 0) { + currentLocation = nextLocation; + break; + } + + JoinedRow3 joinedRow = new JoinedRow3(currentValue, nextValue); + mergeProjection.target(currentValue).apply(joinedRow); + } + + return new MapEntry(currentKey, currentValue); + } catch (IOException e) { + 
cleanupResources(); + // Scala iterators don't declare any checked exceptions, so we need to use this hack + // to re-throw the exception: + PlatformDependent.throwException(e); + } + throw new RuntimeException("Exception should have been re-thrown in next()"); + } + }; + } + + public BufferedIterator asBuffered(AbstractScalaIterator iterator) { + return new BufferedIterator(iterator); + } + + public interface AbstractScalaIterator { + + public abstract boolean hasNext(); + + public abstract E next() throws IOException; + + } + + public class BufferedIterator { + + private UnsafeRowLocation location = null; + private AbstractScalaIterator iterator; + + public BufferedIterator(AbstractScalaIterator iterator) { + this.iterator = iterator; + } + + public boolean hasNext() { + return iterator.hasNext(); + } + + public void loadNext() throws IOException { + location = iterator.next(); + } + + public UnsafeRowLocation getRecordLocation() { + return location; + } + } + + /** + * Mutable pair object + */ + public class MapEntry { + + public UnsafeRow key; + public UnsafeRow value; + + public MapEntry() { + this.key = new UnsafeRow(); + this.value = new UnsafeRow(); + } + + public MapEntry(UnsafeRow key, UnsafeRow value) { + this.key = key; + this.value = value; + } + } + + + public class UnsafeSorterKVSpillWriter { + + static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; + + // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to + // be an API to directly transfer bytes from managed memory to the disk writer, we buffer + // data through a byte array. + private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; + + private final File file; + private final BlockId blockId; + private final int numRecordsToWrite; + private BlockObjectWriter writer; + private int numRecordsSpilled = 0; + + public UnsafeSorterKVSpillWriter( + BlockManager blockManager, + int fileBufferSize, + ShuffleWriteMetrics writeMetrics, + int numRecordsToWrite) throws IOException { + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempLocalBlock(); + this.file = spilledFileInfo._2(); + this.blockId = spilledFileInfo._1(); + this.numRecordsToWrite = numRecordsToWrite; + // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + // Our write path doesn't actually use this serializer (since we end up calling the `write()` + // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + // around this, we pass a dummy no-op serializer. + writer = blockManager.getDiskWriter( + blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics); + // Write the number of records + writeIntToBuffer(numRecordsToWrite, 0); + writer.write(writeBuffer, 0, 4); + } + + // Based on DataOutputStream.writeInt. 
+ private void writeIntToBuffer(int v, int offset) throws IOException { + writeBuffer[offset + 0] = (byte)(v >>> 24); + writeBuffer[offset + 1] = (byte)(v >>> 16); + writeBuffer[offset + 2] = (byte)(v >>> 8); + writeBuffer[offset + 3] = (byte)(v >>> 0); + } + + public void write(UnsafeRowLocation loc) throws IOException { + if (numRecordsSpilled == numRecordsToWrite) { + throw new IllegalStateException( + "Number of records written exceeded numRecordsToWrite = " + numRecordsToWrite); + } else { + numRecordsSpilled++; + } + this.write(loc.getKeyAddress().getBaseObject(), loc.getKeyAddress().getBaseOffset(), + loc.getKeyLength()); + this.write(loc.getValueAddress().getBaseObject(), loc.getValueAddress().getBaseOffset(), + loc.getValueLength()); + writer.recordWritten(); + } + + + /** + * Write a record to a spill file. + * + * @param baseObject the base object / memory page containing the record + * @param baseOffset the base offset which points directly to the record data. + * @param recordLength the length of the record. + */ + public void write( + Object baseObject, + long baseOffset, + int recordLength) throws IOException { + writeIntToBuffer(recordLength, 0); + int dataRemaining = recordLength; + int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4; // space used by len + long recordReadPosition = baseOffset; + while (dataRemaining > 0) { + final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining); + PlatformDependent.copyMemory( + baseObject, + recordReadPosition, + writeBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), + toTransfer); + writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer); + recordReadPosition += toTransfer; + dataRemaining -= toTransfer; + freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE; + } + if (freeSpaceInWriteBuffer < DISK_WRITE_BUFFER_SIZE) { + writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer)); + } + } + + public void close() throws IOException { + writer.commitAndClose(); + writer = null; + writeBuffer = null; + } + + public AbstractScalaIterator getReader(BlockManager blockManager) + throws IOException { + return new UnsafeSorterKVSpillReader(blockManager, file, blockId); + } + } + + final class UnsafeSorterKVSpillReader implements AbstractScalaIterator { + + private InputStream in; + private DataInputStream din; + + // Variables that change with every kv read: + private int numRecordsRemaining; + + private int keyLength; + private byte[] keyArray = new byte[1024 * 1024]; + + private int valueLength; + private byte[] valueArray = new byte[1024 * 1024]; + private final UnsafeRowLocation location = new UnsafeRowLocation(memoryManager); + + public UnsafeSorterKVSpillReader( + BlockManager blockManager, + File file, + BlockId blockId) throws IOException { + assert (file.length() > 0); + final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); + this.in = blockManager.wrapForCompression(blockId, bs); + this.din = new DataInputStream(this.in); + numRecordsRemaining = din.readInt(); + } + + @Override + public boolean hasNext() { + return (numRecordsRemaining > 0); + } + + @Override + public UnsafeRowLocation next() throws IOException { + keyLength = din.readInt(); + if (keyLength > keyArray.length) { + keyArray = new byte[keyLength]; + } + ByteStreams.readFully(in, keyArray, 0, keyLength); + valueLength = din.readInt(); + if (valueLength > valueArray.length) { + valueArray = new byte[valueLength]; + } + 
ByteStreams.readFully(in, valueArray, 0, valueLength); + numRecordsRemaining--; + if (numRecordsRemaining == 0) { + in.close(); + in = null; + din = null; + } + location.with(keyLength, keyArray, valueLength, valueArray); + return location; + } + } + + + public class UnsafeAppendOnlyMapIterator implements AbstractScalaIterator { + + private final MapEntry entry = new MapEntry(); + private final UnsafeRowLocation loc = new UnsafeRowLocation(memoryManager); + private int currentRecordNumber = 0; + private Object pageBaseObject; + private long offsetInPage; + private int numRecords; + + public UnsafeAppendOnlyMapIterator(int numRecords) { + this.numRecords = numRecords; + if (dataPages.iterator().hasNext()) { + advanceToNextPage(); + } + } + + private void advanceToNextPage() { + final MemoryBlock currentPage = dataPages.iterator().next(); + pageBaseObject = currentPage.getBaseObject(); + offsetInPage = currentPage.getBaseOffset(); + } + + @Override + public boolean hasNext() { + return currentRecordNumber != numRecords; + } + + @Override + public MapEntry next() { + int keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage); + if (keyLength == END_OF_PAGE_MARKER) { + advanceToNextPage(); + keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage); + } + loc.with(pageBaseObject, offsetInPage); + offsetInPage += 8 + 8 + keyLength + loc.getValueLength(); + + MemoryLocation keyAddress = loc.getKeyAddress(); + MemoryLocation valueAddress = loc.getValueAddress(); + entry.key.pointTo( + keyAddress.getBaseObject(), + keyAddress.getBaseOffset(), + keyConverter.numFields(), + loc.getKeyLength(), + keyPool + ); + entry.value.pointTo( + valueAddress.getBaseObject(), + valueAddress.getBaseOffset(), + bufferConverter.numFields(), + loc.getValueLength(), + bufferPool + ); + currentRecordNumber++; + return entry; + } + } + + private long getMemoryUsage() { + return map.getMemoryUsage() + (dataPages.size() * PAGE_SIZE_BYTES); + } + + private void cleanupResources() { + this.freeMemory(); + } + + /** + * Free the unsafe memory associated with this map. 
+ */ + public long freeMemory() { + long memoryFreed = 0; + for (MemoryBlock block : dataPages) { + memoryManager.freePage(block); + shuffleMemoryManager.release(block.size()); + memoryFreed += block.size(); + } + if (map != null) { + map.free(); + } + dataPages.clear(); + currentDataPage = null; + pageCursor = 0; + freeSpaceInCurrentPage = 0; + return memoryFreed; + } + + @SuppressWarnings("UseOfSystemOutOrSystemErr") + public void printPerfMetrics() { + if (!enablePerfMetrics) { + throw new IllegalStateException("Perf metrics not enabled"); + } + System.out.println("Average probes per lookup: " + map.getAverageProbesPerLookup()); + System.out.println("Number of hash collisions: " + map.getNumHashCollisions()); + System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs()); + System.out.println("Total memory consumption (bytes): " + map.getMemoryUsage()); + System.out.println("Number of unique objects in keys: " + keyPool.size()); + System.out.println("Number of objects in buffers: " + bufferPool.size()); + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 886a486bf5ee0..8019da5482a6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -19,6 +19,24 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +/** + * Converts a [[InternalRow]] to another Row given a sequence of expression that define each + * column of the new row. If the schema of the input row is specified, then the given expression + * will be bound to that schema. + * + * In contrast to a normal projection, a MutableProjection reuses the same underlying row object + * each time an input row is added. This significantly reduces the cost of calculating the + * projection, but means that it is not safe to hold on to a reference to a [[InternalRow]] after + * `next()` has been called on the [[Iterator]] that produced it. Instead, the user must call + * `InternalRow.copy()` and hold on to the returned [[InternalRow]] before calling `next()`. + */ +abstract class MutableProjection extends Projection { + def currentValue: InternalRow + + /** Uses the given row to store the output of the projection. */ + def target(row: MutableRow): MutableProjection +} + /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. * @param expressions a sequence of expressions that determine the value of each column of the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 30b7f8d3766a5..5cb59f8c3cba0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -62,22 +62,4 @@ package object expressions { * will be bound to that schema. */ abstract class Projection extends (InternalRow => InternalRow) - - /** - * Converts a [[InternalRow]] to another Row given a sequence of expression that define each - * column of the new row. If the schema of the input row is specified, then the given expression - * will be bound to that schema. 
- * - * In contrast to a normal projection, a MutableProjection reuses the same underlying row object - * each time an input row is added. This significantly reduces the cost of calculating the - * projection, but means that it is not safe to hold on to a reference to a [[InternalRow]] after - * `next()` has been called on the [[Iterator]] that produced it. Instead, the user must call - * `InternalRow.copy()` and hold on to the returned [[InternalRow]] before calling `next()`. - */ - abstract class MutableProjection extends Projection { - def currentValue: InternalRow - - /** Uses the given row to store the output of the projection. */ - def target(row: MutableRow): MutableProjection - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index c069da016f9f0..aac24a7a6bb02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -17,19 +17,25 @@ package org.apache.spark.sql.execution +import java.io.IOException + import org.apache.spark.TaskContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering import org.apache.spark.sql.catalyst.trees._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.PlatformDependent case class AggregateEvaluation( schema: Seq[Attribute], initialValues: Seq[Expression], update: Seq[Expression], + mergeSchema: Seq[Attribute], + merge: Seq[Expression], result: Expression) /** @@ -85,9 +91,16 @@ case class GeneratedAggregate( val currentCount = AttributeReference("currentCount", LongType, nullable = false)() val initialValue = Literal(0L) val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount) + val mergeCount = AttributeReference("mergeCount", LongType, nullable = false)() + val mergeFunction = Coalesce( + Add( + Coalesce(currentCount :: initialValue :: Nil), + Coalesce(mergeCount :: initialValue :: Nil) + ) :: currentCount :: initialValue :: Nil) val result = currentCount - AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, + mergeCount :: Nil, mergeFunction :: Nil, result) case s @ Sum(expr) => val calcType = @@ -109,6 +122,14 @@ case class GeneratedAggregate( Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType) ) :: currentSum :: zero :: Nil) + + val mergeSum = AttributeReference("mergeSum", calcType, nullable = false)() + val mergeFunction = Coalesce( + Add( + Coalesce(currentSum :: zero :: Nil), + Coalesce(mergeSum :: zero :: Nil) + ) :: currentSum :: zero :: Nil) + val result = expr.dataType match { case DecimalType.Fixed(_, _) => @@ -116,7 +137,8 @@ case class GeneratedAggregate( case _ => currentSum } - AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, + mergeSum :: Nil, mergeFunction :: Nil, result) case cs @ CombineSum(expr) => val calcType = @@ -148,6 +170,15 @@ case class GeneratedAggregate( Cast(expr, calcType)) :: currentSum :: zero :: Nil), 
currentSum) + val mergeSum = AttributeReference("mergeSum", calcType, nullable = false)() + val mergeFunction = If( + IsNotNull(mergeSum), + Coalesce( + Add( + Coalesce(currentSum :: zero :: Nil), + Cast(mergeSum, calcType)) :: currentSum :: zero :: Nil), + currentSum) + val result = expr.dataType match { case DecimalType.Fixed(_, _) => @@ -155,28 +186,37 @@ case class GeneratedAggregate( case _ => currentSum } - AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, + mergeSum :: Nil, mergeFunction :: Nil, result) case m @ Max(expr) => val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() val initialValue = Literal.create(null, expr.dataType) val updateMax = MaxOf(currentMax, expr) + val mergeMax = AttributeReference("mergeMax", expr.dataType, nullable = false)() + val mergeFunction = MaxOf(currentMax, mergeMax) AggregateEvaluation( currentMax :: Nil, initialValue :: Nil, updateMax :: Nil, + mergeMax :: Nil, + mergeFunction :: Nil, currentMax) case m @ Min(expr) => val currentMin = AttributeReference("currentMin", expr.dataType, nullable = true)() val initialValue = Literal.create(null, expr.dataType) val updateMin = MinOf(currentMin, expr) + val mergeMin = AttributeReference("mergeMin", expr.dataType, nullable = false)() + val mergeFunction = MinOf(currentMin, mergeMin) AggregateEvaluation( currentMin :: Nil, initialValue :: Nil, updateMin :: Nil, + mergeMin :: Nil, + mergeFunction :: Nil, currentMin) case CollectHashSet(Seq(expr)) => @@ -184,11 +224,16 @@ case class GeneratedAggregate( AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)() val initialValue = NewSet(expr.dataType) val addToSet = AddItemToSet(expr, set) + val mergeSet = + AttributeReference("mergeHashSet", new OpenHashSetUDT(expr.dataType), nullable = false)() + val mergeFunction = AddItemToSet(mergeSet, set) AggregateEvaluation( set :: Nil, initialValue :: Nil, addToSet :: Nil, + mergeSet :: Nil, + mergeFunction :: Nil, set) case CombineSetsAndCount(inputSet) => @@ -197,11 +242,17 @@ case class GeneratedAggregate( AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)() val initialValue = NewSet(elementType) val collectSets = CombineSets(set, inputSet) + val mergeSet = + AttributeReference("mergeHashSet", new OpenHashSetUDT(elementType), nullable = false)() + val mergeFunction = CombineSets(set, mergeSet) + AggregateEvaluation( set :: Nil, initialValue :: Nil, collectSets :: Nil, + mergeSet :: Nil, + mergeFunction :: Nil, CountSet(set)) case o => sys.error(s"$o can't be codegened.") @@ -281,40 +332,92 @@ case class GeneratedAggregate( Iterator(resultProjection(buffer)) } else if (unsafeEnabled) { log.info("Using Unsafe-based aggregator") - val aggregationMap = new UnsafeFixedWidthAggregationMap( - newAggregationBuffer, - new UnsafeRowConverter(groupKeySchema), - new UnsafeRowConverter(aggregationBufferSchema), - TaskContext.get.taskMemoryManager(), - 1024 * 16, // initial capacity - false // disable tracking of performance metrics - ) - - while (iter.hasNext) { - val currentRow: InternalRow = iter.next() - val groupKey: InternalRow = groupProjection(currentRow) - val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey) - updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow)) - } - - new Iterator[InternalRow] { - private[this] val mapIterator = aggregationMap.iterator() - 
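The merge schema and merge expressions added to each AggregateEvaluation above describe how two partial aggregation buffers for the same grouping key are combined, which the external-aggregation path needs once buffers can be spilled and read back. A plain-Scala sketch of the intended semantics for COUNT and MAX (Option stands in for a nullable column; this is not the Catalyst code itself):

    // COUNT: coalesce both sides to 0 and add, mirroring
    //   Coalesce(Add(Coalesce(currentCount, 0), Coalesce(mergeCount, 0)), currentCount, 0)
    def mergeCount(current: Option[Long], other: Option[Long]): Long =
      current.getOrElse(0L) + other.getOrElse(0L)

    // MAX: MaxOf(currentMax, mergeMax) keeps the larger value and ignores an undefined side.
    def mergeMax(current: Option[Int], other: Option[Int]): Option[Int] =
      (current.toList ++ other.toList).reduceOption(_ max _)

    assert(mergeCount(Some(3L), Some(2L)) == 5L)  // in-memory 3 plus spilled 2 gives 5
    assert(mergeMax(Some(7), None) == Some(7))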
private[this] val resultProjection = resultProjectionBuilder() + if (UnsafeExternalAggregation.supportSchema(groupKeySchema, + aggregationBufferSchema)) { + val groupingKeyOrdering: Ordering[InternalRow] = GenerateOrdering.generate( + groupingExpressions.map(_.dataType).zipWithIndex.map { case(dt, index) => + new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + }) + val mergeExpressions = computeFunctions.flatMap(_.merge) + val mergeSchema = computeFunctions.flatMap(_.schema) ++ + computeFunctions.flatMap(_.mergeSchema) + val mergeProjection = newMutableProjection(mergeExpressions, mergeSchema)() + log.info(s"Merge Expressions: ${mergeExpressions.mkString(",")}") + log.info(s"mergeSchema: ${mergeSchema.mkString(",")}") + val aggregationMap = new UnsafeExternalAggregation( + newAggregationBuffer, + updateProjection, + mergeProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggregationBufferSchema), + groupingKeyOrdering, + 1024 * 16, // initial capacity + false + ) + + while (iter.hasNext) { + val currentRow: InternalRow = iter.next() + val groupKey: InternalRow = groupProjection(currentRow) + aggregationMap.insertRow(groupKey, currentRow) + } - def hasNext: Boolean = mapIterator.hasNext + new Iterator[InternalRow] { + private[this] val mapIterator = aggregationMap.iterator() + private[this] val resultProjection = resultProjectionBuilder() + + def hasNext: Boolean = mapIterator.hasNext + + def next(): InternalRow = { + val entry = mapIterator.next() + val result = resultProjection(joinedRow(entry.key, entry.value)) + if (hasNext) { + result + } else { + // This is the last element in the iterator, so let's free the buffer. Before we do, + // though, we need to make a defensive copy of the result so that we don't return an + // object that might contain dangling pointers to the freed memory + val resultCopy = result.copy() + aggregationMap.freeMemory() + resultCopy + } + } + } + } else { + val aggregationMap = new UnsafeFixedWidthAggregationMap( + newAggregationBuffer, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggregationBufferSchema), + TaskContext.get.taskMemoryManager(), + 1024 * 16, // initial capacity + false // disable tracking of performance metrics + ) + + while (iter.hasNext) { + val currentRow: InternalRow = iter.next() + val groupKey: InternalRow = groupProjection(currentRow) + val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey) + updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow)) + } - def next(): InternalRow = { - val entry = mapIterator.next() - val result = resultProjection(joinedRow(entry.key, entry.value)) - if (hasNext) { - result - } else { - // This is the last element in the iterator, so let's free the buffer. Before we do, - // though, we need to make a defensive copy of the result so that we don't return an - // object that might contain dangling pointers to the freed memory - val resultCopy = result.copy() - aggregationMap.free() - resultCopy + new Iterator[InternalRow] { + private[this] val mapIterator = aggregationMap.iterator() + private[this] val resultProjection = resultProjectionBuilder() + + def hasNext: Boolean = mapIterator.hasNext + + def next(): InternalRow = { + val entry = mapIterator.next() + val result = resultProjection(joinedRow(entry.key, entry.value)) + if (hasNext) { + result + } else { + // This is the last element in the iterator, so let's free the buffer. 
Before we do, + // though, we need to make a defensive copy of the result so that we don't return an + // object that might contain dangling pointers to the freed memory + val resultCopy = result.copy() + aggregationMap.free() + resultCopy + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalAggregationSuite.scala new file mode 100644 index 0000000000000..b73930ca3cb3f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalAggregationSuite.scala @@ -0,0 +1,120 @@ +package org.apache.spark.sql.execution + +import org.apache.spark.SparkEnv +import org.apache.spark.sql.{QueryTest, TestData, Row, SQLConf} +import org.scalatest.BeforeAndAfterAll + +class UnsafeExternalAggregationSuite extends QueryTest with BeforeAndAfterAll { + + TestData + val sqlContext = org.apache.spark.sql.test.TestSQLContext + import sqlContext.sql + + test("aggregation with codegen") { + val originalValue = sqlContext.conf.codegenEnabled + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) + val unsafeOriginalValue = sqlContext.conf.unsafeEnabled + sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) + SparkEnv.get.conf.set("spark.test.aggregate.spillFrequency","5") + // Prepare a table that we can group some rows. + sqlContext.table("testData") + .unionAll(sqlContext.table("testData")) + .unionAll(sqlContext.table("testData")) + .registerTempTable("testData3x") + + def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { + val df = sql(sqlText) + // First, check if we have GeneratedAggregate. + var hasGeneratedAgg = false + df.queryExecution.executedPlan.foreach { + case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true + case _ => + } + if (!hasGeneratedAgg) { + fail( + s""" + |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. + |${df.queryExecution.simpleString} + """.stripMargin) + } + // Then, check results. + checkAnswer(df, expectedResults) + } + + try { + // Just to group rows. + testCodeGen( + "SELECT key FROM testData3x GROUP BY key", + (1 to 100).map(Row(_))) + // COUNT + testCodeGen( + "SELECT key, count(value) FROM testData3x GROUP BY key", + (1 to 100).map(i => Row(i, 3))) + testCodeGen( + "SELECT count(key) FROM testData3x", + Row(300) :: Nil) + // COUNT DISTINCT ON int + testCodeGen( + "SELECT value, count(distinct key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 1))) + testCodeGen( + "SELECT count(distinct key) FROM testData3x", + Row(100) :: Nil) + // SUM + testCodeGen( + "SELECT value, sum(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 3 * i))) + testCodeGen( + "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", + Row(5050 * 3, 5050 * 3.0) :: Nil) + // AVERAGE + testCodeGen( + "SELECT value, avg(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT avg(key) FROM testData3x", + Row(50.5) :: Nil) + // MAX + testCodeGen( + "SELECT value, max(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT max(key) FROM testData3x", + Row(100) :: Nil) + // MIN + testCodeGen( + "SELECT value, min(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT min(key) FROM testData3x", + Row(1) :: Nil) + // Some combinations. 
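The expected rows in these checks follow from the shape of the fixture: testData is assumed (as the expectations above imply) to hold the pairs (i, i.toString) for i in 1 to 100, and the two unionAll calls triple every row. A quick sanity sketch of the arithmetic, before the combined query that follows:

    // Assumes testData = (key = i, value = i.toString) for i in 1 to 100;
    // "testData3x" is that data unioned with itself twice, so every row appears 3 times.
    val testData   = (1 to 100).map(i => (i, i.toString))
    val testData3x = testData ++ testData ++ testData

    val sumPerValue = testData3x.groupBy(_._2).map { case (v, rows) => v -> rows.map(_._1).sum }
    assert(sumPerValue("7") == 21)                      // SELECT value, sum(key) ... GROUP BY value
    assert(testData3x.size == 300)                      // SELECT count(key) FROM testData3x
    assert(testData3x.map(_._1).distinct.size == 100)   // SELECT count(distinct key) FROM testData3x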
+ testCodeGen( + """ + |SELECT + | value, + | sum(key), + | max(key), + | min(key), + | avg(key), + | count(key), + | count(distinct key) + |FROM testData3x + |GROUP BY value + """.stripMargin, + (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) + testCodeGen( + "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", + Row(100, 1, 50.5, 300, 100) :: Nil) + // Aggregate with Code generation handling all null values + testCodeGen( + "SELECT sum('a'), avg('a'), count(null) FROM testData", + Row(0, null, 0) :: Nil) + } finally { + sqlContext.dropTempTable("testData3x") + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) + sqlContext.setConf(SQLConf.UNSAFE_ENABLED, unsafeOriginalValue) + SparkEnv.get.conf.set("spark.test.aggregate.spillFrequency","0") + } + } +} From 688cdf44ccc6e73bcbd689a47c7537100e0fc8fa Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Thu, 16 Jul 2015 00:40:03 +0800 Subject: [PATCH 2/5] remove unused import --- .../org/apache/spark/sql/execution/GeneratedAggregate.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index aac24a7a6bb02..c3486f4b9bd1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution -import java.io.IOException - import org.apache.spark.TaskContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD @@ -28,7 +26,6 @@ import org.apache.spark.sql.catalyst.trees._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent case class AggregateEvaluation( schema: Seq[Attribute], From 91168108cd9ffb9edae6f66837fee830f44d30e4 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Thu, 16 Jul 2015 00:47:29 +0800 Subject: [PATCH 3/5] add apache license header --- .../catalyst/expressions/UnsafeRowLocation.java | 17 +++++++++++++++++ .../sql/execution/UnsafeAppendOnlyMap.java | 17 +++++++++++++++++ .../UnsafeExternalAggregationSuite.scala | 17 +++++++++++++++++ 3 files changed, 51 insertions(+) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowLocation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowLocation.java index 56dacb7137b63..b77fae6b3e8cd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowLocation.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowLocation.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.unsafe.PlatformDependent; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeAppendOnlyMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeAppendOnlyMap.java index daed2e8b767d3..f265d73fcfa22 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeAppendOnlyMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeAppendOnlyMap.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.sql.execution; import java.util.Comparator; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalAggregationSuite.scala index b73930ca3cb3f..e0bb96fcbbb43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalAggregationSuite.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package org.apache.spark.sql.execution import org.apache.spark.SparkEnv From 68552cc8cb4f93f2a6fed935b4f50b850e0ce9cc Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Thu, 16 Jul 2015 16:55:59 +0800 Subject: [PATCH 4/5] update with scala style --- .../spark/sql/execution/UnsafeExternalAggregationSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalAggregationSuite.scala index e0bb96fcbbb43..45cfb6a397236 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalAggregationSuite.scala @@ -32,7 +32,7 @@ class UnsafeExternalAggregationSuite extends QueryTest with BeforeAndAfterAll { sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) val unsafeOriginalValue = sqlContext.conf.unsafeEnabled sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) - SparkEnv.get.conf.set("spark.test.aggregate.spillFrequency","5") + SparkEnv.get.conf.set("spark.test.aggregate.spillFrequency", "5") // Prepare a table that we can group some rows. sqlContext.table("testData") .unionAll(sqlContext.table("testData")) @@ -131,7 +131,7 @@ class UnsafeExternalAggregationSuite extends QueryTest with BeforeAndAfterAll { sqlContext.dropTempTable("testData3x") sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) sqlContext.setConf(SQLConf.UNSAFE_ENABLED, unsafeOriginalValue) - SparkEnv.get.conf.set("spark.test.aggregate.spillFrequency","0") + SparkEnv.get.conf.set("spark.test.aggregate.spillFrequency", "0") } } } From e6516838b46bff38a7bc07e2ee9e82a235bbf29e Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Thu, 16 Jul 2015 22:47:22 +0800 Subject: [PATCH 5/5] merge with master --- .../apache/spark/sql/execution/UnsafeExternalAggregation.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalAggregation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalAggregation.java index 2a2fa5a11ffe3..077fcb0dff7bc 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalAggregation.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalAggregation.java @@ -42,7 +42,7 @@ import org.apache.spark.sql.types.StructType; import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; -import org.apache.spark.storage.BlockObjectWriter; +import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempLocalBlockId; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -740,7 +740,7 @@ public class UnsafeSorterKVSpillWriter { private final File file; private final BlockId blockId; private final int numRecordsToWrite; - private BlockObjectWriter writer; + private DiskBlockObjectWriter writer; private int numRecordsSpilled = 0; public UnsafeSorterKVSpillWriter(