Skip to content

Commit

Permalink
[#1068] feat(tez): Fail fast in client when failed to send data to se…
Browse files Browse the repository at this point in the history
…rver. (#1069)

### What changes were proposed in this pull request?

Currently, the client only checks for blocks that failed to send in waitSendFinished, i.e. after all buffered data has been sent.
This check should also be performed earlier, in the addRecord method, so that the task can fail fast.

### Why are the changes needed?
Fix: [#1068](#1068)

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added a new test case to WriteBufferManagerTest.
  • Loading branch information
zhuyaogai committed Aug 3, 2023
1 parent 54c81f0 commit 1e4804b
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ public void addRecord(int partitionId, K key, V value) throws InterruptedExcepti
memoryLock.unlock();
}

// Fail fast if there are some failed blocks.
checkFailedBlocks();

if (!buffers.containsKey(partitionId)) {
WriteBuffer<K, V> sortWriterBuffer =
new WriteBuffer(
Expand Down Expand Up @@ -282,14 +285,7 @@ public void waitSendFinished() {
}
long start = System.currentTimeMillis();
while (true) {
if (failedBlockIds.size() > 0) {
String errorMsg =
"Send failed: failed because "
+ failedBlockIds.size()
+ " blocks can't be sent to shuffle server.";
LOG.error(errorMsg);
throw new RssException(errorMsg);
}
checkFailedBlocks();
allBlockIds.removeAll(successBlockIds);
if (allBlockIds.isEmpty()) {
break;
Expand Down Expand Up @@ -335,6 +331,18 @@ public void waitSendFinished() {
sortTime);
}

/**
 * Aborts the caller as soon as any block is recorded as failed.
 *
 * <p>Does nothing while {@code failedBlockIds} is empty; otherwise logs an error and throws.
 *
 * @throws RssException if {@code failedBlockIds} contains at least one entry
 */
private void checkFailedBlocks() {
  // Guard clause: nothing has failed, so there is nothing to report.
  if (failedBlockIds.size() <= 0) {
    return;
  }
  int failedCount = failedBlockIds.size();
  String errorMsg =
      "Send failed: failed because " + failedCount + " blocks can't be sent to shuffle server.";
  LOG.error(errorMsg);
  throw new RssException(errorMsg);
}

ShuffleBlockInfo createShuffleBlock(WriteBuffer wb) {
byte[] data = wb.getData();
copyTime += wb.getCopyTime();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import org.apache.tez.common.counters.TaskCounter;
import org.apache.tez.common.counters.TezCounter;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.dag.records.TezVertexID;
import org.apache.tez.runtime.api.OutputContext;
import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
import org.apache.tez.runtime.library.output.OutputTestHelpers;
Expand All @@ -67,6 +66,7 @@
import org.apache.uniffle.storage.util.StorageType;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class WriteBufferManagerTest {
Expand Down Expand Up @@ -99,7 +99,9 @@ public void testWriteException(@TempDir File tmpDir) throws IOException, Interru
long sendCheckInterval = 500L;
long sendCheckTimeout = 5;
int bitmapSplitNum = 1;
int shuffleId = getShuffleId(tezTaskAttemptID, 1, 2);
int shuffleId =
RssTezUtils.computeShuffleId(
tezTaskAttemptID.getTaskID().getVertexID().getDAGId().getId(), 1, 2);

Configuration conf = new Configuration();
FileSystem localFs = FileSystem.getLocal(conf);
Expand Down Expand Up @@ -139,7 +141,7 @@ public void testWriteException(@TempDir File tmpDir) throws IOException, Interru
rssConf,
partitionToServers,
numMaps,
isMemoryShuffleEnabled(storageType),
StorageType.withMemory(StorageType.valueOf(storageType)),
sendCheckInterval,
sendCheckTimeout,
bitmapSplitNum,
Expand Down Expand Up @@ -197,7 +199,9 @@ public void testWriteNormal(@TempDir File tmpDir) throws IOException, Interrupte
long sendCheckInterval = 500L;
long sendCheckTimeout = 60 * 1000 * 10L;
int bitmapSplitNum = 1;
int shuffleId = getShuffleId(tezTaskAttemptID, 1, 2);
int shuffleId =
RssTezUtils.computeShuffleId(
tezTaskAttemptID.getTaskID().getVertexID().getDAGId().getId(), 1, 2);

Configuration conf = new Configuration();
FileSystem localFs = FileSystem.getLocal(conf);
Expand Down Expand Up @@ -237,7 +241,7 @@ public void testWriteNormal(@TempDir File tmpDir) throws IOException, Interrupte
rssConf,
partitionToServers,
numMaps,
isMemoryShuffleEnabled(storageType),
StorageType.withMemory(StorageType.valueOf(storageType)),
sendCheckInterval,
sendCheckTimeout,
bitmapSplitNum,
Expand Down Expand Up @@ -305,7 +309,9 @@ public void testCommitBlocksWhenMemoryShuffleDisabled(@TempDir File tmpDir)
long sendCheckInterval = 500L;
long sendCheckTimeout = 60 * 1000 * 10L;
int bitmapSplitNum = 1;
int shuffleId = getShuffleId(tezTaskAttemptID, 1, 2);
int shuffleId =
RssTezUtils.computeShuffleId(
tezTaskAttemptID.getTaskID().getVertexID().getDAGId().getId(), 1, 2);

Configuration conf = new Configuration();
FileSystem localFs = FileSystem.getLocal(conf);
Expand Down Expand Up @@ -371,15 +377,109 @@ public void testCommitBlocksWhenMemoryShuffleDisabled(@TempDir File tmpDir)
writeClient.mockedShuffleServer.getFlushBlockSize());
}

private int getShuffleId(TezTaskAttemptID tezTaskAttemptID, int upVertexId, int downVertexId) {
TezVertexID tezVertexID = tezTaskAttemptID.getTaskID().getVertexID();
@Test
public void testFastFailWhenSendBlocksFailed(@TempDir File tmpDir)
throws IOException, InterruptedException {
TezTaskAttemptID tezTaskAttemptID =
TezTaskAttemptID.fromString("attempt_1681717153064_3770270_1_00_000000_0");
final long maxMemSize = 10240;
final String appId = "application_1681717153064_3770270";
final long taskAttemptId = 0;
final Set<Long> successBlockIds = Sets.newConcurrentHashSet();
final Set<Long> failedBlockIds = Sets.newConcurrentHashSet();
MockShuffleWriteClient writeClient = new MockShuffleWriteClient();
// set mode = 1 to fake sending shuffle data failed.
writeClient.setMode(1);
RawComparator comparator = WritableComparator.get(BytesWritable.class);
long maxSegmentSize = 3 * 1024;
SerializationFactory serializationFactory = new SerializationFactory(new JobConf());
Serializer<BytesWritable> keySerializer =
serializationFactory.getSerializer(BytesWritable.class);
Serializer<BytesWritable> valSerializer =
serializationFactory.getSerializer(BytesWritable.class);
// note: max buffer size is tiny.
long maxBufferSize = 14 * 1024;
double memoryThreshold = 0.8f;
int sendThreadNum = 1;
double sendThreshold = 0.2f;
int batch = 50;
int numMaps = 1;
RssConf rssConf = new RssConf();
Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
long sendCheckInterval = 500L;
long sendCheckTimeout = 60 * 1000 * 10L;
int bitmapSplitNum = 1;
int shuffleId =
RssTezUtils.computeShuffleId(tezVertexID.getDAGId().getId(), upVertexId, downVertexId);
return shuffleId;
}
RssTezUtils.computeShuffleId(
tezTaskAttemptID.getTaskID().getVertexID().getDAGId().getId(), 1, 2);

Configuration conf = new Configuration();
FileSystem localFs = FileSystem.getLocal(conf);
Path workingDir =
new Path(
System.getProperty(
"test.build.data", System.getProperty("java.io.tmpdir", tmpDir.toString())),
RssOrderedPartitionedKVOutputTest.class.getName())
.makeQualified(localFs.getUri(), localFs.getWorkingDirectory());
conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS, Text.class.getName());
conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS, Text.class.getName());
conf.set(
TezRuntimeConfiguration.TEZ_RUNTIME_PARTITIONER_CLASS, HashPartitioner.class.getName());
conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS, workingDir.toString());
OutputContext outputContext = OutputTestHelpers.createOutputContext(conf, workingDir);
TezCounter mapOutputByteCounter =
outputContext.getCounters().findCounter(TaskCounter.OUTPUT_BYTES);

private boolean isMemoryShuffleEnabled(String storageType) {
return StorageType.withMemory(StorageType.valueOf(storageType));
WriteBufferManager<BytesWritable, BytesWritable> bufferManager =
new WriteBufferManager(
tezTaskAttemptID,
maxMemSize,
appId,
taskAttemptId,
successBlockIds,
failedBlockIds,
writeClient,
comparator,
maxSegmentSize,
keySerializer,
valSerializer,
maxBufferSize,
memoryThreshold,
sendThreadNum,
sendThreshold,
batch,
rssConf,
partitionToServers,
numMaps,
false,
sendCheckInterval,
sendCheckTimeout,
bitmapSplitNum,
shuffleId,
true,
mapOutputByteCounter);

Random random = new Random();
RssException rssException =
assertThrows(
RssException.class,
() -> {
for (int i = 0; i < 10000; i++) {
byte[] key = new byte[20];
byte[] value = new byte[1024];
random.nextBytes(key);
random.nextBytes(value);
int partitionId = random.nextInt(50);
bufferManager.addRecord(
partitionId, new BytesWritable(key), new BytesWritable(value));
}
});
assertTrue(rssException.getMessage().contains("Send failed"));

rssException = assertThrows(RssException.class, bufferManager::waitSendFinished);
assertTrue(rssException.getMessage().contains("Send failed"));

assertTrue(mapOutputByteCounter.getValue() < 10520000);
}

class MockShuffleServer {
Expand Down

0 comments on commit 1e4804b

Please sign in to comment.