-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-9520][SQL] Support in-place sort in UnsafeFixedWidthAggregationMap #7849
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,19 +17,26 @@ | |
|
|
||
| package org.apache.spark.sql.execution; | ||
|
|
||
| import java.io.IOException; | ||
|
|
||
| import org.apache.spark.shuffle.ShuffleMemoryManager; | ||
| import org.apache.spark.sql.catalyst.InternalRow; | ||
| import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; | ||
| import org.apache.spark.sql.catalyst.expressions.UnsafeRow; | ||
| import org.apache.spark.sql.types.Decimal; | ||
| import org.apache.spark.sql.types.DecimalType; | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering; | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering; | ||
| import org.apache.spark.sql.types.StructField; | ||
| import org.apache.spark.sql.types.StructType; | ||
| import org.apache.spark.unsafe.KVIterator; | ||
| import org.apache.spark.unsafe.PlatformDependent; | ||
| import org.apache.spark.unsafe.map.BytesToBytesMap; | ||
| import org.apache.spark.unsafe.memory.MemoryBlock; | ||
| import org.apache.spark.unsafe.memory.MemoryLocation; | ||
| import org.apache.spark.unsafe.memory.TaskMemoryManager; | ||
| import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; | ||
| import org.apache.spark.util.collection.unsafe.sort.RecordComparator; | ||
| import org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter; | ||
| import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; | ||
|
|
||
| /** | ||
| * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. | ||
|
|
@@ -225,4 +232,93 @@ public void printPerfMetrics() { | |
| System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption()); | ||
| } | ||
|
|
||
| /** | ||
| * Sorts the key, value data in this map in place, and return them as an iterator. | ||
| * | ||
| * The only memory that is allocated is the address/prefix array, 16 bytes per record. | ||
| */ | ||
| public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() { | ||
| int numElements = map.numElements(); | ||
| final int numKeyFields = groupingKeySchema.size(); | ||
| TaskMemoryManager memoryManager = map.getTaskMemoryManager(); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you name this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
|
||
| UnsafeExternalRowSorter.PrefixComputer prefixComp = | ||
| SortPrefixUtils.createPrefixGenerator(groupingKeySchema); | ||
| PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(groupingKeySchema); | ||
|
|
||
| final BaseOrdering ordering = GenerateOrdering.create(groupingKeySchema); | ||
| RecordComparator recordComparator = new RecordComparator() { | ||
| private final UnsafeRow row1 = new UnsafeRow(); | ||
| private final UnsafeRow row2 = new UnsafeRow(); | ||
|
|
||
| @Override | ||
| public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { | ||
| row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
| row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1); | ||
| return ordering.compare(row1, row2); | ||
| } | ||
| }; | ||
|
|
||
| // Insert the records into the in-memory sorter. | ||
| final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter( | ||
| memoryManager, recordComparator, prefixComparator, numElements); | ||
|
|
||
| BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator(); | ||
| UnsafeRow row = new UnsafeRow(); | ||
| while (iter.hasNext()) { | ||
| final BytesToBytesMap.Location loc = iter.next(); | ||
| final Object baseObject = loc.getKeyAddress().getBaseObject(); | ||
| final long baseOffset = loc.getKeyAddress().getBaseOffset(); | ||
|
|
||
| // Get encoded memory address | ||
| MemoryBlock page = loc.getMemoryPage(); | ||
| long address = memoryManager.encodePageNumberAndOffset(page, baseOffset - 8); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add a one-line comment here (or at the top of the loop) to explain that |
||
|
|
||
| // Compute prefix | ||
| row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength()); | ||
| final long prefix = prefixComp.computePrefix(row); | ||
|
|
||
| sorter.insertRecord(address, prefix); | ||
| } | ||
|
|
||
| // Return the sorted result as an iterator. | ||
| return new KVIterator<UnsafeRow, UnsafeRow>() { | ||
|
|
||
| private UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); | ||
| private final UnsafeRow key = new UnsafeRow(); | ||
| private final UnsafeRow value = new UnsafeRow(); | ||
| private int numValueFields = aggregationBufferSchema.size(); | ||
|
|
||
| @Override | ||
| public boolean next() throws IOException { | ||
| if (sortedIterator.hasNext()) { | ||
| sortedIterator.loadNext(); | ||
| Object baseObj = sortedIterator.getBaseObject(); | ||
| long recordOffset = sortedIterator.getBaseOffset(); | ||
| int recordLen = sortedIterator.getRecordLength(); | ||
| int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This points to the key length because There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. note that this is gone now |
||
| key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen); | ||
| value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, recordLen - keyLen); | ||
| return true; | ||
| } else { | ||
| return false; | ||
| } | ||
| } | ||
|
|
||
| @Override | ||
| public UnsafeRow getKey() { | ||
| return key; | ||
| } | ||
|
|
||
| @Override | ||
| public UnsafeRow getValue() { | ||
| return value; | ||
| } | ||
|
|
||
| @Override | ||
| public void close() { | ||
| // Do nothing | ||
| } | ||
| }; | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,7 +18,8 @@ | |
|
|
||
| package org.apache.spark.sql.execution | ||
|
|
||
| import org.apache.spark.sql.catalyst.expressions.SortOrder | ||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.types._ | ||
| import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator} | ||
|
|
||
|
|
@@ -46,4 +47,19 @@ object SortPrefixUtils { | |
| case _ => NoOpPrefixComparator | ||
| } | ||
| } | ||
|
|
||
| def getPrefixComparator(schema: StructType): PrefixComparator = { | ||
| val field = schema.head | ||
| getPrefixComparator(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending)) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we want to hardcode |
||
| } | ||
|
|
||
| def createPrefixGenerator(schema: StructType): UnsafeExternalRowSorter.PrefixComputer = { | ||
| val boundReference = BoundReference(0, schema.head.dataType, nullable = true) | ||
| val prefixProjection = UnsafeProjection.create(SortPrefix(SortOrder(boundReference, Ascending))) | ||
| new UnsafeExternalRowSorter.PrefixComputer { | ||
| override def computePrefix(row: InternalRow): Long = { | ||
| prefixProjection.apply(row).getLong(0) | ||
| } | ||
| } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could also be final
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
or just inlined at one site where it's used
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is gone now