Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.util.LinkedList;
import java.util.List;

import javax.annotation.Nullable;

import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -217,6 +219,7 @@ public static final class BytesToBytesMapIterator implements Iterator<Location>
private final Iterator<MemoryBlock> dataPagesIterator;
private final Location loc;

private MemoryBlock currentPage;
private int currentRecordNumber = 0;
private Object pageBaseObject;
private long offsetInPage;
Expand All @@ -232,7 +235,7 @@ private BytesToBytesMapIterator(
}

private void advanceToNextPage() {
final MemoryBlock currentPage = dataPagesIterator.next();
currentPage = dataPagesIterator.next();
pageBaseObject = currentPage.getBaseObject();
offsetInPage = currentPage.getBaseOffset();
}
Expand All @@ -249,7 +252,7 @@ public Location next() {
advanceToNextPage();
totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage);
}
loc.with(pageBaseObject, offsetInPage);
loc.with(currentPage, offsetInPage);
offsetInPage += 8 + totalLength;
currentRecordNumber++;
return loc;
Expand Down Expand Up @@ -346,14 +349,19 @@ public final class Location {
private int keyLength;
private int valueLength;

/**
* Memory page containing the record. Only set if created by {@link BytesToBytesMap#iterator()}.
*/
@Nullable private MemoryBlock memoryPage;

/**
 * Resolves an encoded record address into its page object and in-page offset through the
 * task memory manager, then delegates to the (page, offset) overload to refresh this location.
 */
private void updateAddressesAndSizes(long fullKeyAddress) {
  final Object page = taskMemoryManager.getPage(fullKeyAddress);
  final long offsetInPage = taskMemoryManager.getOffsetInPage(fullKeyAddress);
  updateAddressesAndSizes(page, offsetInPage);
}

private void updateAddressesAndSizes(final Object page, final long keyOffsetInPage) {
long position = keyOffsetInPage;
private void updateAddressesAndSizes(final Object page, final long offsetInPage) {
long position = offsetInPage;
final int totalLength = PlatformDependent.UNSAFE.getInt(page, position);
position += 4;
keyLength = PlatformDependent.UNSAFE.getInt(page, position);
Expand All @@ -366,7 +374,7 @@ private void updateAddressesAndSizes(final Object page, final long keyOffsetInPa
valueMemoryLocation.setObjAndOffset(page, position);
}

Location with(int pos, int keyHashcode, boolean isDefined) {
private Location with(int pos, int keyHashcode, boolean isDefined) {
this.pos = pos;
this.isDefined = isDefined;
this.keyHashcode = keyHashcode;
Expand All @@ -377,12 +385,21 @@ Location with(int pos, int keyHashcode, boolean isDefined) {
return this;
}

Location with(Object page, long keyOffsetInPage) {
private Location with(MemoryBlock page, long offsetInPage) {
this.isDefined = true;
updateAddressesAndSizes(page, keyOffsetInPage);
this.memoryPage = page;
updateAddressesAndSizes(page.getBaseObject(), offsetInPage);
return this;
}

/**
 * Returns the memory page recorded for the record this location points to.
 * The backing field is {@code @Nullable} and is documented as being set only for locations
 * handed out by {@link BytesToBytesMap#iterator()}, so the result should only be relied on
 * for such locations.
 */
public MemoryBlock getMemoryPage() {
  return memoryPage;
}

/**
* Returns true if the key is defined at this position, and false otherwise.
*/
Expand Down Expand Up @@ -538,7 +555,7 @@ public boolean putNewKey(
long insertCursor = dataPageInsertOffset;

// Compute all of our offsets up-front:
final long totalLengthOffset = insertCursor;
final long recordOffset = insertCursor;
insertCursor += 4;
final long keyLengthOffset = insertCursor;
insertCursor += 4;
Expand All @@ -547,7 +564,7 @@ public boolean putNewKey(
final long valueDataOffsetInPage = insertCursor;
insertCursor += valueLengthBytes; // word used to store the value size

PlatformDependent.UNSAFE.putInt(dataPageBaseObject, totalLengthOffset,
PlatformDependent.UNSAFE.putInt(dataPageBaseObject, recordOffset,
keyLengthBytes + valueLengthBytes);
PlatformDependent.UNSAFE.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes);
// Copy the key
Expand All @@ -569,7 +586,7 @@ public boolean putNewKey(
numElements++;
bitset.set(pos);
final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
dataPage, totalLengthOffset);
dataPage, recordOffset);
longArray.set(pos * 2, storedKeyAddress);
longArray.set(pos * 2 + 1, keyHashcode);
updateAddressesAndSizes(storedKeyAddress);
Expand Down Expand Up @@ -618,6 +635,10 @@ public void free() {
assert(dataPages.isEmpty());
}

/** Returns the {@link TaskMemoryManager} instance held by this map. */
public TaskMemoryManager getTaskMemoryManager() {
  return this.taskMemoryManager;
}

/** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */
public long getTotalMemoryConsumption() {
long totalDataPagesSize = 0L;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ object UnsafeProjection {
GenerateUnsafeProjection.generate(exprs)
}

/**
 * Returns an UnsafeProjection for a single expression by wrapping it in a one-element
 * sequence and delegating to the sequence-based `create`.
 */
def create(expr: Expression): UnsafeProjection = {
  val singleton = Seq(expr)
  create(singleton)
}

/**
* Returns an UnsafeProjection for given sequence of Expressions, which will be bound to
* `inputSchema`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.Private
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.StructType

/**
* Inherits some default implementation for Java from `Ordering[Row]`
Expand All @@ -43,7 +44,16 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
in.map(BindReferences.bindReference(_, inputSchema))

protected def create(ordering: Seq[SortOrder]): Ordering[InternalRow] = {
/**
 * Builds a generated ordering over rows of `schema`: every field, taken in schema order, is
 * bound by ordinal (as nullable) and compared in ascending order.
 */
def create(schema: StructType): BaseOrdering = {
  val ascendingOrders = schema.zipWithIndex.map { case (field, ordinal) =>
    val boundField = BoundReference(ordinal, field.dataType, nullable = true)
    SortOrder(boundField, Ascending)
  }
  create(ascendingOrders)
}

protected def create(ordering: Seq[SortOrder]): BaseOrdering = {
val ctx = newCodeGenContext()

val comparisons = ordering.map { order =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,26 @@

package org.apache.spark.sql.execution;

import java.io.IOException;

import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.KVIterator;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.MemoryLocation;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
import org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter;
import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;

/**
* Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
Expand Down Expand Up @@ -225,4 +232,93 @@ public void printPerfMetrics() {
System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
}

/**
* Sorts the key, value data in this map in place, and returns them as an iterator.
*
* The only memory that is allocated is the address/prefix array, 16 bytes per record.
*/
public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() {
int numElements = map.numElements();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could also be final

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or just inlined at one site where it's used

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is gone now

final int numKeyFields = groupingKeySchema.size();
TaskMemoryManager memoryManager = map.getTaskMemoryManager();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you name this taskMemoryManager to disambiguate from ShuffleMemoryManager?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


UnsafeExternalRowSorter.PrefixComputer prefixComp =
SortPrefixUtils.createPrefixGenerator(groupingKeySchema);
PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(groupingKeySchema);

final BaseOrdering ordering = GenerateOrdering.create(groupingKeySchema);
RecordComparator recordComparator = new RecordComparator() {
private final UnsafeRow row1 = new UnsafeRow();
private final UnsafeRow row2 = new UnsafeRow();

/**
 * Compares two records by key: the two reusable rows are pointed 4 bytes past each supplied
 * base offset, and the comparison is delegated to the generated ordering.
 */
@Override
public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
// NOTE(review): the "+ 4" presumably skips the key-length word stored in front of the key
// data — confirm against the record layout written by BytesToBytesMap.
row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The -1 here is for sizeInBytes; if we needed to, I guess we could retrieve the size in bytes since we know where it's stored relative to the row address.

row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1);
return ordering.compare(row1, row2);
}
};

// Insert the records into the in-memory sorter.
final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
memoryManager, recordComparator, prefixComparator, numElements);

BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator();
UnsafeRow row = new UnsafeRow();
while (iter.hasNext()) {
final BytesToBytesMap.Location loc = iter.next();
final Object baseObject = loc.getKeyAddress().getBaseObject();
final long baseOffset = loc.getKeyAddress().getBaseOffset();

// Get encoded memory address
MemoryBlock page = loc.getMemoryPage();
long address = memoryManager.encodePageNumberAndOffset(page, baseOffset - 8);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a one-line comment here (or at the top of the loop) to explain that baseObject + baseOffset point to the beginning of the key data in the map, but that the KV-pair's length data is stored in the word immediately before that address.


// Compute prefix
row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength());
final long prefix = prefixComp.computePrefix(row);

sorter.insertRecord(address, prefix);
}

// Return the sorted result as an iterator.
return new KVIterator<UnsafeRow, UnsafeRow>() {

private UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
private final UnsafeRow key = new UnsafeRow();
private final UnsafeRow value = new UnsafeRow();
private int numValueFields = aggregationBufferSchema.size();

/**
 * Advances to the next record of the sorted iterator and repoints the reusable key and value
 * rows at it.
 *
 * @return true if a record was loaded, false once the sorted iterator has no more records
 */
@Override
public boolean next() throws IOException {
if (sortedIterator.hasNext()) {
sortedIterator.loadNext();
Object baseObj = sortedIterator.getBaseObject();
long recordOffset = sortedIterator.getBaseOffset();
int recordLen = sortedIterator.getRecordLength();
// The int stored at recordOffset is read as the key length.
int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This points to the key length because recordOffset is 4 bytes past where the total record length was stored?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note that this is gone now

// Key data starts right after the 4-byte key-length word; value data follows the key.
key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
// NOTE(review): the value size is taken as recordLen - keyLen, which assumes recordLen counts
// key + value bytes only — confirm against UnsafeSorterIterator.getRecordLength().
value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, recordLen - keyLen);
return true;
} else {
return false;
}
}

/**
 * Returns the key row. The same {@code UnsafeRow} instance is returned on every call;
 * {@code next()} repoints it at the current record.
 */
@Override
public UnsafeRow getKey() {
  return this.key;
}

/**
 * Returns the value row. The same {@code UnsafeRow} instance is returned on every call;
 * {@code next()} repoints it at the current record.
 */
@Override
public UnsafeRow getValue() {
  return this.value;
}

/** No-op: this iterator performs no work on close. */
@Override
public void close() {
// Do nothing
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator}

Expand Down Expand Up @@ -46,4 +47,19 @@ object SortPrefixUtils {
case _ => NoOpPrefixComparator
}
}

/**
 * Returns the prefix comparator for the first field of `schema`, by building an ascending
 * SortOrder on a reference bound to ordinal 0 and delegating to the SortOrder-based overload.
 */
def getPrefixComparator(schema: StructType): PrefixComparator = {
val field = schema.head
getPrefixComparator(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to hardcode Ascending here? I realize that we only need to cluster, not sort, for aggregation, but in terms of API design maybe it would be clearer to fix Ascending in the KV sorter than here. If our KV sorter technically performs clustering instead of sorting then maybe we should add a comment to make that very explicit.

}

/**
 * Creates a prefix computer for rows of `schema`. The prefix is the ascending sort prefix of
 * the first field (bound at ordinal 0 as nullable), evaluated through an UnsafeProjection and
 * read back as a long.
 */
def createPrefixGenerator(schema: StructType): UnsafeExternalRowSorter.PrefixComputer = {
  val firstField = BoundReference(0, schema.head.dataType, nullable = true)
  val sortPrefix = SortPrefix(SortOrder(firstField, Ascending))
  val projection = UnsafeProjection.create(sortPrefix)
  new UnsafeExternalRowSorter.PrefixComputer {
    override def computePrefix(row: InternalRow): Long = {
      val projected = projection.apply(row)
      projected.getLong(0)
    }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,38 @@ class UnsafeFixedWidthAggregationMapSuite
map.free()
}

// Inserts random string keys whose aggregation buffer stores the key length, then checks that
// sortedIterator() returns every key exactly once, in sorted order, paired with its own value.
test("test sorting") {
  val map = new UnsafeFixedWidthAggregationMap(
    emptyAggregationBuffer,
    aggBufferSchema,
    groupKeySchema,
    taskMemoryManager,
    shuffleMemoryManager,
    128, // initial capacity
    PAGE_SIZE_BYTES,
    false // disable perf metrics
  )

  // Fixed seed keeps the generated key set reproducible across runs.
  val rand = new Random(42)
  val groupKeys: Set[String] = Seq.fill(512) {
    Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString
  }.toSet
  groupKeys.foreach { keyString =>
    val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString)))
    // Check for null before the first use, so a missing buffer is reported as an assertion
    // failure rather than a NullPointerException from setInt.
    assert(buf != null)
    buf.setInt(0, keyString.length)
  }

  val out = new scala.collection.mutable.ArrayBuffer[String]
  val iter = map.sortedIterator()
  while (iter.next()) {
    // Each value must still be the length that was stored for its key.
    assert(iter.getKey.getString(0).length === iter.getValue.getInt(0))
    out += iter.getKey.getString(0)
  }

  assert(out === groupKeys.toSeq.sorted)

  map.free()
}

}
4 changes: 3 additions & 1 deletion unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

package org.apache.spark.unsafe;

import java.io.IOException;

public abstract class KVIterator<K, V> {

public abstract boolean next();
public abstract boolean next() throws IOException;

public abstract K getKey();

Expand Down