Skip to content

Commit

Permalink
add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengchenyu001 committed May 19, 2023
1 parent 6afeba4 commit 98c4434
Showing 1 changed file with 104 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@

package org.apache.hadoop.mapred;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import com.google.common.collect.Sets;
import org.apache.hadoop.conf.Configuration;
Expand Down Expand Up @@ -261,9 +263,91 @@ public void testWriteNormal() throws Exception {
assertTrue(manager.getWaitSendBuffers().isEmpty());
}

@Test
public void testCommitBlocksWhenMemoryShuffleDisabled() throws Exception {
JobConf jobConf = new JobConf(new Configuration());
SerializationFactory serializationFactory = new SerializationFactory(jobConf);
MockShuffleWriteClient client = new MockShuffleWriteClient();
client.setMode(3);
Map<Integer, List<ShuffleServerInfo>> partitionToServers = JavaUtils.newConcurrentMap();
Set<Long> successBlocks = Sets.newConcurrentHashSet();
Set<Long> failedBlocks = Sets.newConcurrentHashSet();
Counters.Counter mapOutputByteCounter = new Counters.Counter();
Counters.Counter mapOutputRecordCounter = new Counters.Counter();
SortWriteBufferManager<BytesWritable, BytesWritable> manager;
manager = new SortWriteBufferManager<BytesWritable, BytesWritable>(
10240,
1L,
10,
serializationFactory.getSerializer(BytesWritable.class),
serializationFactory.getSerializer(BytesWritable.class),
WritableComparator.get(BytesWritable.class),
0.9,
"test",
client,
500,
5 * 1000,
partitionToServers,
successBlocks,
failedBlocks,
mapOutputByteCounter,
mapOutputRecordCounter,
1,
100,
1,
false,
5,
0.2f,
1024000L,
new RssConf());
Random random = new Random();
for (int i = 0; i < 1000; i++) {
byte[] key = new byte[20];
byte[] value = new byte[1024];
random.nextBytes(key);
random.nextBytes(value);
int partitionId = random.nextInt(50);
manager.addRecord(partitionId, new BytesWritable(key), new BytesWritable(value));
}
manager.waitSendFinished();
assertTrue(manager.getWaitSendBuffers().isEmpty());
// When MEMOEY storage type is disable, all blocks should flush.
assertEquals(client.mockedShuffleServer.getFinishBlockSize(), client.mockedShuffleServer.getFlushBlockSize());
}

class MockShuffleServer {

// All methods of MockShuffle are thread safe, because send-thread may do something in concurrent way.
private List<ShuffleBlockInfo> cachedBlockInfos = new ArrayList<>();
private List<ShuffleBlockInfo> flushBlockInfos = new ArrayList<>();
private List<Long> finishedBlockInfos = new ArrayList<>();

public synchronized void finishShuffle() {
flushBlockInfos.addAll(cachedBlockInfos);
}

public synchronized void addCachedBlockInfos(List<ShuffleBlockInfo> shuffleBlockInfoList) {
cachedBlockInfos.addAll(shuffleBlockInfoList);
}

public synchronized void addFinishedBlockInfos(List<Long> shuffleBlockInfoList) {
finishedBlockInfos.addAll(shuffleBlockInfoList);
}

public synchronized int getFlushBlockSize() {
return flushBlockInfos.size();
}

public synchronized int getFinishBlockSize() {
return finishedBlockInfos.size();
}
}

class MockShuffleWriteClient implements ShuffleWriteClient {

int mode = 0;
MockShuffleServer mockedShuffleServer = new MockShuffleServer();
int committedMaps = 0;

public void setMode(int mode) {
this.mode = mode;
Expand All @@ -277,6 +361,15 @@ public SendShuffleDataResult sendShuffleData(String appId, List<ShuffleBlockInfo
} else if (mode == 1) {
return new SendShuffleDataResult(Sets.newHashSet(2L), Sets.newHashSet(1L));
} else {
if (mode == 3) {
try {
Thread.sleep(10);
mockedShuffleServer.addCachedBlockInfos(shuffleBlockInfoList);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RssException(e);
}
}
Set<Long> successBlockIds = Sets.newHashSet();
for (ShuffleBlockInfo blockInfo : shuffleBlockInfoList) {
successBlockIds.add(blockInfo.getBlockId());
Expand Down Expand Up @@ -308,6 +401,13 @@ public void registerShuffle(

@Override
public boolean sendCommit(Set<ShuffleServerInfo> shuffleServerInfoSet, String appId, int shuffleId, int numMaps) {
if (mode == 3) {
committedMaps++;
if (committedMaps >= numMaps) {
mockedShuffleServer.finishShuffle();
}
return true;
}
return false;
}

Expand All @@ -329,7 +429,10 @@ public RemoteStorageInfo fetchRemoteStorage(String appId) {
@Override
public void reportShuffleResult(Map<Integer, List<ShuffleServerInfo>> partitionToServers, String appId,
int shuffleId, long taskAttemptId, Map<Integer, List<Long>> partitionToBlockIds, int bitmapNum) {

if (mode == 3) {
mockedShuffleServer.addFinishedBlockInfos(
partitionToBlockIds.values().stream().flatMap(it -> it.stream()).collect(Collectors.toList()));
}
}

@Override
Expand Down

0 comments on commit 98c4434

Please sign in to comment.