diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/BitSet.java
similarity index 98%
rename from unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java
rename to core/src/main/java/org/apache/spark/util/collection/unsafe/map/BitSet.java
index 7c124173b0bbb..2abbf0b9af51f 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/BitSet.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.unsafe.bitset;
+package org.apache.spark.util.collection.unsafe.map;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/BitSetMethods.java
similarity index 98%
rename from unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
rename to core/src/main/java/org/apache/spark/util/collection/unsafe/map/BitSetMethods.java
index 27462c7fa5e62..a173b7912380d 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/BitSetMethods.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.unsafe.bitset;
+package org.apache.spark.util.collection.unsafe.map;
import org.apache.spark.unsafe.PlatformDependent;
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/BytesToBytesMap.java
similarity index 79%
rename from unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
rename to core/src/main/java/org/apache/spark/util/collection/unsafe/map/BytesToBytesMap.java
index d0bde69cc1068..a4eb66d655be4 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/BytesToBytesMap.java
@@ -15,22 +15,21 @@
* limitations under the License.
*/
-package org.apache.spark.unsafe.map;
+package org.apache.spark.util.collection.unsafe.map;
-import java.lang.Override;
-import java.lang.UnsupportedOperationException;
+import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import com.google.common.annotations.VisibleForTesting;
-
-import org.apache.spark.unsafe.*;
+import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods;
-import org.apache.spark.unsafe.array.LongArray;
-import org.apache.spark.unsafe.bitset.BitSet;
-import org.apache.spark.unsafe.hash.Murmur3_x86_32;
-import org.apache.spark.unsafe.memory.*;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.MemoryLocation;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.collection.Sorter;
+import org.apache.spark.util.collection.SortDataFormat;
/**
* An append-only hash map where keys and values are contiguous regions of bytes.
@@ -98,7 +97,8 @@ public final class BytesToBytesMap {
* Position {@code 2 * i} in the array is used to track a pointer to the key at index {@code i},
* while position {@code 2 * i + 1} in the array holds key's full 32-bit hashcode.
*/
- private LongArray longArray;
+ private long[] longArray;
+ // TODO: replace longArray with LongArray
// TODO: we're wasting 32 bits of space here; we can probably store fewer bits of the hashcode
// and exploit word-alignment to use fewer bits to hold the address. This might let us store
// only one long per map entry, increasing the chance that this array will fit in cache at the
@@ -184,6 +184,30 @@ public BytesToBytesMap(
*/
public int size() { return size; }
+  /**
+   * Reports whether another entry can be added while the map is still below its growth
+   * threshold.
+   *
+   * @return true while {@code size} is strictly smaller than {@code growthThreshold}
+   */
+  public boolean hasSpaceForAnotherRecord() {
+    return growthThreshold > size;
+  }
+
+  /**
+   * Computes the capacity the growth strategy asks for next, based on the current bit set
+   * capacity. This method only computes the value; it does not resize anything itself.
+   *
+   * @return the growth strategy's proposal, capped at MAX_CAPACITY and then normalized by
+   *         {@link #getCapacity(int)}
+   */
+  public int getNextCapacity() {
+    final int currentCapacity = (int) bitset.capacity();
+    final int proposed = Math.min(growthStrategy.nextCapacity(currentCapacity), MAX_CAPACITY);
+    // getCapacity applies nextPowerOf2, the MAX_CAPACITY cap and the lower bound of 64, so the
+    // result stays divisible by 64 and the bit set can be sized properly.
+    return getCapacity(proposed);
+  }
+
+ /**
+ * Creates a new {@link Location} handle. Location is an inner class of this map, so the
+ * returned handle is tied to this map instance.
+ */
+ public Location getNewLocation() {
+ return new Location();
+ }
+
+  /**
+   * Normalizes a requested capacity into one the map can actually use.
+   *
+   * @param initialCapacity the requested capacity
+   * @return {@code nextPowerOf2(initialCapacity)}, capped at MAX_CAPACITY and never below 64
+   */
+  public static int getCapacity(int initialCapacity) {
+    // The capacity needs to be divisible by 64 so that our bit set can be sized properly
+    final int roundedAndCapped = (int) Math.min(MAX_CAPACITY, nextPowerOf2(initialCapacity));
+    return Math.max(roundedAndCapped, 64);
+  }
+
+  /**
+   * Computes the number of bytes accounted for a table with {@code capacity} slots: two 8-byte
+   * longs per slot plus {@code capacity / 8} bytes (one bit per slot).
+   *
+   * @param capacity number of slots in the table
+   * @return the byte count as a long
+   */
+  public static long getMemoryUsage(int capacity) {
+    final long pointerArrayBytes = 2L * 8L * capacity;
+    final long bitsetBytes = capacity / 8;
+    return pointerArrayBytes + bitsetBytes;
+  }
+
private static final class BytesToBytesMapIterator implements Iterator {
private final int numRecords;
@@ -245,6 +269,84 @@ public Iterator iterator() {
return new BytesToBytesMapIterator(size, dataPages.iterator(), loc);
}
+  /**
+   * {@link SortDataFormat} over a flat {@code long[]} in which record {@code i} occupies the two
+   * adjacent slots {@code 2 * i} and {@code 2 * i + 1}. The first slot of each pair is the value
+   * returned as the sort key.
+   */
+  private static final class KVLongArraySortDataFormat extends SortDataFormat {
+
+    /** Returns the first slot of record {@code pos}. */
+    @Override
+    public Long getKey(long[] data, int pos) {
+      final int keySlot = 2 * pos;
+      return data[keySlot];
+    }
+
+    /** Returns a placeholder key of zero. */
+    @Override
+    public Long newKey() {
+      return 0L;
+    }
+
+    /** Exchanges both slots of record {@code pos0} with those of record {@code pos1}. */
+    @Override
+    public void swap(long[] data,int pos0, int pos1) {
+      final int first = 2 * pos0;
+      final int second = 2 * pos1;
+      final long savedKey = data[first];
+      final long savedVal = data[first + 1];
+      data[first] = data[second];
+      data[first + 1] = data[second + 1];
+      data[second] = savedKey;
+      data[second + 1] = savedVal;
+    }
+
+    /** Copies the two slots of record {@code srcPos} in {@code src} to {@code dstPos} in {@code dst}. */
+    @Override
+    public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
+      final int from = 2 * srcPos;
+      final int to = 2 * dstPos;
+      dst[to] = src[from];
+      dst[to + 1] = src[from + 1];
+    }
+
+    /** Copies {@code length} consecutive records, i.e. {@code 2 * length} longs. */
+    @Override
+    public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
+      final int fromSlot = 2 * srcPos;
+      final int toSlot = 2 * dstPos;
+      final int slotCount = 2 * length;
+      System.arraycopy(src, fromSlot, dst, toSlot, slotCount);
+    }
+
+    /** Allocates backing storage for {@code length} records (two longs each). */
+    @Override
+    public long[] allocate(int length) {
+      final int slotCount = 2 * length;
+      return new long[slotCount];
+    }
+  }
+
+  /**
+   * Sorts the entries of this map and returns an iterator over them in sorted order.
+   *
+   * <p>This operation rewrites the pointer array in place: every occupied slot is first packed
+   * to the front of {@code longArray} and the packed prefix is then sorted. The bit set is not
+   * updated to match. NOTE(review): after this call the array no longer reflects the hash
+   * layout, so lookups and inserts on this map are presumably no longer valid - confirm with
+   * callers.
+   *
+   * <p>Every call to {@code next()} returns the same shared {@code loc} handle, re-pointed at
+   * the next record; callers must not hold on to it across calls.
+   *
+   * @param sortComparator comparator handed to {@link Sorter}; presumably it is applied to the
+   *        keys produced by {@code KVLongArraySortDataFormat.getKey} - confirm against Sorter
+   * @return an iterator over the sorted records; {@code remove()} is unsupported and
+   *         {@code next()} throws {@link java.util.NoSuchElementException} once exhausted
+   */
+  public Iterator getSortedIterator(Comparator sortComparator) {
+    // Pack KV pairs into the front of the underlying array. The capacity does not change
+    // while packing, so it is read once.
+    final long slots = bitset.capacity();
+    int packed = 0;
+    for (int slot = 0; slot < slots; slot++) {
+      if (bitset.isSet(slot)) {
+        longArray[2 * packed] = longArray[2 * slot];
+        longArray[2 * packed + 1] = longArray[2 * slot + 1];
+        packed++;
+      }
+    }
+
+    new Sorter(new KVLongArraySortDataFormat()).sort(longArray, 0, packed, sortComparator);
+
+    final int numRecords = packed;
+    return new Iterator() {
+
+      private int currentRecordNumber = 0;
+
+      @Override
+      public boolean hasNext() {
+        return currentRecordNumber < numRecords;
+      }
+
+      @Override
+      public Location next() {
+        // Honor the Iterator contract instead of reading past the sorted prefix.
+        if (!hasNext()) {
+          throw new java.util.NoSuchElementException();
+        }
+        loc.with(longArray[currentRecordNumber * 2]);
+        currentRecordNumber++;
+        return loc;
+      }
+
+      @Override
+      public void remove() {
+        throw new UnsupportedOperationException();
+      }
+    };
+  }
+
/**
* Looks up a key, and return a {@link Location} handle that can be used to test existence
* and read/write values.
@@ -269,7 +371,7 @@ public Location lookup(
// This is a new key.
return loc.with(pos, hashcode, false);
} else {
- long stored = longArray.get(pos * 2 + 1);
+ long stored = longArray[pos * 2 + 1];
if ((int) (stored) == hashcode) {
// Full hash code matches. Let's compare the keys for equality.
loc.with(pos, hashcode, true);
@@ -303,6 +405,7 @@ public Location lookup(
* Handle returned by {@link BytesToBytesMap#lookup(Object, long, int)} function.
*/
public final class Location {
+
/** An index into the hash map's Long array */
private int pos;
/** True if this location points to a position where a key is defined, false otherwise */
@@ -334,23 +437,39 @@ private void updateAddressesAndSizes(Object page, long keyOffsetInPage) {
valueMemoryLocation.setObjAndOffset(page, position);
}
+ /**
+ * Re-points this handle at the record identified by {@code fullKeyAddress} and marks the
+ * handle as defined. This overload does not itself assign {@code pos} or {@code keyHashcode}.
+ *
+ * @param fullKeyAddress key address as stored in the map's long array
+ * @return this handle, to allow chaining
+ */
+ public Location with(long fullKeyAddress) {
+ this.isDefined = true;
+ updateAddressesAndSizes(fullKeyAddress);
+ return this;
+ }
+
Location with(int pos, int keyHashcode, boolean isDefined) {
this.pos = pos;
this.isDefined = isDefined;
this.keyHashcode = keyHashcode;
if (isDefined) {
- final long fullKeyAddress = longArray.get(pos * 2);
+ final long fullKeyAddress = longArray[pos * 2];
updateAddressesAndSizes(fullKeyAddress);
}
return this;
}
- Location with(Object page, long keyOffsetInPage) {
+ public Location with(Object page, long keyOffsetInPage) {
this.isDefined = true;
updateAddressesAndSizes(page, keyOffsetInPage);
return this;
}
+    /**
+     * Points this handle at a key and a value that live in two separate byte arrays and marks
+     * the handle as defined. Both memory locations are set to the start of the array data
+     * ({@code PlatformDependent.BYTE_ARRAY_OFFSET}).
+     *
+     * @param keyLength length recorded for the key, in bytes
+     * @param keyArray array holding the key bytes
+     * @param valueLength length recorded for the value, in bytes
+     * @param valueArray array holding the value bytes
+     * @return this handle, to allow chaining
+     */
+    public Location with(int keyLength, byte[] keyArray, int valueLength,
+        byte[] valueArray) {
+      final long arrayDataOffset = PlatformDependent.BYTE_ARRAY_OFFSET;
+      this.isDefined = true;
+      // Key side.
+      this.keyLength = keyLength;
+      keyMemoryLocation.setObjAndOffset(keyArray, arrayDataOffset);
+      // Value side.
+      this.valueLength = valueLength;
+      valueMemoryLocation.setObjAndOffset(valueArray, arrayDataOffset);
+      return this;
+    }
+
/**
* Returns true if the key is defined at this position, and false otherwise.
*/
@@ -378,6 +497,22 @@ public int getKeyLength() {
return keyLength;
}
+ /**
+ * Returns the base object of the key, as reported by {@code keyMemoryLocation}.
+ * Asserts that this location is defined; the check only runs when assertions are enabled.
+ */
+ public Object getKeyBaseObject() {
+ assert (isDefined);
+ return keyMemoryLocation.getBaseObject();
+ }
+
+ /**
+ * Returns the base offset of the key, as reported by {@code keyMemoryLocation}.
+ * Asserts that this location is defined; the check only runs when assertions are enabled.
+ */
+ public long getKeyBaseOffset() {
+ assert (isDefined);
+ return keyMemoryLocation.getBaseOffset();
+ }
+
/**
* Returns the address of the value defined at this position.
* This points to the first byte of the value data.
@@ -398,6 +533,23 @@ public int getValueLength() {
return valueLength;
}
+ /**
+ * Returns the base object of the value, as reported by {@code valueMemoryLocation}.
+ * Asserts that this location is defined; the check only runs when assertions are enabled.
+ */
+ public Object getValueBaseObject() {
+ assert (isDefined);
+ return valueMemoryLocation.getBaseObject();
+ }
+
+ /**
+ * Returns the base offset of the value, as reported by {@code valueMemoryLocation}.
+ * Asserts that this location is defined; the check only runs when assertions are enabled.
+ */
+ public long getValueBaseOffset() {
+ assert (isDefined);
+ return valueMemoryLocation.getBaseOffset();
+ }
+
+
/**
* Store a new key and value. This method may only be called once for a given key; if you want
* to update the value associated with a key, then you can directly manipulate the bytes stored
@@ -425,6 +577,23 @@ public int getValueLength() {
* Unspecified behavior if the key is not defined.
*
*/
+
+    /**
+     * Registers a record at this location's slot using an already-encoded key address. The
+     * address and the key hashcode are written into the long array, the slot is marked in the
+     * bit set, and this handle is re-pointed at the record and marked defined. The table is
+     * grown and re-hashed when the size exceeds the growth threshold and the array is still
+     * shorter than MAX_CAPACITY.
+     *
+     * @param storedKeyAddress encoded address of the stored key
+     * @throws IllegalStateException if the map already holds MAX_CAPACITY entries
+     */
+    public void putNewKey(long storedKeyAddress) {
+      assert (!isDefined) : "Can only set value once for a key";
+      if (size == MAX_CAPACITY) {
+        throw new IllegalStateException("BytesToBytesMap has reached maximum capacity");
+      }
+      final int keySlot = pos * 2;
+      size++;
+      bitset.set(pos);
+      longArray[keySlot] = storedKeyAddress;
+      longArray[keySlot + 1] = keyHashcode;
+      updateAddressesAndSizes(storedKeyAddress);
+      isDefined = true;
+      final boolean overThreshold = size > growthThreshold;
+      if (overThreshold && longArray.length < MAX_CAPACITY) {
+        growAndRehash();
+      }
+    }
+
public void putNewKey(
Object keyBaseObject,
long keyBaseOffset,
@@ -432,20 +601,15 @@ public void putNewKey(
Object valueBaseObject,
long valueBaseOffset,
int valueLengthBytes) {
- assert (!isDefined) : "Can only set value once for a key";
assert (keyLengthBytes % 8 == 0);
assert (valueLengthBytes % 8 == 0);
- if (size == MAX_CAPACITY) {
- throw new IllegalStateException("BytesToBytesMap has reached maximum capacity");
- }
+
// Here, we'll copy the data into our data pages. Because we only store a relative offset from
// the key address instead of storing the absolute address of the value, the key and value
// must be stored in the same memory page.
// (8 byte key length) (key) (8 byte value length) (value)
final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes;
assert (requiredSize <= PAGE_SIZE_BYTES - 8); // Reserve 8 bytes for the end-of-page marker.
- size++;
- bitset.set(pos);
// If there's not enough space in the current page, allocate a new page (8 bytes are reserved
// for the end-of-page marker).
@@ -485,13 +649,7 @@ public void putNewKey(
final long storedKeyAddress = memoryManager.encodePageNumberAndOffset(
currentDataPage, keySizeOffsetInPage);
- longArray.set(pos * 2, storedKeyAddress);
- longArray.set(pos * 2 + 1, keyHashcode);
- updateAddressesAndSizes(storedKeyAddress);
- isDefined = true;
- if (size > growthThreshold && longArray.size() < MAX_CAPACITY) {
- growAndRehash();
- }
+ this.putNewKey(storedKeyAddress);
}
}
@@ -506,7 +664,7 @@ private void allocate(int capacity) {
// The capacity needs to be divisible by 64 so that our bit set can be sized properly
capacity = Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(capacity)), 64);
assert (capacity <= MAX_CAPACITY);
- longArray = new LongArray(memoryManager.allocate(capacity * 8L * 2));
+ longArray = new long[capacity * 2];
bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64]));
this.growthThreshold = (int) (capacity * loadFactor);
@@ -521,7 +679,6 @@ private void allocate(int capacity) {
*/
public void free() {
if (longArray != null) {
- memoryManager.free(longArray.memoryBlock());
longArray = null;
}
if (bitset != null) {
@@ -541,7 +698,7 @@ public long getTotalMemoryConsumption() {
return (
dataPages.size() * PAGE_SIZE_BYTES +
bitset.memoryBlock().size() +
- longArray.memoryBlock().size());
+ longArray.length * 8L);
}
/**
@@ -581,13 +738,13 @@ int getNumDataPages() {
* Grows the size of the hash table and re-hash everything.
*/
@VisibleForTesting
- void growAndRehash() {
+ public void growAndRehash() {
long resizeStartTime = -1;
if (enablePerfMetrics) {
resizeStartTime = System.nanoTime();
}
// Store references to the old data structures to be used when we re-hash
- final LongArray oldLongArray = longArray;
+ final long[] oldLongArray = longArray;
final BitSet oldBitSet = bitset;
final int oldCapacity = (int) oldBitSet.capacity();
@@ -596,8 +753,8 @@ void growAndRehash() {
// Re-mask (we don't recompute the hashcode because we stored all 32 bits of it)
for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) {
- final long keyPointer = oldLongArray.get(pos * 2);
- final int hashcode = (int) oldLongArray.get(pos * 2 + 1);
+ final long keyPointer = oldLongArray[pos * 2];
+ final int hashcode = (int) oldLongArray[pos * 2 + 1];
int newPos = hashcode & mask;
int step = 1;
boolean keepGoing = true;
@@ -607,8 +764,8 @@ void growAndRehash() {
while (keepGoing) {
if (!bitset.isSet(newPos)) {
bitset.set(newPos);
- longArray.set(newPos * 2, keyPointer);
- longArray.set(newPos * 2 + 1, hashcode);
+ longArray[newPos * 2] = keyPointer;
+ longArray[newPos * 2 + 1] = hashcode;
keepGoing = false;
} else {
newPos = (newPos + step) & mask;
@@ -617,8 +774,6 @@ void growAndRehash() {
}
}
- // Deallocate the old data structures.
- memoryManager.free(oldLongArray.memoryBlock());
if (enablePerfMetrics) {
timeSpentResizingNs += System.nanoTime() - resizeStartTime;
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/HashMapGrowthStrategy.java
similarity index 96%
rename from unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
rename to core/src/main/java/org/apache/spark/util/collection/unsafe/map/HashMapGrowthStrategy.java
index 20654e4eeaa02..d89a3213e7b23 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/HashMapGrowthStrategy.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.unsafe.map;
+package org.apache.spark.util.collection.unsafe.map;
/**
* Interface that defines how we can grow the size of a hash map when it is over a threshold.
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/Murmur3_x86_32.java
similarity index 98%
rename from unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
rename to core/src/main/java/org/apache/spark/util/collection/unsafe/map/Murmur3_x86_32.java
index 61f483ced3217..2b295d922b0b4 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/map/Murmur3_x86_32.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.unsafe.hash;
+package org.apache.spark.util.collection.unsafe.map;
import org.apache.spark.unsafe.PlatformDependent;
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/AbstractBytesToBytesMapSuite.java
similarity index 99%
rename from unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
rename to core/src/test/java/org/apache/spark/util/collection/unsafe/map/AbstractBytesToBytesMapSuite.java
index dae47e4bab0cb..e448d9f840328 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.unsafe.map;
+package org.apache.spark.util.collection.unsafe.map;
import java.lang.Exception;
import java.nio.ByteBuffer;
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/BitSetSuite.java
similarity index 98%
rename from unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java
rename to core/src/test/java/org/apache/spark/util/collection/unsafe/map/BitSetSuite.java
index a93fc0ee297c4..efee5c1327eb0 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/BitSetSuite.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.unsafe.bitset;
+package org.apache.spark.util.collection.unsafe.map;
import junit.framework.Assert;
import org.junit.Test;
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/BytesToBytesMapOffHeapSuite.java
similarity index 95%
rename from unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
rename to core/src/test/java/org/apache/spark/util/collection/unsafe/map/BytesToBytesMapOffHeapSuite.java
index 5a10de49f54fe..68a4534d0d41b 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/BytesToBytesMapOffHeapSuite.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.unsafe.map;
+package org.apache.spark.util.collection.unsafe.map;
import org.apache.spark.unsafe.memory.MemoryAllocator;
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/BytesToBytesMapOnHeapSuite.java
similarity index 95%
rename from unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
rename to core/src/test/java/org/apache/spark/util/collection/unsafe/map/BytesToBytesMapOnHeapSuite.java
index 12cc9b25d93b3..6ec919ba17c17 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/BytesToBytesMapOnHeapSuite.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.unsafe.map;
+package org.apache.spark.util.collection.unsafe.map;
import org.apache.spark.unsafe.memory.MemoryAllocator;
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/Murmur3_x86_32Suite.java
similarity index 98%
rename from unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java
rename to core/src/test/java/org/apache/spark/util/collection/unsafe/map/Murmur3_x86_32Suite.java
index 3b9175835229c..9aba304b3a1d5 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/map/Murmur3_x86_32Suite.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.unsafe.hash;
+package org.apache.spark.util.collection.unsafe.map;
import java.util.HashSet;
import java.util.Random;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index 684de6e81d67c..fb2198854b10f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -23,7 +23,7 @@
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.PlatformDependent;
-import org.apache.spark.unsafe.map.BytesToBytesMap;
+import org.apache.spark.util.collection.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryLocation;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 87e5a89c19658..03f3bb4a4a14a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -27,8 +27,8 @@
import org.apache.spark.sql.types.DataType;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods;
-import org.apache.spark.unsafe.bitset.BitSetMethods;
-import org.apache.spark.unsafe.hash.Murmur3_x86_32;
+import org.apache.spark.util.collection.unsafe.map.BitSetMethods;
+import org.apache.spark.util.collection.unsafe.map.Murmur3_x86_32;
import org.apache.spark.unsafe.types.UTF8String;
import static org.apache.spark.sql.types.DataTypes.*;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalAggregation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalAggregation.java
new file mode 100644
index 0000000000000..585f9f93d521a
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalAggregation.java
@@ -0,0 +1,979 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import java.io.*;
+import java.util.*;
+
+import scala.Tuple2;
+import scala.math.Ordering;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.io.ByteStreams;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.DummySerializerInstance;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.SparkConf;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2;
+import org.apache.spark.sql.catalyst.expressions.*;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.DiskBlockObjectWriter;
+import org.apache.spark.storage.TempLocalBlockId;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.MemoryLocation;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.collection.unsafe.map.BytesToBytesMap;
+import org.apache.spark.util.Utils;
+
+/**
+ * Unsafe sort-based external aggregation.
+ */
+public final class UnsafeExternalAggregation {
+
+ private final Logger logger = LoggerFactory.getLogger(UnsafeExternalAggregation.class);
+
+ /**
+ * Special record length that is placed after the last record in a data page.
+ */
+ private static final int END_OF_PAGE_MARKER = -1;
+
+ private final TaskMemoryManager memoryManager;
+
+ private final ShuffleMemoryManager shuffleMemoryManager;
+
+ private final BlockManager blockManager;
+
+ private final TaskContext taskContext;
+
+ private ShuffleWriteMetrics writeMetrics;
+
+ /**
+ * The buffer size to use when writing spills using DiskBlockObjectWriter
+ */
+ private final int fileBufferSizeBytes;
+
+ /**
+ * A linked list for tracking all allocated data pages so that we can free all of our memory.
+ */
+ private final List dataPages = new LinkedList();
+
+ /**
+ * The data page that will be used to store keys and values for new hashtable entries. When this
+ * page becomes full, a new page will be allocated and this pointer will change to point to that
+ * new page.
+ */
+ private MemoryBlock currentDataPage = null;
+
+ /**
+ * Offset into `currentDataPage` that points to the location where new data can be inserted into
+ * the page. This does not incorporate the page's base offset.
+ */
+ private long pageCursor = 0;
+
+ private long freeSpaceInCurrentPage = 0;
+
+ /**
+ * The size of the data pages that hold key and value data. Map entries cannot span multiple
+ * pages, so this limits the maximum entry size.
+ */
+ private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
+
+ private final int initialCapacity;
+
+ /**
+ * A hashmap which maps from opaque byteArray keys to byteArray values.
+ */
+ private BytesToBytesMap map;
+
+ /**
+ * Re-used pointer to the current aggregation buffer
+ */
+ private final UnsafeRow currentBuffer = new UnsafeRow();
+
+ private final JoinedRow joinedRow = new JoinedRow();
+
+ private MutableProjection algebraicUpdateProjection;
+
+ private MutableProjection algebraicMergeProjection;
+
+ private AggregateFunction2[] nonAlgebraicAggregateFunctions;
+
+ /**
+ * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the
+ * map, we copy this buffer and use it as the value.
+ */
+ private final byte[] emptyAggregationBuffer;
+
+ private final StructType aggregationBufferSchema;
+
+ private final StructType groupingKeySchema;
+
+ private final UnsafeProjection groupingKeyProjection;
+
+ /**
+ * Ordering over grouping keys, as passed to the constructor.
+ */
+ private final Ordering groupingKeyOrdering;
+
+ private boolean enablePerfMetrics;
+
+ private int testSpillFrequency = 0;
+
+ private long numRowsInserted = 0;
+
+ private final LinkedList spillWriters = new LinkedList<>();
+
+  /**
+   * Creates the aggregation state and its initial hash map.
+   *
+   * @param memoryManager task memory manager handed to the hash map
+   * @param shuffleMemoryManager manager from which memory for the hash map is acquired
+   * @param blockManager block manager kept for later use by this object
+   * @param taskContext context of the running task
+   * @param algebraicUpdateProjection projection applied per input row for algebraic functions
+   * @param algebraicMergeProjection projection stored for merging algebraic buffers
+   * @param nonAlgebraicAggregateFunctions functions updated one by one per input row
+   * @param emptyAggregationBuffer initial buffer; converted to UnsafeRow bytes and stored
+   * @param aggregationBufferSchema schema of the aggregation buffer
+   * @param groupingKeySchema schema of the grouping key
+   * @param groupingKeyOrdering ordering over grouping keys
+   * @param initialCapacity requested initial capacity of the hash map
+   * @param conf source of "spark.shuffle.file.buffer" and "spark.test.aggregate.spillFrequency"
+   * @param enablePerfMetrics forwarded to the hash map
+   * @throws IOException if the memory for the initial hash map cannot be acquired
+   */
+  public UnsafeExternalAggregation(
+      TaskMemoryManager memoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
+      BlockManager blockManager,
+      TaskContext taskContext,
+      MutableProjection algebraicUpdateProjection,
+      MutableProjection algebraicMergeProjection,
+      AggregateFunction2[] nonAlgebraicAggregateFunctions,
+      InternalRow emptyAggregationBuffer,
+      StructType aggregationBufferSchema,
+      StructType groupingKeySchema,
+      Ordering groupingKeyOrdering,
+      int initialCapacity,
+      SparkConf conf,
+      boolean enablePerfMetrics) throws IOException {
+    // Execution environment.
+    this.memoryManager = memoryManager;
+    this.shuffleMemoryManager = shuffleMemoryManager;
+    this.blockManager = blockManager;
+    this.taskContext = taskContext;
+    this.initialCapacity = initialCapacity;
+    this.enablePerfMetrics = enablePerfMetrics;
+
+    // Aggregate functions.
+    this.algebraicUpdateProjection = algebraicUpdateProjection;
+    this.algebraicMergeProjection = algebraicMergeProjection;
+    this.nonAlgebraicAggregateFunctions = nonAlgebraicAggregateFunctions;
+
+    // Schemas, key projection and key ordering.
+    this.aggregationBufferSchema = aggregationBufferSchema;
+    this.groupingKeySchema = groupingKeySchema;
+    this.groupingKeyOrdering = groupingKeyOrdering;
+    this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
+
+    // Encode the empty aggregation buffer once; its bytes are copied for every new key.
+    final UnsafeProjection bufferProjection = UnsafeProjection.create(aggregationBufferSchema);
+    this.emptyAggregationBuffer = bufferProjection.apply(emptyAggregationBuffer).getBytes();
+    assert(this.emptyAggregationBuffer.length == aggregationBufferSchema.length() * 8 +
+        UnsafeRow.calculateBitSetWidthInBytes(aggregationBufferSchema.length()));
+
+    // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+    final long fileBufferKb = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k");
+    this.fileBufferSizeBytes = (int) fileBufferKb * 1024;
+    this.testSpillFrequency = conf.getInt("spark.test.aggregate.spillFrequency", 0);
+
+    initializeUnsafeAppendMap();
+  }
+
+  /**
+   * Resets the write metrics, acquires memory for a new hash map from the shuffle memory
+   * manager and creates the map. Called from the constructor.
+   * NOTE(review): presumably also re-run after a spill - confirm against spill().
+   *
+   * @throws IOException if less memory than requested is granted; the partial grant is
+   *         released first
+   */
+  private void initializeUnsafeAppendMap() throws IOException {
+    this.writeMetrics = new ShuffleWriteMetrics();
+
+    final int mapCapacity = BytesToBytesMap.getCapacity(initialCapacity);
+    final long bytesNeeded = BytesToBytesMap.getMemoryUsage(mapCapacity);
+    final long bytesGranted = shuffleMemoryManager.tryToAcquire(bytesNeeded);
+    if (bytesGranted == bytesNeeded) {
+      this.map = new BytesToBytesMap(memoryManager, mapCapacity, enablePerfMetrics);
+      return;
+    }
+    // Only part of the request was granted: hand it back before failing.
+    shuffleMemoryManager.release(bytesGranted);
+    throw new IOException("Could not acquire " + bytesNeeded + " bytes of memory");
+  }
+
+  /**
+   * Checks whether external aggregation can handle the given schemas.
+   *
+   * @param groupKeySchema schema of the grouping key
+   * @param aggregationBufferSchema schema of the aggregation buffer
+   * @return true if the data type of every field in both schemas is contained in
+   *         {@code UnsafeRow.readableFieldTypes}, false otherwise
+   */
+  public static boolean supportSchema(StructType groupKeySchema,
+      StructType aggregationBufferSchema) {
+    // Grouping key first, then the aggregation buffer; stop at the first unsupported field.
+    final StructType[] schemas = {groupKeySchema, aggregationBufferSchema};
+    for (StructType schema : schemas) {
+      for (StructField field : schema.fields()) {
+        if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+ /**
+ * Forces spills to occur every `frequency` records. Only for use in tests.
+ * Replaces the value read from "spark.test.aggregate.spillFrequency" in the constructor;
+ * insertRow spills whenever the number of inserted rows is a multiple of this value.
+ * The positivity check is an assert, so it only runs when assertions are enabled.
+ */
+ @VisibleForTesting
+ public void setTestSpillFrequency(int frequency) {
+ assert frequency > 0 : "Frequency must be positive";
+ testSpillFrequency = frequency;
+ }
+
+  /**
+   * Folds one input row into the aggregation buffer of its group.
+   *
+   * @param groupingKey grouping key of the row; converted with the grouping key projection
+   * @param currentRow input row passed to the aggregate functions
+   * @throws IOException propagated from spilling or from obtaining the aggregation buffer
+   */
+  public void insertRow(InternalRow groupingKey, InternalRow currentRow)
+      throws IOException {
+    final UnsafeRow keyRow = this.groupingKeyProjection.apply(groupingKey);
+    numRowsInserted++;
+    // Test hook: spill on every testSpillFrequency-th inserted row.
+    final boolean forcedSpillDue =
+        testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0;
+    if (forcedSpillDue) {
+      spill();
+    }
+    final UnsafeRow buffer = this.getAggregationBuffer(keyRow);
+
+    // Process all algebraic aggregate functions.
+    this.algebraicUpdateProjection.target(buffer).apply(joinedRow.apply(buffer, currentRow));
+    // Process all non-algebraic aggregate functions.
+    for (int i = 0; i < nonAlgebraicAggregateFunctions.length; i++) {
+      nonAlgebraicAggregateFunctions[i].update(buffer, currentRow);
+    }
+  }
+
+  /**
+   * Returns the aggregation buffer for the given grouping key, inserting a copy of the empty
+   * aggregation buffer when the key is not in the map yet. For efficiency, every call returns
+   * the same re-pointed {@code currentBuffer} object.
+   *
+   * @param unsafeGroupingKeyRow grouping key in UnsafeRow form
+   * @return the shared row, pointing at the value stored for this key
+   * @throws IOException propagated from inserting the new key
+   */
+  public UnsafeRow getAggregationBuffer(UnsafeRow unsafeGroupingKeyRow)
+      throws IOException {
+    final Object keyBase = unsafeGroupingKeyRow.getBaseObject();
+    final long keyOffset = unsafeGroupingKeyRow.getBaseOffset();
+    final int keySize = unsafeGroupingKeyRow.getSizeInBytes();
+
+    // Probe our map using the serialized key
+    final BytesToBytesMap.Location loc = map.lookup(keyBase, keyOffset, keySize);
+    if (!loc.isDefined()) {
+      final boolean inserted = this.putNewKey(
+        keyBase,
+        keyOffset,
+        keySize,
+        emptyAggregationBuffer,
+        PlatformDependent.BYTE_ARRAY_OFFSET,
+        emptyAggregationBuffer.length,
+        loc);
+      if (!inserted) {
+        // The insert failed because a spill occurred, so look the key up again from scratch.
+        return this.getAggregationBuffer(unsafeGroupingKeyRow);
+      }
+    }
+    // Reset the pointer to point to the value that we just stored or looked up:
+    final MemoryLocation valueAddress = loc.getValueAddress();
+    currentBuffer.pointTo(
+      valueAddress.getBaseObject(),
+      valueAddress.getBaseOffset(),
+      aggregationBufferSchema.length(),
+      loc.getValueLength());
+    return currentBuffer;
+  }
+
+ public boolean putNewKey( // Appends one (key, value) record to the current data page; returns false if a spill got in the way.
+ Object keyBaseObject,
+ long keyBaseOffset,
+ int keyLengthBytes,
+ Object valueBaseObject,
+ long valueBaseOffset,
+ int valueLengthBytes,
+ BytesToBytesMap.Location location) throws IOException {
+ assert (!location.isDefined()) : "Can only set value once for a key";
+
+ assert (keyLengthBytes % 8 == 0);
+ assert (valueLengthBytes % 8 == 0);
+
+ // Here, we'll copy the data into our data pages. Because we only store a relative offset from
+ // the key address instead of storing the absolute address of the value, the key and value
+ // must be stored in the same memory page.
+ // (8 byte key length) (key) (8 byte value length) (value)
+ final int requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes;
+ if (!haveSpaceForRecord(requiredSize)) {
+ if (!allocateSpaceForRecord(requiredSize)){
+ // A spill happened while making room: return false so the caller re-inserts the groupingKey.
+ return false;
+ }
+ }
+
+ freeSpaceInCurrentPage -= requiredSize;
+ // Compute all of our offsets up-front:
+ final Object pageBaseObject = currentDataPage.getBaseObject();
+ final long pageBaseOffset = currentDataPage.getBaseOffset();
+ final long keySizeOffsetInPage = pageBaseOffset + pageCursor;
+ pageCursor += 8; // word used to store the key size
+ final long keyDataOffsetInPage = pageBaseOffset + pageCursor;
+ pageCursor += keyLengthBytes;
+ final long valueSizeOffsetInPage = pageBaseOffset + pageCursor;
+ pageCursor += 8; // word used to store the value size
+ final long valueDataOffsetInPage = pageBaseOffset + pageCursor;
+ pageCursor += valueLengthBytes;
+
+ // Copy the key, preceded by its length word
+ PlatformDependent.UNSAFE.putLong(pageBaseObject, keySizeOffsetInPage, keyLengthBytes);
+ PlatformDependent.copyMemory(
+ keyBaseObject, keyBaseOffset, pageBaseObject, keyDataOffsetInPage, keyLengthBytes);
+ // Copy the value, preceded by its length word
+ PlatformDependent.UNSAFE.putLong(pageBaseObject, valueSizeOffsetInPage, valueLengthBytes);
+ PlatformDependent.copyMemory(
+ valueBaseObject, valueBaseOffset, pageBaseObject, valueDataOffsetInPage, valueLengthBytes);
+
+ final long storedKeyAddress = memoryManager.encodePageNumberAndOffset(
+ currentDataPage, keySizeOffsetInPage); // the encoded address refers to the key-length word, not the key bytes
+ location.putNewKey(storedKeyAddress);
+ return true;
+ }
+
+ /**
+  * Tells whether a record of the given size fits without any further allocation: the hash map
+  * must be able to take another record and the current data page must have enough free bytes.
+  *
+  * @param requiredSpace the required space in the data page, in bytes, including space for storing
+  *                      the record size.
+  * @return true if the record can be inserted without requiring more allocations, false otherwise.
+  */
+ private boolean haveSpaceForRecord(int requiredSpace) {
+   assert requiredSpace > 0;
+   if (!map.hasSpaceForAnotherRecord()) {
+     return false;
+   }
+   return freeSpaceInCurrentPage >= requiredSpace;
+ }
+
+ /**
+  * Allocates more memory in order to insert an additional record. This will request additional
+  * memory from the {@link org.apache.spark.shuffle.ShuffleMemoryManager} and
+  * spill if the requested memory can not be obtained.
+  *
+  * @param requiredSpace the required space in the data page, in bytes, including space for storing
+  *                      the record size.
+  * @return true if space was made available without spilling; false if a spill occurred. A spill
+  *         frees the map and re-initializes it, so after a false result the caller must look the
+  *         key up again instead of reusing a previously obtained location.
+  * @throws IOException if a data page cannot be acquired even after spilling, or if spilling fails.
+  */
+ private boolean allocateSpaceForRecord(int requiredSpace) throws IOException {
+   boolean noSpill = true;
+   assert (requiredSpace <= PAGE_SIZE_BYTES - 8); // Reserve 8 bytes for the end-of-page marker.
+   if (!map.hasSpaceForAnotherRecord()) {
+     logger.debug("Attempting to grow size of hash table");
+     final int nextCapacity = map.getNextCapacity();
+     final long oldPointerArrayMemoryUsage = map.getTotalMemoryConsumption();
+     final long memoryToGrowPointerArray = BytesToBytesMap.getMemoryUsage(nextCapacity);
+     final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
+     if (memoryAcquired < memoryToGrowPointerArray) {
+       shuffleMemoryManager.release(memoryAcquired);
+       spill();
+       // The spill replaced the map, exactly like the spill in the page-allocation path below,
+       // so it has to be reported to the caller as well.
+       noSpill = false;
+     } else {
+       map.growAndRehash();
+       shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
+     }
+   }
+
+   // If there's not enough space in the current page, allocate a new page (8 bytes are reserved
+   // for the end-of-page marker).
+   if (currentDataPage == null || freeSpaceInCurrentPage < requiredSpace) {
+     logger.trace("Required space {} is greater than free space in current page ({})",
+       requiredSpace, freeSpaceInCurrentPage);
+     if (currentDataPage != null) {
+       // There wasn't enough space in the current page, so write an end-of-page marker:
+       final Object pageBaseObject = currentDataPage.getBaseObject();
+       final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
+       PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
+     }
+     long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE_BYTES);
+     if (memoryAcquired < PAGE_SIZE_BYTES) {
+       // Give back the partial grant, spill to free memory, then retry once.
+       shuffleMemoryManager.release(memoryAcquired);
+       spill();
+       noSpill = false;
+       final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE_BYTES);
+       if (memoryAcquiredAfterSpilling != PAGE_SIZE_BYTES) {
+         shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+         throw new IOException("Unable to acquire " + PAGE_SIZE_BYTES + " bytes of memory");
+       } else {
+         memoryAcquired = memoryAcquiredAfterSpilling;
+       }
+     }
+     MemoryBlock newPage = memoryManager.allocatePage(memoryAcquired);
+     dataPages.add(newPage);
+     pageCursor = 0;
+     freeSpaceInCurrentPage = PAGE_SIZE_BYTES - 8;
+     currentDataPage = newPage;
+   }
+   return noSpill;
+ }
+
+ private static final class SortComparator implements Comparator { // Orders encoded record addresses by comparing the key rows they refer to with the given Ordering.
+
+ private final TaskMemoryManager memoryManager;
+ private final Ordering ordering;
+ private final int numFields;
+ private final UnsafeRow row1 = new UnsafeRow(); // scratch rows, re-pointed on every compare() call
+ private final UnsafeRow row2 = new UnsafeRow();
+
+ SortComparator(TaskMemoryManager memoryManager, Ordering ordering, int numFields) {
+ this.memoryManager = memoryManager;
+ this.numFields = numFields;
+ this.ordering = ordering;
+ }
+
+ @Override
+ public int compare(Long fullKeyAddress1, Long fullKeyAddress2) {
+ final Object baseObject1 = memoryManager.getPage(fullKeyAddress1);
+ final long baseOffset1 = memoryManager.getOffsetInPage(fullKeyAddress1) + 8; // presumably + 8 steps over the 8-byte key-length word - confirm against the stored layout
+
+ final Object baseObject2 = memoryManager.getPage(fullKeyAddress2);
+ final long baseOffset2 = memoryManager.getOffsetInPage(fullKeyAddress2) + 8;
+
+ row1.pointTo(baseObject1, baseOffset1, numFields, -1); // -1 is passed as the row size; NOTE(review): assumes the ordering never needs sizeInBytes - confirm
+ row2.pointTo(baseObject2, baseOffset2, numFields, -1);
+ return ordering.compare(row1, row2);
+ }
+ }
+
+ /**
+ * Sort and spill the current records in response to memory pressure.
+ * <p>
+ * The map's records are written, in the order given by the sorted iterator, to a new spill
+ * writer that is appended to {@code spillWriters}. Afterwards the in-memory pages and the map
+ * are freed, the freed byte count is added to the task's spilled-memory metric, and the map is
+ * initialized again.
+ *
+ * @throws IOException propagated from the spill writer
+ */
+ @VisibleForTesting
+ void spill() throws IOException {
+ logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
+ Thread.currentThread().getId(),
+ Utils.bytesToString(getMemoryUsage()),
+ spillWriters.size(),
+ spillWriters.size() > 1 ? " times" : " time");
+
+ final UnsafeSorterKVSpillWriter spillWriter =
+ new UnsafeSorterKVSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
+ map.size()); // the writer is told up-front how many records it will receive
+ spillWriters.add(spillWriter);
+ final Iterator sortedRecords = map.getSortedIterator(
+ new SortComparator(this.memoryManager, groupingKeyOrdering, groupingKeySchema.size()));
+ while (sortedRecords.hasNext()) {
+ BytesToBytesMap.Location location = sortedRecords.next();
+ spillWriter.write(location);
+ }
+ spillWriter.close();
+
+ final long spillSize = freeMemory(); // bytes released from the data pages and the map
+ taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
+
+ initializeUnsafeAppendMap();
+ }
+
+ public AbstractScalaIterator getSortedIterator() { // Exposes the map's sorted record iterator (sorted with SortComparator) as an AbstractScalaIterator.
+ return new AbstractScalaIterator() {
+
+ Iterator sorter = map.getSortedIterator(
+ new SortComparator(memoryManager, groupingKeyOrdering, groupingKeySchema.size())); // created once, when the wrapper is constructed
+
+ @Override
+ public boolean hasNext() {
+ return sorter.hasNext();
+ }
+
+ @Override
+ public BytesToBytesMap.Location next() {
+ return sorter.next();
+ }
+ };
+ }
+
+ /**
+  * Returns an iterator over the aggregated groups. When nothing has been spilled this is the
+  * plain in-memory iterator; otherwise the sorted in-memory records are merged with the
+  * spilled files.
+  *
+  * @throws IOException propagated from setting up the merge
+  */
+ public AbstractScalaIterator iterator() throws IOException {
+   if (!spillWriters.isEmpty()) {
+     return this.merge(this.getSortedIterator());
+   }
+   return this.getMemoryIterator();
+ }
+
+ /**
+  * Wraps the key bytes of a map record in a newly allocated {@link UnsafeRow}. The row points
+  * at the location's memory; no bytes are copied.
+  *
+  * @param location the record whose key should be exposed
+  * @return a row with as many fields as the grouping-key schema, pointing at the key bytes
+  */
+ public UnsafeRow getKey(BytesToBytesMap.Location location) {
+   final UnsafeRow key = new UnsafeRow();
+   key.pointTo(
+     location.getKeyBaseObject(),
+     location.getKeyBaseOffset(),
+     // The field count must match the schema the key was written with, as every other
+     // key row in this class does; it was previously hard-coded to 1.
+     groupingKeySchema.length(),
+     location.getKeyLength());
+   return key;
+ }
+
+ /**
+  * Wraps the value bytes of a map record in a newly allocated {@link UnsafeRow}. The row points
+  * at the location's memory; no bytes are copied.
+  *
+  * @param location the record whose aggregation buffer should be exposed
+  * @return a row with as many fields as the aggregation-buffer schema, pointing at the value bytes
+  */
+ public UnsafeRow getValue(BytesToBytesMap.Location location) {
+   final UnsafeRow value = new UnsafeRow();
+   value.pointTo(
+     location.getValueBaseObject(),
+     location.getValueBaseOffset(),
+     // The field count must match the schema the buffer was written with, as every other
+     // buffer row in this class does; it was previously hard-coded to 1.
+     aggregationBufferSchema.length(),
+     location.getValueLength());
+   return value;
+ }
+
+ /**
+  * Creates an iterator over the key/value records currently held in the in-memory map. The
+  * iterator is sized with the map's record count at the time of this call.
+  */
+ public AbstractScalaIterator getMemoryIterator() {
+   final int recordCount = map.size();
+   return new BytesToBytesMapIterator(recordCount);
+ }
+
+ /**
+ * Merge aggregate of the in-memory map with the spilled files, giving an iterator over elements.
+ * <p>
+ * Every source (the given in-memory iterator plus one reader per spill writer) is wrapped in a
+ * {@code BufferedIterator} and placed in a priority queue ordered by grouping key. The returned
+ * iterator pulls records from that queue and folds consecutive records with equal keys into one
+ * {@code MapEntry}, using the algebraic merge projection and the non-algebraic functions'
+ * {@code merge}.
+ *
+ * @param inMemory iterator over the in-memory records; NOTE(review): assumed to be sorted by
+ * grouping key (the only visible caller passes {@code getSortedIterator()}) - confirm
+ * @return an iterator yielding one entry per distinct grouping key
+ * @throws IOException propagated from opening or reading the spill files
+ */
+ private AbstractScalaIterator merge(
+ AbstractScalaIterator inMemory) throws IOException {
+
+ final Comparator keyOrdering =
+ new Comparator() {
+ private final UnsafeRow row1 = new UnsafeRow(); // scratch rows, re-pointed on every comparison
+ private final UnsafeRow row2 = new UnsafeRow();
+
+ public int compare(BufferedIterator o1, BufferedIterator o2){
+ row1.pointTo(
+ o1.getRecordLocation().getKeyBaseObject(),
+ o1.getRecordLocation().getKeyBaseOffset(),
+ groupingKeySchema.size(),
+ o1.getRecordLocation().getKeyLength());
+ row2.pointTo(
+ o2.getRecordLocation().getKeyBaseObject(),
+ o2.getRecordLocation().getKeyBaseOffset(),
+ groupingKeySchema.size(),
+ o2.getRecordLocation().getKeyLength());
+ return groupingKeyOrdering.compare(row1, row2);
+ }
+ };
+ final Queue priorityQueue =
+ new PriorityQueue(spillWriters.size() + 1, keyOrdering); // one slot per spill file plus the in-memory source
+ BufferedIterator inMemoryBuffer = this.asBuffered(inMemory);
+ if (inMemoryBuffer.hasNext()) {
+ inMemoryBuffer.loadNext();
+ priorityQueue.add(inMemoryBuffer);
+ }
+
+ for (int i = 0; i < spillWriters.size(); i++) {
+ BufferedIterator spillBuffer = this.asBuffered(spillWriters.get(i).getReader(blockManager));
+ if (spillBuffer.hasNext()) {
+ spillBuffer.loadNext();
+ priorityQueue.add(spillBuffer);
+ }
+ }
+ final AbstractScalaIterator mergeIter =
+ new AbstractScalaIterator() {
+
+ BufferedIterator topIter = null; // source whose record was handed out by the previous next()
+
+ @Override
+ public boolean hasNext() {
+ return !priorityQueue.isEmpty() || (topIter != null && topIter.hasNext());
+ }
+
+ @Override
+ public BytesToBytesMap.Location next() throws IOException {
+ if (topIter != null && topIter.hasNext()) { // advance the previously returned source and put it back in the queue
+ topIter.loadNext();
+ priorityQueue.add(topIter);
+ }
+ topIter = priorityQueue.poll();
+ return topIter.getRecordLocation();
+ }
+ };
+
+ final BufferedIterator sorted = asBuffered(mergeIter);
+ return new AbstractScalaIterator() {
+
+ private UnsafeRow currentKey = new UnsafeRow();
+ private UnsafeRow currentValue = new UnsafeRow();
+ private UnsafeRow nextKey = new UnsafeRow();
+ private UnsafeRow nextValue = new UnsafeRow();
+ private BytesToBytesMap.Location currentLocation = null; // first record of the next group, when one was already read
+
+ @Override
+ public boolean hasNext() {
+ return currentLocation != null || sorted.hasNext();
+ }
+
+ @Override
+ public MapEntry next() throws IOException {
+ try {
+ if (currentLocation == null) {
+ sorted.loadNext();
+ currentLocation = sorted.getRecordLocation();
+ }
+ currentKey.pointTo(
+ currentLocation.getKeyBaseObject(),
+ currentLocation.getKeyBaseOffset(),
+ groupingKeySchema.size(),
+ currentLocation.getKeyLength());
+ currentKey = currentKey.copy(); // copy before loading further records: the spill reader reuses its key/value arrays
+ currentValue.pointTo(
+ currentLocation.getValueBaseObject(),
+ currentLocation.getValueBaseOffset(),
+ aggregationBufferSchema.size(),
+ currentLocation.getValueLength());
+ currentValue = currentValue.copy();
+ currentLocation = null;
+ while (sorted.hasNext()) {
+ sorted.loadNext();
+ BytesToBytesMap.Location nextLocation = sorted.getRecordLocation();
+ nextKey.pointTo(
+ nextLocation.getKeyBaseObject(),
+ nextLocation.getKeyBaseOffset(),
+ groupingKeySchema.size(),
+ nextLocation.getKeyLength());
+ nextValue.pointTo(
+ nextLocation.getValueBaseObject(),
+ nextLocation.getValueBaseOffset(),
+ aggregationBufferSchema.size(),
+ nextLocation.getValueLength());
+
+ if (groupingKeyOrdering.compare(currentKey, nextKey) != 0) {
+ currentLocation = nextLocation; // a different key: keep it for the following next() call
+ break;
+ }
+
+ // Process all algebraic aggregate functions.
+ algebraicMergeProjection.target(currentValue).apply(
+ joinedRow.apply(currentValue, nextValue));
+ // Process all non-algebraic aggregate functions.
+ int i = 0;
+ while (i < nonAlgebraicAggregateFunctions.length) {
+ nonAlgebraicAggregateFunctions[i].merge(currentValue, nextValue);
+ i += 1;
+ }
+ }
+
+ return new MapEntry(currentKey, currentValue);
+ } catch (IOException e) {
+ cleanupResources();
+ // Scala iterators don't declare any checked exceptions, so we need to use this hack
+ // to re-throw the exception:
+ PlatformDependent.throwException(e);
+ }
+ throw new RuntimeException("Exception should have been re-thrown in next()");
+ }
+ };
+ }
+
+ /**
+  * Wraps the given iterator in a {@link BufferedIterator}, which remembers the most recently
+  * loaded record.
+  */
+ public BufferedIterator asBuffered(AbstractScalaIterator iterator) {
+   final BufferedIterator buffered = new BufferedIterator(iterator);
+   return buffered;
+ }
+
+ public interface AbstractScalaIterator { // Minimal iterator contract used by this class; unlike java.util.Iterator, next() may throw IOException.
+
+ public abstract boolean hasNext(); // true while next() can produce another element
+
+ public abstract E next() throws IOException; // returns the next element
+
+ }
+
+ public class BufferedIterator { // Wraps an iterator and remembers the last Location it produced, so that record can be read repeatedly.
+
+ private BytesToBytesMap.Location location = null; // null until loadNext() has been called
+ private AbstractScalaIterator iterator;
+
+ public BufferedIterator(AbstractScalaIterator iterator) {
+ this.iterator = iterator;
+ }
+
+ public boolean hasNext() { // delegates to the wrapped iterator; says nothing about the buffered record
+ return iterator.hasNext();
+ }
+
+ public void loadNext() throws IOException { // advances the wrapped iterator and buffers what it returned
+ location = iterator.next();
+ }
+
+ public BytesToBytesMap.Location getRecordLocation() { // the record buffered by the latest loadNext()
+ return location;
+ }
+ }
+
+ /**
+  * Mutable (key, value) pair of rows produced by the iterators of this class.
+  */
+ public class MapEntry {
+
+   public UnsafeRow key;
+   public UnsafeRow value;
+
+   /** Creates an entry holding two newly allocated rows. */
+   public MapEntry() {
+     this(new UnsafeRow(), new UnsafeRow());
+   }
+
+   /** Creates an entry around the given rows; no copy is made. */
+   public MapEntry(UnsafeRow key, UnsafeRow value) {
+     this.key = key;
+     this.value = value;
+   }
+ }
+
+
+ public class UnsafeSorterKVSpillWriter { // Writes map records to a temp block file: a 4-byte record count, then a length-prefixed key and value per record.
+
+ static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
+
+ // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
+ // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
+ // data through a byte array.
+ private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
+
+ private final File file;
+ private final BlockId blockId;
+ private final int numRecordsToWrite;
+ private DiskBlockObjectWriter writer;
+ private int numRecordsSpilled = 0;
+
+ public UnsafeSorterKVSpillWriter(
+ BlockManager blockManager,
+ int fileBufferSize,
+ ShuffleWriteMetrics writeMetrics,
+ int numRecordsToWrite) throws IOException {
+ final Tuple2 spilledFileInfo =
+ blockManager.diskBlockManager().createTempLocalBlock();
+ this.file = spilledFileInfo._2();
+ this.blockId = spilledFileInfo._1();
+ this.numRecordsToWrite = numRecordsToWrite;
+ // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
+ // Our write path doesn't actually use this serializer (since we end up calling the `write()`
+ // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
+ // around this, we pass a dummy no-op serializer.
+ writer = blockManager.getDiskWriter(
+ blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics);
+ // Write the number of records
+ writeIntToBuffer(numRecordsToWrite, 0);
+ writer.write(writeBuffer, 0, 4);
+ }
+
+ // Based on DataOutputStream.writeInt.
+ private void writeIntToBuffer(int v, int offset) throws IOException { // stores v in 4 bytes, most significant byte first
+ writeBuffer[offset + 0] = (byte)(v >>> 24);
+ writeBuffer[offset + 1] = (byte)(v >>> 16);
+ writeBuffer[offset + 2] = (byte)(v >>> 8);
+ writeBuffer[offset + 3] = (byte)(v >>> 0);
+ }
+
+ public void write(BytesToBytesMap.Location loc) throws IOException { // writes one record: the key, then the value
+ if (numRecordsSpilled == numRecordsToWrite) { // the record count was already written in the constructor, so it must not be exceeded
+ throw new IllegalStateException(
+ "Number of records written exceeded numRecordsToWrite = " + numRecordsToWrite);
+ } else {
+ numRecordsSpilled++;
+ }
+ this.write(loc.getKeyAddress().getBaseObject(), loc.getKeyAddress().getBaseOffset(),
+ loc.getKeyLength());
+ this.write(loc.getValueAddress().getBaseObject(), loc.getValueAddress().getBaseOffset(),
+ loc.getValueLength());
+ writer.recordWritten();
+ }
+
+
+ /**
+ * Write a record to a spill file.
+ * <p>
+ * The output is a 4-byte length followed by {@code recordLength} bytes, staged through
+ * {@code writeBuffer} in chunks of at most {@code DISK_WRITE_BUFFER_SIZE} bytes.
+ *
+ * @param baseObject the base object / memory page containing the record
+ * @param baseOffset the base offset which points directly to the record data.
+ * @param recordLength the length of the record.
+ */
+ public void write(
+ Object baseObject,
+ long baseOffset,
+ int recordLength) throws IOException {
+ writeIntToBuffer(recordLength, 0);
+ int dataRemaining = recordLength;
+ int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4; // space used by len
+ long recordReadPosition = baseOffset;
+ while (dataRemaining > 0) {
+ final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining);
+ PlatformDependent.copyMemory(
+ baseObject,
+ recordReadPosition,
+ writeBuffer,
+ PlatformDependent.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer),
+ toTransfer);
+ writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer);
+ recordReadPosition += toTransfer;
+ dataRemaining -= toTransfer;
+ freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE; // later chunks start at the beginning of the buffer
+ }
+ if (freeSpaceInWriteBuffer < DISK_WRITE_BUFFER_SIZE) { // only true when the loop never ran: emit just the length prefix
+ writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer));
+ }
+ }
+
+ public void close() throws IOException { // commits the file and drops the writer and buffer references
+ writer.commitAndClose();
+ writer = null;
+ writeBuffer = null;
+ }
+
+ public AbstractScalaIterator getReader(BlockManager blockManager)
+ throws IOException { // opens a new reader over the file written by this writer
+ return new UnsafeSorterKVSpillReader(blockManager, file, blockId);
+ }
+ }
+
+ final class UnsafeSorterKVSpillReader implements AbstractScalaIterator { // Reads a spill file back: an int record count, then per record an int-length-prefixed key and value.
+
+ private InputStream in;
+ private DataInputStream din;
+
+ // Variables that change with every kv read:
+ private int numRecordsRemaining;
+
+ private int keyLength;
+ private byte[] keyArray = new byte[1024 * 1024]; // replaced by a larger array when a key does not fit
+
+ private int valueLength;
+ private byte[] valueArray = new byte[1024 * 1024]; // replaced by a larger array when a value does not fit
+ private final BytesToBytesMap.Location location = map.getNewLocation(); // the single Location object returned by every next() call
+
+ public UnsafeSorterKVSpillReader(
+ BlockManager blockManager,
+ File file,
+ BlockId blockId) throws IOException {
+ assert (file.length() > 0);
+ final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
+ this.in = blockManager.wrapForCompression(blockId, bs);
+ this.din = new DataInputStream(this.in);
+ numRecordsRemaining = din.readInt(); // the file starts with the record count
+ }
+
+ @Override
+ public boolean hasNext() {
+ return (numRecordsRemaining > 0);
+ }
+
+ @Override
+ public BytesToBytesMap.Location next() throws IOException {
+ keyLength = din.readInt();
+ if (keyLength > keyArray.length) {
+ keyArray = new byte[keyLength];
+ }
+ ByteStreams.readFully(in, keyArray, 0, keyLength);
+ valueLength = din.readInt();
+ if (valueLength > valueArray.length) {
+ valueArray = new byte[valueLength];
+ }
+ ByteStreams.readFully(in, valueArray, 0, valueLength);
+ numRecordsRemaining--;
+ if (numRecordsRemaining == 0) { // last record consumed: close the stream and drop the references
+ in.close();
+ in = null;
+ din = null;
+ }
+ location.with(keyLength, keyArray, valueLength, valueArray); // keyArray/valueArray are overwritten by the following next() call
+ return location;
+ }
+ }
+
+
+ /**
+  * Iterates over the records stored in the data pages, following the pages in the order of
+  * {@code dataPages}.
+  * <p>
+  * The {@link MapEntry} returned by {@link #next()} is one reused object whose rows are
+  * re-pointed at each record, so callers must copy the rows if they need to keep them.
+  */
+ public class BytesToBytesMapIterator implements AbstractScalaIterator {
+
+   private final MapEntry entry = new MapEntry();
+   private final BytesToBytesMap.Location loc = map.getNewLocation();
+   // A single page iterator is kept for the whole traversal. Asking dataPages for a new
+   // iterator on every page switch would hand back the first page again each time.
+   private final java.util.Iterator<MemoryBlock> dataPagesIterator;
+   private int currentRecordNumber = 0;
+   private Object pageBaseObject;
+   private long offsetInPage;
+   private int numRecords;
+
+   /**
+    * @param numRecords number of records to iterate over before {@link #hasNext()} turns false
+    */
+   public BytesToBytesMapIterator(int numRecords) {
+     this.numRecords = numRecords;
+     this.dataPagesIterator = dataPages.iterator();
+     if (dataPagesIterator.hasNext()) {
+       advanceToNextPage();
+     }
+   }
+
+   /** Moves the read position to the start of the next data page. */
+   private void advanceToNextPage() {
+     final MemoryBlock currentPage = dataPagesIterator.next();
+     pageBaseObject = currentPage.getBaseObject();
+     offsetInPage = currentPage.getBaseOffset();
+   }
+
+   @Override
+   public boolean hasNext() {
+     return currentRecordNumber != numRecords;
+   }
+
+   /**
+    * Points the shared entry at the next record and returns it.
+    *
+    * @throws java.util.NoSuchElementException if all records have already been returned
+    */
+   @Override
+   public MapEntry next() {
+     if (!hasNext()) {
+       // Reading on would dereference memory beyond the last stored record.
+       throw new java.util.NoSuchElementException();
+     }
+     int keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage);
+     if (keyLength == END_OF_PAGE_MARKER) {
+       // The current page is exhausted; the record lives at the start of the next page.
+       advanceToNextPage();
+       keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage);
+     }
+     loc.with(pageBaseObject, offsetInPage);
+     // Skip the two 8-byte length words plus the key and value bytes.
+     offsetInPage += 8 + 8 + keyLength + loc.getValueLength();
+
+     MemoryLocation keyAddress = loc.getKeyAddress();
+     MemoryLocation valueAddress = loc.getValueAddress();
+     entry.key.pointTo(
+       keyAddress.getBaseObject(),
+       keyAddress.getBaseOffset(),
+       groupingKeySchema.length(),
+       loc.getKeyLength()
+     );
+     entry.value.pointTo(
+       valueAddress.getBaseObject(),
+       valueAddress.getBaseOffset(),
+       aggregationBufferSchema.length(),
+       loc.getValueLength()
+     );
+     currentRecordNumber++;
+     return entry;
+   }
+ }
+
+ private long getMemoryUsage() {
+ return map.getTotalMemoryConsumption() + (dataPages.size() * PAGE_SIZE_BYTES);
+ }
+
+ /** Releases the memory held by this map; the number of freed bytes is ignored. */
+ private void cleanupResources() {
+   freeMemory();
+ }
+
+ /**
+  * Frees every data page and, if present, the hash map, releasing the corresponding amounts
+  * from the shuffle memory manager, and resets the page bookkeeping.
+  *
+  * @return the number of bytes released
+  */
+ public long freeMemory() {
+   long releasedBytes = 0;
+   for (MemoryBlock page : dataPages) {
+     memoryManager.freePage(page);
+     shuffleMemoryManager.release(page.size());
+     releasedBytes += page.size();
+   }
+   if (map != null) {
+     // Read the map's footprint before freeing it.
+     final long mapBytes = map.getTotalMemoryConsumption();
+     map.free();
+     map = null;
+     shuffleMemoryManager.release(mapBytes);
+     releasedBytes += mapBytes;
+   }
+   // No page is current any more.
+   dataPages.clear();
+   currentDataPage = null;
+   pageCursor = 0;
+   freeSpaceInCurrentPage = 0;
+   return releasedBytes;
+ }
+
+ /**
+  * Prints the hash map's performance counters to standard output.
+  *
+  * @throws IllegalStateException if perf metrics are not enabled
+  */
+ @SuppressWarnings("UseOfSystemOutOrSystemErr")
+ public void printPerfMetrics() {
+   if (enablePerfMetrics) {
+     System.out.println("Average probes per lookup: " + map.getAverageProbesPerLookup());
+     System.out.println("Number of hash collisions: " + map.getNumHashCollisions());
+     System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs());
+     System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
+     return;
+   }
+   throw new IllegalStateException("Perf metrics not enabled");
+ }
+
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index cc89d74146b34..21f34a0009be9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -22,6 +22,24 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjecti
import org.apache.spark.sql.types.{StructType, DataType}
import org.apache.spark.unsafe.types.UTF8String
+/**
+ * Converts an [[InternalRow]] to another Row given a sequence of expressions that define each
+ * column of the new row. If the schema of the input row is specified, then the given expressions
+ * will be bound to that schema.
+ *
+ * In contrast to a normal projection, a MutableProjection reuses the same underlying row object
+ * each time an input row is added. This significantly reduces the cost of calculating the
+ * projection, but means that it is not safe to hold on to a reference to an [[InternalRow]] after
+ * `next()` has been called on the [[Iterator]] that produced it. Instead, the user must call
+ * `InternalRow.copy()` and hold on to the returned [[InternalRow]] before calling `next()`.
+ */
+abstract class MutableProjection extends Projection {
+ def currentValue: InternalRow // NOTE(review): presumably the row most recently written by this projection - confirm
+
+ /** Uses the given row to store the output of the projection. */
+ def target(row: MutableRow): MutableProjection
+}
+
/**
* A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
* @param expressions a sequence of expressions that determine the value of each column of the
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 30b7f8d3766a5..aba2c7bb4f40c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -63,21 +63,4 @@ package object expressions {
*/
abstract class Projection extends (InternalRow => InternalRow)
- /**
- * Converts a [[InternalRow]] to another Row given a sequence of expression that define each
- * column of the new row. If the schema of the input row is specified, then the given expression
- * will be bound to that schema.
- *
- * In contrast to a normal projection, a MutableProjection reuses the same underlying row object
- * each time an input row is added. This significantly reduces the cost of calculating the
- * projection, but means that it is not safe to hold on to a reference to a [[InternalRow]] after
- * `next()` has been called on the [[Iterator]] that produced it. Instead, the user must call
- * `InternalRow.copy()` and hold on to the returned [[InternalRow]] before calling `next()`.
- */
- abstract class MutableProjection extends Projection {
- def currentValue: InternalRow
-
- /** Uses the given row to store the output of the projection. */
- def target(row: MutableRow): MutableProjection
- }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 2a641b9d64a95..1dbf85a1a16b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -401,6 +401,9 @@ private[spark] object SQLConf {
val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2",
defaultValue = Some(true), doc = "")
+ val USE_HYBRID_AGGREGATE = booleanConf("spark.sql.useHybridAggregate",
+ defaultValue = Some(false), doc = "") // off by default; read through SQLConf.useHybridAggregate
+
val USE_SQL_SERIALIZER2 = booleanConf(
"spark.sql.useSerializer2",
defaultValue = Some(true), isPublic = false)
@@ -474,6 +477,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2)
+ private[spark] def useHybridAggregate: Boolean = getConf(USE_HYBRID_AGGREGATE) // current value of spark.sql.useHybridAggregate
+
private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2)
private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index eb4be1900b153..459111a0554a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.{SQLContext, Strategy, execution}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
@@ -221,7 +221,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
expr.collect {
case agg: AggregateExpression2 => agg
}
- }.toSet.toSeq
+ }
// For those distinct aggregate expressions, we create a map from the
// aggregate function to the corresponding attribute of the function.
val aggregateFunctionMap = aggregateExpressions.map { agg =>
@@ -247,6 +247,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
aggregateExpressions,
aggregateFunctionMap,
resultExpressions,
+ sqlContext.conf.useHybridAggregate,
planLater(child))
} else {
aggregate.Utils.planAggregateWithOneDistinct(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
new file mode 100644
index 0000000000000..0d7732fd8ab8e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -0,0 +1,112 @@
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+
+import scala.collection.mutable.ArrayBuffer
+
+private[sql] abstract class AggregationIterator( // Shared setup for aggregation iterators: bound aggregate functions, the shared buffer and its initializer.
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression2],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow]){ // inputIter is not read inside this base class
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Static fields for this iterator
+ ///////////////////////////////////////////////////////////////////////////
+
+ protected val aggregateFunctions: Array[AggregateFunction2] = { // one function per aggregate expression, each assigned its offset into the shared buffer
+ var bufferOffset = initialBufferOffset
+ val functions = new Array[AggregateFunction2](aggregateExpressions.length)
+ var i = 0
+ while (i < aggregateExpressions.length) {
+ val func = aggregateExpressions(i).aggregateFunction
+ val funcWithBoundReferences = aggregateExpressions(i).mode match {
+ case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] =>
+ // We need to create BoundReferences if the function is not an
+ // AlgebraicAggregate (it does not support code-gen) and the mode of
+ // this function is Partial or Complete because we will call eval of this
+ // function's children in the update method of this aggregate function.
+ // Those eval calls require BoundReferences to work.
+ BindReferences.bindReference(func, inputAttributes)
+ case _ => func
+ }
+ // Set bufferOffset for this function. It is important that setting bufferOffset
+ // happens after all potential bindReference operations because bindReference
+ // will create a new instance of the function.
+ funcWithBoundReferences.bufferOffset = bufferOffset
+ bufferOffset += funcWithBoundReferences.bufferSchema.length
+ functions(i) = funcWithBoundReferences
+ i += 1
+ }
+ functions
+ }
+
+ // Positions of those non-algebraic aggregate functions in aggregateFunctions.
+ // For example, we have func1, func2, func3, func4 in aggregateFunctions, and
+ // func2 and func3 are non-algebraic aggregate functions.
+ // nonAlgebraicAggregateFunctionPositions will be [1, 2].
+ protected val nonAlgebraicAggregateFunctionPositions: Array[Int] = {
+ val positions = new ArrayBuffer[Int]()
+ var i = 0
+ while (i < aggregateFunctions.length) {
+ aggregateFunctions(i) match {
+ case agg: AlgebraicAggregate => // algebraic functions are skipped here
+ case _ => positions += i
+ }
+ i += 1
+ }
+ positions.toArray
+ }
+
+ // All non-algebraic aggregate functions.
+ protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
+ nonAlgebraicAggregateFunctionPositions.map(aggregateFunctions)
+
+ // The underlying buffer shared by all aggregate functions.
+ protected val buffer: MutableRow = {
+ // The number of elements of the underlying buffer of this operator.
+ // All aggregate functions are sharing this underlying buffer and they find their
+ // buffer values through bufferOffset.
+ var size = initialBufferOffset
+ var i = 0
+ while (i < aggregateFunctions.length) {
+ size += aggregateFunctions(i).bufferSchema.length
+ i += 1
+ }
+ new GenericMutableRow(size)
+ }
+
+ /** Initializes buffer values for all aggregate functions. */
+ protected def initializeBuffer(): Unit = {
+ algebraicInitialProjection(EmptyRow) // algebraic functions: initial values are written by the projection targeting `buffer`
+ var i = 0
+ while (i < nonAlgebraicAggregateFunctions.length) {
+ nonAlgebraicAggregateFunctions(i).initialize(buffer)
+ i += 1
+ }
+ }
+
+ protected val joinedRow = new JoinedRow
+
+ // This is used to project expressions for the grouping expressions.
+ protected val groupGenerator =
+ newMutableProjection(groupingExpressions, inputAttributes)()
+
+ protected def initialBufferOffset: Int // number of leading buffer slots placed before the first function's values
+
+ protected val placeholderExpressions = Seq.fill(initialBufferOffset)(NoOp) // NoOp fillers for those leading slots
+
+ // This projection is used to initialize buffer values for all AlgebraicAggregates.
+ protected val algebraicInitialProjection = {
+ val initExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.initialValues
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ newMutableProjection(initExpressions, Nil)().target(buffer)
+ }
+
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HybridAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HybridAggregationIterator.scala
new file mode 100644
index 0000000000000..92734162f8c65
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HybridAggregationIterator.scala
@@ -0,0 +1,273 @@
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.{Logging, SparkEnv, TaskContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
+import org.apache.spark.sql.execution.UnsafeExternalAggregation
+import org.apache.spark.sql.types.{StructField, StructType, NullType}
+
+/**
+ * Partial step of the hybrid aggregate: inserts every input row into an
+ * UnsafeExternalAggregation under its grouping key and returns key ++ buffer for each entry.
+ */
+class HybridPartialAggregationIterator(
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression2],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow])
+ extends AggregationIterator(
+ groupingExpressions,
+ aggregateExpressions,
+ newMutableProjection,
+ inputAttributes,
+ inputIter)
+ with Logging {
+
+ override protected def initialBufferOffset: Int = 0
+
+ // Applies the update expressions of all AlgebraicAggregates to (buffer ++ input row).
+ private val algebraicUpdateProjection = {
+ val bufferSchema = aggregateFunctions.flatMap(_.bufferAttributes)
+ val updateExpressions = aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)()
+ }
+
+ // Applies the merge expressions of all AlgebraicAggregates to (buffer ++ cloned buffer).
+ private val algebraicMergeProjection = {
+ val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes) ++
+ aggregateFunctions.flatMap(_.cloneBufferAttributes)
+ val mergeExpressions = aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.mergeExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ newMutableProjection(mergeExpressions, bufferSchemata)()
+ }
+
+ /** Drains `inputIter` into the aggregation map, then iterates over the map's entries. */
+ def iterator: Iterator[InternalRow] = {
+ val groupingKeyOrdering: Ordering[InternalRow] = GenerateOrdering.generate(
+ groupingExpressions.map(_.dataType).zipWithIndex.map { case (dt, index) =>
+ new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
+ })
+
+ val aggregationBufferSchema: StructType = StructType.fromAttributes(
+ aggregateFunctions.flatMap(_.bufferAttributes)
+ )
+
+ val groupKeySchema: StructType = {
+ val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) =>
+ StructField(idx.toString, expr.dataType, expr.nullable)
+ }
+ StructType(fields)
+ }
+
+ this.initializeBuffer()
+
+ logDebug(s"GroupKey Schema: ${groupKeySchema.mkString(",")}")
+ logDebug(s"AggregationBuffer Schema: ${aggregationBufferSchema.mkString(",")}")
+
+ val sparkEnv: SparkEnv = SparkEnv.get
+ val taskContext: TaskContext = TaskContext.get
+ val aggregationMap = new UnsafeExternalAggregation(
+ taskContext.taskMemoryManager,
+ sparkEnv.shuffleMemoryManager,
+ sparkEnv.blockManager,
+ taskContext,
+ algebraicUpdateProjection,
+ algebraicMergeProjection,
+ nonAlgebraicAggregateFunctions,
+ buffer,
+ aggregationBufferSchema,
+ groupKeySchema,
+ groupingKeyOrdering,
+ 1024 * 16,
+ sparkEnv.conf,
+ false
+ )
+
+ while (inputIter.hasNext) {
+ val currentRow: InternalRow = inputIter.next()
+ val groupKey: InternalRow = groupGenerator(currentRow).copy()
+ aggregationMap.insertRow(groupKey, currentRow)
+ }
+
+ new Iterator[InternalRow] {
+ private[this] val mapIterator = aggregationMap.iterator()
+
+ def hasNext: Boolean = mapIterator.hasNext
+
+ def next(): InternalRow = {
+ val entry = mapIterator.next()
+ if (hasNext) {
+ joinedRow(entry.key, entry.value)
+ } else {
+ // Last entry: copy key and value before freeing the map so the returned row holds no
+ // pointers into released memory. NOTE(review): the map is only freed on this path;
+ // confirm how memory is released when the iterator is empty or not fully consumed.
+ val resultCopy = joinedRow(entry.key.copy(), entry.value.copy())
+ aggregationMap.freeMemory()
+ resultCopy
+ }
+ }
+ }
+ }
+}
+
+/**
+ * Final step of the hybrid aggregate: merges the partial aggregation buffers carried by the
+ * input rows into an UnsafeExternalAggregation keyed by grouping key, then evaluates every
+ * aggregate function per entry and projects the result expressions.
+ */
+class HybridFinalAggregationIterator(
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression2],
+ aggregateAttributes: Seq[Attribute],
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow])
+ extends AggregationIterator(
+ groupingExpressions,
+ aggregateExpressions,
+ newMutableProjection,
+ inputAttributes,
+ inputIter)
+ with Logging {
+
+ // The result of aggregate functions.
+ private val aggregateResult: MutableRow = new GenericMutableRow(aggregateAttributes.length)
+
+ // The projection used to generate the output rows of this operator.
+ // This is only used when we are generating final results of aggregate functions.
+ private val resultProjection =
+ newMutableProjection(
+ resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)()
+
+ override protected def initialBufferOffset: Int = 0
+
+ // One NullType placeholder per grouping expression, used to pad the partial-merge input schema.
+ private val offsetAttributes =
+ Seq.fill(groupingExpressions.length)(AttributeReference("placeholder", NullType)())
+
+ // Merges (buffer ++ placeholders ++ cloned buffer) for all AlgebraicAggregates.
+ private val algebraicPartialMergeProjection = {
+ val bufferSchema = aggregateFunctions.flatMap(_.bufferAttributes) ++
+ offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
+ val mergeExpressions = aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.mergeExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ newMutableProjection(mergeExpressions, bufferSchema)()
+ }
+
+ // Merges (buffer ++ cloned buffer) for all AlgebraicAggregates.
+ private val algebraicMergeProjection = {
+ val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes) ++
+ aggregateFunctions.flatMap(_.cloneBufferAttributes)
+ val mergeExpressions = aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.mergeExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ newMutableProjection(mergeExpressions, bufferSchemata)()
+ }
+
+ // Evaluates all AlgebraicAggregates from a buffer row, writing into aggregateResult.
+ private val algebraicEvalProjection = {
+ val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes) ++
+ aggregateFunctions.flatMap(_.cloneBufferAttributes)
+ val evalExpressions = aggregateFunctions.map {
+ case ae: AlgebraicAggregate => ae.evaluateExpression
+ case agg: AggregateFunction2 => NoOp
+ }
+ newMutableProjection(evalExpressions, bufferSchemata)().target(aggregateResult)
+ }
+
+ /** Drains `inputIter` into the aggregation map, then returns one result row per map entry. */
+ def iterator: Iterator[InternalRow] = {
+ val groupingKeyOrdering: Ordering[InternalRow] = GenerateOrdering.generate(
+ groupingExpressions.map(_.dataType).zipWithIndex.map { case (dt, index) =>
+ new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
+ })
+
+ val aggregationBufferSchema: StructType = StructType.fromAttributes(
+ aggregateFunctions.flatMap(_.bufferAttributes)
+ )
+
+ val groupKeySchema: StructType = {
+ val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) =>
+ // This is a dummy field name
+ StructField(idx.toString, expr.dataType, expr.nullable)
+ }
+ StructType(fields)
+ }
+
+ this.initializeBuffer()
+
+ logDebug(s"GroupKey Schema: ${groupKeySchema.mkString(",")}")
+ logDebug(s"AggregationBuffer Schema: ${aggregationBufferSchema.mkString(",")}")
+
+ val sparkEnv: SparkEnv = SparkEnv.get
+ val taskContext: TaskContext = TaskContext.get
+ val aggregationMap = new UnsafeExternalAggregation(
+ taskContext.taskMemoryManager,
+ sparkEnv.shuffleMemoryManager,
+ sparkEnv.blockManager,
+ taskContext,
+ algebraicPartialMergeProjection,
+ algebraicMergeProjection,
+ nonAlgebraicAggregateFunctions,
+ buffer,
+ aggregationBufferSchema,
+ groupKeySchema,
+ groupingKeyOrdering,
+ 1024 * 16,
+ sparkEnv.conf,
+ false
+ )
+
+ while (inputIter.hasNext) {
+ val currentRow: InternalRow = inputIter.next()
+ val groupKey: InternalRow = groupGenerator(currentRow).copy()
+ aggregationMap.insertRow(groupKey, currentRow)
+ }
+
+ new Iterator[InternalRow] {
+ private[this] val mapIterator = aggregationMap.iterator()
+ private[this] var nextKey: InternalRow = _
+ private[this] var nextValue: InternalRow = _
+
+ def hasNext: Boolean = mapIterator.hasNext
+
+ def next(): InternalRow = {
+ val entry = mapIterator.next()
+ if (hasNext) {
+ nextKey = entry.key
+ nextValue = entry.value
+ } else {
+ // Last entry: copy key and value before the map's memory is freed.
+ nextKey = entry.key.copy()
+ nextValue = entry.value.copy()
+ aggregationMap.freeMemory()
+ }
+ // Generate results for all algebraic aggregate functions.
+ algebraicEvalProjection(nextValue)
+ // Generate results for all non-algebraic aggregate functions.
+ var i = 0
+ while (i < nonAlgebraicAggregateFunctions.length) {
+ aggregateResult.update(
+ nonAlgebraicAggregateFunctionPositions(i),
+ nonAlgebraicAggregateFunctions(i).eval(nextValue))
+ i += 1
+ }
+ resultProjection(joinedRow(nextKey, aggregateResult))
+ }
+ }
+ }
+}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
index 0c9082897f390..0e59bb8f0c0cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
@@ -117,6 +117,70 @@ case class Aggregate2Sort(
}
}
+/**
+ * Physical operator for hybrid aggregation: runs HybridPartialAggregationIterator when every
+ * aggregate expression is in Partial mode and HybridFinalAggregationIterator when in Final mode.
+ */
+case class Aggregate2Hybrid(
+ requiredChildDistributionExpressions: Option[Seq[Expression]],
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression2],
+ aggregateAttributes: Seq[Attribute],
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan)
+ extends UnaryNode {
+
+ override def canProcessUnsafeRows: Boolean = true
+
+ override def references: AttributeSet = {
+ // Result expressions may refer to aggregateAttributes; those are excluded here.
+ val resultOnlyReferences =
+ AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes)
+ val inputReferences =
+ groupingExpressions.flatMap(_.references) ++ aggregateExpressions.flatMap(_.references)
+ AttributeSet(inputReferences ++ resultOnlyReferences)
+ }
+
+ override def requiredChildDistribution: List[Distribution] =
+ requiredChildDistributionExpressions match {
+ case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+ case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
+ case None => UnspecifiedDistribution :: Nil
+ }
+
+ override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+ child.execute().mapPartitions { iter =>
+ assert(aggregateExpressions.nonEmpty, "Use Aggregate2Sort when no aggregate function")
+ val distinctModes = aggregateExpressions.map(_.mode).distinct.toList
+ val aggregationIterator: Iterator[InternalRow] = distinctModes match {
+ case Partial :: Nil =>
+ new HybridPartialAggregationIterator(
+ groupingExpressions,
+ aggregateExpressions,
+ newMutableProjection,
+ child.output,
+ iter).iterator
+ case Final :: Nil =>
+ new HybridFinalAggregationIterator(
+ groupingExpressions,
+ aggregateExpressions,
+ aggregateAttributes,
+ resultExpressions,
+ newMutableProjection,
+ child.output,
+ iter).iterator
+ case other =>
+ sys.error(
+ s"Could not evaluate ${aggregateExpressions} because we do not support evaluate " +
+ s"modes $other in this operator.")
+ }
+ aggregationIterator
+ }
+ }
+}
+
case class FinalAndCompleteAggregate2Sort(
previousGroupingExpressions: Seq[NamedExpression],
groupingExpressions: Seq[NamedExpression],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
index 1b89edafa8dad..666f8429068dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
@@ -34,90 +34,13 @@ private[sql] abstract class SortAggregationIterator(
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
inputAttributes: Seq[Attribute],
inputIter: Iterator[InternalRow])
- extends Iterator[InternalRow] {
-
- ///////////////////////////////////////////////////////////////////////////
- // Static fields for this iterator
- ///////////////////////////////////////////////////////////////////////////
-
- protected val aggregateFunctions: Array[AggregateFunction2] = {
- var bufferOffset = initialBufferOffset
- val functions = new Array[AggregateFunction2](aggregateExpressions.length)
- var i = 0
- while (i < aggregateExpressions.length) {
- val func = aggregateExpressions(i).aggregateFunction
- val funcWithBoundReferences = aggregateExpressions(i).mode match {
- case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] =>
- // We need to create BoundReferences if the function is not an
- // AlgebraicAggregate (it does not support code-gen) and the mode of
- // this function is Partial or Complete because we will call eval of this
- // function's children in the update method of this aggregate function.
- // Those eval calls require BoundReferences to work.
- BindReferences.bindReference(func, inputAttributes)
- case _ => func
- }
- // Set bufferOffset for this function. It is important that setting bufferOffset
- // happens after all potential bindReference operations because bindReference
- // will create a new instance of the function.
- funcWithBoundReferences.bufferOffset = bufferOffset
- bufferOffset += funcWithBoundReferences.bufferSchema.length
- functions(i) = funcWithBoundReferences
- i += 1
- }
- functions
- }
-
- // Positions of those non-algebraic aggregate functions in aggregateFunctions.
- // For example, we have func1, func2, func3, func4 in aggregateFunctions, and
- // func2 and func3 are non-algebraic aggregate functions.
- // nonAlgebraicAggregateFunctionPositions will be [1, 2].
- protected val nonAlgebraicAggregateFunctionPositions: Array[Int] = {
- val positions = new ArrayBuffer[Int]()
- var i = 0
- while (i < aggregateFunctions.length) {
- aggregateFunctions(i) match {
- case agg: AlgebraicAggregate =>
- case _ => positions += i
- }
- i += 1
- }
- positions.toArray
- }
-
- // All non-algebraic aggregate functions.
- protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
- nonAlgebraicAggregateFunctionPositions.map(aggregateFunctions)
-
- // This is used to project expressions for the grouping expressions.
- protected val groupGenerator =
- newMutableProjection(groupingExpressions, inputAttributes)()
-
- // The underlying buffer shared by all aggregate functions.
- protected val buffer: MutableRow = {
- // The number of elements of the underlying buffer of this operator.
- // All aggregate functions are sharing this underlying buffer and they find their
- // buffer values through bufferOffset.
- var size = initialBufferOffset
- var i = 0
- while (i < aggregateFunctions.length) {
- size += aggregateFunctions(i).bufferSchema.length
- i += 1
- }
- new GenericMutableRow(size)
- }
-
- protected val joinedRow = new JoinedRow
-
- protected val placeholderExpressions = Seq.fill(initialBufferOffset)(NoOp)
-
- // This projection is used to initialize buffer values for all AlgebraicAggregates.
- protected val algebraicInitialProjection = {
- val initExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.initialValues
- case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
- }
- newMutableProjection(initExpressions, Nil)().target(buffer)
- }
+ extends AggregationIterator(
+ groupingExpressions,
+ aggregateExpressions,
+ newMutableProjection,
+ inputAttributes,
+ inputIter)
+ with Iterator[InternalRow] {
///////////////////////////////////////////////////////////////////////////
// Mutable states
@@ -136,16 +59,6 @@ private[sql] abstract class SortAggregationIterator(
// Private methods
///////////////////////////////////////////////////////////////////////////
- /** Initializes buffer values for all aggregate functions. */
- protected def initializeBuffer(): Unit = {
- algebraicInitialProjection(EmptyRow)
- var i = 0
- while (i < nonAlgebraicAggregateFunctions.length) {
- nonAlgebraicAggregateFunctions(i).initialize(buffer)
- i += 1
- }
- }
-
protected def initialize(): Unit = {
if (inputIter.hasNext) {
initializeBuffer()
@@ -218,8 +131,6 @@ private[sql] abstract class SortAggregationIterator(
// Methods that need to be implemented
///////////////////////////////////////////////////////////////////////////
- protected def initialBufferOffset: Int
-
protected def processRow(row: InternalRow): Unit
protected def generateOutput(): InternalRow
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 5bbe6c162ff4b..3a881f2fa5858 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -114,7 +114,7 @@ object Utils {
expr.collect {
case agg: AggregateExpression2 => agg
}
- }.toSet.toSeq
+ }
val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
val hasMultipleDistinctColumnSets =
if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
@@ -178,6 +178,7 @@ object Utils {
aggregateExpressions: Seq[AggregateExpression2],
aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
resultExpressions: Seq[NamedExpression],
+ useHybridAggregate: Boolean,
child: SparkPlan): Seq[SparkPlan] = {
// 1. Create an Aggregate Operator for partial aggregations.
val namedGroupingExpressions = groupingExpressions.map {
@@ -195,7 +196,20 @@ object Utils {
val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
agg.aggregateFunction.bufferAttributes
}
- val partialAggregate =
+
+ var useHybridAggregateNew = useHybridAggregate
+ if (aggregateExpressions.length == 0) {
+ useHybridAggregateNew = false
+ }
+ val partialAggregate = if (useHybridAggregateNew) {
+ Aggregate2Hybrid(
+ None: Option[Seq[Expression]],
+ namedGroupingExpressions.map(_._2),
+ partialAggregateExpressions,
+ partialAggregateAttributes,
+ namedGroupingAttributes ++ partialAggregateAttributes,
+ child)
+ } else {
Aggregate2Sort(
None: Option[Seq[Expression]],
namedGroupingExpressions.map(_._2),
@@ -203,6 +217,7 @@ object Utils {
partialAggregateAttributes,
namedGroupingAttributes ++ partialAggregateAttributes,
child)
+ }
// 2. Create an Aggregate Operator for final aggregations.
val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
@@ -222,13 +237,23 @@ object Utils {
}.getOrElse(expression)
}.asInstanceOf[NamedExpression]
}
- val finalAggregate = Aggregate2Sort(
- Some(namedGroupingAttributes),
- namedGroupingAttributes,
- finalAggregateExpressions,
- finalAggregateAttributes,
- rewrittenResultExpressions,
- partialAggregate)
+ val finalAggregate = if (useHybridAggregateNew) {
+ Aggregate2Hybrid(
+ Some(namedGroupingAttributes),
+ namedGroupingAttributes,
+ finalAggregateExpressions,
+ finalAggregateAttributes,
+ rewrittenResultExpressions,
+ partialAggregate)
+ } else {
+ Aggregate2Sort(
+ Some(namedGroupingAttributes),
+ namedGroupingAttributes,
+ finalAggregateExpressions,
+ finalAggregateAttributes,
+ rewrittenResultExpressions,
+ partialAggregate)
+ }
finalAggregate :: Nil
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index cd386b7a3ecf9..a0e9858d43210 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -21,9 +21,10 @@ import org.scalatest.BeforeAndAfterAll
import java.sql.Timestamp
+import org.apache.spark.SparkEnv
import org.apache.spark.sql.catalyst.DefaultParserDialect
import org.apache.spark.sql.catalyst.errors.DialectException
-import org.apache.spark.sql.execution.aggregate.Aggregate2Sort
+import org.apache.spark.sql.execution.aggregate.{Aggregate2Hybrid, Aggregate2Sort}
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
@@ -305,6 +306,105 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
}
}
+ test("aggregation with hybridAggregate") {
+ val originalValue = sqlContext.conf.codegenEnabled
+ sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
+ val originalUseAggregate2 = sqlContext.conf.useSqlAggregate2
+ sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2, true)
+ val originalUseHybridAggregate = sqlContext.conf.useHybridAggregate
+ sqlContext.setConf(SQLConf.USE_HYBRID_AGGREGATE, true)
+ SparkEnv.get.conf.set("spark.test.aggregate.spillFrequency","10")
+ // Prepare a table holding three copies of testData so that rows can be grouped.
+ sqlContext.table("testData")
+ .unionAll(sqlContext.table("testData"))
+ .unionAll(sqlContext.table("testData"))
+ .registerTempTable("testData3x")
+
+ def testHybridAggregate(sqlText: String, expectedResults: Seq[Row]): Unit = {
+ val df = sql(sqlText)
+ // First, check that the executed plan contains an Aggregate2Hybrid operator.
+ var hasAggregate2Hybrid = false
+ df.queryExecution.executedPlan.foreach {
+ case newAggregate: Aggregate2Hybrid => hasAggregate2Hybrid = true
+ case _ =>
+ }
+ if (!hasAggregate2Hybrid) {
+ fail(
+ s"""
+ |Codegen is enabled, but query $sqlText does not have Aggregate2Hybrid in the plan.
+ |${df.queryExecution.simpleString}
+ """.stripMargin)
+ }
+ // Then, check results.
+ checkAnswer(df, expectedResults)
+ }
+
+ try {
+ // COUNT
+ testHybridAggregate(
+ "SELECT key, count(value) FROM testData3x GROUP BY key",
+ (1 to 100).map(i => Row(i, 3)))
+ testHybridAggregate(
+ "SELECT count(key) FROM testData3x",
+ Row(300) :: Nil)
+ // SUM
+ testHybridAggregate(
+ "SELECT value, sum(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, 3 * i)))
+ testHybridAggregate(
+ "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x",
+ Row(5050 * 3, 5050 * 3.0) :: Nil)
+ // AVERAGE
+ testHybridAggregate(
+ "SELECT value, avg(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, i)))
+ testHybridAggregate(
+ "SELECT avg(key) FROM testData3x",
+ Row(50.5) :: Nil)
+ // MAX
+ testHybridAggregate(
+ "SELECT value, max(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, i)))
+ testHybridAggregate(
+ "SELECT max(key) FROM testData3x",
+ Row(100) :: Nil)
+ // MIN
+ testHybridAggregate(
+ "SELECT value, min(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, i)))
+ testHybridAggregate(
+ "SELECT min(key) FROM testData3x",
+ Row(1) :: Nil)
+ // Some combinations.
+ testHybridAggregate(
+ """
+ |SELECT
+ | value,
+ | sum(key),
+ | max(key),
+ | min(key),
+ | avg(key),
+ | count(key)
+ |FROM testData3x
+ |GROUP BY value
+ """.stripMargin,
+ (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3)))
+ testHybridAggregate(
+ "SELECT max(key), min(key), avg(key), count(key) FROM testData3x",
+ Row(100, 1, 50.5, 300) :: Nil)
+ // Hybrid aggregation over inputs that are expected to be null for every row.
+ testHybridAggregate(
+ "SELECT sum('a'), avg('a'), count(null) FROM testData",
+ Row(null, null, 0) :: Nil)
+ } finally {
+ sqlContext.dropTempTable("testData3x")
+ sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue)
+ sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2, originalUseAggregate2)
+ sqlContext.setConf(SQLConf.USE_HYBRID_AGGREGATE, originalUseHybridAggregate)
+ SparkEnv.get.conf.set("spark.test.aggregate.spillFrequency","0")
+ }
+ }
+
test("Add Parser of SQL COALESCE()") {
checkAnswer(
sql("""SELECT COALESCE(1, 2)"""),