Skip to content

Commit

Permalink
[#1018] test(tez) RssUnorderedPartitionedKVOutputTest add close func …
Browse files Browse the repository at this point in the history
…unit test (#1034)

### What changes were proposed in this pull request?

tez-client, RssUnorderedPartitionedKVOutputTest add close func unit test

### Why are the changes needed?

tez-client, RssUnorderedPartitionedKVOutputTest add close func unit test

Fix: #1018 

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

unit test.
  • Loading branch information
bin41215 committed Jul 24, 2023
1 parent ecfed5e commit 6fb2a9a
Showing 1 changed file with 54 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.util.ConverterUtils;
import org.apache.tez.common.GetShuffleServerResponse;
import org.apache.tez.common.ShuffleAssignmentsInfoWritable;
import org.apache.tez.common.TezCommonUtils;
import org.apache.tez.common.TezRemoteShuffleUmbilicalProtocol;
import org.apache.tez.common.TezRuntimeFrameworkConfigs;
import org.apache.tez.common.TezUtilsInternal;
import org.apache.tez.runtime.api.Event;
Expand All @@ -41,21 +47,35 @@
import org.apache.tez.runtime.library.partitioner.HashPartitioner;
import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.mockito.MockedStatic;
import org.mockito.Mockito;

import org.apache.uniffle.common.ShuffleAssignmentsInfo;
import org.apache.uniffle.common.ShuffleServerInfo;

import static org.apache.tez.common.RssTezConfig.RSS_AM_SHUFFLE_MANAGER_ADDRESS;
import static org.apache.tez.common.RssTezConfig.RSS_AM_SHUFFLE_MANAGER_PORT;
import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID;
import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;

public class RssUnorderedPartitionedKVOutputTest {
private static Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
private Configuration conf;
private FileSystem localFs;
private Path workingDir;

/** set up */
@BeforeEach
public void setup() throws IOException {
conf = new Configuration();
Expand All @@ -70,6 +90,10 @@ public void setup() throws IOException {
conf.set(
TezRuntimeConfiguration.TEZ_RUNTIME_PARTITIONER_CLASS, HashPartitioner.class.getName());
conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS, workingDir.toString());
conf.set(RSS_AM_SHUFFLE_MANAGER_ADDRESS, "localhost");
conf.setInt(RSS_AM_SHUFFLE_MANAGER_PORT, 0);
conf.setInt(RSS_SHUFFLE_SOURCE_VERTEX_ID, 0);
conf.setInt(RSS_SHUFFLE_DESTINATION_VERTEX_ID, 1);
}

@AfterEach
Expand Down Expand Up @@ -103,4 +127,34 @@ public void testNonStartedOutput() throws Exception {
assertTrue(emptyPartionsBitSet.get(i));
}
}

@Test
@Timeout(value = 8000, unit = TimeUnit.MILLISECONDS)
public void testClose() throws Exception {
try (MockedStatic<RPC> rpc = Mockito.mockStatic(RPC.class); ) {
TezRemoteShuffleUmbilicalProtocol protocol = mock(TezRemoteShuffleUmbilicalProtocol.class);
GetShuffleServerResponse response = new GetShuffleServerResponse();
ShuffleAssignmentsInfo shuffleAssignmentsInfo =
new ShuffleAssignmentsInfo(new HashMap(), new HashMap());
response.setShuffleAssignmentsInfoWritable(
new ShuffleAssignmentsInfoWritable(shuffleAssignmentsInfo));
doReturn(response).when(protocol).getShuffleAssignments(any());
rpc.when(() -> RPC.getProxy(any(), anyLong(), any(), any())).thenReturn(protocol);
try (MockedStatic<ConverterUtils> converterUtils = Mockito.mockStatic(ConverterUtils.class)) {
ContainerId containerId = ContainerId.newContainerId(OutputTestHelpers.APP_ATTEMPT_ID, 1);
converterUtils.when(() -> ConverterUtils.toContainerId(null)).thenReturn(containerId);
converterUtils
.when(() -> ConverterUtils.toContainerId(anyString()))
.thenReturn(containerId);
OutputContext outputContext = OutputTestHelpers.createOutputContext(conf, workingDir);
int numPartitions = 1;
RssUnorderedPartitionedKVOutput output =
new RssUnorderedPartitionedKVOutput(outputContext, numPartitions);
output.initialize();
output.start();
Assertions.assertNotNull(output.getWriter());
output.close();
}
}
}
}

0 comments on commit 6fb2a9a

Please sign in to comment.