[SPARK-41415][3.2] SASL Request Retries #39645

Status: Closed
@@ -19,6 +19,7 @@
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.concurrent.TimeoutException;
 import javax.security.sasl.Sasl;
 import javax.security.sasl.SaslException;

@@ -65,9 +66,18 @@ public void doBootstrap(TransportClient client, Channel channel) {
       SaslMessage msg = new SaslMessage(appId, payload);
       ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size());
       msg.encode(buf);
+      ByteBuffer response;
       buf.writeBytes(msg.body().nioByteBuffer());

-      ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs());
+      try {
+        response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs());
+      } catch (RuntimeException ex) {
+        // We know it is a SASL timeout if the cause is a TimeoutException.
+        if (ex.getCause() instanceof TimeoutException) {
+          throw new SaslTimeoutException(ex.getCause());
+        } else {
+          throw ex;
+        }
+      }
       payload = saslClient.response(JavaUtils.bufferToArray(response));
     }
@@ -0,0 +1,35 @@ (new file: SaslTimeoutException.java)
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+/**
+ * An exception thrown if there is a SASL timeout.
+ */
+public class SaslTimeoutException extends RuntimeException {
+  public SaslTimeoutException(Throwable cause) {
+    super(cause);
+  }
+
+  public SaslTimeoutException(String message) {
+    super(message);
+  }
+
+  public SaslTimeoutException(String message, Throwable cause) {
+    super(message, cause);
+  }
+}
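
The class carries no state beyond RuntimeException; its value is the type itself, which lets retry code recognize a SASL handshake timeout with an instanceof check instead of message parsing. Below is a minimal sketch of the produce/consume idiom, condensed from the doBootstrap and shouldRetry changes in this PR (the helper class and method names are illustrative, not part of the change):

import java.util.concurrent.TimeoutException;
import org.apache.spark.network.sasl.SaslTimeoutException;

final class SaslTimeoutIdiom {
  // Producer side: sendRpcSync surfaces the failure as a RuntimeException,
  // so the timeout has to be recovered from getCause() before re-typing it.
  static RuntimeException reclassify(RuntimeException ex) {
    if (ex.getCause() instanceof TimeoutException) {
      return new SaslTimeoutException(ex.getCause());
    }
    return ex;
  }

  // Consumer side: a SASL timeout is only retryable when the new flag is on.
  static boolean isRetryableSaslFailure(Throwable t, boolean saslRetriesEnabled) {
    return saslRetriesEnabled && t instanceof SaslTimeoutException;
  }
}

At the call site, `throw reclassify(ex);` reproduces the bootstrap behavior above.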
@@ -374,6 +374,13 @@ public boolean useOldFetchProtocol() {
     return conf.getBoolean("spark.shuffle.useOldFetchProtocol", false);
   }

+  /** Whether to enable SASL retries or not. The number of retries is dictated by the config
+   * `spark.shuffle.io.maxRetries`.
+   */
+  public boolean enableSaslRetries() {
+    return conf.getBoolean("spark.shuffle.sasl.enableRetries", false);
+  }
+
   /**
    * Class name of the implementation of MergedShuffleFileManager that merges the blocks
    * pushed to it when push-based shuffle is enabled. By default, push-based shuffle is disabled at
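
The new flag is off by default and deliberately reuses the existing shuffle IO retry knobs rather than adding a separate budget. A hedged configuration sketch (values are illustrative; this assumes the settings are supplied through a standard SparkConf, from which the network layer's TransportConf picks them up):

import org.apache.spark.SparkConf;

public class EnableSaslRetriesExample {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf()
      // New in this PR; default is false.
      .set("spark.shuffle.sasl.enableRetries", "true")
      // Pre-existing knobs that also bound SASL retries once enabled.
      .set("spark.shuffle.io.maxRetries", "3")
      .set("spark.shuffle.io.retryWait", "5s");
  }
}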
Expand Up @@ -24,12 +24,14 @@
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.Uninterruptibles;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.sasl.SaslTimeoutException;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportConf;

@@ -85,6 +87,8 @@ void createAndStart(String[] blockIds, BlockTransferListener listener)
   /** Number of times we've attempted to retry so far. */
   private int retryCount = 0;

+  private boolean saslTimeoutSeen;
+
   /**
    * Set of all block ids which have not been transferred successfully or with a non-IO Exception.
    * A retry involves requesting every outstanding block. Note that since this is a LinkedHashSet,
@@ -99,6 +103,9 @@ void createAndStart(String[] blockIds, BlockTransferListener listener)
    */
   private RetryingBlockTransferListener currentListener;

+  /** Whether SASL retries are enabled. */
+  private final boolean enableSaslRetries;
+
   private final ErrorHandler errorHandler;

   public RetryingBlockTransferor(
@@ -115,6 +122,8 @@ public RetryingBlockTransferor(
     Collections.addAll(outstandingBlocksIds, blockIds);
     this.currentListener = new RetryingBlockTransferListener();
     this.errorHandler = errorHandler;
+    this.enableSaslRetries = conf.enableSaslRetries();
+    this.saslTimeoutSeen = false;
   }

   public RetryingBlockTransferor(
@@ -187,13 +196,29 @@ private synchronized void initiateRetry() {

   /**
    * Returns true if we should retry due to a block transfer failure. We will retry if and only if
-   * the exception was an IOException and we haven't retried 'maxRetries' times already.
+   * the exception was an IOException or SaslTimeoutException and we haven't retried
+   * 'maxRetries' times already.
    */
   private synchronized boolean shouldRetry(Throwable e) {
     boolean isIOException = e instanceof IOException
-      || (e.getCause() != null && e.getCause() instanceof IOException);
+      || e.getCause() instanceof IOException;
+    boolean isSaslTimeout = enableSaslRetries && e instanceof SaslTimeoutException;
+    if (!isSaslTimeout && saslTimeoutSeen) {
+      retryCount = 0;
+      saslTimeoutSeen = false;
+    }
     boolean hasRemainingRetries = retryCount < maxRetries;
-    return isIOException && hasRemainingRetries && errorHandler.shouldRetryError(e);
+    boolean shouldRetry = (isSaslTimeout || isIOException) &&
+      hasRemainingRetries && errorHandler.shouldRetryError(e);
+    if (shouldRetry && isSaslTimeout) {
+      this.saslTimeoutSeen = true;
+    }
+    return shouldRetry;
   }

+  @VisibleForTesting
+  public int getRetryCount() {
+    return retryCount;
+  }

Review comment (Contributor), on the simplified isIOException null check above: "This looks like the main change from 3.3/master related to this diff. This is fine."
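
To see the retry semantics without the diff markers, here is a compact, self-contained restatement of the decision above. It is a sketch, not PR code: the ErrorHandler check, synchronization, and listener plumbing are omitted, and the retryCount increment that initiateRetry() performs in the real class is folded into shouldRetry():

import java.io.IOException;
import org.apache.spark.network.sasl.SaslTimeoutException;

class RetryDecisionSketch {
  private final boolean enableSaslRetries; // spark.shuffle.sasl.enableRetries
  private final int maxRetries;            // spark.shuffle.io.maxRetries
  private int retryCount = 0;
  private boolean saslTimeoutSeen = false;

  RetryDecisionSketch(boolean enableSaslRetries, int maxRetries) {
    this.enableSaslRetries = enableSaslRetries;
    this.maxRetries = maxRetries;
  }

  boolean shouldRetry(Throwable e) {
    boolean isIOException = e instanceof IOException
      || e.getCause() instanceof IOException;
    boolean isSaslTimeout = enableSaslRetries && e instanceof SaslTimeoutException;
    if (!isSaslTimeout && saslTimeoutSeen) {
      // A non-SASL failure arriving after a run of SASL timeouts means the
      // handshake got through, so the full IO retry budget is restored.
      retryCount = 0;
      saslTimeoutSeen = false;
    }
    boolean retry = (isSaslTimeout || isIOException) && retryCount < maxRetries;
    if (retry) {
      saslTimeoutSeen = isSaslTimeout;
      retryCount++; // initiateRetry() does this in the real class
    }
    return retry;
  }
}

The same reset also happens on a successful transfer (the handleBlockTransferSuccess change below), so SASL timeouts that eventually resolve never consume the IO retry budget.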
@@ -211,6 +236,10 @@ private void handleBlockTransferSuccess(String blockId, ManagedBuffer data) {
       if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
         outstandingBlocksIds.remove(blockId);
         shouldForwardSuccess = true;
+        if (saslTimeoutSeen) {
+          retryCount = 0;
+          saslTimeoutSeen = false;
+        }
       }
     }
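
A worked sequence makes the reset concrete. This trace reuses the RetryDecisionSketch defined above, with maxRetries = 2 and SASL retries enabled (illustrative only):

import java.io.IOException;
import java.util.concurrent.TimeoutException;
import org.apache.spark.network.sasl.SaslTimeoutException;

public class RetryBudgetTrace {
  public static void main(String[] args) {
    RetryDecisionSketch d = new RetryDecisionSketch(true, 2);
    // Two SASL timeouts: both retried, retryCount climbs to 2.
    System.out.println(d.shouldRetry(new SaslTimeoutException(new TimeoutException()))); // true
    System.out.println(d.shouldRetry(new SaslTimeoutException(new TimeoutException()))); // true
    // An IO failure follows: retryCount resets to 0 first, so this is retried
    // with a fresh budget instead of failing at the maxRetries cap.
    System.out.println(d.shouldRetry(new IOException())); // true
  }
}

This mirrors testBlockTransferFailureAfterSasl below, which expects a final retry count of 1 after a SASL timeout retry, a success, and then one IO retry.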
Expand Up @@ -20,13 +20,18 @@

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeoutException;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;

import org.junit.Before;
import org.junit.Test;
import org.mockito.stubbing.Answer;
import org.mockito.stubbing.Stubber;
Expand All @@ -38,6 +43,7 @@
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;
import org.apache.spark.network.sasl.SaslTimeoutException;
import static org.apache.spark.network.shuffle.RetryingBlockTransferor.BlockTransferStarter;

@@ -49,6 +55,16 @@ public class RetryingBlockTransferorSuite {
   private final ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13]));
   private final ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
   private final ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19]));
+  private static Map<String, String> configMap;
+  private static RetryingBlockTransferor _retryingBlockTransferor;
+
+  @Before
+  public void initMap() {
+    configMap = new HashMap<String, String>() {{
+      put("spark.shuffle.io.maxRetries", "2");
+      put("spark.shuffle.io.retryWait", "0");
+    }};
+  }

   @Test
   public void testNoFailures() throws IOException, InterruptedException {
@@ -230,6 +246,101 @@ public void testRetryAndUnrecoverable() throws IOException, InterruptedException {
     verifyNoMoreInteractions(listener);
   }

+  @Test
+  public void testSaslTimeoutFailure() throws IOException, InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+    TimeoutException timeoutException = new TimeoutException();
+    SaslTimeoutException saslTimeoutException =
+      new SaslTimeoutException(timeoutException);
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
+      ImmutableMap.<String, Object>builder()
+        .put("b0", saslTimeoutException)
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b0", block0)
+        .build()
+    );
+
+    performInteractions(interactions, listener);
+
+    verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslTimeoutException);
+    verify(listener).getTransferType();
+    verifyNoMoreInteractions(listener);
+  }
+
+  @Test
+  public void testRetryOnSaslTimeout() throws IOException, InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
+      // A SASL timeout causes a retry; since b0 fails, it is requested again.
+      ImmutableMap.<String, Object>builder()
+        .put("b0", new SaslTimeoutException(new TimeoutException()))
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b0", block0)
+        .build()
+    );
+    configMap.put("spark.shuffle.sasl.enableRetries", "true");
+    performInteractions(interactions, listener);
+
+    verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0);
+    verify(listener).getTransferType();
+    verifyNoMoreInteractions(listener);
+    assert(_retryingBlockTransferor.getRetryCount() == 0);
+  }
+
+  @Test
+  public void testRepeatedSaslRetryFailures() throws IOException, InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+    TimeoutException timeoutException = new TimeoutException();
+    SaslTimeoutException saslTimeoutException =
+      new SaslTimeoutException(timeoutException);
+    List<ImmutableMap<String, Object>> interactions = new ArrayList<>();
+    for (int i = 0; i < 3; i++) {
+      interactions.add(
+        ImmutableMap.<String, Object>builder()
+          .put("b0", saslTimeoutException)
+          .build()
+      );
+    }
+    configMap.put("spark.shuffle.sasl.enableRetries", "true");
+    performInteractions(interactions, listener);
+    verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslTimeoutException);
+    verify(listener, times(3)).getTransferType();
+    verifyNoMoreInteractions(listener);
+    assert(_retryingBlockTransferor.getRetryCount() == 2);
+  }
+
+  @Test
+  public void testBlockTransferFailureAfterSasl() throws IOException, InterruptedException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
+      ImmutableMap.<String, Object>builder()
+        .put("b0", new SaslTimeoutException(new TimeoutException()))
+        .put("b1", new IOException())
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b0", block0)
+        .put("b1", new IOException())
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b1", block1)
+        .build()
+    );
+    configMap.put("spark.shuffle.sasl.enableRetries", "true");
+    performInteractions(interactions, listener);
+    verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0);
+    verify(listener, timeout(5000)).onBlockTransferSuccess("b1", block1);
+    verify(listener, atLeastOnce()).getTransferType();
+    verifyNoMoreInteractions(listener);
+    // This should be equal to 1: after the SASL exception is retried,
+    // retryCount is reset to 0. Then b1 hits an IOException, which is retried.
+    assert(_retryingBlockTransferor.getRetryCount() == 1);
+  }

  /**
   * Performs a set of interactions in response to block requests from a RetryingBlockFetcher.
   * Each interaction is a Map from BlockId to either ManagedBuffer or Exception. This interaction
@@ -245,9 +356,7 @@ private static void performInteractions(List<? extends Map<String, Object>> interactions,
       BlockFetchingListener listener)
       throws IOException, InterruptedException {

-    MapConfigProvider provider = new MapConfigProvider(ImmutableMap.of(
-      "spark.shuffle.io.maxRetries", "2",
-      "spark.shuffle.io.retryWait", "0"));
+    MapConfigProvider provider = new MapConfigProvider(configMap);
     TransportConf conf = new TransportConf("shuffle", provider);
     BlockTransferStarter fetchStarter = mock(BlockTransferStarter.class);

@@ -299,6 +408,8 @@ private static void performInteractions(List<? extends Map<String, Object>> interactions,
     assertNotNull(stub);
     stub.when(fetchStarter).createAndStart(any(), any());
     String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]);
-    new RetryingBlockTransferor(conf, fetchStarter, blockIdArray, listener).start();
+    _retryingBlockTransferor =
+      new RetryingBlockTransferor(conf, fetchStarter, blockIdArray, listener);
+    _retryingBlockTransferor.start();
   }
 }
}