[SPARK-41415] SASL Request Retries
### What changes were proposed in this pull request?

Add the ability to retry SASL requests. A follow-up will add a metric to track SASL retries.

### Why are the changes needed?
We are seeing increased SASL timeouts internally, and this change would mitigate the issue. We already have this feature enabled for our 2.3 jobs, and failures have decreased significantly.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added unit tests, and tested on a cluster to ensure that the retries are triggered correctly.
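
For reference, a sketch of how a job would opt in. Only the config keys come from this PR; the `SparkConf` wiring below is illustrative, not part of this commit:

    // Illustrative only: enable SASL request retries for shuffle RPCs.
    // spark.shuffle.sasl.enableRetries is the flag added by this PR; the
    // retry budget and backoff reuse the existing shuffle IO settings.
    SparkConf conf = new SparkConf()
        .set("spark.shuffle.sasl.enableRetries", "true")
        .set("spark.shuffle.io.maxRetries", "3")
        .set("spark.shuffle.io.retryWait", "5s");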

Closes #38959 from akpatnam25/SPARK-41415.

Authored-by: Aravind Patnam <apatnam@linkedin.com>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
Aravind Patnam authored and Mridul Muralidharan committed Jan 15, 2023
1 parent 9dc792d commit 2878cd8
Showing 5 changed files with 200 additions and 8 deletions.
@@ -19,6 +19,7 @@

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.concurrent.TimeoutException;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;

@@ -65,9 +66,18 @@ public void doBootstrap(TransportClient client, Channel channel) {
SaslMessage msg = new SaslMessage(appId, payload);
ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size());
msg.encode(buf);
ByteBuffer response;
buf.writeBytes(msg.body().nioByteBuffer());

try {
response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs());
} catch (RuntimeException ex) {
// We know it is a Sasl timeout here if it is a TimeoutException.
if (ex.getCause() instanceof TimeoutException) {
throw new SaslTimeoutException(ex.getCause());
} else {
throw ex;
}
}
payload = saslClient.response(JavaUtils.bufferToArray(response));
}
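
For context, `sendRpcSync` blocks on the RPC response and surfaces checked failures wrapped in a `RuntimeException`, which is why the catch block above inspects `getCause()`. A simplified sketch of that propagation (an assumption about the sync call, with names invented for illustration):

    import java.nio.ByteBuffer;
    import java.util.concurrent.ExecutionException;
    import java.util.concurrent.Future;
    import java.util.concurrent.TimeUnit;
    import java.util.concurrent.TimeoutException;

    // Sketch only: models how a checked TimeoutException from a blocking
    // RPC call ends up as the *cause* of a RuntimeException in doBootstrap().
    static ByteBuffer sendRpcSyncSketch(Future<ByteBuffer> result, long timeoutMs) {
      try {
        return result.get(timeoutMs, TimeUnit.MILLISECONDS);
      } catch (TimeoutException te) {
        throw new RuntimeException(te);  // the timeout becomes the cause
      } catch (ExecutionException | InterruptedException e) {
        throw new RuntimeException(e);
      }
    }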

@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.sasl;

/**
* An exception thrown if there is a SASL timeout.
*/
public class SaslTimeoutException extends RuntimeException {
public SaslTimeoutException(Throwable cause) {
super(cause);
}

public SaslTimeoutException(String message) {
super(message);
}

public SaslTimeoutException(String message, Throwable cause) {
super(message, cause);
}
}
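
As a usage sketch (hypothetical caller code, not from this commit), the dedicated exception type lets callers separate retryable SASL timeouts from other bootstrap failures:

    // Hypothetical: bootstrap, client and channel are assumed to be in scope.
    try {
      bootstrap.doBootstrap(client, channel);
    } catch (SaslTimeoutException e) {
      // Retry candidate: the SASL round trip timed out.
    } catch (RuntimeException e) {
      throw e;  // everything else propagates unchanged
    }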
@@ -333,6 +333,13 @@ public boolean useOldFetchProtocol() {
return conf.getBoolean("spark.shuffle.useOldFetchProtocol", false);
}

/**
 * Whether to enable SASL retries. The number of retries is dictated by the config
 * `spark.shuffle.io.maxRetries`.
 */
public boolean enableSaslRetries() {
return conf.getBoolean("spark.shuffle.sasl.enableRetries", false);
}

/**
* Class name of the implementation of MergedShuffleFileManager that merges the blocks
* pushed to it when push-based shuffle is enabled. By default, push-based shuffle is disabled at
@@ -24,12 +24,14 @@
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.Uninterruptibles;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.sasl.SaslTimeoutException;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportConf;

@@ -85,6 +87,8 @@ void createAndStart(String[] blockIds, BlockTransferListener listener)
/** Number of times we've attempted to retry so far. */
private int retryCount = 0;

/** Whether a SASL timeout was seen in the most recent failure; used to reset retryCount. */
private boolean saslTimeoutSeen;

/**
* Set of all block ids which have not been transferred successfully or with a non-IO Exception.
* A retry involves requesting every outstanding block. Note that since this is a LinkedHashSet,
@@ -99,6 +103,9 @@ void createAndStart(String[] blockIds, BlockTransferListener listener)
*/
private RetryingBlockTransferListener currentListener;

/** Whether sasl retries are enabled. */
private final boolean enableSaslRetries;

private final ErrorHandler errorHandler;

public RetryingBlockTransferor(
@@ -115,6 +122,8 @@ public RetryingBlockTransferor(
Collections.addAll(outstandingBlocksIds, blockIds);
this.currentListener = new RetryingBlockTransferListener();
this.errorHandler = errorHandler;
this.enableSaslRetries = conf.enableSaslRetries();
this.saslTimeoutSeen = false;
}

public RetryingBlockTransferor(
@@ -187,13 +196,29 @@ private synchronized void initiateRetry() {

/**
* Returns true if we should retry due to a block transfer failure. We will retry if and only if
* the exception was an IOException or SaslTimeoutException and we haven't retried
* 'maxRetries' times already.
*/
private synchronized boolean shouldRetry(Throwable e) {
boolean isIOException = e instanceof IOException
|| e.getCause() instanceof IOException;
boolean isSaslTimeout = enableSaslRetries && e instanceof SaslTimeoutException;
if (!isSaslTimeout && saslTimeoutSeen) {
retryCount = 0;
saslTimeoutSeen = false;
}
boolean hasRemainingRetries = retryCount < maxRetries;
boolean shouldRetry = (isSaslTimeout || isIOException) &&
hasRemainingRetries && errorHandler.shouldRetryError(e);
if (shouldRetry && isSaslTimeout) {
this.saslTimeoutSeen = true;
}
return shouldRetry;
}

@VisibleForTesting
public int getRetryCount() {
return retryCount;
}

@@ -211,6 +236,10 @@ private void handleBlockTransferSuccess(String blockId, ManagedBuffer data) {
if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
outstandingBlocksIds.remove(blockId);
shouldForwardSuccess = true;
if (saslTimeoutSeen) {
retryCount = 0;
saslTimeoutSeen = false;
}
}
}
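
The subtle part above is how `saslTimeoutSeen` interacts with `retryCount`: SASL timeouts share the `maxRetries` budget with IOExceptions, but the count starts over once a SASL timeout is followed by a different failure or by a successful transfer. A minimal standalone sketch of that bookkeeping (hypothetical class, mirroring `shouldRetry` and `handleBlockTransferSuccess`; not part of the commit):

    // Illustrative model of the retry bookkeeping in RetryingBlockTransferor.
    class RetryBudgetSketch {
      private final int maxRetries;
      private int retryCount = 0;
      private boolean saslTimeoutSeen = false;

      RetryBudgetSketch(int maxRetries) { this.maxRetries = maxRetries; }

      // Mirrors shouldRetry(): a non-SASL failure after SASL timeouts means
      // the handshake eventually succeeded, so the budget resets.
      boolean onFailure(boolean isSaslTimeout) {
        if (!isSaslTimeout && saslTimeoutSeen) {
          retryCount = 0;
          saslTimeoutSeen = false;
        }
        boolean retry = retryCount < maxRetries;
        if (retry && isSaslTimeout) {
          saslTimeoutSeen = true;
        }
        if (retry) {
          retryCount++;  // initiateRetry() increments in the real class
        }
        return retry;
      }

      // Mirrors handleBlockTransferSuccess(): success after SASL timeouts
      // also resets the budget.
      void onSuccess() {
        if (saslTimeoutSeen) {
          retryCount = 0;
          saslTimeoutSeen = false;
        }
      }
    }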

@@ -20,13 +20,18 @@

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeoutException;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;

import org.junit.Before;
import org.junit.Test;
import org.mockito.stubbing.Answer;
import org.mockito.stubbing.Stubber;
@@ -38,6 +43,7 @@
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;
import org.apache.spark.network.sasl.SaslTimeoutException;
import static org.apache.spark.network.shuffle.RetryingBlockTransferor.BlockTransferStarter;

@@ -49,6 +55,16 @@ public class RetryingBlockTransferorSuite {
private final ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13]));
private final ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
private final ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19]));
private static Map<String, String> configMap;
private static RetryingBlockTransferor _retryingBlockTransferor;

@Before
public void initMap() {
configMap = new HashMap<String, String>() {{
put("spark.shuffle.io.maxRetries", "2");
put("spark.shuffle.io.retryWait", "0");
}};
}

@Test
public void testNoFailures() throws IOException, InterruptedException {
@@ -230,6 +246,101 @@ public void testRetryAndUnrecoverable() throws IOException, InterruptedException
verifyNoMoreInteractions(listener);
}

@Test
public void testSaslTimeoutFailure() throws IOException, InterruptedException {
BlockFetchingListener listener = mock(BlockFetchingListener.class);
TimeoutException timeoutException = new TimeoutException();
SaslTimeoutException saslTimeoutException =
new SaslTimeoutException(timeoutException);
List<? extends Map<String, Object>> interactions = Arrays.asList(
ImmutableMap.<String, Object>builder()
.put("b0", saslTimeoutException)
.build(),
ImmutableMap.<String, Object>builder()
.put("b0", block0)
.build()
);

performInteractions(interactions, listener);

verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslTimeoutException);
verify(listener).getTransferType();
verifyNoMoreInteractions(listener);
}

@Test
public void testRetryOnSaslTimeout() throws IOException, InterruptedException {
BlockFetchingListener listener = mock(BlockFetchingListener.class);

List<? extends Map<String, Object>> interactions = Arrays.asList(
// SaslTimeout will cause a retry. Since b0 fails, we will retry both.
ImmutableMap.<String, Object>builder()
.put("b0", new SaslTimeoutException(new TimeoutException()))
.build(),
ImmutableMap.<String, Object>builder()
.put("b0", block0)
.build()
);
configMap.put("spark.shuffle.sasl.enableRetries", "true");
performInteractions(interactions, listener);

verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0);
verify(listener).getTransferType();
verifyNoMoreInteractions(listener);
assert(_retryingBlockTransferor.getRetryCount() == 0);
}

@Test
public void testRepeatedSaslRetryFailures() throws IOException, InterruptedException {
BlockFetchingListener listener = mock(BlockFetchingListener.class);
TimeoutException timeoutException = new TimeoutException();
SaslTimeoutException saslTimeoutException =
new SaslTimeoutException(timeoutException);
List<ImmutableMap<String, Object>> interactions = new ArrayList<>();
for (int i = 0; i < 3; i++) {
interactions.add(
ImmutableMap.<String, Object>builder()
.put("b0", saslTimeoutException)
.build()
);
}
configMap.put("spark.shuffle.sasl.enableRetries", "true");
performInteractions(interactions, listener);
verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslTimeoutException);
verify(listener, times(3)).getTransferType();
verifyNoMoreInteractions(listener);
assert(_retryingBlockTransferor.getRetryCount() == 2);
}

@Test
public void testBlockTransferFailureAfterSasl() throws IOException, InterruptedException {
BlockFetchingListener listener = mock(BlockFetchingListener.class);

List<? extends Map<String, Object>> interactions = Arrays.asList(
ImmutableMap.<String, Object>builder()
.put("b0", new SaslTimeoutException(new TimeoutException()))
.put("b1", new IOException())
.build(),
ImmutableMap.<String, Object>builder()
.put("b0", block0)
.put("b1", new IOException())
.build(),
ImmutableMap.<String, Object>builder()
.put("b1", block1)
.build()
);
configMap.put("spark.shuffle.sasl.enableRetries", "true");
performInteractions(interactions, listener);
verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0);
verify(listener, timeout(5000)).onBlockTransferSuccess("b1", block1);
verify(listener, atLeastOnce()).getTransferType();
verifyNoMoreInteractions(listener);
// This should be equal to 1 because after the SASL exception is retried,
// retryCount should be set back to 0. Then after that b1 encounters an
// exception that is retried.
assert(_retryingBlockTransferor.getRetryCount() == 1);
}

/**
* Performs a set of interactions in response to block requests from a RetryingBlockFetcher.
* Each interaction is a Map from BlockId to either ManagedBuffer or Exception. This interaction
@@ -244,9 +355,7 @@ private static void performInteractions(List<? extends Map<String, Object>> inte
BlockFetchingListener listener)
throws IOException, InterruptedException {

MapConfigProvider provider = new MapConfigProvider(configMap);
TransportConf conf = new TransportConf("shuffle", provider);
BlockTransferStarter fetchStarter = mock(BlockTransferStarter.class);

@@ -298,6 +407,8 @@ private static void performInteractions(List<? extends Map<String, Object>> inte
assertNotNull(stub);
stub.when(fetchStarter).createAndStart(any(), any());
String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]);
_retryingBlockTransferor =
new RetryingBlockTransferor(conf, fetchStarter, blockIdArray, listener);
_retryingBlockTransferor.start();
}
}
