Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-12950] [SQL] Improve lookup of BytesToBytesMap in aggregate #11010

Closed
wants to merge 12 commits into from
108 changes: 61 additions & 47 deletions core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.MemoryLocation;
import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader;
import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter;

Expand All @@ -65,8 +64,6 @@ public final class BytesToBytesMap extends MemoryConsumer {

private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class);

private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0);

private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;

private final TaskMemoryManager taskMemoryManager;
Expand Down Expand Up @@ -417,7 +414,19 @@ public MapIterator destructiveIterator() {
* This function always return the same {@link Location} instance to avoid object allocation.
*/
public Location lookup(Object keyBase, long keyOffset, int keyLength) {
safeLookup(keyBase, keyOffset, keyLength, loc);
safeLookup(keyBase, keyOffset, keyLength, loc,
Murmur3_x86_32.hashUnsafeWords(keyBase, keyOffset, keyLength, 42));
return loc;
}

/**
* Looks up a key, and return a {@link Location} handle that can be used to test existence
* and read/write values.
*
* This function always return the same {@link Location} instance to avoid object allocation.
*/
public Location lookup(Object keyBase, long keyOffset, int keyLength, int hash) {
safeLookup(keyBase, keyOffset, keyLength, loc, hash);
return loc;
}

Expand All @@ -426,37 +435,33 @@ public Location lookup(Object keyBase, long keyOffset, int keyLength) {
*
* This is a thread-safe version of `lookup`, could be used by multiple threads.
*/
public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc) {
public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc, int hash) {
assert(longArray != null);

if (enablePerfMetrics) {
numKeyLookups++;
}
final int hashcode = HASHER.hashUnsafeWords(keyBase, keyOffset, keyLength);
int pos = hashcode & mask;
int pos = hash & mask;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't work if hash is negative right? Not clear to me the new hash doesn't return negatives. assert this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It works.

int step = 1;
while (true) {
if (enablePerfMetrics) {
numProbes++;
}
if (longArray.get(pos * 2) == 0) {
// This is a new key.
loc.with(pos, hashcode, false);
loc.with(pos, hash, false);
return;
} else {
long stored = longArray.get(pos * 2 + 1);
if ((int) (stored) == hashcode) {
if ((int) (stored) == hash) {
// Full hash code matches. Let's compare the keys for equality.
loc.with(pos, hashcode, true);
loc.with(pos, hash, true);
if (loc.getKeyLength() == keyLength) {
final MemoryLocation keyAddress = loc.getKeyAddress();
final Object storedkeyBase = keyAddress.getBaseObject();
final long storedkeyOffset = keyAddress.getBaseOffset();
final boolean areEqual = ByteArrayMethods.arrayEquals(
keyBase,
keyOffset,
storedkeyBase,
storedkeyOffset,
loc.getKeyBase(),
loc.getKeyOffset(),
keyLength
);
if (areEqual) {
Expand Down Expand Up @@ -484,13 +489,14 @@ public final class Location {
private boolean isDefined;
/**
* The hashcode of the most recent key passed to
* {@link BytesToBytesMap#lookup(Object, long, int)}. Caching this hashcode here allows us to
* avoid re-hashing the key when storing a value for that key.
* {@link BytesToBytesMap#lookup(Object, long, int, int)}. Caching this hashcode here allows us
* to avoid re-hashing the key when storing a value for that key.
*/
private int keyHashcode;
private final MemoryLocation keyMemoryLocation = new MemoryLocation();
private final MemoryLocation valueMemoryLocation = new MemoryLocation();
private Object baseObject; // the base object for key and value
private long keyOffset;
private int keyLength;
private long valueOffset;
private int valueLength;

/**
Expand All @@ -504,18 +510,15 @@ private void updateAddressesAndSizes(long fullKeyAddress) {
taskMemoryManager.getOffsetInPage(fullKeyAddress));
}

private void updateAddressesAndSizes(final Object base, final long offset) {
long position = offset;
final int totalLength = Platform.getInt(base, position);
position += 4;
keyLength = Platform.getInt(base, position);
position += 4;
private void updateAddressesAndSizes(final Object base, long offset) {
baseObject = base;
final int totalLength = Platform.getInt(base, offset);
offset += 4;
keyLength = Platform.getInt(base, offset);
offset += 4;
keyOffset = offset;
valueOffset = offset + keyLength;
valueLength = totalLength - keyLength - 4;

keyMemoryLocation.setObjAndOffset(base, position);

position += keyLength;
valueMemoryLocation.setObjAndOffset(base, position);
}

private Location with(int pos, int keyHashcode, boolean isDefined) {
Expand Down Expand Up @@ -543,10 +546,11 @@ private Location with(MemoryBlock page, long offsetInPage) {
private Location with(Object base, long offset, int length) {
this.isDefined = true;
this.memoryPage = null;
baseObject = base;
keyOffset = offset + 4;
keyLength = Platform.getInt(base, offset);
valueOffset = offset + 4 + keyLength;
valueLength = length - 4 - keyLength;
keyMemoryLocation.setObjAndOffset(base, offset + 4);
valueMemoryLocation.setObjAndOffset(base, offset + 4 + keyLength);
return this;
}

Expand All @@ -566,34 +570,44 @@ public boolean isDefined() {
}

/**
* Returns the address of the key defined at this position.
* This points to the first byte of the key data.
* Unspecified behavior if the key is not defined.
* For efficiency reasons, calls to this method always returns the same MemoryLocation object.
* Returns the base object for key.
*/
public MemoryLocation getKeyAddress() {
public Object getKeyBase() {
assert (isDefined);
return keyMemoryLocation;
return baseObject;
}

/**
* Returns the length of the key defined at this position.
* Unspecified behavior if the key is not defined.
* Returns the offset for key.
*/
public int getKeyLength() {
public long getKeyOffset() {
assert (isDefined);
return keyLength;
return keyOffset;
}

/**
* Returns the base object for value.
*/
public Object getValueBase() {
assert (isDefined);
return baseObject;
}

/**
* Returns the address of the value defined at this position.
* This points to the first byte of the value data.
* Returns the offset for value.
*/
public long getValueOffset() {
assert (isDefined);
return valueOffset;
}

/**
* Returns the length of the key defined at this position.
* Unspecified behavior if the key is not defined.
* For efficiency reasons, calls to this method always returns the same MemoryLocation object.
*/
public MemoryLocation getValueAddress() {
public int getKeyLength() {
assert (isDefined);
return valueMemoryLocation;
return keyLength;
}

/**
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/util/Benchmark.scala
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ private[spark] object Benchmark {
}
val best = runTimes.min
val avg = runTimes.sum / iters
Result(avg / 1000000, num / (best / 1000), best / 1000000)
Result(avg / 1000000.0, num / (best / 1000.0), best / 1000000.0)
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,13 @@

import org.apache.spark.SparkConf;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.storage.*;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.memory.MemoryLocation;
import org.apache.spark.util.Utils;

import static org.hamcrest.Matchers.greaterThan;
Expand Down Expand Up @@ -142,10 +141,9 @@ public void tearDown() {

protected abstract boolean useOffHeapMemoryAllocator();

private static byte[] getByteArray(MemoryLocation loc, int size) {
private static byte[] getByteArray(Object base, long offset, int size) {
final byte[] arr = new byte[size];
Platform.copyMemory(
loc.getBaseObject(), loc.getBaseOffset(), arr, Platform.BYTE_ARRAY_OFFSET, size);
Platform.copyMemory(base, offset, arr, Platform.BYTE_ARRAY_OFFSET, size);
return arr;
}

Expand All @@ -163,13 +161,14 @@ private byte[] getRandomByteArray(int numWords) {
*/
private static boolean arrayEquals(
byte[] expected,
MemoryLocation actualAddr,
Object base,
long offset,
long actualLengthBytes) {
return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals(
expected,
Platform.BYTE_ARRAY_OFFSET,
actualAddr.getBaseObject(),
actualAddr.getBaseOffset(),
base,
offset,
expected.length
);
}
Expand Down Expand Up @@ -212,16 +211,20 @@ public void setAndRetrieveAKey() {
// reflect the result of this store without us having to call lookup() again on the same key.
Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
Assert.assertEquals(recordLengthBytes, loc.getValueLength());
Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes));
Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
Assert.assertArrayEquals(keyData,
getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes));
Assert.assertArrayEquals(valueData,
getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes));

// After calling lookup() the location should still point to the correct data.
Assert.assertTrue(
map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined());
Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
Assert.assertEquals(recordLengthBytes, loc.getValueLength());
Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes));
Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
Assert.assertArrayEquals(keyData,
getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes));
Assert.assertArrayEquals(valueData,
getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes));

try {
Assert.assertTrue(loc.putNewKey(
Expand Down Expand Up @@ -283,15 +286,12 @@ private void iteratorTestBase(boolean destructive) throws Exception {
while (iter.hasNext()) {
final BytesToBytesMap.Location loc = iter.next();
Assert.assertTrue(loc.isDefined());
final MemoryLocation keyAddress = loc.getKeyAddress();
final MemoryLocation valueAddress = loc.getValueAddress();
final long value = Platform.getLong(
valueAddress.getBaseObject(), valueAddress.getBaseOffset());
final long value = Platform.getLong(loc.getValueBase(), loc.getValueOffset());
final long keyLength = loc.getKeyLength();
if (keyLength == 0) {
Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0);
} else {
final long key = Platform.getLong(keyAddress.getBaseObject(), keyAddress.getBaseOffset());
final long key = Platform.getLong(loc.getKeyBase(), loc.getKeyOffset());
Assert.assertEquals(value, key);
}
valuesSeen.set((int) value);
Expand Down Expand Up @@ -365,15 +365,15 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception {
Assert.assertEquals(KEY_LENGTH, loc.getKeyLength());
Assert.assertEquals(VALUE_LENGTH, loc.getValueLength());
Platform.copyMemory(
loc.getKeyAddress().getBaseObject(),
loc.getKeyAddress().getBaseOffset(),
loc.getKeyBase(),
loc.getKeyOffset(),
key,
Platform.LONG_ARRAY_OFFSET,
KEY_LENGTH
);
Platform.copyMemory(
loc.getValueAddress().getBaseObject(),
loc.getValueAddress().getBaseOffset(),
loc.getValueBase(),
loc.getValueOffset(),
value,
Platform.LONG_ARRAY_OFFSET,
VALUE_LENGTH
Expand Down Expand Up @@ -425,8 +425,9 @@ public void randomizedStressTest() {
Assert.assertTrue(loc.isDefined());
Assert.assertEquals(key.length, loc.getKeyLength());
Assert.assertEquals(value.length, loc.getValueLength());
Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length));
Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length));
Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length));
Assert.assertTrue(
arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length));
}
}

Expand All @@ -436,8 +437,10 @@ public void randomizedStressTest() {
final BytesToBytesMap.Location loc =
map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
Assert.assertTrue(loc.isDefined());
Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength()));
Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength()));
Assert.assertTrue(
arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength()));
Assert.assertTrue(
arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength()));
}
} finally {
map.free();
Expand Down Expand Up @@ -476,8 +479,9 @@ public void randomizedTestWithRecordsLargerThanPageSize() {
Assert.assertTrue(loc.isDefined());
Assert.assertEquals(key.length, loc.getKeyLength());
Assert.assertEquals(value.length, loc.getValueLength());
Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length));
Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length));
Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length));
Assert.assertTrue(
arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length));
}
}
for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
Expand All @@ -486,8 +490,10 @@ public void randomizedTestWithRecordsLargerThanPageSize() {
final BytesToBytesMap.Location loc =
map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
Assert.assertTrue(loc.isDefined());
Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength()));
Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength()));
Assert.assertTrue(
arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength()));
Assert.assertTrue(
arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength()));
}
} finally {
map.free();
Expand Down
1 change: 1 addition & 0 deletions project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ object MimaExcludes {
excludePackage("org.apache.spark.rpc"),
excludePackage("org.spark-project.jetty"),
excludePackage("org.apache.spark.unused"),
excludePackage("org.apache.spark.unsafe"),
excludePackage("org.apache.spark.util.collection.unsafe"),
excludePackage("org.apache.spark.sql.catalyst"),
excludePackage("org.apache.spark.sql.execution"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
}
}


override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
ev.isNull = "false"
val childrenHash = children.map { child =>
Expand Down
Loading