Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed bugs in async signing. #4567

Merged
merged 3 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,5 +102,5 @@ static Checksummer forNoOp() {
/**
* Given a payload, asynchronously calculate a checksum and promise to add it to the request.
*/
CompletableFuture<Void> checksum(Publisher<ByteBuffer> payload, SdkHttpRequest.Builder request);
CompletableFuture<Publisher<ByteBuffer>> checksum(Publisher<ByteBuffer> payload, SdkHttpRequest.Builder request);
}
Original file line number Diff line number Diff line change
Expand Up @@ -257,16 +257,14 @@ private static CompletableFuture<AsyncSignedRequest> doSign(AsyncSignRequest<? e

SdkHttpRequest.Builder requestBuilder = request.request().toBuilder();

CompletableFuture<V4RequestSigningResult> resultSigningResultFuture =
checksummer.checksum(request.payload().orElse(null), requestBuilder)
.thenApply(__ -> requestSigner.sign(requestBuilder));

return resultSigningResultFuture.thenApply(
resultSigningResult -> AsyncSignedRequest.builder()
.request(resultSigningResult.getSignedRequest().build())
.payload(payloadSigner.signAsync(request.payload().orElse(null), resultSigningResult))
.build()
);
return checksummer.checksum(request.payload().orElse(null), requestBuilder)
.thenApply(payload -> {
V4RequestSigningResult requestSigningResultFuture = requestSigner.sign(requestBuilder);
return AsyncSignedRequest.builder()
.request(requestSigningResultFuture.getSignedRequest().build())
.payload(payloadSigner.signAsync(payload, requestSigningResultFuture))
.build();
});
}

private static Duration validateExpirationDuration(Duration expirationDuration) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,18 @@ public void checksum(ContentStreamProvider payload, SdkHttpRequest.Builder reque
}

@Override
public CompletableFuture<Void> checksum(Publisher<ByteBuffer> payload, SdkHttpRequest.Builder request) {
public CompletableFuture<Publisher<ByteBuffer>> checksum(Publisher<ByteBuffer> payload, SdkHttpRequest.Builder request) {
ChecksumSubscriber checksumSubscriber = new ChecksumSubscriber(optionToSdkChecksum.values());

if (payload != null) {
payload.subscribe(checksumSubscriber);
if (payload == null) {
addChecksums(request);
return CompletableFuture.completedFuture(null);
}

return checksumSubscriber.checksum().thenRun(() -> addChecksums(request));
payload.subscribe(checksumSubscriber);
CompletableFuture<Publisher<ByteBuffer>> result = checksumSubscriber.completeFuture();
result.thenRun(() -> addChecksums(request));
return result;
}

private void addChecksums(SdkHttpRequest.Builder request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,11 @@ public void checksum(ContentStreamProvider payload, SdkHttpRequest.Builder reque
}

@Override
public CompletableFuture<Void> checksum(Publisher<ByteBuffer> payload, SdkHttpRequest.Builder request) {
public CompletableFuture<Publisher<ByteBuffer>> checksum(Publisher<ByteBuffer> payload, SdkHttpRequest.Builder request) {
try {
String checksum = computation.call();
request.putHeader(X_AMZ_CONTENT_SHA256, checksum);

return CompletableFuture.completedFuture(null);
return CompletableFuture.completedFuture(payload);
} catch (Exception e) {
throw new RuntimeException("Could not retrieve checksum: ", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.zip.Checksum;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.annotations.SdkInternalApi;
Expand All @@ -30,11 +32,13 @@
*/
@SdkInternalApi
public final class ChecksumSubscriber implements Subscriber<ByteBuffer> {
private final CompletableFuture<Void> checksumming = new CompletableFuture<>();
private final CompletableFuture<Publisher<ByteBuffer>> checksumming = new CompletableFuture<>();
private final Collection<Checksum> checksums = new ArrayList<>();
private volatile boolean canceled = false;
private volatile Subscription subscription;

private final List<ByteBuffer> bufferedPayload = new ArrayList<>();

public ChecksumSubscriber(Collection<? extends Checksum> consumers) {
this.checksums.addAll(consumers);

Expand Down Expand Up @@ -65,31 +69,34 @@ public void onSubscribe(Subscription subscription) {
@Override
public void onNext(ByteBuffer byteBuffer) {
if (!canceled) {
byte[] buf;

if (byteBuffer.hasArray()) {
buf = byteBuffer.array();
} else {
buf = new byte[byteBuffer.remaining()];
byteBuffer.get(buf);
}
updateChecksumsAndBuffer(byteBuffer);
}
}

// We have to use a byte[], since update(<ByteBuffer>) is java 9+
checksums.forEach(checksum -> checksum.update(buf, 0, buf.length));
private void updateChecksumsAndBuffer(ByteBuffer buffer) {
int remaining = buffer.remaining();
if (remaining <= 0) {
return;
}

byte[] copyBuffer = new byte[remaining];
buffer.get(copyBuffer);
checksums.forEach(c -> c.update(copyBuffer, 0, remaining));
bufferedPayload.add(ByteBuffer.wrap(copyBuffer));
}


@Override
public void onError(Throwable throwable) {
checksumming.completeExceptionally(throwable);
}

@Override
public void onComplete() {
checksumming.complete(null);
checksumming.complete(new InMemoryPublisher(bufferedPayload));
}

public CompletableFuture<Void> checksum() {
public CompletableFuture<Publisher<ByteBuffer>> completeFuture() {
return checksumming;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.http.auth.aws.internal.signer.io;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.utils.Validate;

/**
* Temporarily used for buffering all data into memory. TODO(sra-identity-auth): Remove this by supporting chunked encoding. We
* should not buffer everything into memory.
*/
@SdkInternalApi
public class InMemoryPublisher implements Publisher<ByteBuffer> {
private final AtomicBoolean subscribed = new AtomicBoolean(false);
private final List<ByteBuffer> data;

public InMemoryPublisher(List<ByteBuffer> data) {
this.data = new ArrayList<>(Validate.noNullElements(data, "Data must not contain null elements."));
}

@Override
public void subscribe(Subscriber<? super ByteBuffer> s) {
if (!subscribed.compareAndSet(false, true)) {
s.onSubscribe(new NoOpSubscription());
s.onError(new IllegalStateException("InMemoryPublisher cannot be subscribed to twice."));
return;
}

s.onSubscribe(new Subscription() {
private final AtomicBoolean sending = new AtomicBoolean(false);

private final Object doneLock = new Object();
private final AtomicBoolean done = new AtomicBoolean(false);
private final AtomicLong demand = new AtomicLong(0);
private int position = 0;

@Override
public void request(long n) {
if (done.get()) {
return;
}

try {
demand.addAndGet(n);
fulfillDemand();
} catch (Throwable t) {
finish(() -> s.onError(t));
}
}

private void fulfillDemand() {
do {
if (sending.compareAndSet(false, true)) {
try {
send();
} finally {
sending.set(false);
}
}
} while (!done.get() && demand.get() > 0);
}

private void send() {
while (true) {
assert position >= 0;
assert position <= data.size();

if (done.get()) {
break;
}

if (position == data.size()) {
finish(s::onComplete);
break;
}

if (demand.get() == 0) {
break;
}

demand.decrementAndGet();
int dataIndex = position;
s.onNext(data.get(dataIndex));
data.set(dataIndex, null); // We're done with this data here, so allow it to be garbage collected
position++;
}
}

@Override
public void cancel() {
finish(() -> {
});
}

private void finish(Runnable thingToDo) {
synchronized (doneLock) {
if (done.compareAndSet(false, true)) {
thingToDo.run();
}
}
}
});
}

private static class NoOpSubscription implements Subscription {
@Override
public void request(long n) {
}

@Override
public void cancel() {
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public void checksum_computesCorrectSha256() {
Flowable<ByteBuffer> publisher = Flowable.just(ByteBuffer.wrap(testString.getBytes(StandardCharsets.UTF_8)));
publisher.subscribe(subscriber);

joinLikeSync(subscriber.checksum());
joinLikeSync(subscriber.completeFuture());
String computedDigest = BinaryUtils.toHex(checksum.getChecksumBytes());

assertThat(computedDigest).isEqualTo(expectedDigest);
Expand All @@ -65,7 +65,7 @@ public void checksum_withMultipleChecksums_shouldComputeCorrectChecksums() {
Flowable<ByteBuffer> publisher = Flowable.just(ByteBuffer.wrap(testString.getBytes(StandardCharsets.UTF_8)));
publisher.subscribe(subscriber);

joinLikeSync(subscriber.checksum());
joinLikeSync(subscriber.completeFuture());
String computedSha256Digest = BinaryUtils.toHex(sha256Checksum.getChecksumBytes());
String computedCrc32Digest = BinaryUtils.toHex(crc32Checksum.getChecksumBytes());

Expand All @@ -79,7 +79,7 @@ public void checksum_futureCancelledBeforeSubscribe_cancelsSubscription() {

ChecksumSubscriber subscriber = new ChecksumSubscriber(Collections.emptyList());

subscriber.checksum().cancel(true);
subscriber.completeFuture().cancel(true);

subscriber.onSubscribe(mockSubscription);

Expand All @@ -97,6 +97,6 @@ public void checksum_publisherCallsOnError_errorPropagatedToFuture() {
RuntimeException error = new RuntimeException("error");
subscriber.onError(error);

assertThatThrownBy(subscriber.checksum()::join).hasCause(error);
assertThatThrownBy(subscriber.completeFuture()::join).hasCause(error);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ public void onStream(Publisher<ByteBuffer> publisher) {

@Override
public void onError(Throwable err) {
if (streamFuture == null) {
prepare();
}
Comment on lines +76 to +78
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my understanding:

  • so onError can possibly get called without prepare()?
  • calling prepare here and then immediately calling completeExceptionally on it. is the idea to just make sure streamFuture is non-null or is also setting up the thenComplete important? Wondering about just doing streamFuture = new CompletableFuture<>(); here instead of prepare(), and further why not initialize streamFuture = new CompletableFuture<>() in the field declaration itself instead of in prepare().

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIU, prepare() would be called in the normal flow to get the transformed response (see MakeAsyncHttpRequestStage), so if there is an error after prepare() is called there, then streamFuture wouldn't be null. On the other hand, if there is an error before the normal flow is executed (e.g. prepare() wasn't called before an error occured), then prepare is called here so that the response-handler is called before exceptionally-completing the stream-future.

Copy link
Contributor Author

@millems millems Oct 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so onError can possibly get called without prepare()?

Yes. We aren't clear in the docs, but it can happen as it happened in this case. I think in general our use of prepare() should be pushed upstream more, since we have other cases where prepare() isn't getting called soon enough. I wanted to have a light touch here, since I know this is a fragile part of the codebase.

calling prepare here and then immediately calling completeExceptionally on it. is the idea to just make sure streamFuture is non-null or is also setting up the thenComplete important? Wondering about just doing streamFuture = new CompletableFuture<>(); here instead of prepare(), and further why not initialize streamFuture = new CompletableFuture<>() in the field declaration itself instead of in prepare().

I was going for a light touch, but it's possible we could have done something else.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we did streamFuture = new CompletableFuture<>() instead of prepare(), the behavior would be the same, assuming prepare() is defined as it is today. Which we should use depends on what we think is safer in the future.

Initializing streamFuture in the field declaration is riskier. We don't really document the lifecycle of a handler, so if the same handler is used for retries it wouldn't work. I don't know exactly which we do.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Matt for helping me understand better. At least my questions seem ballpark relevant. Agree with the light touch approach - this change seems safest option to me right now.

streamFuture.completeExceptionally(err);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public void onError(Throwable error) {
if (headersFuture != null) { // Failure in marshalling calls this before prepare() so value is null
headersFuture.completeExceptionally(error);
}

successResponseHandler.onError(error);
errorResponseHandler.onError(error);
}
Expand Down
Loading