Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#1398] fix(mr)(tez): Make attempId computable and move it to taskAttemptId in BlockId layout. #1418

Merged
merged 13 commits into from
Jul 5, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ public void init(Context context) throws IOException, ClassNotFoundException {
ApplicationAttemptId applicationAttemptId = RssMRUtils.getApplicationAttemptId();
String appId = applicationAttemptId.toString();
long taskAttemptId =
RssMRUtils.convertTaskAttemptIdToLong(
mapTask.getTaskID(), applicationAttemptId.getAttemptId());
RssMRUtils.createRssTaskAttemptId(
mapTask.getTaskID(), applicationAttemptId.getAttemptId(), mrJobConf);
double sendThreshold =
RssMRUtils.getDouble(
rssJobConf,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.BlockIdLayout;
Expand All @@ -44,37 +45,61 @@ public class RssMRUtils {

private static final Logger LOG = LoggerFactory.getLogger(RssMRUtils.class);
private static final BlockIdLayout LAYOUT = BlockIdLayout.DEFAULT;
private static final int MAX_ATTEMPT_LENGTH = 6;
private static final int MAX_ATTEMPT_ID = (1 << MAX_ATTEMPT_LENGTH) - 1;
private static final int MAX_SEQUENCE_NO =
(1 << (LAYOUT.sequenceNoBits - MAX_ATTEMPT_LENGTH)) - 1;

// Class TaskAttemptId have two field id and mapId, rss taskAttemptID have 21 bits,
// mapId is 19 bits, id is 2 bits. MR have a trick logic, taskAttemptId will increase
// 1000 * (appAttemptId - 1), so we will decrease it.
public static long convertTaskAttemptIdToLong(TaskAttemptID taskAttemptID, int appAttemptId) {
int lowBytes = taskAttemptID.getTaskID().getId();
if (lowBytes > LAYOUT.maxTaskAttemptId) {
throw new RssException("TaskAttempt " + taskAttemptID + " low bytes " + lowBytes + " exceed");
}

// Class TaskAttemptId have two field id and mapId. MR have a trick logic, taskAttemptId will
// increase 1000 * (appAttemptId - 1), so we will decrease it.
public static int createRssTaskAttemptId(
TaskAttemptID taskAttemptID, int appAttemptId, int maxAttemptNo) {
int attemptBits = ClientUtils.getNumberOfSignificantBits(maxAttemptNo);

if (appAttemptId < 1) {
throw new RssException("appAttemptId " + appAttemptId + " is wrong");
}
int highBytes = taskAttemptID.getId() - (appAttemptId - 1) * 1000;
if (highBytes > MAX_ATTEMPT_ID || highBytes < 0) {
int attemptId = taskAttemptID.getId() - (appAttemptId - 1) * 1000;
if (attemptId > maxAttemptNo || attemptId < 0) {
throw new RssException(
"TaskAttempt " + taskAttemptID + " high bytes " + highBytes + " exceed");
"TaskAttempt " + taskAttemptID + " attemptId " + attemptId + " exceed " + maxAttemptNo);
}
return LAYOUT.getBlockId(highBytes, 0, lowBytes);
int taskId = taskAttemptID.getTaskID().getId();

int mapIndexBits = ClientUtils.getNumberOfSignificantBits(taskId);
if (mapIndexBits + attemptBits > LAYOUT.taskAttemptIdBits) {
throw new RssException(
"Observing taskId["
+ taskId
+ "] that would produce a taskAttemptId with "
+ (mapIndexBits + attemptBits)
+ " bits which is larger than the allowed "
+ LAYOUT.taskAttemptIdBits
+ "]). Please consider providing more bits for taskAttemptIds.");
}

return (taskId << attemptBits) | attemptId;
}

public static int createRssTaskAttemptId(
TaskAttemptID taskAttemptID, int appAttemptId, int maxFailures, boolean speculation) {
int maxAttemptNo = ClientUtils.getMaxAttemptNo(maxFailures, speculation);
return createRssTaskAttemptId(taskAttemptID, appAttemptId, maxAttemptNo);
}

public static int createRssTaskAttemptId(
TaskAttemptID taskAttemptID, int appAttemptId, Configuration conf) {
int maxFailures = conf.getInt(MRJobConfig.MAP_MAX_ATTEMPTS, 4);
boolean speculation = conf.getBoolean(MRJobConfig.MAP_SPECULATIVE, true);
return createRssTaskAttemptId(taskAttemptID, appAttemptId, maxFailures, speculation);
}

public static TaskAttemptID createMRTaskAttemptId(
JobID jobID, TaskType taskType, long rssTaskAttemptId, int appAttemptId) {
JobID jobID, TaskType taskType, long rssTaskAttemptId, int appAttemptId, int maxAttemptNo) {
int attemptBits = ClientUtils.getNumberOfSignificantBits(maxAttemptNo);
if (appAttemptId < 1) {
throw new RssException("appAttemptId " + appAttemptId + " is wrong");
}
TaskID taskID = new TaskID(jobID, taskType, LAYOUT.getTaskAttemptId(rssTaskAttemptId));
int id = LAYOUT.getSequenceNo(rssTaskAttemptId) + 1000 * (appAttemptId - 1);
int task = (int) rssTaskAttemptId >> attemptBits;
int attempt = (int) rssTaskAttemptId & ((1 << attemptBits) - 1);
TaskID taskID = new TaskID(jobID, taskType, task);
int id = attempt + 1000 * (appAttemptId - 1);
return new TaskAttemptID(taskID, id);
}

Expand Down Expand Up @@ -228,27 +253,11 @@ public static String getString(Configuration rssJobConf, String key, String defa
}

public static long getBlockId(int partitionId, long taskAttemptId, int nextSeqNo) {
qijiale76 marked this conversation as resolved.
Show resolved Hide resolved
long attemptId = taskAttemptId >> (LAYOUT.partitionIdBits + LAYOUT.taskAttemptIdBits);
if (attemptId < 0 || attemptId > MAX_ATTEMPT_ID) {
throw new RssException(
"Can't support attemptId [" + attemptId + "], the max value should be " + MAX_ATTEMPT_ID);
}
if (nextSeqNo < 0 || nextSeqNo > MAX_SEQUENCE_NO) {
throw new RssException(
"Can't support sequence [" + nextSeqNo + "], the max value should be " + MAX_SEQUENCE_NO);
}

int atomicInt = (int) ((nextSeqNo << MAX_ATTEMPT_LENGTH) + attemptId);
long taskId =
taskAttemptId - (attemptId << (LAYOUT.partitionIdBits + LAYOUT.taskAttemptIdBits));

return LAYOUT.getBlockId(atomicInt, partitionId, taskId);
return LAYOUT.getBlockId(nextSeqNo, partitionId, taskAttemptId);
}

public static long getTaskAttemptId(long blockId) {
int mapId = LAYOUT.getTaskAttemptId(blockId);
int attemptId = LAYOUT.getSequenceNo(blockId) & MAX_ATTEMPT_ID;
return LAYOUT.getBlockId(attemptId, 0, mapId);
public static int getTaskAttemptId(long blockId) {
return LAYOUT.getTaskAttemptId(blockId);
}

public static int estimateTaskConcurrency(JobConf jobConf) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public Roaring64NavigableMap fetchAllRssTaskIds() {
String errMsg = "TaskAttemptIDs are inconsistent with map tasks";
for (TaskAttemptID taskAttemptID : successMaps) {
if (!obsoleteMaps.contains(taskAttemptID)) {
long rssTaskId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID, appAttemptId);
long rssTaskId = RssMRUtils.createRssTaskAttemptId(taskAttemptID, appAttemptId, jobConf);
int mapIndex = taskAttemptID.getTaskID().getId();
// There can be multiple successful attempts on same map task.
// So we only need to accept one of them.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public void testWriteException() throws Exception {
manager =
new SortWriteBufferManager<BytesWritable, BytesWritable>(
10240,
1L,
1,
10,
serializationFactory.getSerializer(BytesWritable.class),
serializationFactory.getSerializer(BytesWritable.class),
Expand Down Expand Up @@ -140,7 +140,7 @@ public void testWriteException() throws Exception {
manager =
new SortWriteBufferManager<BytesWritable, BytesWritable>(
100,
1L,
1,
10,
serializationFactory.getSerializer(BytesWritable.class),
serializationFactory.getSerializer(BytesWritable.class),
Expand Down Expand Up @@ -192,7 +192,7 @@ public void testOnePartition() throws Exception {
manager =
new SortWriteBufferManager<BytesWritable, BytesWritable>(
10240,
1L,
1,
10,
serializationFactory.getSerializer(BytesWritable.class),
serializationFactory.getSerializer(BytesWritable.class),
Expand Down Expand Up @@ -244,7 +244,7 @@ public void testWriteNormal() throws Exception {
manager =
new SortWriteBufferManager<BytesWritable, BytesWritable>(
10240,
1L,
1,
10,
serializationFactory.getSerializer(BytesWritable.class),
serializationFactory.getSerializer(BytesWritable.class),
Expand Down Expand Up @@ -311,7 +311,7 @@ public void testCommitBlocksWhenMemoryShuffleDisabled() throws Exception {
manager =
new SortWriteBufferManager<BytesWritable, BytesWritable>(
10240,
1L,
1,
10,
serializationFactory.getSerializer(BytesWritable.class),
serializationFactory.getSerializer(BytesWritable.class),
Expand Down Expand Up @@ -390,7 +390,7 @@ public void testCombineBuffer() throws Exception {
SortWriteBufferManager<Text, IntWritable> manager =
new SortWriteBufferManager<Text, IntWritable>(
10240,
1L,
1,
10,
keySerializer,
valueSerializer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,21 @@ public void baskAttemptIdTest() {
TaskAttemptID mrTaskAttemptId = new TaskAttemptID(taskId, 3);
boolean isException = false;
try {
RssMRUtils.convertTaskAttemptIdToLong(mrTaskAttemptId, 1);
RssMRUtils.createRssTaskAttemptId(mrTaskAttemptId, 1, 4);
} catch (RssException e) {
isException = true;
}
assertTrue(isException);
taskAttemptId = (1 << 20) + 0x123;
mrTaskAttemptId = RssMRUtils.createMRTaskAttemptId(new JobID(), TaskType.MAP, taskAttemptId, 1);
long testId = RssMRUtils.convertTaskAttemptIdToLong(mrTaskAttemptId, 1);
taskAttemptId = (0x123 << 3) + 1;
mrTaskAttemptId =
RssMRUtils.createMRTaskAttemptId(new JobID(), TaskType.MAP, taskAttemptId, 1, 4);
int testId = RssMRUtils.createRssTaskAttemptId(mrTaskAttemptId, 1, 4);
assertEquals(taskAttemptId, testId);
TaskID taskID = new TaskID(new org.apache.hadoop.mapred.JobID(), TaskType.MAP, (int) (1 << 21));
TaskID taskID = new TaskID(new org.apache.hadoop.mapred.JobID(), TaskType.MAP, 1 << 21);
mrTaskAttemptId = new TaskAttemptID(taskID, 2);
isException = false;
try {
RssMRUtils.convertTaskAttemptIdToLong(mrTaskAttemptId, 1);
RssMRUtils.createRssTaskAttemptId(mrTaskAttemptId, 1, 4);
} catch (RssException e) {
isException = true;
}
Expand All @@ -70,7 +71,7 @@ public void blockConvertTest() {
JobID jobID = new JobID();
TaskID taskId = new TaskID(jobID, TaskType.MAP, 233);
TaskAttemptID taskAttemptID = new TaskAttemptID(taskId, 1);
long taskAttemptId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID, 1);
long taskAttemptId = RssMRUtils.createRssTaskAttemptId(taskAttemptID, 1, 4);
long blockId = RssMRUtils.getBlockId(1, taskAttemptId, 0);
long newTaskAttemptId = RssMRUtils.getTaskAttemptId(blockId);
assertEquals(taskAttemptId, newTaskAttemptId);
Expand All @@ -85,7 +86,7 @@ public void partitionIdConvertBlockTest() {
JobID jobID = new JobID();
TaskID taskId = new TaskID(jobID, TaskType.MAP, 233);
TaskAttemptID taskAttemptID = new TaskAttemptID(taskId, 1);
long taskAttemptId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID, 1);
long taskAttemptId = RssMRUtils.createRssTaskAttemptId(taskAttemptID, 1, 4);
long mask = (1L << layout.partitionIdBits) - 1;
for (int partitionId = 0; partitionId <= 3000; partitionId++) {
for (int seqNo = 0; seqNo <= 10; seqNo++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ public void singlePassEventFetch() throws IOException {
Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
long rssTaskId =
RssMRUtils.convertTaskAttemptIdToLong(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
RssMRUtils.createRssTaskAttemptId(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
expected.addLong(rssTaskId);
}

Expand Down Expand Up @@ -89,8 +89,8 @@ public void singlePassWithRepeatedSuccessEventFetch() throws IOException {
Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
long rssTaskId =
RssMRUtils.convertTaskAttemptIdToLong(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
RssMRUtils.createRssTaskAttemptId(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
expected.addLong(rssTaskId);
}

Expand Down Expand Up @@ -121,8 +121,8 @@ public void multiPassEventFetch() throws IOException {
Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
long rssTaskId =
RssMRUtils.convertTaskAttemptIdToLong(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
RssMRUtils.createRssTaskAttemptId(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
expected.addLong(rssTaskId);
}
Roaring64NavigableMap taskIdBitmap = ef.fetchAllRssTaskIds();
Expand All @@ -146,8 +146,8 @@ public void missingEventFetch() throws IOException {
Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
long rssTaskId =
RssMRUtils.convertTaskAttemptIdToLong(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
RssMRUtils.createRssTaskAttemptId(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
expected.addLong(rssTaskId);
}
IllegalStateException ex =
Expand All @@ -172,8 +172,8 @@ public void extraEventFetch() throws IOException {
Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
long rssTaskId =
RssMRUtils.convertTaskAttemptIdToLong(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
RssMRUtils.createRssTaskAttemptId(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
expected.addLong(rssTaskId);
}
IllegalStateException ex =
Expand Down Expand Up @@ -205,14 +205,14 @@ public void obsoletedAndTipFailedEventFetch() throws IOException {
for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
if (!tipFailed.contains(mapIndex) && !obsoleted.contains(mapIndex)) {
long rssTaskId =
RssMRUtils.convertTaskAttemptIdToLong(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
RssMRUtils.createRssTaskAttemptId(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
expected.addLong(rssTaskId);
}
if (obsoleted.contains(mapIndex)) {
long rssTaskId =
RssMRUtils.convertTaskAttemptIdToLong(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 1), 1);
RssMRUtils.createRssTaskAttemptId(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 1), 1, 4);
expected.addLong(rssTaskId);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,8 @@ public void testCodecIsDuplicated() throws Exception {
null,
new Progress(),
new MROutputFiles());
TaskAttemptID taskAttemptID = RssMRUtils.createMRTaskAttemptId(new JobID(), TaskType.MAP, 1, 1);
TaskAttemptID taskAttemptID =
RssMRUtils.createMRTaskAttemptId(new JobID(), TaskType.MAP, 1, 1, 4);
byte[] buffer = new byte[10];
MapOutput mapOutput1 = merger.reserve(taskAttemptID, 10, 1);
RssBypassWriter.write(mapOutput1, buffer);
Expand Down Expand Up @@ -350,7 +351,7 @@ private static byte[] writeMapOutputRss(Configuration conf, Map<String, String>
SortWriteBufferManager<Text, Text> manager =
new SortWriteBufferManager(
10240,
1L,
1,
10,
serializationFactory.getSerializer(Text.class),
serializationFactory.getSerializer(Text.class),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,10 @@ private static void configureBlockIdLayoutFromMaxPartitions(
+ maxPartitions);
}

int attemptIdBits = getAttemptIdBits(getMaxAttemptNo(maxFailures, speculation));
int partitionIdBits = 32 - Integer.numberOfLeadingZeros(maxPartitions - 1); // [1..31]
int attemptIdBits =
ClientUtils.getNumberOfSignificantBits(
ClientUtils.getMaxAttemptNo(maxFailures, speculation));
int partitionIdBits = ClientUtils.getNumberOfSignificantBits(maxPartitions - 1); // [1..31]
int taskAttemptIdBits = partitionIdBits + attemptIdBits; // [1+attemptIdBits..31+attemptIdBits]
int sequenceNoBits = 63 - partitionIdBits - taskAttemptIdBits; // [1-attemptIdBits..61]

Expand Down Expand Up @@ -332,23 +334,6 @@ private static void configureBlockIdLayoutFromLayoutConfig(
}
}

protected static int getMaxAttemptNo(int maxFailures, boolean speculation) {
// attempt number is zero based: 0, 1, …, maxFailures-1
// max maxFailures < 1 is not allowed but for safety, we interpret that as maxFailures == 1
int maxAttemptNo = maxFailures < 1 ? 0 : maxFailures - 1;

// with speculative execution enabled we could observe +1 attempts
if (speculation) {
maxAttemptNo++;
}

return maxAttemptNo;
}

protected static int getAttemptIdBits(int maxAttemptNo) {
return 32 - Integer.numberOfLeadingZeros(maxAttemptNo);
}

/** See static overload of this method. */
public abstract long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo);

Expand All @@ -367,8 +352,8 @@ protected static int getAttemptIdBits(int maxAttemptNo) {
*/
protected static long getTaskAttemptIdForBlockId(
int mapIndex, int attemptNo, int maxFailures, boolean speculation, int maxTaskAttemptIdBits) {
int maxAttemptNo = getMaxAttemptNo(maxFailures, speculation);
int attemptBits = getAttemptIdBits(maxAttemptNo);
int maxAttemptNo = ClientUtils.getMaxAttemptNo(maxFailures, speculation);
int attemptBits = ClientUtils.getNumberOfSignificantBits(maxAttemptNo);

if (attemptNo > maxAttemptNo) {
// this should never happen, if it does, our assumptions are wrong,
Expand All @@ -382,7 +367,7 @@ protected static long getTaskAttemptIdForBlockId(
+ ".");
}

int mapIndexBits = 32 - Integer.numberOfLeadingZeros(mapIndex);
int mapIndexBits = ClientUtils.getNumberOfSignificantBits(mapIndex);
if (mapIndexBits + attemptBits > maxTaskAttemptIdBits) {
throw new RssException(
"Observing mapIndex["
Expand Down
Loading
Loading