From f440f8ec3f63e5d663e1f9d9614f05a39422102a Mon Sep 17 00:00:00 2001 From: Chi Wang Date: Thu, 28 Jul 2022 17:34:03 +0200 Subject: [PATCH] Remote: Fix performance regression in "upload missing inputs". (#15998) The regression was introduced in 702df847cf32789ffe6c0a7c7679e92a7ccccb8d where we essentially create a subscriber for each digest to subscribe the result of `findMissingBlobs`. This change update the code to not create so many subscribers but maintain the same functionalities. Fixes #15872. Closes #15890. PiperOrigin-RevId: 463826260 Change-Id: Id0b1c7c309fc9653a47c5df95c609b34e6510cde --- .../lib/remote/RemoteExecutionCache.java | 280 ++++++++++-------- .../build/lib/remote/util/AsyncTaskCache.java | 34 ++- .../build/lib/remote/InMemoryRemoteCache.java | 6 + .../build/lib/remote/RemoteCacheTest.java | 213 ++++++++++++- .../remote/RemoteExecutionServiceTest.java | 36 ++- .../lib/remote/util/InMemoryCacheClient.java | 2 +- 6 files changed, 422 insertions(+), 149 deletions(-) diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java index 5474f884233832..24c93850c2c5f2 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java +++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java @@ -13,8 +13,9 @@ // limitations under the License. package com.google.devtools.build.lib.remote; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.util.concurrent.Futures.immediateFailedFuture; +import static com.google.common.util.concurrent.Futures.immediateFuture; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static com.google.devtools.build.lib.remote.util.RxFutures.toCompletable; import static com.google.devtools.build.lib.remote.util.RxFutures.toSingle; @@ -25,9 +26,11 @@ import build.bazel.remote.execution.v2.Digest; import build.bazel.remote.execution.v2.Directory; import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import com.google.devtools.build.lib.profiler.Profiler; +import com.google.devtools.build.lib.profiler.SilentCloseable; import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext; import com.google.devtools.build.lib.remote.common.RemoteCacheClient; import com.google.devtools.build.lib.remote.merkletree.MerkleTree; @@ -36,16 +39,20 @@ import com.google.devtools.build.lib.remote.util.DigestUtil; import com.google.devtools.build.lib.remote.util.RxUtils.TransferResult; import com.google.protobuf.Message; +import io.reactivex.rxjava3.annotations.NonNull; import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.CompletableObserver; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Observable; import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.core.SingleEmitter; +import io.reactivex.rxjava3.disposables.Disposable; import io.reactivex.rxjava3.subjects.AsyncSubject; import java.io.IOException; -import java.util.HashSet; +import java.util.List; import java.util.Map; -import java.util.Set; -import java.util.concurrent.atomic.AtomicBoolean; -import javax.annotation.concurrent.GuardedBy; +import java.util.concurrent.atomic.AtomicReference; /** A {@link RemoteCache} with additional functionality needed for remote execution. */ public class RemoteExecutionCache extends RemoteCache { @@ -85,13 +92,10 @@ public void ensureInputsPresent( return; } - MissingDigestFinder missingDigestFinder = new MissingDigestFinder(context, allDigests.size()); Flowable uploads = - Flowable.fromIterable(allDigests) - .flatMapSingle( - digest -> - uploadBlobIfMissing( - context, merkleTree, additionalInputs, force, missingDigestFinder, digest)); + createUploadTasks(context, merkleTree, additionalInputs, allDigests, force) + .flatMap(uploadTasks -> findMissingBlobs(context, uploadTasks)) + .flatMapPublisher(this::waitForUploadTasks); try { mergeBulkTransfer(uploads).blockingAwait(); @@ -105,36 +109,6 @@ public void ensureInputsPresent( } } - private Single uploadBlobIfMissing( - RemoteActionExecutionContext context, - MerkleTree merkleTree, - Map additionalInputs, - boolean force, - MissingDigestFinder missingDigestFinder, - Digest digest) { - Completable upload = - casUploadCache.execute( - digest, - Completable.defer( - () -> - // Only reach here if the digest is missing and is not being uploaded. - missingDigestFinder - .registerAndCount(digest) - .flatMapCompletable( - missingDigests -> { - if (missingDigests.contains(digest)) { - return toCompletable( - () -> uploadBlob(context, digest, merkleTree, additionalInputs), - directExecutor()); - } else { - return Completable.complete(); - } - })), - /* onIgnored= */ missingDigestFinder::count, - force); - return toTransferResult(upload); - } - private ListenableFuture uploadBlob( RemoteActionExecutionContext context, Digest digest, @@ -158,99 +132,159 @@ private ListenableFuture uploadBlob( return cacheProtocol.uploadBlob(context, digest, message.toByteString()); } - return Futures.immediateFailedFuture( + return immediateFailedFuture( new IOException( format( "findMissingDigests returned a missing digest that has not been requested: %s", digest))); } - /** - * A missing digest finder that initiates the request when the internal counter reaches an - * expected count. - */ - class MissingDigestFinder { - private final int expectedCount; - - private final AsyncSubject> digestsSubject; - private final Single> resultSingle; + static class UploadTask { + Digest digest; + AtomicReference disposable; + SingleEmitter continuation; + Completable completion; + } - @GuardedBy("this") - private final Set digests; + private Single> createUploadTasks( + RemoteActionExecutionContext context, + MerkleTree merkleTree, + Map additionalInputs, + Iterable allDigests, + boolean force) { + return Single.using( + () -> Profiler.instance().profile("collect digests"), + ignored -> + Flowable.fromIterable(allDigests) + .flatMapMaybe( + digest -> + maybeCreateUploadTask(context, merkleTree, additionalInputs, digest, force)) + .collect(toImmutableList()), + SilentCloseable::close); + } - @GuardedBy("this") - private int currentCount = 0; + private Maybe maybeCreateUploadTask( + RemoteActionExecutionContext context, + MerkleTree merkleTree, + Map additionalInputs, + Digest digest, + boolean force) { + return Maybe.create( + emitter -> { + AsyncSubject completion = AsyncSubject.create(); + UploadTask uploadTask = new UploadTask(); + uploadTask.digest = digest; + uploadTask.disposable = new AtomicReference<>(); + uploadTask.completion = + Completable.fromObservable( + completion.doOnDispose( + () -> { + Disposable d = uploadTask.disposable.getAndSet(null); + if (d != null) { + d.dispose(); + } + })); + Completable upload = + casUploadCache.execute( + digest, + Single.create( + continuation -> { + uploadTask.continuation = continuation; + emitter.onSuccess(uploadTask); + }) + .flatMapCompletable( + shouldUpload -> { + if (!shouldUpload) { + return Completable.complete(); + } - MissingDigestFinder(RemoteActionExecutionContext context, int expectedCount) { - checkArgument(expectedCount > 0, "expectedCount should be greater than 0"); - this.expectedCount = expectedCount; - this.digestsSubject = AsyncSubject.create(); - this.digests = new HashSet<>(); + return toCompletable( + () -> + uploadBlob( + context, uploadTask.digest, merkleTree, additionalInputs), + directExecutor()); + }), + /* onAlreadyRunning= */ () -> emitter.onSuccess(uploadTask), + /* onAlreadyFinished= */ emitter::onComplete, + force); + upload.subscribe( + new CompletableObserver() { + @Override + public void onSubscribe(@NonNull Disposable d) { + uploadTask.disposable.set(d); + } - AtomicBoolean findMissingDigestsCalled = new AtomicBoolean(false); - this.resultSingle = - Single.fromObservable( - digestsSubject - .flatMapSingle( - digests -> { - boolean wasCalled = findMissingDigestsCalled.getAndSet(true); - // Make sure we don't have re-subscription caused by refCount() below. - checkState(!wasCalled, "FindMissingDigests is called more than once"); - return toSingle( - () -> findMissingDigests(context, digests), directExecutor()); - }) - // Use replay here because we could have a race condition that downstream hasn't - // been added to the subscription list (to receive the upstream result) while - // upstream is completed. - .replay(1) - .refCount()); - } + @Override + public void onComplete() { + completion.onComplete(); + } - /** - * Register the {@code digest} and increase the counter. - * - *

Returned Single cannot be subscribed more than once. - * - * @return Single that emits the result of the {@code FindMissingDigest} request. - */ - Single> registerAndCount(Digest digest) { - AtomicBoolean subscribed = new AtomicBoolean(false); - // count() will potentially trigger the findMissingDigests call. Adding and counting before - // returning the Single could introduce a race that the result of findMissingDigests is - // available but the consumer doesn't get it because it hasn't subscribed the returned - // Single. In this case, it subscribes after upstream is completed resulting a re-run of - // findMissingDigests (due to refCount()). - // - // Calling count() inside doOnSubscribe to ensure the consumer already subscribed to the - // returned Single to avoid a re-execution of findMissingDigests. - return resultSingle.doOnSubscribe( - d -> { - boolean wasSubscribed = subscribed.getAndSet(true); - checkState(!wasSubscribed, "Single is subscribed more than once"); - synchronized (this) { - digests.add(digest); - } - count(); - }); - } + @Override + public void onError(@NonNull Throwable e) { + Disposable d = uploadTask.disposable.get(); + if (d != null && d.isDisposed()) { + return; + } - /** Increase the counter. */ - void count() { - ImmutableSet digestsResult = null; + completion.onError(e); + } + }); + }); + } - synchronized (this) { - if (currentCount < expectedCount) { - currentCount++; - if (currentCount == expectedCount) { - digestsResult = ImmutableSet.copyOf(digests); - } - } - } + private Single> findMissingBlobs( + RemoteActionExecutionContext context, List uploadTasks) { + return Single.using( + () -> Profiler.instance().profile("findMissingDigests"), + ignored -> + Single.fromObservable( + Observable.fromSingle( + toSingle( + () -> { + ImmutableList digestsToQuery = + uploadTasks.stream() + .filter(uploadTask -> uploadTask.continuation != null) + .map(uploadTask -> uploadTask.digest) + .collect(toImmutableList()); + if (digestsToQuery.isEmpty()) { + return immediateFuture(ImmutableSet.of()); + } + return findMissingDigests(context, digestsToQuery); + }, + directExecutor()) + .map( + missingDigests -> { + for (UploadTask uploadTask : uploadTasks) { + if (uploadTask.continuation != null) { + uploadTask.continuation.onSuccess( + missingDigests.contains(uploadTask.digest)); + } + } + return uploadTasks; + })) + // Use AsyncSubject so that if downstream is disposed, the + // findMissingDigests call is not cancelled (because it may be needed by + // other + // threads). + .subscribeWith(AsyncSubject.create())) + .doOnDispose( + () -> { + for (UploadTask uploadTask : uploadTasks) { + Disposable d = uploadTask.disposable.getAndSet(null); + if (d != null) { + d.dispose(); + } + } + }), + SilentCloseable::close); + } - if (digestsResult != null) { - digestsSubject.onNext(digestsResult); - digestsSubject.onComplete(); - } - } + private Flowable waitForUploadTasks(List uploadTasks) { + return Flowable.using( + () -> Profiler.instance().profile("upload"), + ignored -> + Flowable.fromIterable(uploadTasks) + .flatMapSingle(uploadTask -> toTransferResult(uploadTask.completion)), + SilentCloseable::close); } } diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java index 31369ef4ee1eab..07bed15a53b0d0 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java +++ b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java @@ -257,10 +257,10 @@ public boolean isDisposed() { /** * Executes a task. * - * @see #execute(Object, Single, Action, boolean). + * @see #execute(Object, Single, Action, Action, boolean). */ public Single execute(KeyT key, Single task, boolean force) { - return execute(key, task, () -> {}, force); + return execute(key, task, () -> {}, () -> {}, force); } /** @@ -270,12 +270,18 @@ public Single execute(KeyT key, Single task, boolean force) { *

If the cache is already shutdown, a {@link CancellationException} will be emitted. * * @param key identifies the task. - * @param onIgnored callback called when provided task is ignored. + * @param onAlreadyRunning callback called when provided task is already running. + * @param onAlreadyFinished callback called when provided task is already finished. * @param force re-execute a finished task if set to {@code true}. * @return a {@link Single} which turns to completed once the task is finished or propagates the * error if any. */ - public Single execute(KeyT key, Single task, Action onIgnored, boolean force) { + public Single execute( + KeyT key, + Single task, + Action onAlreadyRunning, + Action onAlreadyFinished, + boolean force) { return Single.create( emitter -> { synchronized (lock) { @@ -285,7 +291,7 @@ public Single execute(KeyT key, Single task, Action onIgnored, b } if (!force && finished.containsKey(key)) { - onIgnored.run(); + onAlreadyFinished.run(); emitter.onSuccess(finished.get(key)); return; } @@ -294,7 +300,7 @@ public Single execute(KeyT key, Single task, Action onIgnored, b Execution execution = inProgress.get(key); if (execution != null) { - onIgnored.run(); + onAlreadyRunning.run(); } else { execution = new Execution(key, task); inProgress.put(key, execution); @@ -445,13 +451,23 @@ public Completable executeIfNot(KeyT key, Completable task) { /** Same as {@link AsyncTaskCache#execute} but operates on {@link Completable}. */ public Completable execute(KeyT key, Completable task, boolean force) { - return execute(key, task, () -> {}, force); + return execute(key, task, () -> {}, () -> {}, force); } /** Same as {@link AsyncTaskCache#execute} but operates on {@link Completable}. */ - public Completable execute(KeyT key, Completable task, Action onIgnored, boolean force) { + public Completable execute( + KeyT key, + Completable task, + Action onAlreadyRunning, + Action onAlreadyFinished, + boolean force) { return Completable.fromSingle( - cache.execute(key, task.toSingleDefault(Optional.empty()), onIgnored, force)); + cache.execute( + key, + task.toSingleDefault(Optional.empty()), + onAlreadyRunning, + onAlreadyFinished, + force)); } /** Returns a set of keys for tasks which is finished. */ diff --git a/src/test/java/com/google/devtools/build/lib/remote/InMemoryRemoteCache.java b/src/test/java/com/google/devtools/build/lib/remote/InMemoryRemoteCache.java index 2c332180a9593a..3a79cac47e6010 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/InMemoryRemoteCache.java +++ b/src/test/java/com/google/devtools/build/lib/remote/InMemoryRemoteCache.java @@ -17,6 +17,7 @@ import build.bazel.remote.execution.v2.Digest; import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext; +import com.google.devtools.build.lib.remote.common.RemoteCacheClient; import com.google.devtools.build.lib.remote.options.RemoteOptions; import com.google.devtools.build.lib.remote.util.DigestUtil; import com.google.devtools.build.lib.remote.util.InMemoryCacheClient; @@ -39,6 +40,11 @@ class InMemoryRemoteCache extends RemoteExecutionCache { super(new InMemoryCacheClient(), options, digestUtil); } + InMemoryRemoteCache( + RemoteCacheClient cacheProtocol, RemoteOptions options, DigestUtil digestUtil) { + super(cacheProtocol, options, digestUtil); + } + Digest addContents(RemoteActionExecutionContext context, String txt) throws IOException, InterruptedException { return addContents(context, txt.getBytes(UTF_8)); diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteCacheTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteCacheTest.java index 3b3772cc7004e2..2406c0ef2de55b 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/RemoteCacheTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteCacheTest.java @@ -39,9 +39,12 @@ import com.google.devtools.build.lib.collect.nestedset.Order; import com.google.devtools.build.lib.exec.util.FakeOwner; import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext; +import com.google.devtools.build.lib.remote.common.RemoteCacheClient; import com.google.devtools.build.lib.remote.merkletree.MerkleTree; import com.google.devtools.build.lib.remote.options.RemoteOptions; import com.google.devtools.build.lib.remote.util.DigestUtil; +import com.google.devtools.build.lib.remote.util.InMemoryCacheClient; +import com.google.devtools.build.lib.remote.util.RxNoGlobalErrorsRule; import com.google.devtools.build.lib.remote.util.TracingMetadataUtils; import com.google.devtools.build.lib.testutil.TestUtils; import com.google.devtools.build.lib.util.io.FileOutErr; @@ -56,15 +59,20 @@ import java.io.IOException; import java.io.OutputStream; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; import java.util.SortedMap; import java.util.TreeMap; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import org.junit.After; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -73,6 +81,7 @@ /** Tests for {@link RemoteCache}. */ @RunWith(JUnit4.class) public class RemoteCacheTest { + @Rule public final RxNoGlobalErrorsRule rxNoGlobalErrorsRule = new RxNoGlobalErrorsRule(); private RemoteActionExecutionContext context; private FileSystem fs; @@ -199,18 +208,32 @@ public void upload_emptyBlobAndFile_doNotPerformUpload() throws Exception { } @Test - public void ensureInputsPresent_interrupted_cancelInProgressUploadTasks() throws Exception { + public void ensureInputsPresent_interruptedDuringUploadBlobs_cancelInProgressUploadTasks() + throws Exception { // arrange - InMemoryRemoteCache remoteCache = spy(newRemoteCache()); + RemoteCacheClient cacheProtocol = spy(new InMemoryCacheClient()); + RemoteExecutionCache remoteCache = spy(newRemoteExecutionCache(cacheProtocol)); - CountDownLatch findMissingDigestsCalled = new CountDownLatch(1); + List> futures = new ArrayList<>(); + CountDownLatch uploadBlobCalls = new CountDownLatch(2); doAnswer( invocationOnMock -> { - findMissingDigestsCalled.countDown(); - return SettableFuture.create(); + SettableFuture future = SettableFuture.create(); + futures.add(future); + uploadBlobCalls.countDown(); + return future; }) - .when(remoteCache) - .findMissingDigests(any(), any()); + .when(cacheProtocol) + .uploadBlob(any(), any(), any()); + doAnswer( + invocationOnMock -> { + SettableFuture future = SettableFuture.create(); + futures.add(future); + uploadBlobCalls.countDown(); + return future; + }) + .when(cacheProtocol) + .uploadFile(any(), any(), any()); Path path = fs.getPath("/execroot/foo"); FileSystemUtils.writeContentAsLatin1(path, "bar"); @@ -233,7 +256,8 @@ public void ensureInputsPresent_interrupted_cancelInProgressUploadTasks() throws // act thread.start(); - findMissingDigestsCalled.await(); + uploadBlobCalls.await(); + assertThat(futures).hasSize(2); assertThat(remoteCache.casUploadCache.getInProgressTasks()).isNotEmpty(); thread.interrupt(); @@ -241,10 +265,183 @@ public void ensureInputsPresent_interrupted_cancelInProgressUploadTasks() throws // assert assertThat(remoteCache.casUploadCache.getInProgressTasks()).isEmpty(); + assertThat(remoteCache.casUploadCache.getFinishedTasks()).isEmpty(); + for (SettableFuture future : futures) { + assertThat(future.isCancelled()).isTrue(); + } + } + + @Test + public void + ensureInputsPresent_multipleConsumers_interruptedOneDuringFindMissingBlobs_keepAndFinishInProgressUploadTasks() + throws Exception { + // arrange + RemoteCacheClient cacheProtocol = spy(new InMemoryCacheClient()); + RemoteExecutionCache remoteCache = spy(newRemoteExecutionCache(cacheProtocol)); + + SettableFuture> findMissingDigestsFuture = SettableFuture.create(); + CountDownLatch findMissingDigestsCalled = new CountDownLatch(1); + doAnswer( + invocationOnMock -> { + findMissingDigestsCalled.countDown(); + return findMissingDigestsFuture; + }) + .when(remoteCache) + .findMissingDigests(any(), any()); + Deque> futures = new ConcurrentLinkedDeque<>(); + CountDownLatch uploadBlobCalls = new CountDownLatch(2); + doAnswer( + invocationOnMock -> { + SettableFuture future = SettableFuture.create(); + futures.add(future); + uploadBlobCalls.countDown(); + return future; + }) + .when(cacheProtocol) + .uploadBlob(any(), any(), any()); + doAnswer( + invocationOnMock -> { + SettableFuture future = SettableFuture.create(); + futures.add(future); + uploadBlobCalls.countDown(); + return future; + }) + .when(cacheProtocol) + .uploadFile(any(), any(), any()); + + Path path = fs.getPath("/execroot/foo"); + FileSystemUtils.writeContentAsLatin1(path, "bar"); + SortedMap inputs = new TreeMap<>(); + inputs.put(PathFragment.create("foo"), path); + MerkleTree merkleTree = MerkleTree.build(inputs, digestUtil); + + CountDownLatch ensureInputsPresentReturned = new CountDownLatch(2); + CountDownLatch ensureInterrupted = new CountDownLatch(1); + Runnable work = + () -> { + try { + remoteCache.ensureInputsPresent(context, merkleTree, ImmutableMap.of(), false); + } catch (IOException ignored) { + // ignored + } catch (InterruptedException e) { + ensureInterrupted.countDown(); + } finally { + ensureInputsPresentReturned.countDown(); + } + }; + Thread thread1 = new Thread(work); + Thread thread2 = new Thread(work); + thread1.start(); + thread2.start(); + findMissingDigestsCalled.await(); + + // act + thread1.interrupt(); + ensureInterrupted.await(); + findMissingDigestsFuture.set(ImmutableSet.copyOf(merkleTree.getAllDigests())); + + uploadBlobCalls.await(); + assertThat(futures).hasSize(2); + + // assert + assertThat(remoteCache.casUploadCache.getInProgressTasks()).hasSize(2); + assertThat(remoteCache.casUploadCache.getFinishedTasks()).isEmpty(); + for (SettableFuture future : futures) { + assertThat(future.isCancelled()).isFalse(); + } + + for (SettableFuture future : futures) { + future.set(null); + } + ensureInputsPresentReturned.await(); + assertThat(remoteCache.casUploadCache.getInProgressTasks()).isEmpty(); + assertThat(remoteCache.casUploadCache.getFinishedTasks()).hasSize(2); + } + + @Test + public void + ensureInputsPresent_multipleConsumers_interruptedOneDuringUploadBlobs_keepInProgressUploadTasks() + throws Exception { + // arrange + RemoteCacheClient cacheProtocol = spy(new InMemoryCacheClient()); + RemoteExecutionCache remoteCache = spy(newRemoteExecutionCache(cacheProtocol)); + + List> futures = new ArrayList<>(); + CountDownLatch uploadBlobCalls = new CountDownLatch(2); + doAnswer( + invocationOnMock -> { + SettableFuture future = SettableFuture.create(); + futures.add(future); + uploadBlobCalls.countDown(); + return future; + }) + .when(cacheProtocol) + .uploadBlob(any(), any(), any()); + doAnswer( + invocationOnMock -> { + SettableFuture future = SettableFuture.create(); + futures.add(future); + uploadBlobCalls.countDown(); + return future; + }) + .when(cacheProtocol) + .uploadFile(any(), any(), any()); + + Path path = fs.getPath("/execroot/foo"); + FileSystemUtils.writeContentAsLatin1(path, "bar"); + SortedMap inputs = new TreeMap<>(); + inputs.put(PathFragment.create("foo"), path); + MerkleTree merkleTree = MerkleTree.build(inputs, digestUtil); + + CountDownLatch ensureInputsPresentReturned = new CountDownLatch(2); + CountDownLatch ensureInterrupted = new CountDownLatch(1); + Runnable work = + () -> { + try { + remoteCache.ensureInputsPresent(context, merkleTree, ImmutableMap.of(), false); + } catch (IOException ignored) { + // ignored + } catch (InterruptedException e) { + ensureInterrupted.countDown(); + } finally { + ensureInputsPresentReturned.countDown(); + } + }; + Thread thread1 = new Thread(work); + Thread thread2 = new Thread(work); + + // act + thread1.start(); + thread2.start(); + uploadBlobCalls.await(); + assertThat(futures).hasSize(2); + assertThat(remoteCache.casUploadCache.getInProgressTasks()).hasSize(2); + + thread1.interrupt(); + ensureInterrupted.await(); + + // assert + assertThat(remoteCache.casUploadCache.getInProgressTasks()).hasSize(2); + assertThat(remoteCache.casUploadCache.getFinishedTasks()).isEmpty(); + for (SettableFuture future : futures) { + assertThat(future.isCancelled()).isFalse(); + } + + for (SettableFuture future : futures) { + future.set(null); + } + ensureInputsPresentReturned.await(); + assertThat(remoteCache.casUploadCache.getInProgressTasks()).isEmpty(); + assertThat(remoteCache.casUploadCache.getFinishedTasks()).hasSize(2); } private InMemoryRemoteCache newRemoteCache() { RemoteOptions options = Options.getDefaults(RemoteOptions.class); return new InMemoryRemoteCache(options, digestUtil); } + + private RemoteExecutionCache newRemoteExecutionCache(RemoteCacheClient remoteCacheClient) { + return new RemoteExecutionCache( + remoteCacheClient, Options.getDefaults(RemoteOptions.class), digestUtil); + } } diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java index 9d39facdc6a0cc..c7e0839512549b 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java @@ -24,6 +24,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -92,6 +93,7 @@ import com.google.devtools.build.lib.remote.options.RemoteOutputsMode; import com.google.devtools.build.lib.remote.util.DigestUtil; import com.google.devtools.build.lib.remote.util.FakeSpawnExecutionContext; +import com.google.devtools.build.lib.remote.util.InMemoryCacheClient; import com.google.devtools.build.lib.remote.util.RxNoGlobalErrorsRule; import com.google.devtools.build.lib.remote.util.TracingMetadataUtils; import com.google.devtools.build.lib.remote.util.Utils.InMemoryOutput; @@ -110,6 +112,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Random; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Semaphore; @@ -160,7 +163,7 @@ public final void setUp() throws Exception { checkNotNull(stderr.getParentDirectory()).createDirectoryAndParents(); outErr = new FileOutErr(stdout, stderr); - cache = spy(new InMemoryRemoteCache(remoteOptions, digestUtil)); + cache = spy(new InMemoryRemoteCache(spy(new InMemoryCacheClient()), remoteOptions, digestUtil)); executor = mock(RemoteExecutionClient.class); RequestMetadata metadata = @@ -1544,8 +1547,16 @@ public void uploadInputsIfNotPresent_sameInputs_interruptOne_keepOthers() throws @Test public void uploadInputsIfNotPresent_interrupted_requestCancelled() throws Exception { + CountDownLatch uploadBlobCalled = new CountDownLatch(1); + CountDownLatch interrupted = new CountDownLatch(1); SettableFuture> future = SettableFuture.create(); - doReturn(future).when(cache).findMissingDigests(any(), any()); + doAnswer( + invocationOnMock -> { + uploadBlobCalled.countDown(); + return future; + }) + .when(cache.cacheProtocol) + .uploadBlob(any(), any(), any()); ActionInput input = ActionInputHelper.fromPath("inputs/foo"); fakeFileCache.createScratchInput(input, "input-foo"); RemoteExecutionService service = newRemoteExecutionService(); @@ -1556,13 +1567,22 @@ public void uploadInputsIfNotPresent_interrupted_requestCancelled() throws Excep NestedSetBuilder.create(Order.STABLE_ORDER, input)); FakeSpawnExecutionContext context = newSpawnExecutionContext(spawn); RemoteAction action = service.buildRemoteAction(spawn, context); + Thread thread = + new Thread( + () -> { + try { + service.uploadInputsIfNotPresent(action, /*force=*/ false); + } catch (InterruptedException ignored) { + interrupted.countDown(); + } catch (IOException ignored) { + // intentionally ignored + } + }); - try { - Thread.currentThread().interrupt(); - service.uploadInputsIfNotPresent(action, /*force=*/ false); - } catch (InterruptedException ignored) { - // Intentionally left empty - } + thread.start(); + uploadBlobCalled.await(); + thread.interrupt(); + interrupted.await(); assertThat(future.isCancelled()).isTrue(); } diff --git a/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java b/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java index 8925640c11ccbc..1fb0fe969ef8ee 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java +++ b/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java @@ -38,7 +38,7 @@ import java.util.stream.Collectors; /** A {@link RemoteCacheClient} that stores its contents in memory. */ -public final class InMemoryCacheClient implements RemoteCacheClient { +public class InMemoryCacheClient implements RemoteCacheClient { private final ListeningExecutorService executorService = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(100));