Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/javaTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
"org.apache.sysds.test.applications.**",
"**.test.usertest.**",
"**.component.c**.** -Dtest-threadCount=1 -Dtest-forkCount=1",
"**.component.e**.**,**.component.f**.**,**.component.m**.**",
"**.component.e**.**,**.component.f**.**,**.component.m**.**,**.component.o**.**",
"**.component.p**.**,**.component.r**.**,**.component.s**.**,**.component.t**.**,**.component.u**.**",
"**.functions.a**.**,**.functions.binary.matrix.**,**.functions.binary.scalar.**,**.functions.binary.tensor.**",
"**.functions.blocks.**,**.functions.data.rand.**,",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -672,8 +672,7 @@ else if(err instanceof Exception)
List<OOCStream.QueueCallback<IndexedMatrixValue>> outList = new ArrayList<>(r.size());
for(int j = 0; j < r.size(); j++) {
if(explicitCaching[j]) {
// Early forget item from cache
outList.add(new OOCStream.SimpleQueueCallback<>(r.get(j).get(), null));
outList.add(r.get(j).keepOpen());
}
else {
outList.add(r.get(j).keepOpen());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.ooc.memory.InMemoryQueueCallback;
import org.apache.sysds.runtime.ooc.stats.OOCEventLog;
import org.apache.sysds.utils.Statistics;

Expand Down Expand Up @@ -130,6 +131,10 @@ public static void forget(long streamId, int blockId) {
getCache().forget(key);
}

/**
 * Forgets the block addressed by the given key.
 * Pure delegation: the key is passed unchanged to {@code forget} of the cache returned by {@code getCache()}.
 *
 * @param key key of the block to forget
 */
public static void forget(BlockKey key) {
getCache().forget(key);
}

/**
* Store a block in the OOC cache (serialize once)
*/
Expand Down Expand Up @@ -195,6 +200,15 @@ public static CompletableFuture<OOCStream.QueueCallback<IndexedMatrixValue>> req
return getCache().request(key).thenApply(e -> toCallback(e, key, null));
}

/**
 * Tries to obtain a block addressed by stream id and block id.
 * Builds the {@code BlockKey} and forwards to {@code tryRequestBlock(BlockKey)}.
 *
 * @param streamId id of the stream the block belongs to
 * @param blockId  id of the block; must fit into an int because the key is built with an int block id
 * @return whatever {@code tryRequestBlock(BlockKey)} returns for the constructed key (may be null)
 * @throws ArithmeticException if {@code blockId} is outside the int range
 */
public static OOCStream.QueueCallback<IndexedMatrixValue> tryRequestBlock(long streamId, long blockId) {
	// A plain (int) cast would silently wrap an out-of-range id and address a different block;
	// Math.toIntExact keeps in-range ids unchanged and fails loudly otherwise.
	return tryRequestBlock(new BlockKey(streamId, Math.toIntExact(blockId)));
}

/**
 * Tries to obtain the block for the given key via {@code tryRequest} of the cache.
 *
 * @param key key of the requested block
 * @return a callback wrapping the entry, or null if the cache returned no entry
 */
public static OOCStream.QueueCallback<IndexedMatrixValue> tryRequestBlock(BlockKey key) {
	BlockEntry cached = getCache().tryRequest(key);
	if(cached == null)
		return null;
	return toCallback(cached, key, null);
}

public static CompletableFuture<List<OOCStream.QueueCallback<IndexedMatrixValue>>> requestManyBlocks(List<BlockKey> keys) {
return getCache().request(keys).thenApply(
l -> {
Expand Down Expand Up @@ -245,6 +259,10 @@ public static boolean canClaimMemory() {
return getCache().isWithinLimits() && OOCInstruction.getComputeInFlight() <= OOCInstruction.getComputeBackpressureThreshold();
}

/**
 * Hands the given in-memory callback over to the cache under the given key.
 * Pure delegation to {@code handover} of the cache returned by {@code getCache()}.
 *
 * @param key      key under which the block is handed over
 * @param callback in-memory callback to hand over
 * @return the handle produced by the cache scheduler
 */
public static OOCCacheScheduler.HandoverHandle handover(BlockKey key, InMemoryQueueCallback callback) {
	OOCCacheScheduler.HandoverHandle handle = getCache().handover(key, callback);
	return handle;
}

/**
 * Pins the given entry by delegating to {@code pin} of the cache returned by {@code getCache()}.
 *
 * @param entry the cache entry to pin
 */
private static void pin(BlockEntry entry) {
getCache().pin(entry);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@

package org.apache.sysds.runtime.ooc.cache;

import org.apache.sysds.runtime.instructions.ooc.OOCStream;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.ooc.memory.InMemoryQueueCallback;

import java.util.Collection;
import java.util.List;
import java.util.concurrent.CompletableFuture;
Expand All @@ -32,6 +36,17 @@ public interface OOCCacheScheduler {
*/
CompletableFuture<BlockEntry> request(BlockKey key);

/**
 * Single-key convenience variant of the list-based {@code tryRequest}.
 * Immediately returns the entry if present, otherwise null without scheduling reads.
 * The key is wrapped into a one-element list and the first element of the answer is returned.
 *
 * @param key the requested key associated to the block
 * @return the available BlockEntry, or null if the list-based call returned null or an empty list
 */
default BlockEntry tryRequest(BlockKey key) {
	List<BlockEntry> entries = tryRequest(List.of(key));
	if(entries != null && !entries.isEmpty())
		return entries.get(0);
	return null;
}

/**
* Requests a list of blocks from the cache that must be available at the same time.
* @param keys the requested keys associated to the block
Expand Down Expand Up @@ -81,6 +96,15 @@ public interface OOCCacheScheduler {
*/
BlockEntry putAndPin(BlockKey key, Object data, long size);

/**
 * Handle returned by {@code handover} that lets the caller observe or withdraw a handover request.
 * The notes below describe the behavior of the PendingHandover implementation in OOCLRUCacheScheduler;
 * other schedulers may differ.
 */
interface HandoverHandle {
/** @return the key under which the block is handed over */
BlockKey getKey();
/** @return true once the handover has been committed */
boolean isCommitted();
/**
 * @return a future that the LRU implementation completes with true after a commit and with false when
 *         the handover is cancelled or reclaimed
 */
CompletableFuture<Boolean> getCompletionFuture();
/**
 * Withdraws the handover and gives the callback back to the caller.
 *
 * @return the original callback, or null; the LRU implementation returns null when the handover is
 *         already committed or currently being committed
 */
OOCStream.QueueCallback<IndexedMatrixValue> reclaim();
}

/**
 * Hands the block held by an in-memory callback over to the cache under the given key.
 * NOTE(review): OOCLRUCacheScheduler commits right away when the block fits under its hard limit and
 * queues the request otherwise; confirm whether that is the contract for every scheduler.
 *
 * @param key      the key under which the block is handed over
 * @param callback the in-memory callback to hand over
 * @return a handle for tracking or reclaiming the handover
 */
HandoverHandle handover(BlockKey key, InMemoryQueueCallback callback);

/**
* Places a new source-backed block in the cache and registers the location with the IO handler. The entry is
* treated as backed by disk, so eviction does not schedule spill writes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.instructions.ooc.OOCStream;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.ooc.memory.InMemoryQueueCallback;
import org.apache.sysds.runtime.ooc.stats.OOCEventLog;
import org.apache.sysds.utils.Statistics;
import scala.Tuple2;
Expand All @@ -49,6 +52,7 @@ public class OOCLRUCacheScheduler implements OOCCacheScheduler {
private final HashMap<BlockKey, BlockEntry> _evictionCache;
private final DeferredReadQueue _deferredReadRequests;
private final Deque<DeferredReadRequest> _processingReadRequests;
private final Deque<PendingHandover> _pendingHandovers;
private final HashMap<BlockKey, BlockReadState> _blockReads;
private volatile long _hardLimit;
private long _evictionLimit;
Expand All @@ -74,6 +78,7 @@ public OOCLRUCacheScheduler(OOCIOHandler ioHandler, long evictionLimit, long har
this._evictionCache = new HashMap<>();
this._deferredReadRequests = new DeferredReadQueue();
this._processingReadRequests = new ArrayDeque<>();
this._pendingHandovers = new ArrayDeque<>();
this._blockReads = new HashMap<>();
this._hardLimit = hardLimit;
this._evictionLimit = evictionLimit;
Expand Down Expand Up @@ -282,6 +287,25 @@ public BlockEntry putAndPin(BlockKey key, Object data, long size) {
return put(key, data, size, true, null);
}

/**
 * Accepts an in-memory callback for handover into the cache.
 * If the callback's managed bytes fit under the hard limit, the handover is committed right away (outside
 * the scheduler lock) and the cache-size hook is triggered; otherwise it is appended to the pending queue
 * to be processed later.
 *
 * @param key      key under which the block is handed over
 * @param callback in-memory callback to hand over
 * @return the handle tracking this handover
 * @throws IllegalStateException if the scheduler has been shut down
 */
@Override
public HandoverHandle handover(BlockKey key, InMemoryQueueCallback callback) {
	// Cheap early exit before allocating the handle
	if(!this._running)
		throw new IllegalStateException("Cache scheduler has been shut down.");
	PendingHandover handover = new PendingHandover(key, callback);
	boolean immediateCommit;
	synchronized(this) {
		// Re-check under the monitor: shutdown() is synchronized and drains the pending queue, so a
		// request enqueued after that drain would never be committed, reclaimed or completed.
		if(!this._running)
			throw new IllegalStateException("Cache scheduler has been shut down.");
		immediateCommit = canAcceptHandoverLocked(callback.getManagedBytes());
		if(!immediateCommit)
			_pendingHandovers.addLast(handover);
	}
	// Commit outside the lock; only signal a size change when the commit actually happened
	if(immediateCommit && commitHandover(handover))
		onCacheSizeChanged(true);
	return handover;
}

@Override
public void putSourceBacked(BlockKey key, Object data, long size, OOCIOHandler.SourceBlockDescriptor descriptor) {
put(key, data, size, false, descriptor);
Expand Down Expand Up @@ -487,6 +511,14 @@ public synchronized void shutdown() {
_cache.clear();
_evictionCache.clear();
_processingReadRequests.clear();
while(!_pendingHandovers.isEmpty()) {
PendingHandover pending = _pendingHandovers.pollFirst();
if(pending == null)
continue;
OOCStream.QueueCallback<IndexedMatrixValue> callback = pending.reclaim();
if(callback != null)
callback.close();
}
_deferredReadRequests.clear();
_deferredReadCountHint = 0;
_blockReads.clear();
Expand Down Expand Up @@ -555,6 +587,9 @@ private void onCacheSizeChangedInternal(boolean incr) {
onCacheSizeIncremented();
else
while(onCacheSizeDecremented()) {}
while(processPendingHandovers()) {
onCacheSizeIncremented();
}
if(DMLScript.OOC_LOG_EVENTS)
OOCEventLog.onCacheSizeChangedEvent(_callerId, System.nanoTime(), _cacheSize, _bytesUpForEviction,
_pinnedBytes, _readingReservedBytes);
Expand Down Expand Up @@ -721,6 +756,32 @@ private long getEvictionPressure() {
return _cacheSize + _readBuffer - _bytesUpForEviction;
}

/**
 * Commits queued handovers in FIFO order for as long as they fit under the hard limit.
 * Cancelled requests at the head of the queue are dropped. Selection happens under the scheduler lock,
 * the commits themselves run after the lock has been released.
 *
 * @return true if at least one handover was committed in this pass
 */
private boolean processPendingHandovers() {
	List<PendingHandover> committed = new ArrayList<>();
	synchronized(this) {
		// Bytes admitted during this pass. The commits run after the lock is released, so _cacheSize
		// does not reflect them yet; without this running total every queued request would be checked
		// against the same stale size and the batch as a whole could exceed the hard limit.
		long admittedBytes = 0;
		PendingHandover pending;
		while((pending = _pendingHandovers.peekFirst()) != null) {
			if(pending.isCancelled()) {
				_pendingHandovers.pollFirst();
				continue;
			}
			long bytes = pending.getManagedBytes();
			// Stop at the first request that does not fit to preserve FIFO order
			if(bytes < 0 || !canAcceptHandoverLocked(admittedBytes + bytes))
				break;
			_pendingHandovers.pollFirst();
			committed.add(pending);
			admittedBytes += bytes;
		}
	}
	boolean progress = false;
	for(PendingHandover pending : committed) {
		if(commitHandover(pending))
			progress = true;
	}
	return progress;
}

private boolean onCacheSizeDecremented() {
if(_cacheSize + 10000000 >= _hardLimit || _deferredReadCountHint == 0)
return false;
Expand Down Expand Up @@ -1018,6 +1079,34 @@ private void registerWaiter(BlockKey key, DeferredReadRequest request, int index
state.waiters.add(new DeferredReadWaiter(request, index));
}

/**
 * Moves the block of a pending handover into the cache.
 * Takes the callback from the handover, inserts a new entry under the scheduler lock and adds its size to
 * the cache size, then releases the callback's managed memory, closes it and marks the handover committed.
 * On any failure the handover is marked cancelled, the callback is closed and the failure is rethrown.
 *
 * NOTE(review): an existing entry under the same key is replaced without adjusting _cacheSize, and a
 * failure after the cache insertion still reports the handover as cancelled; confirm both are intended.
 *
 * @param pending the handover to commit
 * @return true if the block was committed, false if the handover was no longer available for commit
 */
private boolean commitHandover(PendingHandover pending) {
	InMemoryQueueCallback callback = pending.takeForCommit();
	if(callback == null)
		return false;
	try {
		IndexedMatrixValue value = callback.takeManagedResultForHandover();
		long size = callback.getManagedBytes();
		synchronized(this) {
			BlockEntry entry = new BlockEntry(pending.getKey(), size, value);
			_cache.put(pending.getKey(), entry);
			_cacheSize += size;
		}
		callback.releaseManagedMemory();
		callback.close();
		pending.markCommitted();
		return true;
	}
	catch(Throwable t) {
		pending.markCancelled();
		try {
			callback.close();
		}
		catch(Throwable closeFailure) {
			// Keep the original failure as the primary one instead of letting cleanup replace it
			t.addSuppressed(closeFailure);
		}
		throw t;
	}
}

/**
 * Checks whether a handover of the given size fits under the hard limit.
 * The name indicates that callers are expected to hold the scheduler lock.
 *
 * @param bytes size of the block to hand over
 * @return true if {@code bytes} is non-negative and the cache size plus {@code bytes} does not exceed
 *         the hard limit
 */
private boolean canAcceptHandoverLocked(long bytes) {
	if(bytes < 0)
		return false;
	return _cacheSize + bytes <= _hardLimit;
}

private static class BlockReadState {
private double priority;
private final List<DeferredReadWaiter> waiters;
Expand All @@ -1037,4 +1126,76 @@ private DeferredReadWaiter(DeferredReadRequest request, int index) {
this.index = index;
}
}

/**
 * Scheduler-side state of a single handover request; also serves as the handle returned to the caller.
 * A request starts pending, may pass through a committing phase and ends committed or cancelled.
 * All accessors synchronize on the instance.
 */
private static class PendingHandover implements HandoverHandle {
	private final BlockKey _key;
	private final CompletableFuture<Boolean> _completionFuture;
	private InMemoryQueueCallback _callback;
	private boolean _committed;
	private boolean _cancelled;
	private boolean _committing;

	private PendingHandover(BlockKey key, InMemoryQueueCallback callback) {
		_key = key;
		_callback = callback;
		_completionFuture = new CompletableFuture<>();
	}

	@Override
	public synchronized BlockKey getKey() {
		return _key;
	}

	@Override
	public synchronized boolean isCommitted() {
		return _committed;
	}

	@Override
	public synchronized CompletableFuture<Boolean> getCompletionFuture() {
		return _completionFuture;
	}

	/**
	 * Cancels the request unless it is committed or being committed, completes the future with false
	 * and returns the callback that was still held (null if none is held anymore).
	 */
	@Override
	public synchronized OOCStream.QueueCallback<IndexedMatrixValue> reclaim() {
		boolean ownedByCache = _committed || _committing;
		if(ownedByCache)
			return null;
		_cancelled = true;
		_completionFuture.complete(false);
		OOCStream.QueueCallback<IndexedMatrixValue> released = _callback;
		_callback = null;
		return released;
	}

	/** Managed bytes of the held callback, or 0 once the callback has been taken or reclaimed. */
	private synchronized long getManagedBytes() {
		if(_callback != null)
			return _callback.getManagedBytes();
		return 0;
	}

	private synchronized boolean isCancelled() {
		return _cancelled;
	}

	/**
	 * Claims the callback for a commit. Returns null if the request is already committed, cancelled or
	 * being committed; otherwise enters the committing phase and gives up the callback reference.
	 */
	private synchronized InMemoryQueueCallback takeForCommit() {
		boolean stillPending = !_committed && !_cancelled && !_committing;
		if(!stillPending)
			return null;
		_committing = true;
		InMemoryQueueCallback taken = _callback;
		_callback = null;
		return taken;
	}

	/** Ends the committing phase successfully and completes the future with true. */
	private synchronized void markCommitted() {
		_committing = false;
		_committed = true;
		_completionFuture.complete(true);
	}

	/** Cancels the request and completes the future with false, unless already committed or cancelled. */
	private synchronized void markCancelled() {
		if(!_committed && !_cancelled) {
			_committing = false;
			_cancelled = true;
			_completionFuture.complete(false);
		}
	}
}
}
Loading
Loading