From a51b64118c198c636815d0abda4625edf5309248 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 31 Jul 2015 17:41:25 -0700 Subject: [PATCH 1/3] Added a KV sorter interface. --- .../spark/unsafe/map/BytesToBytesMap.java | 8 +- .../sql/execution/UnsafeKeyValueSorter.java | 30 ++++++++ .../UnsafeFixedWidthAggregationMap.java | 73 +++++++++---------- .../sql/execution/GeneratedAggregate.scala | 27 ++++--- .../UnsafeFixedWidthAggregationMapSuite.scala | 27 ++++--- .../org/apache/spark/unsafe/KVIterator.java | 29 ++++++++ 6 files changed, 130 insertions(+), 64 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeKeyValueSorter.java create mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 0f42950e6ed8b..be4b5cc62b263 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -17,7 +17,6 @@ package org.apache.spark.unsafe.map; -import java.io.IOException; import java.lang.Override; import java.lang.UnsupportedOperationException; import java.util.Iterator; @@ -212,7 +211,7 @@ public BytesToBytesMap( */ public int numElements() { return numElements; } - private static final class BytesToBytesMapIterator implements Iterator { + public static final class BytesToBytesMapIterator implements Iterator { private final int numRecords; private final Iterator dataPagesIterator; @@ -222,7 +221,8 @@ private static final class BytesToBytesMapIterator implements Iterator private Object pageBaseObject; private long offsetInPage; - BytesToBytesMapIterator(int numRecords, Iterator dataPagesIterator, Location loc) { + private BytesToBytesMapIterator( + int numRecords, Iterator dataPagesIterator, Location loc) { this.numRecords = numRecords; this.dataPagesIterator = 
dataPagesIterator; this.loc = loc; @@ -269,7 +269,7 @@ public void remove() { * If any other lookups or operations are performed on this map while iterating over it, including * `lookup()`, the behavior of the returned iterator is undefined. */ - public Iterator iterator() { + public BytesToBytesMapIterator iterator() { return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeKeyValueSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeKeyValueSorter.java new file mode 100644 index 0000000000000..59c774da74acf --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeKeyValueSorter.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution; + +import java.io.IOException; + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.unsafe.KVIterator; + +public abstract class UnsafeKeyValueSorter { + + public abstract void insert(UnsafeRow key, UnsafeRow value); + + public abstract KVIterator sort() throws IOException; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 08a98cdd94a4c..c18b6dea6b2e1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -17,9 +17,6 @@ package org.apache.spark.sql.execution; -import java.io.IOException; -import java.util.Iterator; - import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; @@ -28,6 +25,7 @@ import org.apache.spark.sql.types.DecimalType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.KVIterator; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryLocation; @@ -156,54 +154,55 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { return currentAggregationBuffer; } - /** - * Mutable pair object returned by {@link UnsafeFixedWidthAggregationMap#iterator()}. - */ - public static class MapEntry { - private MapEntry() { }; - public final UnsafeRow key = new UnsafeRow(); - public final UnsafeRow value = new UnsafeRow(); - } - /** * Returns an iterator over the keys and values in this map. * * For efficiency, each call returns the same object. 
*/ - public Iterator iterator() { - return new Iterator() { + public KVIterator iterator() { + return new KVIterator() { + + private final BytesToBytesMap.BytesToBytesMapIterator mapLocationIterator = map.iterator(); + private final UnsafeRow key = new UnsafeRow(); + private final UnsafeRow value = new UnsafeRow(); - private final MapEntry entry = new MapEntry(); - private final Iterator mapLocationIterator = map.iterator(); + @Override + public boolean next() { + if (mapLocationIterator.hasNext()) { + final BytesToBytesMap.Location loc = mapLocationIterator.next(); + final MemoryLocation keyAddress = loc.getKeyAddress(); + final MemoryLocation valueAddress = loc.getValueAddress(); + key.pointTo( + keyAddress.getBaseObject(), + keyAddress.getBaseOffset(), + groupingKeySchema.length(), + loc.getKeyLength() + ); + value.pointTo( + valueAddress.getBaseObject(), + valueAddress.getBaseOffset(), + aggregationBufferSchema.length(), + loc.getValueLength() + ); + return true; + } else { + return false; + } + } @Override - public boolean hasNext() { - return mapLocationIterator.hasNext(); + public UnsafeRow getKey() { + return key; } @Override - public MapEntry next() { - final BytesToBytesMap.Location loc = mapLocationIterator.next(); - final MemoryLocation keyAddress = loc.getKeyAddress(); - final MemoryLocation valueAddress = loc.getValueAddress(); - entry.key.pointTo( - keyAddress.getBaseObject(), - keyAddress.getBaseOffset(), - groupingKeySchema.length(), - loc.getKeyLength() - ); - entry.value.pointTo( - valueAddress.getBaseObject(), - valueAddress.getBaseOffset(), - aggregationBufferSchema.length(), - loc.getValueLength() - ); - return entry; + public UnsafeRow getValue() { + return value; } @Override - public void remove() { - throw new UnsupportedOperationException(); + public void close() { + // Do nothing. 
} }; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 469de6ca8e101..1c88c0a98f906 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -287,21 +287,26 @@ case class GeneratedAggregate( new Iterator[InternalRow] { private[this] val mapIterator = aggregationMap.iterator() private[this] val resultProjection = resultProjectionBuilder() + private[this] var _hasNext = mapIterator.next() - def hasNext: Boolean = mapIterator.hasNext + def hasNext: Boolean = _hasNext def next(): InternalRow = { - val entry = mapIterator.next() - val result = resultProjection(joinedRow(entry.key, entry.value)) - if (hasNext) { - result + if (_hasNext) { + val result = resultProjection(joinedRow(mapIterator.getKey, mapIterator.getKey)) + _hasNext = mapIterator.next() + if (_hasNext) { + result + } else { + // This is the last element in the iterator, so let's free the buffer. Before we do, + // though, we need to make a defensive copy of the result so that we don't return an + // object that might contain dangling pointers to the freed memory + val resultCopy = result.copy() + aggregationMap.free() + resultCopy + } } else { - // This is the last element in the iterator, so let's free the buffer. 
Before we do, - // though, we need to make a defensive copy of the result so that we don't return an - // object that might contain dangling pointers to the freed memory - val resultCopy = result.copy() - aggregationMap.free() - resultCopy + throw new java.util.NoSuchElementException } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 79fd52dacda52..6a2c51ca88ac3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.scalatest.{BeforeAndAfterEach, Matchers} import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.Random import org.apache.spark.SparkFunSuite @@ -52,7 +53,7 @@ class UnsafeFixedWidthAggregationMapSuite override def afterEach(): Unit = { if (taskMemoryManager != null) { - val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask + val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask() assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0) assert(leakedShuffleMemory === 0) taskMemoryManager = null @@ -80,7 +81,7 @@ class UnsafeFixedWidthAggregationMapSuite PAGE_SIZE_BYTES, false // disable perf metrics ) - assert(!map.iterator().hasNext) + assert(!map.iterator().next()) map.free() } @@ -100,13 +101,13 @@ class UnsafeFixedWidthAggregationMapSuite // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts) assert(map.getAggregationBuffer(groupKey) != null) val iter = map.iterator() - val entry = iter.next() - assert(!iter.hasNext) - entry.key.getString(0) should be ("cats") - entry.value.getInt(0) should be (0) + assert(iter.next()) + 
iter.getKey.getString(0) should be ("cats") + iter.getValue.getInt(0) should be (0) + assert(!iter.next()) // Modifications to rows retrieved from the map should update the values in the map - entry.value.setInt(0, 42) + iter.getValue.setInt(0, 42) map.getAggregationBuffer(groupKey).getInt(0) should be (42) map.free() @@ -128,12 +129,14 @@ class UnsafeFixedWidthAggregationMapSuite groupKeys.foreach { keyString => assert(map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) != null) } - val seenKeys: Set[String] = map.iterator().asScala.map { entry => - entry.key.getString(0) - }.toSet - seenKeys.size should be (groupKeys.size) - seenKeys should be (groupKeys) + val seenKeys = new mutable.HashSet[String] + val iter = map.iterator() + while (iter.next()) { + seenKeys += iter.getKey.getString(0) + } + assert(seenKeys.size === groupKeys.size) + assert(seenKeys === groupKeys) map.free() } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java b/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java new file mode 100644 index 0000000000000..fb163401c0d27 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe; + +public abstract class KVIterator { + + public abstract boolean next(); + + public abstract K getKey(); + + public abstract V getValue(); + + public abstract void close(); +} From 2e62ccbc83803bf32060fd21ab66aeafad2e5e92 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 31 Jul 2015 19:15:13 -0700 Subject: [PATCH 2/3] Updated BytesToBytesMap's data encoding to put the key first. --- .../spark/unsafe/map/BytesToBytesMap.java | 50 ++++++++++--------- .../unsafe/sort/UnsafeExternalSorter.java | 15 ++++++ .../map/AbstractBytesToBytesMapSuite.java | 6 +-- 3 files changed, 45 insertions(+), 26 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index be4b5cc62b263..481375f493a50 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -244,13 +244,13 @@ public boolean hasNext() { @Override public Location next() { - int keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage); - if (keyLength == END_OF_PAGE_MARKER) { + int totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage); + if (totalLength == END_OF_PAGE_MARKER) { advanceToNextPage(); - keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage); + totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage); } loc.with(pageBaseObject, offsetInPage); - offsetInPage += 8 + 8 + keyLength + loc.getValueLength(); + offsetInPage += 8 + totalLength; currentRecordNumber++; return loc; } @@ -352,15 +352,18 @@ private void updateAddressesAndSizes(long fullKeyAddress) { taskMemoryManager.getOffsetInPage(fullKeyAddress)); } - private void updateAddressesAndSizes(Object page, long 
keyOffsetInPage) { - long position = keyOffsetInPage; - keyLength = (int) PlatformDependent.UNSAFE.getLong(page, position); - position += 8; // word used to store the key size - keyMemoryLocation.setObjAndOffset(page, position); - position += keyLength; - valueLength = (int) PlatformDependent.UNSAFE.getLong(page, position); - position += 8; // word used to store the key size - valueMemoryLocation.setObjAndOffset(page, position); + private void updateAddressesAndSizes(final Object page, final long keyOffsetInPage) { + long position = keyOffsetInPage; + final int totalLength = PlatformDependent.UNSAFE.getInt(page, position); + position += 4; + keyLength = PlatformDependent.UNSAFE.getInt(page, position); + position += 4; + valueLength = totalLength - keyLength; + + keyMemoryLocation.setObjAndOffset(page, position); + + position += keyLength; + valueMemoryLocation.setObjAndOffset(page, position); } Location with(int pos, int keyHashcode, boolean isDefined) { @@ -478,7 +481,7 @@ public boolean putNewKey( // the key address instead of storing the absolute address of the value, the key and value // must be stored in the same memory page. 
// (8 byte key length) (key) (8 byte value length) (value) - final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes; + final long requiredSize = 8 + keyLengthBytes + valueLengthBytes; // --- Figure out where to insert the new record --------------------------------------------- @@ -508,7 +511,7 @@ public boolean putNewKey( // There wasn't enough space in the current page, so write an end-of-page marker: final Object pageBaseObject = currentDataPage.getBaseObject(); final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; - PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); + PlatformDependent.UNSAFE.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); } final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes); if (memoryGranted != pageSizeBytes) { @@ -535,21 +538,22 @@ public boolean putNewKey( long insertCursor = dataPageInsertOffset; // Compute all of our offsets up-front: - final long keySizeOffsetInPage = insertCursor; - insertCursor += 8; // word used to store the key size + final long totalLengthOffset = insertCursor; + insertCursor += 4; + final long keyLengthOffset = insertCursor; + insertCursor += 4; final long keyDataOffsetInPage = insertCursor; insertCursor += keyLengthBytes; - final long valueSizeOffsetInPage = insertCursor; - insertCursor += 8; // word used to store the value size final long valueDataOffsetInPage = insertCursor; insertCursor += valueLengthBytes; // word used to store the value size + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, totalLengthOffset, + keyLengthBytes + valueLengthBytes); + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes); // Copy the key - PlatformDependent.UNSAFE.putLong(dataPageBaseObject, keySizeOffsetInPage, keyLengthBytes); PlatformDependent.copyMemory( keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes); // Copy the value - 
PlatformDependent.UNSAFE.putLong(dataPageBaseObject, valueSizeOffsetInPage, valueLengthBytes); PlatformDependent.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject, valueDataOffsetInPage, valueLengthBytes); @@ -557,7 +561,7 @@ public boolean putNewKey( if (useOverflowPage) { // Store the end-of-page marker at the end of the data page - PlatformDependent.UNSAFE.putLong(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER); + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER); } else { pageCursor += requiredSize; } @@ -565,7 +569,7 @@ public boolean putNewKey( numElements++; bitset.set(pos); final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset( - dataPage, keySizeOffsetInPage); + dataPage, totalLengthOffset); longArray.set(pos * 2, storedKeyAddress); longArray.set(pos * 2 + 1, keyHashcode); updateAddressesAndSizes(storedKeyAddress); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 866e0b4151577..c05f2c332eee3 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -282,6 +282,21 @@ public void insertRecord( sorter.insertRecord(recordAddress, prefix); } + /** + * Write a record to the sorter. The record is broken down into two different parts, each + * given as its own (base object, base offset, length in bytes) triple. The body is currently a stub. + */ + public void insertRecord( + Object recordBaseObject1, + long recordBaseOffset1, + int lengthInBytes1, + Object recordBaseObject2, + long recordBaseOffset2, + int lengthInBytes2, + long prefix) throws IOException { + + } + public UnsafeSorterIterator getSortedIterator() throws IOException { final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator(); int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 
1 : 0); diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 60f483acbcb80..70f8ca4d21345 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -243,17 +243,17 @@ public void iteratorTest() throws Exception { @Test public void iteratingOverDataPagesWithWastedSpace() throws Exception { final int NUM_ENTRIES = 1000 * 1000; - final int KEY_LENGTH = 16; + final int KEY_LENGTH = 24; final int VALUE_LENGTH = 40; final BytesToBytesMap map = new BytesToBytesMap( taskMemoryManager, shuffleMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES); - // Each record will take 8 + 8 + 16 + 40 = 72 bytes of space in the data page. Our 64-megabyte + // Each record will take 8 + 24 + 40 = 72 bytes of space in the data page. Our 64-megabyte // pages won't be evenly-divisible by records of this size, which will cause us to waste some // space at the end of the page. This is necessary in order for us to take the end-of-record // handling branch in iterator(). try { for (int i = 0; i < NUM_ENTRIES; i++) { - final long[] key = new long[] { i, i }; // 2 * 8 = 16 bytes + final long[] key = new long[] { i, i, i }; // 3 * 8 = 24 bytes final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes final BytesToBytesMap.Location loc = map.lookup( key, From 5716b59a310e903219f28f9d54cc4340c99271b6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 31 Jul 2015 21:26:19 -0700 Subject: [PATCH 3/3] Fixed test. 
--- .../org/apache/spark/sql/execution/GeneratedAggregate.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 1c88c0a98f906..cd87b8deba0c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -293,7 +293,7 @@ case class GeneratedAggregate( def next(): InternalRow = { if (_hasNext) { - val result = resultProjection(joinedRow(mapIterator.getKey, mapIterator.getKey)) + val result = resultProjection(joinedRow(mapIterator.getKey, mapIterator.getValue)) _hasNext = mapIterator.next() if (_hasNext) { result