From d45ff6d381a9ccf1a075d9384aafa8889ea8f993 Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Sat, 11 Apr 2026 15:52:28 +0200 Subject: [PATCH] [SYSTEMDS-3891] Add OOC Memory Tracking Closes #2458. --- .github/workflows/javaTests.yml | 2 +- .../instructions/ooc/OOCInstruction.java | 3 +- .../runtime/ooc/cache/OOCCacheManager.java | 18 + .../runtime/ooc/cache/OOCCacheScheduler.java | 24 + .../ooc/cache/OOCLRUCacheScheduler.java | 161 ++++++ .../runtime/ooc/memory/CachedAllowance.java | 498 ++++++++++++++++++ .../ooc/memory/GlobalMemoryBroker.java | 231 ++++++++ .../ooc/memory/InMemoryQueueCallback.java | 194 +++++++ .../runtime/ooc/memory/MemoryAllowance.java | 44 ++ .../runtime/ooc/memory/MemoryBroker.java | 28 + .../ooc/memory/SyncMemoryAllowance.java | 191 +++++++ .../ooc/memory/OOCMemoryAllowanceTest.java | 459 ++++++++++++++++ 12 files changed, 1850 insertions(+), 3 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/memory/CachedAllowance.java create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/memory/GlobalMemoryBroker.java create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/memory/InMemoryQueueCallback.java create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/memory/MemoryAllowance.java create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/memory/MemoryBroker.java create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/memory/SyncMemoryAllowance.java create mode 100644 src/test/java/org/apache/sysds/test/component/ooc/memory/OOCMemoryAllowanceTest.java diff --git a/.github/workflows/javaTests.yml b/.github/workflows/javaTests.yml index 4a377443ba3..6c73738b558 100644 --- a/.github/workflows/javaTests.yml +++ b/.github/workflows/javaTests.yml @@ -54,7 +54,7 @@ jobs: "org.apache.sysds.test.applications.**", "**.test.usertest.**", "**.component.c**.** -Dtest-threadCount=1 -Dtest-forkCount=1", - "**.component.e**.**,**.component.f**.**,**.component.m**.**", + 
"**.component.e**.**,**.component.f**.**,**.component.m**.**,**.component.o**.**", "**.component.p**.**,**.component.r**.**,**.component.s**.**,**.component.t**.**,**.component.u**.**", "**.functions.a**.**,**.functions.binary.matrix.**,**.functions.binary.scalar.**,**.functions.binary.tensor.**", "**.functions.blocks.**,**.functions.data.rand.**,", diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index aaa24ac2f42..be9728d87b9 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -672,8 +672,7 @@ else if(err instanceof Exception) List> outList = new ArrayList<>(r.size()); for(int j = 0; j < r.size(); j++) { if(explicitCaching[j]) { - // Early forget item from cache - outList.add(new OOCStream.SimpleQueueCallback<>(r.get(j).get(), null)); + outList.add(r.get(j).keepOpen()); } else { outList.add(r.get(j).keepOpen()); diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java index a0c8bb075a8..9f0f8c15b49 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java @@ -26,6 +26,7 @@ import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.ooc.memory.InMemoryQueueCallback; import org.apache.sysds.runtime.ooc.stats.OOCEventLog; import org.apache.sysds.utils.Statistics; @@ -130,6 +131,10 @@ public static void forget(long streamId, int blockId) { getCache().forget(key); } + public static void forget(BlockKey key) { + getCache().forget(key); + } + /** * Store a block in the 
OOC cache (serialize once) */ @@ -195,6 +200,15 @@ public static CompletableFuture> req return getCache().request(key).thenApply(e -> toCallback(e, key, null)); } + public static OOCStream.QueueCallback tryRequestBlock(long streamId, long blockId) { + return tryRequestBlock(new BlockKey(streamId, (int) blockId)); + } + + public static OOCStream.QueueCallback tryRequestBlock(BlockKey key) { + BlockEntry entry = getCache().tryRequest(key); + return entry == null ? null : toCallback(entry, key, null); + } + public static CompletableFuture>> requestManyBlocks(List keys) { return getCache().request(keys).thenApply( l -> { @@ -245,6 +259,10 @@ public static boolean canClaimMemory() { return getCache().isWithinLimits() && OOCInstruction.getComputeInFlight() <= OOCInstruction.getComputeBackpressureThreshold(); } + public static OOCCacheScheduler.HandoverHandle handover(BlockKey key, InMemoryQueueCallback callback) { + return getCache().handover(key, callback); + } + private static void pin(BlockEntry entry) { getCache().pin(entry); } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java index dbbd73d53a4..f78327160fa 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java @@ -19,6 +19,10 @@ package org.apache.sysds.runtime.ooc.cache; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.ooc.memory.InMemoryQueueCallback; + import java.util.Collection; import java.util.List; import java.util.concurrent.CompletableFuture; @@ -32,6 +36,17 @@ public interface OOCCacheScheduler { */ CompletableFuture request(BlockKey key); + /** + * Tries to request a single block from the cache. 
+ * Immediately returns the entry if present, otherwise null without scheduling reads. + * @param key the requested key associated to the block + * @return the available BlockEntry or null + */ + default BlockEntry tryRequest(BlockKey key) { + List out = tryRequest(List.of(key)); + return out == null || out.isEmpty() ? null : out.get(0); + } + /** * Requests a list of blocks from the cache that must be available at the same time. * @param keys the requested keys associated to the block @@ -81,6 +96,15 @@ public interface OOCCacheScheduler { */ BlockEntry putAndPin(BlockKey key, Object data, long size); + interface HandoverHandle { + BlockKey getKey(); + boolean isCommitted(); + CompletableFuture getCompletionFuture(); + OOCStream.QueueCallback reclaim(); + } + + HandoverHandle handover(BlockKey key, InMemoryQueueCallback callback); + /** * Places a new source-backed block in the cache and registers the location with the IO handler. The entry is * treated as backed by disk, so eviction does not schedule spill writes. 
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java index cc7aa7bcd1f..e8af837d670 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java @@ -23,6 +23,9 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.api.DMLScript; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.ooc.memory.InMemoryQueueCallback; import org.apache.sysds.runtime.ooc.stats.OOCEventLog; import org.apache.sysds.utils.Statistics; import scala.Tuple2; @@ -49,6 +52,7 @@ public class OOCLRUCacheScheduler implements OOCCacheScheduler { private final HashMap _evictionCache; private final DeferredReadQueue _deferredReadRequests; private final Deque _processingReadRequests; + private final Deque _pendingHandovers; private final HashMap _blockReads; private volatile long _hardLimit; private long _evictionLimit; @@ -74,6 +78,7 @@ public OOCLRUCacheScheduler(OOCIOHandler ioHandler, long evictionLimit, long har this._evictionCache = new HashMap<>(); this._deferredReadRequests = new DeferredReadQueue(); this._processingReadRequests = new ArrayDeque<>(); + this._pendingHandovers = new ArrayDeque<>(); this._blockReads = new HashMap<>(); this._hardLimit = hardLimit; this._evictionLimit = evictionLimit; @@ -282,6 +287,25 @@ public BlockEntry putAndPin(BlockKey key, Object data, long size) { return put(key, data, size, true, null); } + @Override + public HandoverHandle handover(BlockKey key, InMemoryQueueCallback callback) { + if(!this._running) + throw new IllegalStateException("Cache scheduler has been shut down."); + PendingHandover handover = new PendingHandover(key, callback); + boolean immediateCommit = 
false; + synchronized(this) { + if(canAcceptHandoverLocked(callback.getManagedBytes())) + immediateCommit = true; + else + _pendingHandovers.addLast(handover); + } + if(immediateCommit) { + if(commitHandover(handover)) + onCacheSizeChanged(true); + } + return handover; + } + @Override public void putSourceBacked(BlockKey key, Object data, long size, OOCIOHandler.SourceBlockDescriptor descriptor) { put(key, data, size, false, descriptor); @@ -487,6 +511,14 @@ public synchronized void shutdown() { _cache.clear(); _evictionCache.clear(); _processingReadRequests.clear(); + while(!_pendingHandovers.isEmpty()) { + PendingHandover pending = _pendingHandovers.pollFirst(); + if(pending == null) + continue; + OOCStream.QueueCallback callback = pending.reclaim(); + if(callback != null) + callback.close(); + } _deferredReadRequests.clear(); _deferredReadCountHint = 0; _blockReads.clear(); @@ -555,6 +587,9 @@ private void onCacheSizeChangedInternal(boolean incr) { onCacheSizeIncremented(); else while(onCacheSizeDecremented()) {} + while(processPendingHandovers()) { + onCacheSizeIncremented(); + } if(DMLScript.OOC_LOG_EVENTS) OOCEventLog.onCacheSizeChangedEvent(_callerId, System.nanoTime(), _cacheSize, _bytesUpForEviction, _pinnedBytes, _readingReservedBytes); @@ -721,6 +756,32 @@ private long getEvictionPressure() { return _cacheSize + _readBuffer - _bytesUpForEviction; } + private boolean processPendingHandovers() { + List committed = new ArrayList<>(); + synchronized(this) { + while(!_pendingHandovers.isEmpty()) { + PendingHandover pending = _pendingHandovers.peekFirst(); + if(pending == null) + break; + if(pending.isCancelled()) { + _pendingHandovers.pollFirst(); + continue; + } + long bytes = pending.getManagedBytes(); + if(!canAcceptHandoverLocked(bytes)) + break; + _pendingHandovers.pollFirst(); + committed.add(pending); + } + } + boolean progress = false; + for(PendingHandover pending : committed) { + if(commitHandover(pending)) + progress = true; + } + return progress; 
+ } + private boolean onCacheSizeDecremented() { if(_cacheSize + 10000000 >= _hardLimit || _deferredReadCountHint == 0) return false; @@ -1018,6 +1079,34 @@ private void registerWaiter(BlockKey key, DeferredReadRequest request, int index state.waiters.add(new DeferredReadWaiter(request, index)); } + private boolean commitHandover(PendingHandover pending) { + InMemoryQueueCallback callback = pending.takeForCommit(); + if(callback == null) + return false; + try { + IndexedMatrixValue value = callback.takeManagedResultForHandover(); + long size = callback.getManagedBytes(); + synchronized(this) { + BlockEntry entry = new BlockEntry(pending.getKey(), size, value); + _cache.put(pending.getKey(), entry); + _cacheSize += size; + } + callback.releaseManagedMemory(); + callback.close(); + pending.markCommitted(); + return true; + } + catch(Throwable t) { + pending.markCancelled(); + callback.close(); + throw t; + } + } + + private boolean canAcceptHandoverLocked(long bytes) { + return bytes >= 0 && _cacheSize + bytes <= _hardLimit; + } + private static class BlockReadState { private double priority; private final List waiters; @@ -1037,4 +1126,76 @@ private DeferredReadWaiter(DeferredReadRequest request, int index) { this.index = index; } } + + private static class PendingHandover implements HandoverHandle { + private final BlockKey _key; + private final CompletableFuture _completionFuture; + private InMemoryQueueCallback _callback; + private boolean _committed; + private boolean _cancelled; + private boolean _committing; + + private PendingHandover(BlockKey key, InMemoryQueueCallback callback) { + _key = key; + _completionFuture = new CompletableFuture<>(); + _callback = callback; + } + + @Override + public synchronized BlockKey getKey() { + return _key; + } + + @Override + public synchronized boolean isCommitted() { + return _committed; + } + + @Override + public synchronized CompletableFuture getCompletionFuture() { + return _completionFuture; + } + + @Override + public 
synchronized OOCStream.QueueCallback reclaim() { + if(_committed || _committing) + return null; + _cancelled = true; + _completionFuture.complete(false); + OOCStream.QueueCallback callback = _callback; + _callback = null; + return callback; + } + + private synchronized long getManagedBytes() { + return _callback == null ? 0 : _callback.getManagedBytes(); + } + + private synchronized boolean isCancelled() { + return _cancelled; + } + + private synchronized InMemoryQueueCallback takeForCommit() { + if(_committed || _cancelled || _committing) + return null; + _committing = true; + InMemoryQueueCallback callback = _callback; + _callback = null; + return callback; + } + + private synchronized void markCommitted() { + _committing = false; + _committed = true; + _completionFuture.complete(true); + } + + private synchronized void markCancelled() { + if(_committed || _cancelled) + return; + _committing = false; + _cancelled = true; + _completionFuture.complete(false); + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/memory/CachedAllowance.java b/src/main/java/org/apache/sysds/runtime/ooc/memory/CachedAllowance.java new file mode 100644 index 00000000000..4649e47f81e --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/ooc/memory/CachedAllowance.java @@ -0,0 +1,498 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.ooc.memory; + +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.instructions.ooc.CachingStream; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.ooc.cache.BlockKey; +import org.apache.sysds.runtime.ooc.cache.OOCCacheManager; +import org.apache.sysds.runtime.ooc.cache.OOCCacheScheduler; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReferenceArray; + +public class CachedAllowance extends SyncMemoryAllowance { + private static final int INITIAL_SLOTS = 64; + private static final long MIN_HANDOVER_SLACK = 1_000_000L; + private static final long MAX_HANDOVER_SLACK = 128_000_000L; + + private final long _streamId; + private final AtomicLong _nextBlockId; + private volatile AtomicReferenceArray _slots; + private long _pendingHandoverBytes; + private int _highestPopulatedIndex; + private boolean _handoverScheduling; + private boolean _handoverSchedulingRequested; + + public CachedAllowance(MemoryBroker broker) { + super(broker); + _streamId = CachingStream._streamSeq.getNextID(); + _slots = new AtomicReferenceArray<>(INITIAL_SLOTS); + _nextBlockId = new AtomicLong(0); + _pendingHandoverBytes = 0; + _highestPopulatedIndex = -1; + _handoverScheduling = false; + _handoverSchedulingRequested = false; + } + + public void 
handover(InMemoryQueueCallback callback, int index) { + if(callback == null) + throw new IllegalArgumentException("Cannot hand over null callback."); + callback.transferOwnershipBlocking(this); + + InMemoryQueueCallback root = (InMemoryQueueCallback) callback.keepOpen(); + callback.close(); + root.getHandle().attachCachedAllowance(this, index); + + SlotEntry entry = new SlotEntry(root); + synchronized(this) { + ensureCapacity(index); + AtomicReferenceArray slots = _slots; + if(slots.get(index) != null) { + root.getHandle().detachCachedAllowance(); + root.close(); + throw new IllegalStateException("Cached allowance slot " + index + " already occupied."); + } + slots.set(index, entry); + if(index > _highestPopulatedIndex) + _highestPopulatedIndex = index; + } + } + + public OOCStream.QueueCallback tryGet(int index) { + SlotEntry entry = getSlot(index); + if(entry == null) + return null; + + while(true) { + BlockKey cacheKey = null; + OOCCacheScheduler.HandoverHandle handover = null; + InMemoryQueueCallback local = null; + + synchronized(entry) { + if(entry._local != null && entry._handover == null) + local = entry._local; + else if(entry._handover != null) { + handover = entry._handover; + cacheKey = entry._cacheKey; + } + else if(entry._cacheKey != null) + cacheKey = entry._cacheKey; + else + return null; + } + + if(local != null) + return local.keepOpen(); + + if(handover != null) { + OOCStream.QueueCallback reclaimed = handover.reclaim(); + if(reclaimed != null) { + reclaimed.close(); + synchronized(entry) { + if(entry._handover == handover) { + finishPendingHandover(entry); + entry._handover = null; + entry._cacheKey = null; + } + } + continue; + } + + CompletableFuture future = handover.getCompletionFuture(); + if(!future.isDone()) + return null; + boolean committed = future.join(); + InMemoryQueueCallback localToClose = null; + synchronized(entry) { + if(entry._handover != handover) + continue; + finishPendingHandover(entry); + entry._handover = null; + 
if(committed) { + localToClose = entry._local; + entry._local = null; + } + else { + entry._cacheKey = null; + } + } + if(localToClose != null) + closeRoot(localToClose); + continue; + } + + return OOCCacheManager.tryRequestBlock(cacheKey); + } + } + + public CompletableFuture> get(int index) { + OOCStream.QueueCallback immediate = tryGet(index); + if(immediate != null) + return CompletableFuture.completedFuture(immediate); + + SlotEntry entry = getSlot(index); + if(entry == null) + return CompletableFuture.completedFuture(null); + + OOCCacheScheduler.HandoverHandle handover; + BlockKey cacheKey; + synchronized(entry) { + if(entry._local != null && entry._handover == null) + return CompletableFuture.completedFuture(entry._local.keepOpen()); + handover = entry._handover; + cacheKey = entry._cacheKey; + } + + if(handover != null) { + return handover.getCompletionFuture().handle((committed, ex) -> { + if(ex != null) + throw DMLRuntimeException.of(ex.getCause() == null ? ex : ex.getCause()); + return committed == true; + }).thenCompose(committed -> { + InMemoryQueueCallback localToClose = null; + InMemoryQueueCallback local = null; + BlockKey key; + + synchronized(entry) { + if(entry._handover != handover) + return get(index); + + finishPendingHandover(entry); + entry._handover = null; + if(committed) { + key = entry._cacheKey; + localToClose = entry._local; + entry._local = null; + } + else { + entry._cacheKey = null; + local = entry._local; + key = null; + } + } + + if(localToClose != null) + closeRoot(localToClose); + + if(committed) + return OOCCacheManager.requestBlock(key); + return CompletableFuture.completedFuture(local == null ? 
null : local.keepOpen()); + }); + } + + if(cacheKey != null) + return OOCCacheManager.requestBlock(cacheKey); + return CompletableFuture.completedFuture(null); + } + + public void clear(int index) { + SlotEntry entry = removeSlot(index); + if(entry == null) + return; + + while(true) { + OOCCacheScheduler.HandoverHandle handover = null; + BlockKey forgetKey = null; + InMemoryQueueCallback localToClose = null; + + synchronized(entry) { + if(entry._local != null && entry._handover == null) { + localToClose = entry._local; + entry._local = null; + } + else if(entry._handover != null) + handover = entry._handover; + else if(entry._cacheKey != null) { + forgetKey = entry._cacheKey; + entry._cacheKey = null; + } + else + return; + } + + if(localToClose != null) { + closeRoot(localToClose); + return; + } + + if(forgetKey != null) { + OOCCacheManager.forget(forgetKey); + return; + } + + OOCStream.QueueCallback reclaimed = handover.reclaim(); + if(reclaimed != null) { + reclaimed.close(); + synchronized(entry) { + if(entry._handover == handover) { + finishPendingHandover(entry); + localToClose = entry._local; + entry._local = null; + entry._handover = null; + entry._cacheKey = null; + } + } + if(localToClose != null) + closeRoot(localToClose); + return; + } + + boolean committed; + try { + committed = handover.getCompletionFuture().join(); + } + catch(CompletionException ex) { + throw DMLRuntimeException.of(ex.getCause() == null ? 
ex : ex.getCause()); + } + + synchronized(entry) { + if(entry._handover != handover) + continue; + finishPendingHandover(entry); + localToClose = entry._local; + entry._local = null; + entry._handover = null; + if(committed) + forgetKey = entry._cacheKey; + entry._cacheKey = null; + } + + if(localToClose != null) + closeRoot(localToClose); + if(forgetKey != null) + OOCCacheManager.forget(forgetKey); + return; + } + } + + @Override + public boolean tryReserve(long bytes) { + throw new UnsupportedOperationException("CachedAllowance does not support direct reservations. Use handover(...)."); + } + + @Override + public void reserveBlocking(long bytes) { + throw new UnsupportedOperationException("CachedAllowance does not support direct reservations. Use handover(...)."); + } + + @Override + public void setTargetMemory(long targetMemory) { + super.setTargetMemory(targetMemory); + maybeScheduleHandovers(0); + } + + void onFinishedHandover(long bytes) { + synchronized(this) { + _pendingHandoverBytes -= bytes; + if(_pendingHandoverBytes < 0) + throw new IllegalStateException(); + notifyAll(); + } + maybeScheduleHandovers(0); + } + + void admitBlocking(long bytes) { + while(true) { + if(super.tryReserve(bytes)) + return; + maybeScheduleHandovers(bytes); + if(super.tryReserve(bytes)) + return; + if(_shutdown || _destroyed) + throw new IllegalStateException("Cannot reserve memory on closed allowance."); + synchronized(this) { + if(_shutdown || _destroyed) + throw new IllegalStateException("Cannot reserve memory on closed allowance."); + try { + wait(); + } + catch(InterruptedException e) { + throw new DMLRuntimeException(e); + } + } + } + } + + private void maybeScheduleHandovers(long requestedBytes) { + synchronized(this) { + _handoverSchedulingRequested = true; + if(_handoverScheduling) + return; + _handoverScheduling = true; + } + + boolean restart; + try { + while(true) { + long reclaimGoal; + int startIndex; + synchronized(this) { + _handoverSchedulingRequested = false; + 
if(_shutdown || _destroyed) + return; + long excess = _usedBytes + requestedBytes - _targetBytes - _pendingHandoverBytes; + if(excess <= 0) { + if(!_handoverSchedulingRequested) + return; + continue; + } + + long slack = Math.max(MIN_HANDOVER_SLACK, Math.min(MAX_HANDOVER_SLACK, _targetBytes / 16)); + reclaimGoal = excess + slack; + startIndex = _highestPopulatedIndex; + } + + AtomicReferenceArray slots = _slots; + int newHighest = startIndex; + // Find highest non-null entry + for(int i = Math.min(startIndex, slots.length() - 1); i >= 0; i--) { + if(slots.get(i) != null) { + newHighest = i; + break; + } + } + + for(int i = newHighest; i >= 0 && reclaimGoal > 0; i--) { + long bytes = tryStartCacheHandover(slots.get(i)); + if(bytes <= 0) + continue; + reclaimGoal -= bytes; + } + + synchronized(this) { + if(newHighest < _highestPopulatedIndex) + _highestPopulatedIndex = newHighest; + if(!_handoverSchedulingRequested) + return; + } + } + } + finally { + synchronized(this) { + _handoverScheduling = false; + restart = _handoverSchedulingRequested; + } + if(restart) + maybeScheduleHandovers(requestedBytes); + } + } + + private long tryStartCacheHandover(SlotEntry entry) { + if(entry == null) + return 0; + synchronized(entry) { + if(entry._local == null || entry._handover != null || !entry._local.getHandle().isExclusiveToRoot()) + return 0; + + long bytes = entry._local.getManagedBytes(); + if(bytes <= 0) + return 0; + + InMemoryQueueCallback retained = (InMemoryQueueCallback) entry._local.keepOpen(); + try { + entry._cacheKey = new BlockKey(_streamId, _nextBlockId.getAndIncrement()); + entry._handover = OOCCacheManager.handover(entry._cacheKey, retained); + entry._pendingBytes = bytes; + synchronized(this) { + _pendingHandoverBytes += bytes; + } + entry._handover.getCompletionFuture().whenComplete((committed, ex) -> onHandoverCompleted(entry)); + return bytes; + } + catch(RuntimeException ex) { + entry._cacheKey = null; + entry._handover = null; + entry._pendingBytes = 0; 
+ retained.close(); + throw ex; + } + } + } + + private void onHandoverCompleted(SlotEntry entry) { + synchronized(entry) { + if(entry._pendingBytes <= 0) + return; + finishPendingHandover(entry); + } + maybeScheduleHandovers(0); + } + + private void finishPendingHandover(SlotEntry entry) { + if(entry._pendingBytes <= 0) + return; + long bytes = entry._pendingBytes; + entry._pendingBytes = 0; + onFinishedHandover(bytes); + } + + private void closeRoot(InMemoryQueueCallback local) { + local.getHandle().detachCachedAllowance(); + local.close(); + } + + private SlotEntry getSlot(int index) { + AtomicReferenceArray slots = _slots; + if(index < 0 || index >= slots.length()) + return null; + return slots.get(index); + } + + private SlotEntry removeSlot(int index) { + synchronized(this) { + AtomicReferenceArray slots = _slots; + if(index < 0 || index >= slots.length()) + return null; + SlotEntry entry = slots.get(index); + if(entry != null) + slots.set(index, null); + return entry; + } + } + + private void ensureCapacity(int index) { + AtomicReferenceArray slots = _slots; + if(index < slots.length()) + return; + int newLen = slots.length(); + while(index >= newLen) + newLen *= 2; + AtomicReferenceArray grown = new AtomicReferenceArray<>(newLen); + for(int i = 0; i < slots.length(); i++) + grown.set(i, slots.get(i)); + _slots = grown; + } + + private static final class SlotEntry { + private InMemoryQueueCallback _local; + private BlockKey _cacheKey; + private OOCCacheScheduler.HandoverHandle _handover; + private long _pendingBytes; + + private SlotEntry(InMemoryQueueCallback local) { + _local = local; + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/ooc/memory/GlobalMemoryBroker.java b/src/main/java/org/apache/sysds/runtime/ooc/memory/GlobalMemoryBroker.java new file mode 100644 index 00000000000..6009182f156 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/ooc/memory/GlobalMemoryBroker.java @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.ooc.memory; + +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; + +public class GlobalMemoryBroker implements MemoryBroker { + private enum BrokerMode { + RELAXED, STRICT + } + + private static final GlobalMemoryBroker BROKER = new GlobalMemoryBroker(Runtime.getRuntime().maxMemory() / 3); + + public static GlobalMemoryBroker get() { + return BROKER; + } + + private final long _allowedBytes; + private final List _allowances; + private final LinkedList _overconsumers; + private long _usedBytes; + private BrokerMode _brokerMode; + + private record TargetUpdate(MemoryAllowance _allowance, long _target) {} + + public GlobalMemoryBroker(long allowedBytes) { + _allowedBytes = allowedBytes; + _usedBytes = 0; + _allowances = new ArrayList<>(); + _overconsumers = new LinkedList<>(); + } + + @Override + public long requestMemory(MemoryAllowance allowance, long minSize, long maxSize) { + List updates = null; + long allow = 0; + synchronized(this) { + if(minSize < 0 || maxSize < minSize) + throw new IllegalArgumentException(); + long share = getEqualShare(); + long free = _allowedBytes - _usedBytes; + if(free < minSize) { + 
if(allowance.getGrantedMemory() > share && allowance.getTargetMemory() > allowance.getGrantedMemory()) + updates = List.of(new TargetUpdate(allowance, allowance.getUsedMemory())); + else { + MemoryAllowance largestConsumer = findAndRemoveLargestConsumer(); + if(largestConsumer != null) { + long newTarget = (long) (largestConsumer.getGrantedMemory() * 0.8); + if(newTarget <= share) + newTarget = share; + else + addOverconsumer(largestConsumer); + updates = List.of(new TargetUpdate(largestConsumer, newTarget)); + } + } + } + else { + allow = Math.min(free, maxSize); + _usedBytes += allow; + updates = rebalance(false); + if(allowance.getGrantedMemory() <= share && allowance.getGrantedMemory() + allow > share) + addOverconsumer(allowance); + } + } + if(updates != null) + applyTargetUpdates(updates); + return allow; + } + + private MemoryAllowance findAndRemoveLargestConsumer() { + long largest = Long.MIN_VALUE; + MemoryAllowance allowance = null; + for(MemoryAllowance largestConsumer : _overconsumers) { + if(largestConsumer.getGrantedMemory() > largest) { + largest = largestConsumer.getGrantedMemory(); + allowance = largestConsumer; + } + } + _overconsumers.remove(allowance); + return allowance; + } + + @Override + public void freeMemory(MemoryAllowance allowance, long freedMemory) { + List updates = null; + synchronized(this) { + if(freedMemory < 0) + throw new IllegalArgumentException(); + _usedBytes -= freedMemory; + if(allowance.isShutdown()) + updates = rebalance(false); + long share = getEqualShare(); + if(allowance.getGrantedMemory() <= share && allowance.getGrantedMemory() + freedMemory > share) + _overconsumers.remove(allowance); + else if(allowance.getGrantedMemory() <= allowance.getTargetMemory() && allowance.getGrantedMemory() > share) + addOverconsumer(allowance); + } + if(updates != null) + applyTargetUpdates(updates); + } + + @Override + public void shutdownAllowance(MemoryAllowance allowance) { + List updates; + synchronized(this) { + 
_overconsumers.remove(allowance); + updates = rebalance(true); + } + applyTargetUpdates(updates); + } + + @Override + public void destroyAllowance(MemoryAllowance allowance, long freedMemory) { + List updates; + synchronized(this) { + if(freedMemory < 0) + throw new IllegalArgumentException(); + _allowances.remove(allowance); + _overconsumers.remove(allowance); + _usedBytes -= freedMemory; + updates = rebalance(true); + } + applyTargetUpdates(updates); + } + + @Override + public synchronized void attachAllowance(MemoryAllowance allowance) { + _allowances.add(allowance); + allowance.setTargetMemory(_allowedBytes); + } + + private List rebalance(boolean force) { + long free = _allowedBytes - _usedBytes; + if(force) + _brokerMode = null; + if(free > _allowedBytes / 5) + return switchBrokerMode(BrokerMode.RELAXED); + else + return switchBrokerMode(BrokerMode.STRICT); + } + + private List switchBrokerMode(BrokerMode newMode) { + if(newMode == _brokerMode) + return null; + List updates = switch(newMode) { + case STRICT -> rebalanceToStrict(); + case RELAXED -> rebalanceToRelaxed(); + default -> throw new IllegalStateException("Unsupported broker mode " + newMode); + }; + _brokerMode = newMode; + return updates; + } + + private List rebalanceToStrict() { + List updates = new ArrayList<>(); + long share = getEqualShare(); + for(MemoryAllowance allowance : _allowances) { + if(allowance.isShutdown()) + continue; + if(allowance.getUsedMemory() > share) { + updates.add(new TargetUpdate(allowance, + Math.min(allowance.getTargetMemory(), share + (long) ((allowance.getUsedMemory() - share) * 0.9)))); + } + } + refreshOverconsumers(updates); + return updates; + } + + private List rebalanceToRelaxed() { + List updates = new ArrayList<>(); + long free = _allowedBytes - _usedBytes; + for(MemoryAllowance allowance : _allowances) { + if(allowance.isShutdown()) + continue; + updates.add(new TargetUpdate(allowance, allowance.getGrantedMemory() + free)); + } + 
refreshOverconsumers(updates); + return updates; + } + + private long getEqualShare() { + return _allowances.isEmpty() ? _allowedBytes : _allowedBytes / _allowances.size(); + } + + private void addOverconsumer(MemoryAllowance allowance) { + if(!_overconsumers.contains(allowance)) + _overconsumers.add(allowance); + } + + private void refreshOverconsumers(List updates) { + _overconsumers.clear(); + long share = getEqualShare(); + for(MemoryAllowance allowance : _allowances) { + if(allowance.isShutdown()) + continue; + long target = allowance.getTargetMemory(); + for(TargetUpdate update : updates) { + if(update._allowance == allowance) { + target = update._target; + break; + } + } + if(allowance.getGrantedMemory() > share && allowance.getGrantedMemory() <= target) + _overconsumers.add(allowance); + } + } + + private static void applyTargetUpdates(List updates) { + for(TargetUpdate update : updates) + update._allowance.setTargetMemory(update._target); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/ooc/memory/InMemoryQueueCallback.java b/src/main/java/org/apache/sysds/runtime/ooc/memory/InMemoryQueueCallback.java new file mode 100644 index 00000000000..7496279d3e3 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/ooc/memory/InMemoryQueueCallback.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.ooc.memory;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.ooc.OOCStream;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * Queue callback whose payload is held in memory and accounted against a {@link MemoryAllowance}.
+ *
+ * Every callback instance is a reference to a shared {@link CallbackHandle}. {@link #keepOpen()}
+ * creates a further reference; when the last reference is closed the handle drops the payload and
+ * releases the bytes that are still reserved. {@link #close()} clears this instance's handle
+ * reference, so the accessor methods must not be called on an instance after it was closed.
+ */
+public class InMemoryQueueCallback implements OOCStream.QueueCallback<IndexedMatrixValue> {
+	private CallbackHandle _handle;
+	private boolean _closed;
+
+	/**
+	 * Creates the first reference to a new payload.
+	 *
+	 * @param result        payload; {@code null} together with a {@code null} failure marks end-of-stream
+	 * @param failure       failure to rethrow from {@link #get()}, or {@code null}
+	 * @param allow         allowance that currently accounts for {@code reservedBytes}
+	 * @param reservedBytes bytes handed back to {@code allow} when the payload is finally released
+	 */
+	public InMemoryQueueCallback(IndexedMatrixValue result, DMLRuntimeException failure, MemoryAllowance allow,
+		long reservedBytes) {
+		_handle = new CallbackHandle(result, failure, allow, reservedBytes);
+		_closed = false;
+	}
+
+	/** Creates an additional reference to an existing handle; the caller has already bumped the ref counter. */
+	private InMemoryQueueCallback(CallbackHandle handle) {
+		_handle = handle;
+		_closed = false;
+	}
+
+	/** Returns the payload, or rethrows the failure recorded on the shared handle. */
+	@Override
+	public IndexedMatrixValue get() {
+		return _handle.get();
+	}
+
+	/**
+	 * Creates another reference to the same payload.
+	 *
+	 * @throws IllegalStateException if this reference was already closed
+	 */
+	@Override
+	public synchronized OOCStream.QueueCallback<IndexedMatrixValue> keepOpen() {
+		if(_closed)
+			throw new IllegalStateException("Cannot keep open a closed callback");
+		_handle._refCtr.incrementAndGet();
+		return new InMemoryQueueCallback(_handle);
+	}
+
+	/** Records a failure on the shared handle; it becomes visible to all references. */
+	@Override
+	public void fail(DMLRuntimeException failure) {
+		_handle._failure = failure;
+	}
+
+	/** Returns the number of bytes the handle still has reserved. */
+	public long getManagedBytes() {
+		synchronized(_handle) {
+			return _handle._reservedBytes;
+		}
+	}
+
+	/**
+	 * Tries to move the reservation to another allowance without blocking.
+	 *
+	 * @return {@code true} if nothing is reserved, the allowance already owns the reservation, or the
+	 *         transfer succeeded; {@code false} if the handle is attached to a cache slot or the
+	 *         target allowance refused the reservation
+	 */
+	public boolean tryTransferOwnership(MemoryAllowance allowance) {
+		synchronized(_handle) {
+			long bytes = _handle._reservedBytes;
+			if(bytes <= 0 || _handle._allow == allowance)
+				return true;
+			if(_handle._cacheIdx >= 0)
+				return false;
+			// reserve on the new owner first so the bytes are never unaccounted
+			if(!allowance.tryReserve(bytes))
+				return false;
+			_handle._allow.release(bytes);
+			_handle._allow = allowance;
+			return true;
+		}
+	}
+
+	/**
+	 * Moves the reservation to another allowance, blocking until the target admits it.
+	 * The handle monitor is held while waiting.
+	 *
+	 * @throws IllegalStateException if the handle is attached to a cache slot
+	 */
+	public void transferOwnershipBlocking(MemoryAllowance allowance) {
+		synchronized(_handle) {
+			long bytes = _handle._reservedBytes;
+			if(bytes <= 0 || _handle._allow == allowance)
+				return;
+			if(_handle._cacheIdx >= 0)
+				throw new IllegalStateException("Cannot transfer ownership of a cached allowance callback.");
+			if(allowance instanceof CachedAllowance cached)
+				cached.admitBlocking(bytes);
+			else
+				allowance.reserveBlocking(bytes);
+			_handle._allow.release(bytes);
+			_handle._allow = allowance;
+		}
+	}
+
+	/**
+	 * Releases the reservation early while keeping the payload reachable.
+	 *
+	 * @return the number of bytes released, or 0 if nothing was reserved
+	 */
+	public long releaseManagedMemory() {
+		synchronized(_handle) {
+			long bytes = _handle._reservedBytes;
+			if(bytes <= 0)
+				return 0;
+			_handle._reservedBytes = 0;
+			_handle._allow.release(bytes);
+			return bytes;
+		}
+	}
+
+	/** Closes this reference (idempotent); the last reference releases payload and reservation. */
+	@Override
+	public synchronized void close() {
+		if(_closed)
+			return;
+		_closed = true;
+		if(_handle._refCtr.decrementAndGet() == 0)
+			_handle.closeFinal();
+		_handle = null;
+	}
+
+	/** Returns whether the handle carries neither a payload nor a failure. */
+	@Override
+	public boolean isEos() {
+		return _handle.isEos();
+	}
+
+	/** Returns whether a failure was recorded on the shared handle. */
+	@Override
+	public boolean isFailure() {
+		return _handle._failure != null;
+	}
+
+	/** Returns the shared handle, or {@code null} once this reference was closed. */
+	CallbackHandle getHandle() {
+		return _handle;
+	}
+
+	/**
+	 * State shared by all references to one payload. {@code _allow}, {@code _reservedBytes} and
+	 * {@code _cacheIdx} are guarded by the monitor of the handle.
+	 */
+	static final class CallbackHandle {
+		private volatile IndexedMatrixValue _result;
+		private final AtomicInteger _refCtr;
+		private MemoryAllowance _allow;
+		private long _reservedBytes;
+		private volatile DMLRuntimeException _failure;
+		private int _cacheIdx;
+
+		private CallbackHandle(IndexedMatrixValue result, DMLRuntimeException failure, MemoryAllowance allow,
+			long reservedBytes) {
+			_result = result;
+			_failure = failure;
+			_refCtr = new AtomicInteger(1);
+			_allow = allow;
+			_reservedBytes = reservedBytes;
+			_cacheIdx = -1;
+		}
+
+		private IndexedMatrixValue get() {
+			if(_failure != null)
+				throw _failure;
+			return _result;
+		}
+
+		private boolean isEos() {
+			return _failure == null && _result == null;
+		}
+
+		/**
+		 * Marks the handle as occupying a cache slot.
+		 *
+		 * @throws IllegalStateException if the reservation is not owned by {@code allowance} or the
+		 *                               handle is already attached to a different slot
+		 */
+		synchronized void attachCachedAllowance(CachedAllowance allowance, int index) {
+			if(_allow != allowance)
+				throw new IllegalStateException("Callback ownership must already belong to the cached allowance.");
+			if(_cacheIdx >= 0 && _cacheIdx != index)
+				throw new IllegalStateException("Callback is already attached to a different cached slot.");
+			_cacheIdx = index;
+		}
+
+		/** Clears the cache-slot marker. */
+		synchronized void detachCachedAllowance() {
+			_cacheIdx = -1;
+		}
+
+		/** Returns whether exactly one reference to this handle is open. */
+		boolean isExclusiveToRoot() {
+			return _refCtr.get() == 1;
+		}
+
+		/** Detaches and returns the payload; the reservation is left untouched. */
+		private synchronized IndexedMatrixValue takeManagedResultForHandover() {
+			IndexedMatrixValue result = _result;
+			_result = null;
+			return result;
+		}
+
+		/**
+		 * Final cleanup once the last reference is closed. The reservation is read and zeroed under
+		 * the handle monitor, so a concurrent releaseManagedMemory() or ownership transfer cannot
+		 * release the same bytes a second time. The allowance is called outside the monitor.
+		 */
+		private void closeFinal() {
+			final MemoryAllowance owner;
+			final long bytes;
+			synchronized(this) {
+				_result = null;
+				owner = _allow;
+				bytes = _reservedBytes;
+				_reservedBytes = 0;
+				_cacheIdx = -1;
+			}
+			// invoked even for 0 bytes, as before
+			owner.release(bytes);
+		}
+	}
+
+	/** Detaches and returns the payload of the shared handle for a handover. */
+	public IndexedMatrixValue takeManagedResultForHandover() {
+		return _handle.takeManagedResultForHandover();
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/memory/MemoryAllowance.java b/src/main/java/org/apache/sysds/runtime/ooc/memory/MemoryAllowance.java
new file mode 100644
index 00000000000..64518ded4a3
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/ooc/memory/MemoryAllowance.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.ooc.memory;
+
+/**
+ * Memory budget of a single consumer. Implementations track used, granted and target bytes
+ * and can be shut down.
+ */
+public interface MemoryAllowance {
+	/** Attempts to reserve {@code bytes} without blocking and reports whether it succeeded. */
+	boolean tryReserve(long bytes);
+
+	/** Reserves {@code bytes}, waiting if necessary. */
+	void reserveBlocking(long bytes);
+
+	/** Gives back {@code bytes} that were reserved earlier. */
+	void release(long bytes);
+
+	/** Returns the number of bytes currently reserved. */
+	long getUsedMemory();
+
+	/** Returns the number of bytes currently granted to this allowance. */
+	long getGrantedMemory();
+
+	/** Returns the target number of bytes set for this allowance. */
+	long getTargetMemory();
+
+	/** Sets the target number of bytes for this allowance. */
+	void setTargetMemory(long targetMemory);
+
+	/** Shuts this allowance down. */
+	void shutdown();
+
+	/** Returns whether this allowance was shut down. */
+	boolean isShutdown();
+
+	/** Destroys this allowance; by default this is the same as {@link #shutdown()}. */
+	default void destroy() {
+		this.shutdown();
+	}
+
+	/** Returns the granted bytes that are not in use, never less than zero. */
+	default long getFreeMemory() {
+		final long unused = getGrantedMemory() - getUsedMemory();
+		return unused > 0 ? unused : 0;
+	}
+
+	/** Returns whether more bytes are granted than the target allows. */
+	default boolean isUnderPressure() {
+		final long granted = getGrantedMemory();
+		return getTargetMemory() < granted;
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/memory/MemoryBroker.java b/src/main/java/org/apache/sysds/runtime/ooc/memory/MemoryBroker.java
new file mode 100644
index 00000000000..fb4d6ae182d
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/ooc/memory/MemoryBroker.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.ooc.memory;
+/** Hands out memory to attached allowances; usage notes below refer to SyncMemoryAllowance in this patch. */
+public interface MemoryBroker {
+	long requestMemory(MemoryAllowance allowance, long minSize, long maxSize); // returns the bytes granted; the caller adds them to its granted bytes
+	void freeMemory(MemoryAllowance allowance, long freedMemory); // called when an allowance hands granted bytes back
+	void shutdownAllowance(MemoryAllowance allowance); // called from shutdown() before any destroy/free notification
+	void destroyAllowance(MemoryAllowance allowance, long freedMemory); // called once a shut-down allowance has no used bytes left
+	void attachAllowance(MemoryAllowance allowance); // called from the allowance constructor
+}
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/memory/SyncMemoryAllowance.java b/src/main/java/org/apache/sysds/runtime/ooc/memory/SyncMemoryAllowance.java
new file mode 100644
index 00000000000..85d2cbfcd2b
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/ooc/memory/SyncMemoryAllowance.java
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.ooc.memory;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+
+/**
+ * {@link MemoryAllowance} that keeps its counters consistent with the monitor of the instance and
+ * obtains additional memory from a {@link MemoryBroker}. Calls into the broker are made outside
+ * the monitor.
+ */
+public class SyncMemoryAllowance implements MemoryAllowance {
+	protected final MemoryBroker _broker;
+	protected volatile long _usedBytes;
+	protected volatile long _grantedBytes;
+	protected volatile long _targetBytes;
+	protected volatile boolean _shutdown;
+	protected volatile boolean _destroyed;
+	// Guarded by the monitor. Bumped on every change that may allow a blocked reservation to
+	// proceed, so reserveBlocking can detect a notification that fired before it started waiting.
+	private long _stateVersion;
+
+	/** Creates an empty allowance and attaches it to {@code broker}. */
+	public SyncMemoryAllowance(MemoryBroker broker) {
+		_broker = broker;
+		_usedBytes = 0;
+		_grantedBytes = 0;
+		_targetBytes = 0;
+		_shutdown = false;
+		_destroyed = false;
+		broker.attachAllowance(this);
+	}
+
+	/** Records a state change and wakes all waiters; the caller must hold the monitor. */
+	private void signalStateChange() {
+		_stateVersion++;
+		notifyAll();
+	}
+
+	/**
+	 * Tries to reserve {@code bytes} without blocking. If the request fits the target but not the
+	 * granted bytes, the broker is asked for the missing amount.
+	 *
+	 * @return {@code true} if the bytes were reserved; {@code false} if the allowance is closed,
+	 *         the target would be exceeded, or the broker did not grant enough
+	 */
+	@Override
+	public boolean tryReserve(long bytes) {
+		long minRequest;
+		long maxRequest;
+		synchronized(this) {
+			if(_shutdown || _destroyed)
+				return false;
+			if(_usedBytes + bytes > _targetBytes)
+				return false;
+			if(_usedBytes + bytes <= _grantedBytes) {
+				_usedBytes += bytes;
+				return true;
+			}
+			minRequest = _usedBytes + bytes - _grantedBytes;
+			maxRequest = Math.max(minRequest, Math.max(_grantedBytes, bytes) * 2);
+		}
+
+		long granted = _broker.requestMemory(this, minRequest, maxRequest);
+		long refund = 0;
+		boolean success = false;
+		synchronized(this) {
+			if(_shutdown || _destroyed)
+				refund = granted; // closed while unlocked: hand the grant straight back
+			else {
+				_grantedBytes += granted;
+				if(_usedBytes + bytes <= _targetBytes && _usedBytes + bytes <= _grantedBytes) {
+					_usedBytes += bytes;
+					success = true;
+				}
+				// additional granted bytes may unblock other waiters
+				if(granted > 0)
+					_stateVersion++;
+				notifyAll();
+			}
+		}
+		if(refund > 0)
+			_broker.freeMemory(this, refund);
+		return success;
+	}
+
+	/**
+	 * Reserves {@code bytes}, waiting until the reservation succeeds.
+	 *
+	 * @throws IllegalStateException if the allowance is shut down or destroyed
+	 * @throws DMLRuntimeException   if the waiting thread is interrupted; the interrupt flag is restored
+	 */
+	@Override
+	public void reserveBlocking(long bytes) {
+		while(true) {
+			final long seenVersion;
+			synchronized(this) {
+				if(_shutdown || _destroyed)
+					throw new IllegalStateException("Cannot reserve memory on closed allowance.");
+				seenVersion = _stateVersion;
+			}
+			if(tryReserve(bytes)) {
+				synchronized(this) {
+					notifyAll();
+				}
+				return;
+			}
+			synchronized(this) {
+				if(_shutdown || _destroyed)
+					throw new IllegalStateException("Cannot reserve memory on closed allowance.");
+				// State changed between the failed attempt and now: the notification is
+				// already gone, so retry instead of waiting for one that may never come.
+				if(_stateVersion != seenVersion)
+					continue;
+				try {
+					wait();
+				}
+				catch(InterruptedException e) {
+					Thread.currentThread().interrupt();
+					throw new DMLRuntimeException(e);
+				}
+			}
+		}
+	}
+
+	/**
+	 * Gives back {@code bytes}. Granted bytes above the target, or all unused granted bytes after
+	 * shutdown, are returned to the broker; a shut-down allowance with no used bytes left is destroyed.
+	 */
+	@Override
+	public void release(long bytes) {
+		long freedMemory = 0;
+		long destroyFreedMemory = 0;
+		boolean destroy = false;
+		synchronized(this) {
+			_usedBytes -= bytes;
+			if(_shutdown) {
+				long oldGrantedBytes = _grantedBytes;
+				_grantedBytes = _usedBytes;
+				if(_usedBytes == 0) {
+					_destroyed = true;
+					destroy = true;
+					destroyFreedMemory = oldGrantedBytes;
+				}
+				else {
+					freedMemory = oldGrantedBytes - _grantedBytes;
+				}
+			}
+			else if(_grantedBytes > _targetBytes) {
+				long oldGrantedBytes = _grantedBytes;
+				_grantedBytes = Math.max(_usedBytes, _targetBytes);
+				freedMemory = oldGrantedBytes - _grantedBytes;
+			}
+			signalStateChange();
+		}
+		if(destroy)
+			_broker.destroyAllowance(this, destroyFreedMemory);
+		else if(freedMemory > 0)
+			_broker.freeMemory(this, freedMemory);
+	}
+
+	@Override
+	public long getUsedMemory() {
+		return _usedBytes;
+	}
+
+	@Override
+	public long getGrantedMemory() {
+		return _grantedBytes;
+	}
+
+	@Override
+	public long getTargetMemory() {
+		return _targetBytes;
+	}
+
+	/** Updates the target and wakes waiters; ignored once the allowance is closed. */
+	@Override
+	public synchronized void setTargetMemory(long targetMemory) {
+		if(_shutdown || _destroyed)
+			return;
+		_targetBytes = targetMemory;
+		signalStateChange();
+	}
+
+	/**
+	 * Shuts the allowance down (idempotent): the target drops to zero, unused granted bytes go
+	 * back to the broker, and the allowance is destroyed right away if nothing is in use.
+	 */
+	@Override
+	public void shutdown() {
+		long freedMemory = 0;
+		long destroyFreedMemory = 0;
+		boolean destroy = false;
+		synchronized(this) {
+			if(_shutdown || _destroyed)
+				return;
+			_shutdown = true;
+			long oldGrantedBytes = _grantedBytes;
+			_grantedBytes = _usedBytes;
+			_targetBytes = 0;
+			if(_usedBytes == 0) {
+				_destroyed = true;
+				destroy = true;
+				destroyFreedMemory = oldGrantedBytes;
+			}
+			else {
+				freedMemory = oldGrantedBytes - _grantedBytes;
+			}
+			signalStateChange();
+		}
+		_broker.shutdownAllowance(this);
+		if(destroy)
+			_broker.destroyAllowance(this, destroyFreedMemory);
+		else if(freedMemory > 0)
+			_broker.freeMemory(this, freedMemory);
+	}
+
+	@Override
+	public boolean isShutdown() {
+		return _shutdown || _destroyed;
+	}
+}
diff --git
a/src/test/java/org/apache/sysds/test/component/ooc/memory/OOCMemoryAllowanceTest.java b/src/test/java/org/apache/sysds/test/component/ooc/memory/OOCMemoryAllowanceTest.java new file mode 100644 index 00000000000..6ae3cd62ed8 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/ooc/memory/OOCMemoryAllowanceTest.java @@ -0,0 +1,459 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.sysds.test.component.ooc.memory;
+
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
+import org.apache.sysds.runtime.instructions.ooc.OOCStream;
+import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
+import org.apache.sysds.runtime.ooc.cache.OOCCacheManager;
+import org.apache.sysds.runtime.ooc.memory.CachedAllowance;
+import org.apache.sysds.runtime.ooc.memory.GlobalMemoryBroker;
+import org.apache.sysds.runtime.ooc.memory.InMemoryQueueCallback;
+import org.apache.sysds.runtime.ooc.memory.MemoryAllowance;
+import org.apache.sysds.runtime.ooc.memory.MemoryBroker;
+import org.apache.sysds.runtime.ooc.memory.SyncMemoryAllowance;
+import org.junit.Assert;
+import org.junit.Test;
+import scala.Tuple3;
+
+import java.util.ArrayList;
+import java.util.IdentityHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+/**
+ * Drives a producer/map/join pipeline of TILES blocks whose memory is accounted through
+ * SyncMemoryAllowance and CachedAllowance instances attached to a CoordinatedBroker.
+ * NOTE(review): generic type arguments (e.g. on OOCStream, List, Map, CompletableFuture) appear
+ * to have been lost from this copy of the patch - confirm against the original sources.
+ */
+public class OOCMemoryAllowanceTest {
+	private static final int TILES = 20000; // number of blocks pushed through each input stream
+	// Both inputs deliver their tiles in the same order.
+	@Test
+	public void testOptimal() {
+		test(true, 0, 1);
+	}
+	// The right input delivers its tiles in reverse order.
+	@Test
+	public void testWorstCase() {
+		test(false, 0, 1);
+	}
+	// Runs testNew nWarmup times unmeasured and nMeasure times measured, then resets the OOC cache manager.
+	public void test(boolean optimal, int nWarmup, int nMeasure) {
+		//DMLScript.OOC_STATISTICS = true;
+		long millis;
+		for(int i = 0; i < nWarmup; i++) {
+			testNew(optimal);
+		}
+		millis = 0;
+		for(int i = 0; i < nMeasure; i++) {
+			millis += testNew(optimal);
+		}
+		//System.out.println("New: " + millis + "ms");
+		//System.out.println(Statistics.displayOOCEvictionStats());
+		OOCCacheManager.reset();
+
+		/*for(int i = 0; i < 10; i++) {
+			testOld(optimal);
+		}
+		millis = 0;
+		for(int i = 0; i < 10; i++) {
+			millis += testOld(optimal);
+		}
+		//System.out.println("Old: " + millis + "ms");
+		//System.out.println(Statistics.displayOOCEvictionStats());
+		OOCCacheManager.reset();*/
+	}
+	// Runs the allowance-managed pipeline once, checks every output block, and returns the elapsed milliseconds.
+	public long testNew(boolean optimal) {
+		// We emulate the expression (A + 2) + B with limited memory
+		MemoryBroker parentBroker = new GlobalMemoryBroker(500000000L);
+		CoordinatedBroker broker = new CoordinatedBroker(parentBroker);
+		TestInstruction test = new TestInstruction();
+
+		MemoryAllowance leftAllowance = new SyncMemoryAllowance(broker);
+		MemoryAllowance rightAllowance = new SyncMemoryAllowance(broker);
+		MemoryAllowance joinAllowance = new SyncMemoryAllowance(broker);
+		CachedAllowance cache = new CachedAllowance(broker);
+
+		OOCStream leftStream = new SubscribableTaskQueue<>();
+		OOCStream rightStream = new SubscribableTaskQueue<>();
+		OOCStream outStream = new SubscribableTaskQueue<>();
+
+		long startMillis = System.currentTimeMillis();
+
+		// Left producer reservation thread
+		new Thread(() -> {
+			for(int i = 0; i < TILES; i++) {
+				leftAllowance.reserveBlocking(8 * 1000 + /* Working memory */ 8 * 1000);
+				leftStream.enqueue(i);
+			}
+			leftStream.closeInput();
+		}).start();
+
+		// Right producer reservation thread
+		new Thread(() -> {
+			for(int i = 0; i < TILES; i++) {
+				rightAllowance.reserveBlocking(8 * 1000); // Needs no working memory
+				if(optimal)
+					rightStream.enqueue(i);
+				else
+					rightStream.enqueue(TILES-i-1);
+			}
+			rightStream.closeInput();
+		}).start();
+
+		OOCStream leftStreamOut = new SubscribableTaskQueue<>();
+		OOCStream leftStreamOutOut = new SubscribableTaskQueue<>();
+		OOCStream rightStreamOut = new SubscribableTaskQueue<>();
+		// A: 1000x1 blocks filled with 5.0; each callback hands 8 * 1000 bytes back to leftAllowance when released
+		test.map(leftStream, leftStreamOut, i -> {
+			var imv = new IndexedMatrixValue(new MatrixIndexes(i.longValue(), 1L), new MatrixBlock(1000, 1, 5.0));
+			return new InMemoryQueueCallback(imv, null, leftAllowance, 8 * 1000);
+		});
+		test.map(leftStreamOut, leftStreamOutOut, cb -> {
+			try(cb) {
+				var imv = new IndexedMatrixValue(cb.get().getIndexes(), cb.get().getValue()
+					.scalarOperations(new RightScalarOperator(Plus.getPlusFnObject(), 2.0), new MatrixBlock()));
+				return new InMemoryQueueCallback(imv, null, leftAllowance, 8 * 1000);
+			}
+		});
+		test.map(rightStream, rightStreamOut, i -> {
+			var imv = new IndexedMatrixValue(new MatrixIndexes(i.longValue(), 1L), new MatrixBlock(1000, 1, 3.0));
+			return new InMemoryQueueCallback(imv, null, rightAllowance, 8 * 1000);
+		});
+
+		test.join(leftStreamOutOut, rightStreamOut, outStream, () -> joinAllowance.reserveBlocking(8 * 1000), cache,
+			(l, r) -> {
+				var imv = new IndexedMatrixValue(l.getIndexes(), ((MatrixBlock)l.getValue()).binaryOperations(new BinaryOperator(
+					Plus.getPlusFnObject()), r.getValue()));
+				return new InMemoryQueueCallback(imv, null, joinAllowance, 8 * 1000);
+			});
+
+		CompletableFuture future = new CompletableFuture<>();
+		AtomicInteger ctr = new AtomicInteger();
+		outStream.setSubscriber(cb -> {
+			try {
+				if(cb.isEos()) {
+					future.complete(null);
+					return;
+				}
+				InMemoryQueueCallback inner = cb.get();
+				try(cb; inner) {
+					ctr.incrementAndGet();
+					double checksum =((MatrixBlock)inner.get().getValue()).sum(); // 1000 cells of (5 + 2) + 3
+					if(checksum < 10000.0 - 1e-9 || checksum > 10000.0 + 1e-9)
+						future.completeExceptionally(new AssertionError("Wrong checksum: " + checksum));
+					//System.out.println(cb.get().get().getIndexes());
+				}
+			}
+			catch(Exception e) {
+				future.completeExceptionally(e);
+			}
+		});
+		future.join();
+
+		Assert.assertEquals(TILES, ctr.get());
+		return System.currentTimeMillis() - startMillis;
+	}
+	// Same pipeline without allowances, via joinOOC; only referenced from the commented-out code in test().
+	public long testOld(boolean optimal) {
+		// We emulate the expression (A + 2) + B with limited memory
+		TestInstruction test = new TestInstruction();
+
+		OOCStream leftStream = new SubscribableTaskQueue<>();
+		OOCStream rightStream = new SubscribableTaskQueue<>();
+		OOCStream outStream = new SubscribableTaskQueue<>();
+
+		long startMillis = System.currentTimeMillis();
+
+		// Left producer reservation thread
+		new Thread(() -> {
+			for(int i = 0; i < TILES; i++) {
+				leftStream.enqueue(i);
+			}
+			leftStream.closeInput();
+		}).start();
+
+		// Right producer reservation thread
+		new Thread(() -> {
+			for(int i = 0; i < TILES; i++) {
+				if(optimal)
+					rightStream.enqueue(i);
+				else
+					rightStream.enqueue(TILES-i-1);
+			}
+			rightStream.closeInput();
+		}).start();
+
+		OOCStream leftStreamOut = new SubscribableTaskQueue<>();
+		OOCStream leftStreamOutOut = new SubscribableTaskQueue<>();
+		OOCStream rightStreamOut = new SubscribableTaskQueue<>();
+
+		test.map(leftStream, leftStreamOut, i -> {
+			var imv = new IndexedMatrixValue(new MatrixIndexes(i.longValue(), 1L), new MatrixBlock(1000, 1, 5.0));
+			return imv;
+		});
+		test.map(leftStreamOut, leftStreamOutOut, v -> {
+			var imv = new IndexedMatrixValue(v.getIndexes(), v.getValue()
+				.scalarOperations(new RightScalarOperator(Plus.getPlusFnObject(), 2.0), new MatrixBlock()));
+			return imv;
+		});
+		test.map(rightStream, rightStreamOut, i -> {
+			var imv = new IndexedMatrixValue(new MatrixIndexes(i.longValue(), 1L), new MatrixBlock(1000, 1, 3.0));
+			return imv;
+		});
+
+		test.joinOOC(leftStreamOutOut, rightStreamOut, outStream,
+			(l, r) -> {
+				var imv = new IndexedMatrixValue(l.getIndexes(), ((MatrixBlock)l.getValue()).binaryOperations(new BinaryOperator(
+					Plus.getPlusFnObject()), r.getValue()));
+				return imv;
+			});
+
+		CompletableFuture future = new CompletableFuture<>();
+		outStream.setSubscriber(cb -> {
+			try {
+				if(cb.isEos()) {
+					future.complete(null);
+					return;
+				}
+				try(cb) {
+					//System.out.println(cb.get().getIndexes());
+				}
+			}
+			catch(Exception e) {
+				e.printStackTrace(); // NOTE(review): the future is not completed here, so join() below would keep waiting
+			}
+		});
+		future.join();
+		return System.currentTimeMillis() - startMillis;
+	}
+	// Minimal OOCInstruction that exposes mapOOC/joinOOC and adds a cache-backed join.
+	static class TestInstruction extends OOCInstruction {
+		protected TestInstruction() {
+			super(null, "test", "test");
+		}
+
+		@Override
+		public void processInstruction(ExecutionContext ec) {
+		}
+
+		public CompletableFuture map(OOCStream qIn, OOCStream qOut, Function mapper) {
+			return mapOOC(qIn, qOut, mapper);
+		}
+
+		public CompletableFuture joinOOC(OOCStream l, OOCStream r,
+			OOCStream out, BiFunction joinFn) {
+
+			return super.joinOOC(l, r, out, joinFn, IndexedMatrixValue::getIndexes);
+		}
+		// Joins l and r by row index: a tile whose partner has not arrived is handed to the cache, otherwise the pair is emitted.
+		public CompletableFuture join(OOCStream l, OOCStream r,
+			OOCStream out, Runnable memoryReserver, CachedAllowance cache,
+			BiFunction joinFn) {
+
+			OOCStream, OOCStream.QueueCallback, Integer>> intermediate = createWritableStream();
+
+			new Thread(() -> {
+				InMemoryQueueCallback next;
+				IndexedMatrixValue nextValue;
+				boolean nextLeft = true; // inputs are dequeued alternately, starting with the left one
+				AtomicInteger pendingRequests = new AtomicInteger(1); // starts at 1 for the dequeue loop itself
+
+				while((next = (nextLeft ? l : r).dequeue()) != null) {
+					try {
+						nextValue = next.get();
+						int idx = (int) nextValue.getIndexes().getRowIndex();
+						var future = cache.get(idx);
+						if(future.isDone()) {
+							var cb = future.getNow(null);
+							if(cb == null) {
+								cache.handover(next, idx);
+							}
+							else {
+								try(cb) {
+									memoryReserver.run(); // reserve memory for future pipeline
+									intermediate.enqueue(nextLeft ? new Tuple3<>(next.keepOpen(), cb.keepOpen(), idx) :
+										new Tuple3<>(cb.keepOpen(), next.keepOpen(), idx));
+								}
+							}
+						}
+						else {
+							pendingRequests.incrementAndGet();
+							final var pinned = next.keepOpen();
+							final boolean isLeft = nextLeft;
+							future.thenAccept(cb -> {
+								try(cb; pinned) {
+									intermediate.enqueue(
+										isLeft ? new Tuple3<>(pinned.keepOpen(), cb.keepOpen(), idx) :
+											new Tuple3<>(cb.keepOpen(), pinned.keepOpen(), idx));
+								}
+								if(pendingRequests.decrementAndGet() == 0)
+									intermediate.closeInput();
+							});
+						}
+
+						nextLeft = !nextLeft;
+					}
+					finally {
+						next.close();
+					}
+				}
+				// the loop's own count is dropped; whoever reaches zero closes the intermediate stream
+				if(pendingRequests.decrementAndGet() == 0)
+					intermediate.closeInput();
+			}).start();
+
+			return mapOOC(intermediate, out, t -> {
+				var qL = t._1();
+				var qR = t._2();
+				try(qL; qR) {
+					return joinFn.apply(qL.get(), qR.get());
+				}
+				finally {
+					cache.clear(t._3());
+				}
+			});
+		}
+	}
+	// Broker that is itself an allowance of a parent broker and gives every attached child the same target.
+	static class CoordinatedBroker extends SyncMemoryAllowance implements MemoryBroker {
+		private final List _children;
+		private final Map _credits; // bytes already reserved from the parent on behalf of each child
+		private record TargetUpdate(MemoryAllowance allowance, long target) {}
+
+		CoordinatedBroker(MemoryBroker parentBroker) {
+			super(parentBroker);
+			_children = new ArrayList<>();
+			_credits = new IdentityHashMap<>();
+		}
+
+		@Override
+		public void attachAllowance(MemoryAllowance allowance) {
+			List updates;
+			synchronized(this) {
+				_children.add(allowance);
+				_credits.put(allowance, 0L);
+				updates = rebalanceTargetsLocked();
+			}
+			applyTargetUpdates(updates);
+		}
+		// Grants exactly minSize or nothing; maxSize is not used.
+		@Override
+		public long requestMemory(MemoryAllowance allowance, long minSize, long maxSize) {
+			if(!_credits.containsKey(allowance))
+				throw new UnsupportedOperationException("Allowance is not attached to CoordinatedBroker.");
+			List updates;
+			long granted;
+			synchronized(this) {
+				granted = requestGrantLocked(allowance, minSize);
+				updates = rebalanceTargetsLocked();
+			}
+			applyTargetUpdates(updates);
+			return granted;
+		}
+		// Releases the freed bytes from this broker's own allowance; non-positive amounts are ignored.
+		@Override
+		public void freeMemory(MemoryAllowance allowance, long freedMemory) {
+			if(!_credits.containsKey(allowance))
+				throw new UnsupportedOperationException("Allowance is not attached to CoordinatedBroker.");
+			if(freedMemory <= 0)
+				return;
+			List updates;
+			synchronized(this) {
+				release(freedMemory);
+				updates = rebalanceTargetsLocked();
+			}
+			applyTargetUpdates(updates);
+		}
+
+		@Override
+		public void shutdownAllowance(MemoryAllowance allowance) {
+			if(!_credits.containsKey(allowance))
+				throw new UnsupportedOperationException("Allowance is not attached to CoordinatedBroker.");
+			List updates;
+			synchronized(this) {
+				updates = rebalanceTargetsLocked();
+			}
+			applyTargetUpdates(updates);
+		}
+
+		@Override
+		public void destroyAllowance(MemoryAllowance allowance, long freedMemory) {
+			if(!_credits.containsKey(allowance))
+				throw new UnsupportedOperationException("Allowance is not attached to CoordinatedBroker.");
+			List updates;
+			synchronized(this) {
+				_children.remove(allowance);
+				_credits.remove(allowance);
+				if(freedMemory > 0)
+					release(freedMemory);
+				updates = rebalanceTargetsLocked();
+			}
+			applyTargetUpdates(updates);
+		}
+		// Serves the request from the requester's credit, or reserves the shortfall once per child and credits the other children.
+		private long requestGrantLocked(MemoryAllowance requester, long minSize) {
+			int n = _children.size();
+			if(n == 0)
+				return 0;
+			long credit = _credits.getOrDefault(requester, 0L);
+			if(credit >= minSize) {
+				_credits.put(requester, credit - minSize);
+				return minSize;
+			}
+
+			long granted = credit;
+			long missing = minSize - granted;
+			long total = n * missing;
+			if(!tryReserve(total))
+				return 0;
+			_credits.put(requester, 0L);
+			for(MemoryAllowance child : _children) {
+				if(child == requester)
+					continue;
+				_credits.put(child, _credits.getOrDefault(child, 0L) + missing);
+			}
+			return minSize;
+		}
+		// Builds one update per child that sets its target to this broker's target divided by the number of children.
+		private List rebalanceTargetsLocked() {
+			List updates = new ArrayList<>(_children.size());
+			long target = getTargetMemory();
+			int n = _children.size();
+			long share = n == 0 ? 0 : target / n;
+			for(MemoryAllowance child : _children)
+				updates.add(new TargetUpdate(child, share));
+			return updates;
+		}
+
+		private static void applyTargetUpdates(List updates) {
+			for(TargetUpdate update : updates)
+				update.allowance.setTargetMemory(update.target);
+		}
+	}
+}