Skip to content

Commit

Permalink
Option-2: Dump LongArray to PMem (apache#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
yma11 committed Nov 27, 2020
1 parent 915d7c4 commit 777dc39
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.LinkedList;

public final class PMemWriter {
private static final Logger logger = LoggerFactory.getLogger(PMemWriter.class);
private final ShuffleWriteMetrics writeMetrics;
private final TaskMemoryManager taskMemoryManager;
private final LinkedList<MemoryBlock> allocatedPMemPages = new LinkedList<>();
Expand All @@ -30,12 +34,15 @@ public PMemWriter(
}

public boolean dumpPageToPMem(MemoryBlock page) {
long dumpStartTime = System.nanoTime();
MemoryBlock pMemBlock = taskMemoryManager.allocatePMemPage(page.size());
if (pMemBlock != null) {
Platform.copyMemory(page.getBaseObject(), page.getBaseOffset(), null, pMemBlock.getBaseOffset(), page.size());
writeMetrics.incBytesWritten(page.size());
allocatedPMemPages.add(pMemBlock);
pageMap.put(page, pMemBlock);
System.out.println("page size: " + Utils.bytesToString(page.size())
+ " time: " + (System.nanoTime()-dumpStartTime)/1000000);
return true;
}
return false;
Expand Down Expand Up @@ -63,7 +70,18 @@ public void updateLongArray(LongArray sortedArray, int numRecords, int position)
sortedArray.set(position, pMemOffset);
position += 2;
}
this.sortedArray = sortedArray;
// copy the LongArray to PMem
MemoryBlock arrayBlock = sortedArray.memoryBlock();
MemoryBlock pMemBlock = taskMemoryManager.allocatePMemPage(arrayBlock.size());

if (pMemBlock != null) {
writeMetrics.incBytesWritten(pMemBlock.size());
allocatedPMemPages.add(pMemBlock);
Platform.copyMemory(arrayBlock.getBaseObject(), arrayBlock.getBaseOffset(), null, pMemBlock.getBaseOffset(), arrayBlock.size());
this.sortedArray = new LongArray(pMemBlock);
} else {
logger.error("fails to allocate PMem for LongArray");
}
}

public LongArray getSortedArray() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
* Force this sorter to spill when there are this many elements in memory.
*/
private final int numElementsForSpillThreshold;
// private final boolean spillToPMemEnabled = true;
private final boolean spillToPMemEnabled = SparkEnv.get() != null && (boolean) SparkEnv.get().conf().get(
package$.MODULE$.MEMORY_SPILL_PMEM_ENABLED());
/**
Expand Down Expand Up @@ -219,54 +220,66 @@ public long spill(long size, MemoryConsumer trigger) throws IOException {
long spillSize = 0;
ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
// first try to spill to PMem if spark.memory.spill.pmem.enabled is set to true
long arraySize = inMemSorter.getArray().memoryBlock().size();
long required = getMemoryUsage();
// ensure there is enough PMem space for this spill first
long startTime = System.nanoTime();
long duration = 0;
long sortDuration = 0;
long writeDuration = 0;
if (spillToPMemEnabled && taskMemoryManager.acquireExtendedMemory(required) == required) {
// records are not sorted before being spilled to PMem; could this affect performance?
final PMemWriter pMemSpillWriter = new PMemWriter(writeMetrics, taskContext.taskMetrics(),
taskMemoryManager, inMemSorter.numRecords());
long sortTime = System.nanoTime();
UnsafeSorterIterator sortedIterator = inMemSorter.getSortedIterator();

sortDuration = System.nanoTime() - sortTime;
System.out.println("inMemSorter records: " + sortedIterator.getNumRecords());
long dumpTime = System.nanoTime();
for (MemoryBlock page : allocatedPages) {
if (pMemSpillWriter.dumpPageToPMem(page)) {
spillSize += page.size();
} else {
if (!pMemSpillWriter.dumpPageToPMem(page)) {
logger.error("UnsafeExternalSorter fails to spill fully to PMem.");
}
}
long dumpDuration = System.nanoTime() - dumpTime;
System.out.println("dump time : " + dumpDuration/1000000);
long sortTime = System.nanoTime();
pMemSpillWriter.updateLongArray(inMemSorter.getSortedArray(), inMemSorter.numRecords(), 0);
long sortDuration = System.nanoTime() - sortTime;
System.out.println("sort time : " + sortDuration/1000000);
writeDuration = System.nanoTime() - dumpTime;

pMemSpillWriter.updateLongArray(inMemSorter.getArray(), inMemSorter.numRecords(), 0);
// verify all records in inMemSorter are spilled to PMem
assert(pMemSpillWriter.getNumRecordsWritten() == inMemSorter.numRecords());
pMemSpillWriters.add(pMemSpillWriter);
duration = System.nanoTime() - startTime;
spillSize += freeMemory();
inMemSorter.resetWithoutLongArrray();
// spill size includes long array size
spillSize += getMemoryUsage();
freeMemory();

} else {
// fall back to disk spill if PMem spill is not enabled or there is not enough PMem space
final UnsafeSorterSpillWriter spillWriter =
new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
inMemSorter.numRecords());
spillWriters.add(spillWriter);
spillIterator(inMemSorter.getSortedIterator(), spillWriter);
duration = System.nanoTime() - startTime;
long sortStartTime = System.nanoTime();
UnsafeSorterIterator sortedIterator = inMemSorter.getSortedIterator();
sortDuration = System.nanoTime() - sortStartTime;
System.out.println("inMemSorter records: " + sortedIterator.getNumRecords());
long writeStartTime = System.nanoTime();
spillIterator(sortedIterator, spillWriter);
writeDuration = System.nanoTime() - writeStartTime;
spillSize += freeMemory();
inMemSorter.reset();
}
inMemSorter.reset();
// Note that this is more-or-less going to be a multiple of the page size, so wasted space in
// pages will currently be counted as memory spilled even though that space isn't actually
// written to disk. This also counts the space needed to store the sorter's pointer array.
// Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the
// records. Otherwise, if the task is over allocated memory, then without freeing the memory
// pages, we might not be able to get memory for the pointer array.
System.out.println("long array size: " + Utils.bytesToString(arraySize)
+ " released size: " + Utils.bytesToString(spillSize) + " write size: "
+ Utils.bytesToString(writeMetrics.bytesWritten()) +
" write time: " + writeDuration/1000000);
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten());
taskContext.taskMetrics().incShuffleSpillWriteTime(duration);
taskContext.taskMetrics().incShuffleSpillWriteTime(writeDuration);
taskContext.taskMetrics().incSpillSortTime(sortDuration);
totalSpillBytes += spillSize;
return spillSize;
}
Expand Down Expand Up @@ -355,9 +368,9 @@ private void deleteSpillFiles() {

private void deletePMemSpillPages() {
for (PMemWriter pMemWriter: pMemSpillWriters) {
if (pMemWriter.getSortedArray().memoryBlock().pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) {
freeArray(pMemWriter.getSortedArray());
}
/* if (pMemWriter.getSortedArray().memoryBlock().pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) {
freePMemPage(pMemWriter.getSortedArray().memoryBlock());
}*/
for (MemoryBlock block : pMemWriter.getAllocatedPMemPages()) {
freePMemPage(block);
}
Expand Down Expand Up @@ -588,6 +601,7 @@ public long spill() throws IOException {
((UnsafeInMemorySorter.SortedIterator) upstream).clone();

ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
long arraySize = inMemSorter.getArray().memoryBlock().size();
long required = getMemoryUsage();
long startTime = System.nanoTime();
long released = 0L;
Expand All @@ -611,7 +625,7 @@ public long spill() throws IOException {
assert(inMemSorter != null);
released += inMemSorter.getMemoryUsage();
totalSortTimeNanos += inMemSorter.getSortTimeNanos();
inMemSorter.freeWithoutLongArray();
inMemSorter.free();
} else {
// Iterate over the records that have not been returned and spill them.
final UnsafeSorterSpillWriter spillWriter =
Expand Down Expand Up @@ -641,6 +655,10 @@ public long spill() throws IOException {
}
allocatedPages.clear();
}
System.out.println("long array size: " + Utils.bytesToString(arraySize)
+ "released size: " + Utils.bytesToString(released) + "write size: "
+ Utils.bytesToString(writeMetrics.bytesWritten()) +
"write time: " + duration/1000000);
inMemSorter = null;
taskContext.taskMetrics().incMemoryBytesSpilled(released);
taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten());
Expand Down

0 comments on commit 777dc39

Please sign in to comment.