Skip to content

Commit

Permalink
address latest comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Aravind Patnam committed Jan 13, 2023
1 parent b91c2d4 commit 3440f12
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -294,15 +294,6 @@ public void onFailure(Throwable e) {
}
}

/**
* Unchecked exception thrown when a SASL request times out. The supplied cause is
* passed unchanged to the {@link RuntimeException} constructor.
*/
public static class SaslTimeoutException extends RuntimeException {
public SaslTimeoutException(Throwable cause) {
super((cause));
}
}

/**
* Sends an opaque message to the RpcHandler on the server-side. No reply is expected for the
* message, and no delivery guarantees are made.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.network.sasl;

import com.google.common.base.Throwables;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.concurrent.TimeoutException;
Expand Down Expand Up @@ -74,7 +73,7 @@ public void doBootstrap(TransportClient client, Channel channel) {
} catch (RuntimeException ex) {
// We know it is a Sasl timeout here if it is a TimeoutException.
if (ex.getCause() instanceof TimeoutException) {
throw Throwables.propagate(new TransportClient.SaslTimeoutException(ex.getCause()));
throw new SaslTimeoutException(ex.getCause());
} else {
throw ex;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package org.apache.spark.network.sasl;

/**
 * Unchecked exception signalling that a SASL exchange timed out.
 *
 * <p>Every constructor delegates directly to the matching {@link RuntimeException}
 * constructor; no additional state is kept.
 */
public class SaslTimeoutException extends RuntimeException {

  /**
   * Creates the exception with both a detail message and an underlying cause.
   *
   * @param message detail message describing the timeout
   * @param cause the failure that triggered this exception
   */
  public SaslTimeoutException(String message, Throwable cause) {
    super(message, cause);
  }

  /**
   * Creates the exception with a detail message only.
   *
   * @param message detail message describing the timeout
   */
  public SaslTimeoutException(String message) {
    super(message);
  }

  /**
   * Creates the exception wrapping an underlying cause.
   *
   * @param cause the failure that triggered this exception
   */
  public SaslTimeoutException(Throwable cause) {
    super(cause);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import com.google.common.collect.Sets;
import com.google.common.util.concurrent.Uninterruptibles;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.sasl.SaslTimeoutException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -85,6 +87,12 @@ void createAndStart(String[] blockIds, BlockTransferListener listener)
// while inside a synchronized block.
/** Number of times we've attempted to retry so far. */
private int retryCount = 0;
/**
* Maps each block id to the first exception that caused that block to be retried.
* This is mainly used for SASL retries: when a block that was retried because of a
* SaslTimeoutException is later transferred successfully, `retryCount` is set
* back to 0.
*/
private Map<String, Throwable> blockIdToException;

/**
* Set of all block ids which have not been transferred successfully or with a non-IO Exception.
Expand Down Expand Up @@ -120,6 +128,7 @@ public RetryingBlockTransferor(
this.currentListener = new RetryingBlockTransferListener();
this.errorHandler = errorHandler;
this.enableSaslRetries = conf.enableSaslRetries();
this.blockIdToException = new HashMap<String, Throwable>();
}

public RetryingBlockTransferor(
Expand Down Expand Up @@ -197,9 +206,7 @@ private synchronized void initiateRetry() {
private synchronized boolean shouldRetry(Throwable e) {
boolean isIOException = e instanceof IOException
|| e.getCause() instanceof IOException;
boolean isSaslTimeout = enableSaslRetries &&
(e instanceof TransportClient.SaslTimeoutException ||
(e.getCause() != null && e.getCause() instanceof TransportClient.SaslTimeoutException));
boolean isSaslTimeout = enableSaslRetries && e instanceof SaslTimeoutException;
boolean hasRemainingRetries = retryCount < maxRetries;
return (isSaslTimeout || isIOException) &&
hasRemainingRetries && errorHandler.shouldRetryError(e);
Expand All @@ -220,6 +227,10 @@ private void handleBlockTransferSuccess(String blockId, ManagedBuffer data) {
if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
outstandingBlocksIds.remove(blockId);
shouldForwardSuccess = true;
if (blockIdToException.containsKey(blockId) &&
blockIdToException.get(blockId) instanceof SaslTimeoutException) {
retryCount = 0;
}
}
}

Expand All @@ -236,6 +247,7 @@ private void handleBlockTransferFailure(String blockId, Throwable exception) {
synchronized (RetryingBlockTransferor.this) {
if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
if (shouldRetry(exception)) {
blockIdToException.putIfAbsent(blockId, exception);
initiateRetry();
} else {
if (errorHandler.shouldLogError(exception)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import java.util.concurrent.TimeoutException;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.sasl.SaslTimeoutException;
import org.junit.Before;
import org.junit.Test;
import org.mockito.stubbing.Answer;
Expand Down Expand Up @@ -247,8 +247,8 @@ public void testRetryAndUnrecoverable() throws IOException, InterruptedException
public void testSaslTimeoutFailure() throws IOException, InterruptedException {
BlockFetchingListener listener = mock(BlockFetchingListener.class);
TimeoutException timeoutException = new TimeoutException();
TransportClient.SaslTimeoutException saslTimeoutException =
new TransportClient.SaslTimeoutException(timeoutException);
SaslTimeoutException saslTimeoutException =
new SaslTimeoutException(timeoutException);
List<? extends Map<String, Object>> interactions = Arrays.asList(
ImmutableMap.<String, Object>builder()
.put("b0", saslTimeoutException)
Expand All @@ -272,7 +272,7 @@ public void testRetryOnSaslTimeout() throws IOException, InterruptedException {
List<? extends Map<String, Object>> interactions = Arrays.asList(
// SaslTimeout will cause a retry. Since b0 fails, we will retry both.
ImmutableMap.<String, Object>builder()
.put("b0", new TransportClient.SaslTimeoutException(new TimeoutException()))
.put("b0", new SaslTimeoutException(new TimeoutException()))
.build(),
ImmutableMap.<String, Object>builder()
.put("b0", block0)
Expand Down

0 comments on commit 3440f12

Please sign in to comment.