[#1398] fix(mr,tez): Make attemptId computable and move it to taskAttemptId in BlockId layout. #1418

Open. Wants to merge 9 commits into base: master.
@@ -30,6 +30,7 @@
import org.apache.hadoop.io.RawComparator;
import org.apache.hadoop.io.serializer.SerializationFactory;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.MRJobConfig;
import org.apache.hadoop.mapreduce.RssMRConfig;
import org.apache.hadoop.mapreduce.RssMRUtils;
import org.apache.hadoop.mapreduce.TaskCounter;
@@ -38,6 +39,7 @@
import org.slf4j.LoggerFactory;

import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.ByteUnit;
@@ -99,9 +101,12 @@ public void init(Context context) throws IOException, ClassNotFoundException {
RssMRConfig.RSS_CLIENT_DEFAULT_MEMORY_THRESHOLD);
ApplicationAttemptId applicationAttemptId = RssMRUtils.getApplicationAttemptId();
String appId = applicationAttemptId.toString();
long taskAttemptId =
RssMRUtils.convertTaskAttemptIdToLong(
mapTask.getTaskID(), applicationAttemptId.getAttemptId());
int maxFailures = mrJobConf.getInt(MRJobConfig.MAP_MAX_ATTEMPTS, 4);
boolean speculation = mrJobConf.getBoolean(MRJobConfig.MAP_SPECULATIVE, true);
int maxAttemptNo = ClientUtils.getMaxAttemptNo(maxFailures, speculation);
int taskAttemptId =
RssMRUtils.createRssTaskAttemptId(
mapTask.getTaskID(), applicationAttemptId.getAttemptId(), maxAttemptNo);
double sendThreshold =
RssMRUtils.getDouble(
rssJobConf,
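
A note on the two ClientUtils helpers used above: they derive how many low bits of the taskAttemptId are reserved for the attempt number from the job's configured maximum attempts. The method bodies below are a plausible sketch of that idea, not a copy of the project code; only the helper names come from the diff.

// Sketch only: plausible logic behind ClientUtils.getMaxAttemptNo / getAttemptIdBits.
public final class AttemptBitsSketch {
  // Highest attempt number we may observe: attempt numbers are zero-based,
  // and speculative execution may add one extra attempt (assumption).
  static int maxAttemptNo(int maxFailures, boolean speculation) {
    int maxAttemptNo = Math.max(1, maxFailures) - 1;
    return speculation ? maxAttemptNo + 1 : maxAttemptNo;
  }

  // Number of bits needed to encode the values 0..maxAttemptNo.
  static int attemptIdBits(int maxAttemptNo) {
    return 32 - Integer.numberOfLeadingZeros(maxAttemptNo);
  }

  public static void main(String[] args) {
    // 4 allowed attempts plus speculation -> max attempt number 4 -> 3 bits
    System.out.println(attemptIdBits(maxAttemptNo(4, true)));
  }
}
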
@@ -64,7 +64,7 @@ public class SortWriteBufferManager<K, V> {
private final Counters.Counter mapOutputRecordCounter;
private long uncompressedDataLen = 0;
private long compressTime = 0;
private final long taskAttemptId;
private final int taskAttemptId;

Contributor (review comment):

I am not sure about restricting taskAttemptIds to int in such places.

Here is the situation:

  1. Spark, Tez and MR provide us with long task attempt ids (for Tez and MR, the pair (taskId, attemptId) constitutes a long task attempt id, which we restrict to int for reasons similar to 2.).
  2. For the purpose of the block id, we limit those long task attempt ids to int, since we allow fewer than 32 bits for them.
  3. The task attempt id retrieved from the block id is therefore int.
  4. Still, all other places could continue to work with long task attempt ids where the type makes no difference for that code; up-casting an int task attempt id to long does no harm, as long as the code works with long.

This would allow us to support truly long task attempt ids in the future without reverting such code changes.

@zuston @jerqi @zhengchenyu what do you think?
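
For illustration, a minimal, dependency-free sketch (not project code) of point 4: a taskAttemptId packed into an int widens to long losslessly, so code that merely stores or compares ids can stay long-typed. The bit widths and values here are arbitrary examples.

// Sketch only: pack (taskId, attemptId) into an int, then widen to long.
public final class WideningSketch {
  public static void main(String[] args) {
    int attemptBits = 3;                       // enough for attempt numbers 0..7 (example value)
    int taskId = 0x123;                        // map index
    int attemptId = 1;
    int packed = (taskId << attemptBits) | attemptId;

    long widened = packed;                     // implicit, lossless up-cast
    System.out.println(widened == (long) packed);  // true
    System.out.println((int) widened == packed);   // true: narrowing back is safe while the value fits in an int
  }
}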


Collaborator (review comment):

@EnricoMi
I think the original taskAttemptId is long because it comes from Spark's TaskContext::taskAttemptId, which is the unique id of an attempt at the application level. MR and Tez inherited the long type, but implement taskAttemptId as the bit concatenation of taskId and attemptId within the task.
I think you changed taskAttemptId to the concatenation of taskId and attemptId in #731 and limited it to no more than 32 bits, so int is enough.
What do you think?


Contributor (review comment):

I would prefer using long for some common methods.

private final AtomicLong memoryUsedSize = new AtomicLong(0);
private final int batch;
private final AtomicLong inSendListBytes = new AtomicLong(0);
@@ -101,7 +101,7 @@ public class SortWriteBufferManager<K, V> {

public SortWriteBufferManager(
long maxMemSize,
long taskAttemptId,
int taskAttemptId,
int batch,
Serializer<K> keySerializer,
Serializer<V> valSerializer,
@@ -35,6 +35,7 @@

import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.BlockIdLayout;
@@ -44,38 +45,46 @@ public class RssMRUtils {

private static final Logger LOG = LoggerFactory.getLogger(RssMRUtils.class);
private static final BlockIdLayout LAYOUT = BlockIdLayout.DEFAULT;
private static final int MAX_ATTEMPT_LENGTH = 4;
private static final int MAX_ATTEMPT_ID = (1 << MAX_ATTEMPT_LENGTH) - 1;
private static final int MAX_TASK_LENGTH = LAYOUT.taskAttemptIdBits - MAX_ATTEMPT_LENGTH;
private static final int MAX_TASK_ID = (1 << MAX_TASK_LENGTH) - 1;

// Class TaskAttemptID has two fields, id and mapId. MR has a quirk: the task attempt's id is
// increased by 1000 * (appAttemptId - 1), so we decrease it here.
public static long convertTaskAttemptIdToLong(TaskAttemptID taskAttemptID, int appAttemptId) {
public static int createRssTaskAttemptId(
TaskAttemptID taskAttemptID, int appAttemptId, int maxAttemptNo) {
int attemptBits = ClientUtils.getAttemptIdBits(maxAttemptNo);

if (appAttemptId < 1) {
throw new RssException("appAttemptId " + appAttemptId + " is wrong");
}
long lowBytes = taskAttemptID.getId() - (appAttemptId - 1) * 1000L;
if (lowBytes > MAX_ATTEMPT_ID || lowBytes < 0) {
int attemptId = taskAttemptID.getId() - (appAttemptId - 1) * 1000;
if (attemptId > maxAttemptNo || attemptId < 0) {
throw new RssException(
"TaskAttempt " + taskAttemptID + " low bytes " + lowBytes + " exceed " + MAX_ATTEMPT_ID);
"TaskAttempt " + taskAttemptID + " attemptId " + attemptId + " exceed " + maxAttemptNo);
}
long highBytes = taskAttemptID.getTaskID().getId();
if (highBytes > MAX_TASK_ID || highBytes < 0) {
int taskId = taskAttemptID.getTaskID().getId();

int mapIndexBits = 32 - Integer.numberOfLeadingZeros(taskId);
if (mapIndexBits + attemptBits > LAYOUT.taskAttemptIdBits) {
throw new RssException(
"TaskAttempt " + taskAttemptID + " high bytes " + highBytes + " exceed " + MAX_TASK_ID);
"Observing taskId["
+ taskId
+ "] that would produce a taskAttemptId with "
+ (mapIndexBits + attemptBits)
+ " bits which is larger than the allowed "
+ LAYOUT.taskAttemptIdBits
+ "]). Please consider providing more bits for taskAttemptIds.");
}
long taskAttemptId = (highBytes << (MAX_ATTEMPT_LENGTH)) + lowBytes;
return LAYOUT.getBlockId(0, 0, taskAttemptId);

return (taskId << (attemptBits)) + attemptId;
}

public static TaskAttemptID createMRTaskAttemptId(
JobID jobID, TaskType taskType, long rssTaskAttemptId, int appAttemptId) {
JobID jobID, TaskType taskType, int rssTaskAttemptId, int appAttemptId, int maxAttemptNo) {
int attemptBits = ClientUtils.getAttemptIdBits(maxAttemptNo);
if (appAttemptId < 1) {
throw new RssException("appAttemptId " + appAttemptId + " is wrong");
}
int task = LAYOUT.getTaskAttemptId(rssTaskAttemptId) >> MAX_ATTEMPT_LENGTH;
int attempt = (int) (rssTaskAttemptId & MAX_ATTEMPT_ID);
int task = rssTaskAttemptId >> attemptBits;
int attempt = rssTaskAttemptId & ((1 << attemptBits) - 1);
TaskID taskID = new TaskID(jobID, taskType, task);
int id = attempt + 1000 * (appAttemptId - 1);
return new TaskAttemptID(taskID, id);
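
To make the bit arithmetic of the two methods above concrete, here is a dependency-free round-trip sketch using the same pack/unpack expressions as the diff (the values match the updated test further down; the class name is invented):

// Sketch only: mirrors the arithmetic of createRssTaskAttemptId / createMRTaskAttemptId
// without the Hadoop types, to show that pack followed by unpack is the identity.
public final class PackUnpackSketch {
  public static void main(String[] args) {
    int maxAttemptNo = 4;                                                // e.g. 4 attempts allowed
    int attemptBits = 32 - Integer.numberOfLeadingZeros(maxAttemptNo);   // 3

    int taskId = 0x123;
    int attemptId = 1;

    // pack, as in createRssTaskAttemptId
    int rssTaskAttemptId = (taskId << attemptBits) + attemptId;

    // unpack, as in createMRTaskAttemptId
    int task = rssTaskAttemptId >> attemptBits;
    int attempt = rssTaskAttemptId & ((1 << attemptBits) - 1);

    System.out.println(task == taskId && attempt == attemptId);          // true
  }
}
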
@@ -230,29 +239,7 @@ public static String getString(Configuration rssJobConf, String key, String defa
return rssJobConf.get(key, defaultValue);
}

public static long getBlockId(int partitionId, long taskAttemptId, int nextSeqNo) {
if (taskAttemptId < 0 || taskAttemptId > LAYOUT.maxTaskAttemptId) {
throw new RssException(
"Can't support attemptId ["
+ taskAttemptId
+ "], the max value should be "
+ LAYOUT.maxTaskAttemptId);
}
if (nextSeqNo < 0 || nextSeqNo > LAYOUT.maxSequenceNo) {
throw new RssException(
"Can't support sequence ["
+ nextSeqNo
+ "], the max value should be "
+ LAYOUT.maxSequenceNo);
}

if (partitionId < 0 || partitionId > LAYOUT.maxPartitionId) {
throw new RssException(
"Can't support partitionId ["
+ partitionId
+ "], the max value should be "
+ LAYOUT.maxPartitionId);
}
public static long getBlockId(int partitionId, int taskAttemptId, int nextSeqNo) {

Contributor (review comment):

Technically, the taskAttemptId can be long here, as the block id layout below checks the bit-size constraint (though we feed this method only with int taskAttemptIds produced by RssMRUtils.createRssTaskAttemptId()):

Suggested change
public static long getBlockId(int partitionId, int taskAttemptId, int nextSeqNo) {
public static long getBlockId(int partitionId, long taskAttemptId, int nextSeqNo) {

return LAYOUT.getBlockId(nextSeqNo, partitionId, taskAttemptId);
}
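
For context on why the taskAttemptId bit budget matters here, the sketch below shows how a block id is assembled from its three components. The 18/24/21 bit split is an assumption about BlockIdLayout.DEFAULT, not read from the class; consult BlockIdLayout for the authoritative values.

// Sketch only: a block id laid out as [sequenceNo | partitionId | taskAttemptId].
// The bit widths below are assumed, not taken from BlockIdLayout.DEFAULT.
public final class BlockIdSketch {
  static final int SEQUENCE_BITS = 18;
  static final int PARTITION_BITS = 24;
  static final int TASK_ATTEMPT_BITS = 21;

  static long blockId(long sequenceNo, long partitionId, long taskAttemptId) {
    return (sequenceNo << (PARTITION_BITS + TASK_ATTEMPT_BITS))
        | (partitionId << TASK_ATTEMPT_BITS)
        | taskAttemptId;
  }

  public static void main(String[] args) {
    long taskAttemptId = (0x123 << 3) + 1;      // packed (taskId, attemptId) as above
    System.out.println(Long.toBinaryString(blockId(0, 1, taskAttemptId)));
  }
}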

@@ -27,10 +27,12 @@
import org.apache.hadoop.mapred.MapTaskCompletionEventsUpdate;
import org.apache.hadoop.mapred.TaskCompletionEvent;
import org.apache.hadoop.mapred.TaskUmbilicalProtocol;
import org.apache.hadoop.mapreduce.MRJobConfig;
import org.apache.hadoop.mapreduce.RssMRUtils;
import org.apache.hadoop.mapreduce.TaskAttemptID;
import org.roaringbitmap.longlong.Roaring64NavigableMap;

import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.exception.RssException;

public class RssEventFetcher<K, V> {
@@ -75,7 +77,11 @@ public Roaring64NavigableMap fetchAllRssTaskIds() {
String errMsg = "TaskAttemptIDs are inconsistent with map tasks";
for (TaskAttemptID taskAttemptID : successMaps) {
if (!obsoleteMaps.contains(taskAttemptID)) {
long rssTaskId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID, appAttemptId);
int maxFailures = jobConf.getInt(MRJobConfig.MAP_MAX_ATTEMPTS, 4);
boolean speculation = jobConf.getBoolean(MRJobConfig.MAP_SPECULATIVE, true);
int maxAttemptNo = ClientUtils.getMaxAttemptNo(maxFailures, speculation);
int rssTaskId =
RssMRUtils.createRssTaskAttemptId(taskAttemptID, appAttemptId, maxAttemptNo);
int mapIndex = taskAttemptID.getTaskID().getId();
// There can be multiple successful attempts on same map task.
// So we only need to accept one of them.
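
Since maxFailures, speculation, and maxAttemptNo are per-job values that never change across loop iterations, a reader might expect them to be resolved once up front rather than inside the loop over successMaps. A hypothetical, self-contained variant for comparison (not part of this PR; the class and method names are invented):

import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapreduce.MRJobConfig;
import org.apache.hadoop.mapreduce.RssMRUtils;
import org.apache.hadoop.mapreduce.TaskAttemptID;

import org.apache.uniffle.client.util.ClientUtils;

// Hypothetical helper (not in this PR): resolve the per-job attempt settings once,
// then map each MR TaskAttemptID to its RSS task attempt id.
final class RssTaskIdResolver {
  private final int appAttemptId;
  private final int maxAttemptNo;

  RssTaskIdResolver(JobConf jobConf, int appAttemptId) {
    int maxFailures = jobConf.getInt(MRJobConfig.MAP_MAX_ATTEMPTS, 4);
    boolean speculation = jobConf.getBoolean(MRJobConfig.MAP_SPECULATIVE, true);
    this.appAttemptId = appAttemptId;
    this.maxAttemptNo = ClientUtils.getMaxAttemptNo(maxFailures, speculation);
  }

  int resolve(TaskAttemptID taskAttemptID) {
    return RssMRUtils.createRssTaskAttemptId(taskAttemptID, appAttemptId, maxAttemptNo);
  }
}
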
@@ -75,7 +75,7 @@ public void testWriteException() throws Exception {
manager =
new SortWriteBufferManager<BytesWritable, BytesWritable>(
10240,
1L,
1,
10,
serializationFactory.getSerializer(BytesWritable.class),
serializationFactory.getSerializer(BytesWritable.class),
@@ -139,7 +139,7 @@ public void testWriteException() throws Exception {
manager =
new SortWriteBufferManager<BytesWritable, BytesWritable>(
100,
1L,
1,
10,
serializationFactory.getSerializer(BytesWritable.class),
serializationFactory.getSerializer(BytesWritable.class),
@@ -191,7 +191,7 @@ public void testOnePartition() throws Exception {
manager =
new SortWriteBufferManager<BytesWritable, BytesWritable>(
10240,
1L,
1,
10,
serializationFactory.getSerializer(BytesWritable.class),
serializationFactory.getSerializer(BytesWritable.class),
@@ -243,7 +243,7 @@ public void testWriteNormal() throws Exception {
manager =
new SortWriteBufferManager<BytesWritable, BytesWritable>(
10240,
1L,
1,
10,
serializationFactory.getSerializer(BytesWritable.class),
serializationFactory.getSerializer(BytesWritable.class),
@@ -310,7 +310,7 @@ public void testCommitBlocksWhenMemoryShuffleDisabled() throws Exception {
manager =
new SortWriteBufferManager<BytesWritable, BytesWritable>(
10240,
1L,
1,
10,
serializationFactory.getSerializer(BytesWritable.class),
serializationFactory.getSerializer(BytesWritable.class),
@@ -390,7 +390,7 @@ public void testCombineBuffer() throws Exception {
SortWriteBufferManager<Text, IntWritable> manager =
new SortWriteBufferManager<Text, IntWritable>(
10240,
1L,
1,
10,
keySerializer,
valueSerializer,
@@ -39,26 +39,27 @@ public class RssMRUtilsTest {

@Test
public void baskAttemptIdTest() {
long taskAttemptId = 0x1000ad12;
int taskAttemptId = 0x1000ad12;
JobID jobID = new JobID();
TaskID taskId = new TaskID(jobID, TaskType.MAP, (int) taskAttemptId);
TaskAttemptID mrTaskAttemptId = new TaskAttemptID(taskId, 3);
boolean isException = false;
try {
RssMRUtils.convertTaskAttemptIdToLong(mrTaskAttemptId, 1);
RssMRUtils.createRssTaskAttemptId(mrTaskAttemptId, 1, 4);
} catch (RssException e) {
isException = true;
}
assertTrue(isException);
taskAttemptId = (1 << 20) + 0x123;
mrTaskAttemptId = RssMRUtils.createMRTaskAttemptId(new JobID(), TaskType.MAP, taskAttemptId, 1);
long testId = RssMRUtils.convertTaskAttemptIdToLong(mrTaskAttemptId, 1);
taskAttemptId = (0x123 << 3) + 1;
mrTaskAttemptId =
RssMRUtils.createMRTaskAttemptId(new JobID(), TaskType.MAP, taskAttemptId, 1, 4);
int testId = RssMRUtils.createRssTaskAttemptId(mrTaskAttemptId, 1, 4);
assertEquals(taskAttemptId, testId);
TaskID taskID = new TaskID(new org.apache.hadoop.mapred.JobID(), TaskType.MAP, (int) (1 << 21));
mrTaskAttemptId = new TaskAttemptID(taskID, 2);
isException = false;
try {
RssMRUtils.convertTaskAttemptIdToLong(mrTaskAttemptId, 1);
RssMRUtils.createRssTaskAttemptId(mrTaskAttemptId, 1, 4);
} catch (RssException e) {
isException = true;
}
@@ -70,7 +71,7 @@ public void blockConvertTest() {
JobID jobID = new JobID();
TaskID taskId = new TaskID(jobID, TaskType.MAP, 233);
TaskAttemptID taskAttemptID = new TaskAttemptID(taskId, 1);
long taskAttemptId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID, 1);
int taskAttemptId = RssMRUtils.createRssTaskAttemptId(taskAttemptID, 1, 4);
long blockId = RssMRUtils.getBlockId(1, taskAttemptId, 0);
long newTaskAttemptId = RssMRUtils.getTaskAttemptId(blockId);
assertEquals(taskAttemptId, newTaskAttemptId);
@@ -85,7 +86,7 @@ public void partitionIdConvertBlockTest() {
JobID jobID = new JobID();
TaskID taskId = new TaskID(jobID, TaskType.MAP, 233);
TaskAttemptID taskAttemptID = new TaskAttemptID(taskId, 1);
long taskAttemptId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID, 1);
int taskAttemptId = RssMRUtils.createRssTaskAttemptId(taskAttemptID, 1, 4);
long mask = (1L << layout.partitionIdBits) - 1;
for (int partitionId = 0; partitionId <= 3000; partitionId++) {
for (int seqNo = 0; seqNo <= 10; seqNo++) {
@@ -58,9 +58,9 @@ public void singlePassEventFetch() throws IOException {
RssEventFetcher ef = new RssEventFetcher(1, tid, umbilical, jobConf, MAX_EVENTS_TO_FETCH);
Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
long rssTaskId =
RssMRUtils.convertTaskAttemptIdToLong(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
int rssTaskId =
RssMRUtils.createRssTaskAttemptId(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
expected.addLong(rssTaskId);
}

@@ -88,9 +88,9 @@ public void singlePassWithRepeatedSuccessEventFetch() throws IOException {
RssEventFetcher ef = new RssEventFetcher(1, tid, umbilical, jobConf, MAX_EVENTS_TO_FETCH);
Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
long rssTaskId =
RssMRUtils.convertTaskAttemptIdToLong(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
int rssTaskId =
RssMRUtils.createRssTaskAttemptId(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
expected.addLong(rssTaskId);
}

@@ -120,9 +120,9 @@ public void multiPassEventFetch() throws IOException {

Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
long rssTaskId =
RssMRUtils.convertTaskAttemptIdToLong(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
int rssTaskId =
RssMRUtils.createRssTaskAttemptId(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
expected.addLong(rssTaskId);
}
Roaring64NavigableMap taskIdBitmap = ef.fetchAllRssTaskIds();
@@ -145,9 +145,9 @@ public void missingEventFetch() throws IOException {
RssEventFetcher ef = new RssEventFetcher(1, tid, umbilical, jobConf, MAX_EVENTS_TO_FETCH);
Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
long rssTaskId =
RssMRUtils.convertTaskAttemptIdToLong(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
int rssTaskId =
RssMRUtils.createRssTaskAttemptId(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
expected.addLong(rssTaskId);
}
IllegalStateException ex =
@@ -171,9 +171,9 @@ public void extraEventFetch() throws IOException {
RssEventFetcher ef = new RssEventFetcher(1, tid, umbilical, jobConf, MAX_EVENTS_TO_FETCH);
Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
long rssTaskId =
RssMRUtils.convertTaskAttemptIdToLong(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
int rssTaskId =
RssMRUtils.createRssTaskAttemptId(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
expected.addLong(rssTaskId);
}
IllegalStateException ex =
@@ -204,15 +204,15 @@ public void obsoletedAndTipFailedEventFetch() throws IOException {
Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
if (!tipFailed.contains(mapIndex) && !obsoleted.contains(mapIndex)) {
long rssTaskId =
RssMRUtils.convertTaskAttemptIdToLong(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
int rssTaskId =
RssMRUtils.createRssTaskAttemptId(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
expected.addLong(rssTaskId);
}
if (obsoleted.contains(mapIndex)) {
long rssTaskId =
RssMRUtils.convertTaskAttemptIdToLong(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 1), 1);
int rssTaskId =
RssMRUtils.createRssTaskAttemptId(
new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 1), 1, 4);
expected.addLong(rssTaskId);
}
}
@@ -290,7 +290,8 @@ public void testCodecIsDuplicated() throws Exception {
null,
new Progress(),
new MROutputFiles());
TaskAttemptID taskAttemptID = RssMRUtils.createMRTaskAttemptId(new JobID(), TaskType.MAP, 1, 1);
TaskAttemptID taskAttemptID =
RssMRUtils.createMRTaskAttemptId(new JobID(), TaskType.MAP, 1, 1, 4);
byte[] buffer = new byte[10];
MapOutput mapOutput1 = merger.reserve(taskAttemptID, 10, 1);
RssBypassWriter.write(mapOutput1, buffer);
@@ -349,7 +350,7 @@ private static byte[] writeMapOutputRss(Configuration conf, Map<String, String>
SortWriteBufferManager<Text, Text> manager =
new SortWriteBufferManager(
10240,
1L,
1,
10,
serializationFactory.getSerializer(Text.class),
serializationFactory.getSerializer(Text.class),