Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#1068] improvement(tez): Fail fast in WriteBufferManager when failed to send shuffle data to shuffle sever. #1069

Merged
merged 2 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ public void addRecord(int partitionId, K key, V value) throws InterruptedExcepti
memoryLock.unlock();
}

// Fail fast if there are some failed blocks.
checkFailedBlocks();

if (!buffers.containsKey(partitionId)) {
WriteBuffer<K, V> sortWriterBuffer =
new WriteBuffer(
Expand Down Expand Up @@ -282,14 +285,7 @@ public void waitSendFinished() {
}
long start = System.currentTimeMillis();
while (true) {
if (failedBlockIds.size() > 0) {
String errorMsg =
"Send failed: failed because "
+ failedBlockIds.size()
+ " blocks can't be sent to shuffle server.";
LOG.error(errorMsg);
throw new RssException(errorMsg);
}
checkFailedBlocks();
allBlockIds.removeAll(successBlockIds);
if (allBlockIds.isEmpty()) {
break;
Expand Down Expand Up @@ -335,6 +331,18 @@ public void waitSendFinished() {
sortTime);
}

// Check if there are some failed blocks, if true then throw Exception.
private void checkFailedBlocks() {
if (failedBlockIds.size() > 0) {
String errorMsg =
"Send failed: failed because "
+ failedBlockIds.size()
+ " blocks can't be sent to shuffle server.";
LOG.error(errorMsg);
throw new RssException(errorMsg);
}
}

ShuffleBlockInfo createShuffleBlock(WriteBuffer wb) {
byte[] data = wb.getData();
copyTime += wb.getCopyTime();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import org.apache.tez.common.counters.TaskCounter;
import org.apache.tez.common.counters.TezCounter;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.dag.records.TezVertexID;
import org.apache.tez.runtime.api.OutputContext;
import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
import org.apache.tez.runtime.library.output.OutputTestHelpers;
Expand All @@ -67,6 +66,7 @@
import org.apache.uniffle.storage.util.StorageType;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class WriteBufferManagerTest {
Expand Down Expand Up @@ -99,7 +99,9 @@ public void testWriteException(@TempDir File tmpDir) throws IOException, Interru
long sendCheckInterval = 500L;
long sendCheckTimeout = 5;
int bitmapSplitNum = 1;
int shuffleId = getShuffleId(tezTaskAttemptID, 1, 2);
int shuffleId =
RssTezUtils.computeShuffleId(
tezTaskAttemptID.getTaskID().getVertexID().getDAGId().getId(), 1, 2);

Configuration conf = new Configuration();
FileSystem localFs = FileSystem.getLocal(conf);
Expand Down Expand Up @@ -139,7 +141,7 @@ public void testWriteException(@TempDir File tmpDir) throws IOException, Interru
rssConf,
partitionToServers,
numMaps,
isMemoryShuffleEnabled(storageType),
StorageType.withMemory(StorageType.valueOf(storageType)),
sendCheckInterval,
sendCheckTimeout,
bitmapSplitNum,
Expand Down Expand Up @@ -197,7 +199,9 @@ public void testWriteNormal(@TempDir File tmpDir) throws IOException, Interrupte
long sendCheckInterval = 500L;
long sendCheckTimeout = 60 * 1000 * 10L;
int bitmapSplitNum = 1;
int shuffleId = getShuffleId(tezTaskAttemptID, 1, 2);
int shuffleId =
RssTezUtils.computeShuffleId(
tezTaskAttemptID.getTaskID().getVertexID().getDAGId().getId(), 1, 2);

Configuration conf = new Configuration();
FileSystem localFs = FileSystem.getLocal(conf);
Expand Down Expand Up @@ -237,7 +241,7 @@ public void testWriteNormal(@TempDir File tmpDir) throws IOException, Interrupte
rssConf,
partitionToServers,
numMaps,
isMemoryShuffleEnabled(storageType),
StorageType.withMemory(StorageType.valueOf(storageType)),
sendCheckInterval,
sendCheckTimeout,
bitmapSplitNum,
Expand Down Expand Up @@ -305,7 +309,9 @@ public void testCommitBlocksWhenMemoryShuffleDisabled(@TempDir File tmpDir)
long sendCheckInterval = 500L;
long sendCheckTimeout = 60 * 1000 * 10L;
int bitmapSplitNum = 1;
int shuffleId = getShuffleId(tezTaskAttemptID, 1, 2);
int shuffleId =
RssTezUtils.computeShuffleId(
tezTaskAttemptID.getTaskID().getVertexID().getDAGId().getId(), 1, 2);

Configuration conf = new Configuration();
FileSystem localFs = FileSystem.getLocal(conf);
Expand Down Expand Up @@ -371,15 +377,109 @@ public void testCommitBlocksWhenMemoryShuffleDisabled(@TempDir File tmpDir)
writeClient.mockedShuffleServer.getFlushBlockSize());
}

private int getShuffleId(TezTaskAttemptID tezTaskAttemptID, int upVertexId, int downVertexId) {
TezVertexID tezVertexID = tezTaskAttemptID.getTaskID().getVertexID();
@Test
public void testFastFailWhenSendBlocksFailed(@TempDir File tmpDir)
throws IOException, InterruptedException {
TezTaskAttemptID tezTaskAttemptID =
TezTaskAttemptID.fromString("attempt_1681717153064_3770270_1_00_000000_0");
final long maxMemSize = 10240;
final String appId = "application_1681717153064_3770270";
final long taskAttemptId = 0;
final Set<Long> successBlockIds = Sets.newConcurrentHashSet();
final Set<Long> failedBlockIds = Sets.newConcurrentHashSet();
MockShuffleWriteClient writeClient = new MockShuffleWriteClient();
// set mode = 1 to fake sending shuffle data failed.
writeClient.setMode(1);
RawComparator comparator = WritableComparator.get(BytesWritable.class);
long maxSegmentSize = 3 * 1024;
SerializationFactory serializationFactory = new SerializationFactory(new JobConf());
Serializer<BytesWritable> keySerializer =
serializationFactory.getSerializer(BytesWritable.class);
Serializer<BytesWritable> valSerializer =
serializationFactory.getSerializer(BytesWritable.class);
// note: max buffer size is tiny.
long maxBufferSize = 14 * 1024;
double memoryThreshold = 0.8f;
int sendThreadNum = 1;
double sendThreshold = 0.2f;
int batch = 50;
int numMaps = 1;
RssConf rssConf = new RssConf();
Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
long sendCheckInterval = 500L;
long sendCheckTimeout = 60 * 1000 * 10L;
int bitmapSplitNum = 1;
int shuffleId =
RssTezUtils.computeShuffleId(tezVertexID.getDAGId().getId(), upVertexId, downVertexId);
return shuffleId;
}
RssTezUtils.computeShuffleId(
tezTaskAttemptID.getTaskID().getVertexID().getDAGId().getId(), 1, 2);

Configuration conf = new Configuration();
FileSystem localFs = FileSystem.getLocal(conf);
Path workingDir =
new Path(
System.getProperty(
"test.build.data", System.getProperty("java.io.tmpdir", tmpDir.toString())),
RssOrderedPartitionedKVOutputTest.class.getName())
.makeQualified(localFs.getUri(), localFs.getWorkingDirectory());
conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS, Text.class.getName());
conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS, Text.class.getName());
conf.set(
TezRuntimeConfiguration.TEZ_RUNTIME_PARTITIONER_CLASS, HashPartitioner.class.getName());
conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS, workingDir.toString());
OutputContext outputContext = OutputTestHelpers.createOutputContext(conf, workingDir);
TezCounter mapOutputByteCounter =
outputContext.getCounters().findCounter(TaskCounter.OUTPUT_BYTES);

private boolean isMemoryShuffleEnabled(String storageType) {
return StorageType.withMemory(StorageType.valueOf(storageType));
WriteBufferManager<BytesWritable, BytesWritable> bufferManager =
new WriteBufferManager(
tezTaskAttemptID,
maxMemSize,
appId,
taskAttemptId,
successBlockIds,
failedBlockIds,
writeClient,
comparator,
maxSegmentSize,
keySerializer,
valSerializer,
maxBufferSize,
memoryThreshold,
sendThreadNum,
sendThreshold,
batch,
rssConf,
partitionToServers,
numMaps,
false,
sendCheckInterval,
sendCheckTimeout,
bitmapSplitNum,
shuffleId,
true,
mapOutputByteCounter);

Random random = new Random();
RssException rssException =
assertThrows(
RssException.class,
() -> {
for (int i = 0; i < 10000; i++) {
byte[] key = new byte[20];
byte[] value = new byte[1024];
random.nextBytes(key);
random.nextBytes(value);
int partitionId = random.nextInt(50);
bufferManager.addRecord(
partitionId, new BytesWritable(key), new BytesWritable(value));
}
});
assertTrue(rssException.getMessage().contains("Send failed"));

rssException = assertThrows(RssException.class, bufferManager::waitSendFinished);
assertTrue(rssException.getMessage().contains("Send failed"));

assertTrue(mapOutputByteCounter.getValue() < 10520000);
}

class MockShuffleServer {
Expand Down