Skip to content

Commit

Permalink
Use mapIndex and attemptNo as taskAttemptNo in Spark2
Browse files Browse the repository at this point in the history
  • Loading branch information
EnricoMi committed Feb 27, 2024
1 parent ff555c5 commit 0762498
Show file tree
Hide file tree
Showing 5 changed files with 325 additions and 308 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Maps;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.MapOutputTracker;
Expand Down Expand Up @@ -50,6 +51,63 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac
private Method unregisterAllMapOutputMethod;
private Method registerShuffleMethod;

/**
 * Provides a task attempt id that is unique for a shuffle stage.
 *
 * <p>We are not using context.taskAttemptId() here as this is a monotonically increasing number
 * that is unique across the entire Spark app and can therefore become very large (practically up
 * to Long.MAX_VALUE). That would overflow the bits in the block id.
 *
 * <p>Here we use the map index or task id, appended by the attempt number per task. The map index
 * is limited by the number of partitions of a stage. The attempt number per task is limited /
 * configured by spark.task.maxFailures (default: 4).
 *
 * <p>The id is laid out as {@code (mapIndex << attemptBits) | attemptNo}, where attemptBits is the
 * number of bits needed to represent the largest expected attempt number.
 *
 * @param mapIndex zero-based map index of the task, must not be negative
 * @param attemptNo zero-based attempt number of the task, must not be negative
 * @param maxFailures maximum number of attempts per task; values below 1 are treated as 1
 * @param speculation whether speculative execution is enabled, which allows one extra attempt
 * @param maxTaskAttemptIdBits number of bits available for the returned id
 * @return a task attempt id unique for a shuffle stage
 * @throws RssException if mapIndex or attemptNo is negative, if attemptNo exceeds the largest
 *     expected attempt number, or if the id would need more than maxTaskAttemptIdBits bits
 */
@VisibleForTesting
protected static long getTaskAttemptId(
    int mapIndex, int attemptNo, int maxFailures, boolean speculation, int maxTaskAttemptIdBits) {
  // a negative int is sign-extended when widened to long, which would set all upper bits of the
  // resulting id; reject such values explicitly instead of returning a corrupt id
  if (mapIndex < 0 || attemptNo < 0) {
    throw new RssException(
        "Observing negative mapIndex["
            + mapIndex
            + "] or attempt number["
            + attemptNo
            + "], both must not be negative.");
  }

  // attempt number is zero based: 0, 1, …, maxFailures-1
  // maxFailures < 1 is not allowed but for safety, we interpret that as maxFailures == 1
  int maxAttemptNo = maxFailures < 1 ? 0 : maxFailures - 1;

  // with speculative execution enabled we could observe +1 attempts
  if (speculation) {
    maxAttemptNo++;
  }

  if (attemptNo > maxAttemptNo) {
    // this should never happen, if it does, our assumptions are wrong,
    // and we risk overflowing the attempt number bits
    throw new RssException(
        "Observing attempt number "
            + attemptNo
            + " while maxFailures is set to "
            + maxFailures
            + (speculation ? " with speculation enabled" : "")
            + ".");
  }

  // number of significant bits of a non-negative int: 32 minus its leading zeros (0 for value 0)
  int attemptBits = 32 - Integer.numberOfLeadingZeros(maxAttemptNo);
  int mapIndexBits = 32 - Integer.numberOfLeadingZeros(mapIndex);
  if (mapIndexBits + attemptBits > maxTaskAttemptIdBits) {
    throw new RssException(
        "Observing mapIndex["
            + mapIndex
            + "] that would produce a taskAttemptId with "
            + (mapIndexBits + attemptBits)
            + " bits which is larger than the allowed "
            + maxTaskAttemptIdBits
            + " bits (maxFailures["
            + maxFailures
            + "], speculation["
            + speculation
            + "]). Please consider providing more bits for taskAttemptIds.");
  }

  // widen mapIndex to long before shifting so that large map indices do not overflow an int
  return ((long) mapIndex << attemptBits) | attemptNo;
}

@Override
public void unregisterAllMapOutput(int shuffleId) throws SparkException {
if (!RssSparkShuffleUtils.isStageResubmitSupported()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@

package org.apache.uniffle.shuffle.manager;

import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.junit.jupiter.api.Test;

import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.exception.RssException;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class RssShuffleManagerBaseTest {
Expand All @@ -39,4 +43,254 @@ public void testGetDefaultRemoteStorageInfo() {
assertEquals(remoteStorageInfo.getConfItems().size(), 1);
assertEquals(remoteStorageInfo.getConfItems().get("fs.defaultFs"), "hdfs://rbf-xxx/foo");
}

// Parses a binary string such as "0010|1" into a long. The '|' characters only visualize the
// boundary between map index and attempt number and are dropped before parsing.
private long bits(String string) {
  String binaryDigits = string.replace("|", "");
  return Long.parseLong(binaryDigits, 2);
}

@Test
public void testGetTaskAttemptIdWithoutSpeculation() {
  // without speculation, we expect at most maxFailures attempts: 0, 1, …, maxFailures-1

  // the expected bits("xy|z") represents the expected Long in bit notation where | is used to
  // separate map index from attempt number, so merely for visualization purposes

  // maxFailures < 1 not allowed, we fall back to maxFailures=1 to be robust
  for (int maxFailures : Arrays.asList(-1, 0, 1)) {
    assertEquals(
        bits("0000|"),
        RssShuffleManagerBase.getTaskAttemptId(0, 0, maxFailures, false, 10),
        String.valueOf(maxFailures));
    assertEquals(
        bits("0001|"),
        RssShuffleManagerBase.getTaskAttemptId(1, 0, maxFailures, false, 10),
        String.valueOf(maxFailures));
    assertEquals(
        bits("0010|"),
        RssShuffleManagerBase.getTaskAttemptId(2, 0, maxFailures, false, 10),
        String.valueOf(maxFailures));
  }

  // maxFailures of 2
  assertEquals(bits("000|0"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 2, false, 10));
  assertEquals(bits("000|1"), RssShuffleManagerBase.getTaskAttemptId(0, 1, 2, false, 10));
  assertEquals(bits("001|0"), RssShuffleManagerBase.getTaskAttemptId(1, 0, 2, false, 10));
  assertEquals(bits("001|1"), RssShuffleManagerBase.getTaskAttemptId(1, 1, 2, false, 10));
  assertEquals(bits("010|0"), RssShuffleManagerBase.getTaskAttemptId(2, 0, 2, false, 10));
  assertEquals(bits("010|1"), RssShuffleManagerBase.getTaskAttemptId(2, 1, 2, false, 10));
  assertEquals(bits("011|0"), RssShuffleManagerBase.getTaskAttemptId(3, 0, 2, false, 10));
  assertEquals(bits("011|1"), RssShuffleManagerBase.getTaskAttemptId(3, 1, 2, false, 10));

  // maxFailures of 3
  assertEquals(bits("00|00"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 3, false, 10));
  assertEquals(bits("00|01"), RssShuffleManagerBase.getTaskAttemptId(0, 1, 3, false, 10));
  assertEquals(bits("00|10"), RssShuffleManagerBase.getTaskAttemptId(0, 2, 3, false, 10));
  assertEquals(bits("01|00"), RssShuffleManagerBase.getTaskAttemptId(1, 0, 3, false, 10));
  assertEquals(bits("01|01"), RssShuffleManagerBase.getTaskAttemptId(1, 1, 3, false, 10));
  assertEquals(bits("01|10"), RssShuffleManagerBase.getTaskAttemptId(1, 2, 3, false, 10));
  assertEquals(bits("10|00"), RssShuffleManagerBase.getTaskAttemptId(2, 0, 3, false, 10));
  assertEquals(bits("10|01"), RssShuffleManagerBase.getTaskAttemptId(2, 1, 3, false, 10));
  assertEquals(bits("10|10"), RssShuffleManagerBase.getTaskAttemptId(2, 2, 3, false, 10));
  assertEquals(bits("11|00"), RssShuffleManagerBase.getTaskAttemptId(3, 0, 3, false, 10));
  assertEquals(bits("11|01"), RssShuffleManagerBase.getTaskAttemptId(3, 1, 3, false, 10));
  assertEquals(bits("11|10"), RssShuffleManagerBase.getTaskAttemptId(3, 2, 3, false, 10));

  // maxFailures of 4
  assertEquals(bits("00|00"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 4, false, 10));
  assertEquals(bits("00|01"), RssShuffleManagerBase.getTaskAttemptId(0, 1, 4, false, 10));
  assertEquals(bits("00|10"), RssShuffleManagerBase.getTaskAttemptId(0, 2, 4, false, 10));
  assertEquals(bits("00|11"), RssShuffleManagerBase.getTaskAttemptId(0, 3, 4, false, 10));
  assertEquals(bits("01|00"), RssShuffleManagerBase.getTaskAttemptId(1, 0, 4, false, 10));
  assertEquals(bits("01|01"), RssShuffleManagerBase.getTaskAttemptId(1, 1, 4, false, 10));
  assertEquals(bits("01|10"), RssShuffleManagerBase.getTaskAttemptId(1, 2, 4, false, 10));
  assertEquals(bits("01|11"), RssShuffleManagerBase.getTaskAttemptId(1, 3, 4, false, 10));
  assertEquals(bits("10|00"), RssShuffleManagerBase.getTaskAttemptId(2, 0, 4, false, 10));
  assertEquals(bits("10|01"), RssShuffleManagerBase.getTaskAttemptId(2, 1, 4, false, 10));
  assertEquals(bits("10|10"), RssShuffleManagerBase.getTaskAttemptId(2, 2, 4, false, 10));
  assertEquals(bits("10|11"), RssShuffleManagerBase.getTaskAttemptId(2, 3, 4, false, 10));
  assertEquals(bits("11|00"), RssShuffleManagerBase.getTaskAttemptId(3, 0, 4, false, 10));
  assertEquals(bits("11|01"), RssShuffleManagerBase.getTaskAttemptId(3, 1, 4, false, 10));
  assertEquals(bits("11|10"), RssShuffleManagerBase.getTaskAttemptId(3, 2, 4, false, 10));
  assertEquals(bits("11|11"), RssShuffleManagerBase.getTaskAttemptId(3, 3, 4, false, 10));

  // maxFailures of 5
  assertEquals(bits("0|000"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 5, false, 10));
  assertEquals(bits("1|100"), RssShuffleManagerBase.getTaskAttemptId(1, 4, 5, false, 10));

  // test with ints that overflow into signed int and long
  assertEquals(
      Integer.MAX_VALUE,
      RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 0, 1, false, 31));
  assertEquals(
      (long) Integer.MAX_VALUE << 1 | 1,
      RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 1, 2, false, 32));
  assertEquals(
      (long) Integer.MAX_VALUE << 2 | 3,
      RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 3, 4, false, 33));
  assertEquals(
      (long) Integer.MAX_VALUE << 3 | 7,
      RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 7, 8, false, 34));

  // test with attemptNo >= maxFailures
  assertThrowsExactly(
      RssException.class, () -> RssShuffleManagerBase.getTaskAttemptId(0, 1, -1, false, 10));
  assertThrowsExactly(
      RssException.class, () -> RssShuffleManagerBase.getTaskAttemptId(0, 1, 0, false, 10));
  for (int maxFailures : Arrays.asList(1, 2, 3, 4, 8, 128)) {
    assertThrowsExactly(
        RssException.class,
        () -> RssShuffleManagerBase.getTaskAttemptId(0, maxFailures, maxFailures, false, 10),
        String.valueOf(maxFailures));
    assertThrowsExactly(
        RssException.class,
        () -> RssShuffleManagerBase.getTaskAttemptId(0, maxFailures + 1, maxFailures, false, 10),
        String.valueOf(maxFailures));
    assertThrowsExactly(
        RssException.class,
        () -> RssShuffleManagerBase.getTaskAttemptId(0, maxFailures + 2, maxFailures, false, 10),
        String.valueOf(maxFailures));
    Exception e =
        assertThrowsExactly(
            RssException.class,
            () ->
                RssShuffleManagerBase.getTaskAttemptId(
                    0, maxFailures + 128, maxFailures, false, 10),
            String.valueOf(maxFailures));
    assertEquals(
        "Observing attempt number "
            + (maxFailures + 128)
            + " while maxFailures is set to "
            + maxFailures
            + ".",
        e.getMessage());
  }

  // test with mapIndex that would require more than maxTaskAttemptBits
  // (speculation disabled, as everywhere else in this test)
  Exception e =
      assertThrowsExactly(
          RssException.class, () -> RssShuffleManagerBase.getTaskAttemptId(256, 0, 4, false, 10));
  assertEquals(
      "Observing mapIndex[256] that would produce a taskAttemptId with 11 bits "
          + "which is larger than the allowed 10 bits (maxFailures[4], speculation[false]). "
          + "Please consider providing more bits for taskAttemptIds.",
      e.getMessage());
  // check that a lower mapIndex works as expected
  assertEquals(bits("11111111|00"), RssShuffleManagerBase.getTaskAttemptId(255, 0, 4, false, 10));
}

@Test
public void testGetTaskAttemptIdWithSpeculation() {
  // with speculation, we expect maxFailures+1 attempts

  // the expected bits("xy|z") represents the expected Long in bit notation where | is used to
  // separate map index from attempt number, so merely for visualization purposes

  // maxFailures < 1 not allowed, we fall back to maxFailures=1 to be robust
  for (int maxFailures : Arrays.asList(-1, 0, 1)) {
    for (int attemptNo : Arrays.asList(0, 1)) {
      assertEquals(
          bits("0000|" + attemptNo),
          RssShuffleManagerBase.getTaskAttemptId(0, attemptNo, maxFailures, true, 10),
          "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
      assertEquals(
          bits("0001|" + attemptNo),
          RssShuffleManagerBase.getTaskAttemptId(1, attemptNo, maxFailures, true, 10),
          "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
      assertEquals(
          bits("0010|" + attemptNo),
          RssShuffleManagerBase.getTaskAttemptId(2, attemptNo, maxFailures, true, 10),
          "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
    }
  }

  // maxFailures of 2
  assertEquals(bits("00|00"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 2, true, 10));
  assertEquals(bits("00|01"), RssShuffleManagerBase.getTaskAttemptId(0, 1, 2, true, 10));
  assertEquals(bits("00|10"), RssShuffleManagerBase.getTaskAttemptId(0, 2, 2, true, 10));
  assertEquals(bits("01|00"), RssShuffleManagerBase.getTaskAttemptId(1, 0, 2, true, 10));
  assertEquals(bits("01|01"), RssShuffleManagerBase.getTaskAttemptId(1, 1, 2, true, 10));
  assertEquals(bits("01|10"), RssShuffleManagerBase.getTaskAttemptId(1, 2, 2, true, 10));
  assertEquals(bits("10|00"), RssShuffleManagerBase.getTaskAttemptId(2, 0, 2, true, 10));
  assertEquals(bits("10|01"), RssShuffleManagerBase.getTaskAttemptId(2, 1, 2, true, 10));
  assertEquals(bits("10|10"), RssShuffleManagerBase.getTaskAttemptId(2, 2, 2, true, 10));
  assertEquals(bits("11|00"), RssShuffleManagerBase.getTaskAttemptId(3, 0, 2, true, 10));
  assertEquals(bits("11|01"), RssShuffleManagerBase.getTaskAttemptId(3, 1, 2, true, 10));
  assertEquals(bits("11|10"), RssShuffleManagerBase.getTaskAttemptId(3, 2, 2, true, 10));

  // maxFailures of 3
  assertEquals(bits("00|00"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 3, true, 10));
  assertEquals(bits("00|01"), RssShuffleManagerBase.getTaskAttemptId(0, 1, 3, true, 10));
  assertEquals(bits("00|10"), RssShuffleManagerBase.getTaskAttemptId(0, 2, 3, true, 10));
  assertEquals(bits("00|11"), RssShuffleManagerBase.getTaskAttemptId(0, 3, 3, true, 10));
  assertEquals(bits("01|00"), RssShuffleManagerBase.getTaskAttemptId(1, 0, 3, true, 10));
  assertEquals(bits("01|01"), RssShuffleManagerBase.getTaskAttemptId(1, 1, 3, true, 10));
  assertEquals(bits("01|10"), RssShuffleManagerBase.getTaskAttemptId(1, 2, 3, true, 10));
  assertEquals(bits("01|11"), RssShuffleManagerBase.getTaskAttemptId(1, 3, 3, true, 10));
  assertEquals(bits("10|00"), RssShuffleManagerBase.getTaskAttemptId(2, 0, 3, true, 10));
  assertEquals(bits("10|01"), RssShuffleManagerBase.getTaskAttemptId(2, 1, 3, true, 10));
  assertEquals(bits("10|10"), RssShuffleManagerBase.getTaskAttemptId(2, 2, 3, true, 10));
  assertEquals(bits("10|11"), RssShuffleManagerBase.getTaskAttemptId(2, 3, 3, true, 10));
  assertEquals(bits("11|00"), RssShuffleManagerBase.getTaskAttemptId(3, 0, 3, true, 10));
  assertEquals(bits("11|01"), RssShuffleManagerBase.getTaskAttemptId(3, 1, 3, true, 10));
  assertEquals(bits("11|10"), RssShuffleManagerBase.getTaskAttemptId(3, 2, 3, true, 10));
  assertEquals(bits("11|11"), RssShuffleManagerBase.getTaskAttemptId(3, 3, 3, true, 10));

  // maxFailures of 4
  assertEquals(bits("0|000"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 4, true, 10));
  assertEquals(bits("1|100"), RssShuffleManagerBase.getTaskAttemptId(1, 4, 4, true, 10));

  // test with ints that overflow into signed int and long
  assertEquals(
      (long) Integer.MAX_VALUE << 1,
      RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 0, 1, true, 32));
  assertEquals(
      (long) Integer.MAX_VALUE << 1 | 1,
      RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 1, 1, true, 32));
  assertEquals(
      (long) Integer.MAX_VALUE << 2 | 3,
      RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 3, 3, true, 33));
  assertEquals(
      (long) Integer.MAX_VALUE << 3 | 7,
      RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 7, 7, true, 34));

  // test with attemptNo > maxFailures (attemptNo == maxFailures allowed for speculation enabled)
  assertThrowsExactly(
      RssException.class, () -> RssShuffleManagerBase.getTaskAttemptId(0, 2, -1, true, 10));
  assertThrowsExactly(
      RssException.class, () -> RssShuffleManagerBase.getTaskAttemptId(0, 2, 0, true, 10));
  for (int maxFailures : Arrays.asList(1, 2, 3, 4, 8, 128)) {
    assertThrowsExactly(
        RssException.class,
        () -> RssShuffleManagerBase.getTaskAttemptId(0, maxFailures + 1, maxFailures, true, 10),
        String.valueOf(maxFailures));
    assertThrowsExactly(
        RssException.class,
        () -> RssShuffleManagerBase.getTaskAttemptId(0, maxFailures + 2, maxFailures, true, 10),
        String.valueOf(maxFailures));
    Exception e =
        assertThrowsExactly(
            RssException.class,
            () ->
                RssShuffleManagerBase.getTaskAttemptId(
                    0, maxFailures + 128, maxFailures, true, 10),
            String.valueOf(maxFailures));
    assertEquals(
        "Observing attempt number "
            + (maxFailures + 128)
            + " while maxFailures is set to "
            + maxFailures
            + " with speculation enabled.",
        e.getMessage());
  }

  // test with mapIndex that would require more than maxTaskAttemptBits
  // (speculation enabled, as everywhere else in this test)
  Exception e =
      assertThrowsExactly(
          RssException.class, () -> RssShuffleManagerBase.getTaskAttemptId(256, 0, 3, true, 10));
  assertEquals(
      "Observing mapIndex[256] that would produce a taskAttemptId with 11 bits "
          + "which is larger than the allowed 10 bits (maxFailures[3], speculation[true]). "
          + "Please consider providing more bits for taskAttemptIds.",
      e.getMessage());
  // check that a lower mapIndex works as expected
  assertEquals(bits("11111111|00"), RssShuffleManagerBase.getTaskAttemptId(255, 0, 3, true, 10));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssFetchFailedException;
import org.apache.uniffle.common.rpc.GrpcServer;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.common.util.RssUtils;
Expand Down Expand Up @@ -108,6 +109,8 @@ public class RssShuffleManager extends RssShuffleManagerBase {
private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
private boolean heartbeatStarted = false;
private boolean dynamicConfEnabled = false;
private final int maxFailures;
private final boolean speculation;
private final String user;
private final String uuid;
private DataPusher dataPusher;
Expand Down Expand Up @@ -140,6 +143,8 @@ public RssShuffleManager(SparkConf sparkConf, boolean isDriver) {
"Spark2 doesn't support AQE, spark.sql.adaptive.enabled should be false.");
}
this.sparkConf = sparkConf;
this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
this.speculation = sparkConf.getBoolean("spark.speculation", false);
this.user = sparkConf.get("spark.rss.quota.user", "user");
this.uuid = sparkConf.get("spark.rss.quota.uuid", Long.toString(System.currentTimeMillis()));
// set & check replica config
Expand Down Expand Up @@ -462,11 +467,18 @@ public <K, V> ShuffleWriter<K, V> getWriter(
shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
}
ShuffleWriteMetrics writeMetrics = context.taskMetrics().shuffleWriteMetrics();
long taskAttemptId =
getTaskAttemptId(
context.partitionId(),
context.attemptNumber(),
maxFailures,
speculation,
Constants.TASK_ATTEMPT_ID_MAX_LENGTH);
return new RssShuffleWriter<>(
rssHandle.getAppId(),
shuffleId,
taskId,
context.taskAttemptId(),
taskAttemptId,
writeMetrics,
this,
sparkConf,
Expand Down
Loading

0 comments on commit 0762498

Please sign in to comment.