Skip to content

Commit

Permalink
Address comment
Browse files Browse the repository at this point in the history
  • Loading branch information
gaoyunhaii committed Oct 30, 2021
1 parent e8f3eea commit efd697d
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 144 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@
@Internal
public class IterationFactory {

/**
* We connect the head / tail of the criteria stream into the same pipeline
* region to ensure that they restart simultaneously.
*/
private static final OutputTag<IterationRecord<Integer>> FAKE_CRITERIA_OUTPUT_TAG =
new OutputTag<IterationRecord<Integer>>("fake") {};

@SuppressWarnings({"unchecked", "rawtypes"})
public static DataStreamList createIteration(
DataStreamList initVariableStreams,
Expand Down Expand Up @@ -148,7 +155,7 @@ private static void addCriteriaStream(
// It should always have the IterationRecordTypeInfo
checkState(
terminationCriteria.getType().getClass().equals(IterationRecordTypeInfo.class),
"The termination criteria should always returns IterationRecord.");
"The termination criteria should always return IterationRecord.");
TypeInformation<?> innerType =
((IterationRecordTypeInfo<?>) terminationCriteria.getType()).getInnerTypeInfo();

Expand Down Expand Up @@ -179,10 +186,10 @@ private static void addCriteriaStream(

// Since co-located task must be in the same region, we will have to add a fake op.
((SingleOutputStreamOperator<?>) criteriaHeaders.get(0))
.getSideOutput(new OutputTag<IterationRecord<Integer>>("fake") {})
.getSideOutput(FAKE_CRITERIA_OUTPUT_TAG)
.union(
((SingleOutputStreamOperator<?>) criteriaTails.get(0))
.getSideOutput(new OutputTag<IterationRecord<Integer>>("fake") {}))
.getSideOutput(FAKE_CRITERIA_OUTPUT_TAG))
.map(x -> x)
.returns(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
.name("criteria-discard")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
import org.apache.flink.iteration.Iterations;
import org.apache.flink.iteration.ReplayableDataStreamList;
import org.apache.flink.iteration.compile.DraftExecutionEnvironment;
import org.apache.flink.iteration.itcases.operators.CollectSink;
import org.apache.flink.iteration.itcases.operators.EpochRecord;
import org.apache.flink.iteration.itcases.operators.IncrementEpochMap;
import org.apache.flink.iteration.itcases.operators.OutputRecord;
import org.apache.flink.iteration.itcases.operators.RoundBasedTerminationCriteria;
import org.apache.flink.iteration.itcases.operators.SequenceSource;
Expand All @@ -35,194 +38,185 @@
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.testutils.junit.SharedObjects;
import org.apache.flink.testutils.junit.SharedReference;
import org.apache.flink.util.OutputTag;
import org.apache.flink.util.TestLogger;

import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;

import java.util.HashMap;
import javax.annotation.Nullable;

import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

import static org.apache.flink.iteration.itcases.UnboundedStreamIterationITCase.computeRoundStat;
import static org.apache.flink.iteration.itcases.UnboundedStreamIterationITCase.createMiniClusterConfiguration;
import static org.apache.flink.iteration.itcases.UnboundedStreamIterationITCase.createVariableAndConstantJobGraph;
import static org.apache.flink.iteration.itcases.UnboundedStreamIterationITCase.createVariableOnlyJobGraph;
import static org.apache.flink.iteration.itcases.UnboundedStreamIterationITCase.verifyResult;
import static org.junit.Assert.assertEquals;

/**
* Tests the cases of {@link Iterations#iterateBoundedStreamsUntilTermination(DataStreamList,
* DataStreamList, IterationBody)}.
* ReplayableDataStreamList, IterationConfig, IterationBody)} that using all-round iterations.
*/
public class BoundedAllRoundStreamIterationITCase {
public class BoundedAllRoundStreamIterationITCase extends TestLogger {

@Rule public final SharedObjects sharedObjects = SharedObjects.create();

private static BlockingQueue<OutputRecord<Integer>> result = new LinkedBlockingQueue<>();
private MiniCluster miniCluster;

private SharedReference<BlockingQueue<OutputRecord<Integer>>> result;

@Before
public void setup() {
result.clear();
public void setup() throws Exception {
miniCluster = new MiniCluster(createMiniClusterConfiguration(2, 2));
miniCluster.start();

result = sharedObjects.add(new LinkedBlockingQueue<>());
}

@Test(timeout = 60000)
public void testSyncVariableOnlyBoundedIteration() throws Exception {
try (MiniCluster miniCluster = new MiniCluster(createMiniClusterConfiguration(2, 2))) {
miniCluster.start();

// Create the test job
JobGraph jobGraph =
createVariableOnlyJobGraph(
4,
1000,
false,
0,
true,
4,
new SinkFunction<OutputRecord<Integer>>() {
@Override
public void invoke(OutputRecord<Integer> value, Context context) {
result.add(value);
}
});
miniCluster.executeJobBlocking(jobGraph);

assertEquals(6, result.size());

Map<Integer, Tuple2<Integer, Integer>> roundsStat = new HashMap<>();
for (int i = 0; i < 5; ++i) {
OutputRecord<Integer> next = result.take();
assertEquals(OutputRecord.Event.EPOCH_WATERMARK_INCREMENTED, next.getEvent());
Tuple2<Integer, Integer> state =
roundsStat.computeIfAbsent(next.getRound(), ignored -> new Tuple2<>(0, 0));
state.f0++;
state.f1 = next.getValue();
}

verifyResult(roundsStat, 5, 1, 4 * (0 + 999) * 1000 / 2);
assertEquals(OutputRecord.Event.TERMINATED, result.take().getEvent());
@After
public void teardown() throws Exception {
if (miniCluster != null) {
miniCluster.close();
}
}

@Test(timeout = 60000)
public void testSyncVariableAndConstantBoundedIteration() throws Exception {
try (MiniCluster miniCluster = new MiniCluster(createMiniClusterConfiguration(2, 2))) {
miniCluster.start();

// Create the test job
JobGraph jobGraph =
createVariableAndConstantJobGraph(
4,
1000,
false,
0,
true,
4,
new SinkFunction<OutputRecord<Integer>>() {
@Override
public void invoke(OutputRecord<Integer> value, Context context) {
result.add(value);
}
});
miniCluster.executeJobBlocking(jobGraph);

assertEquals(6, result.size());

Map<Integer, Tuple2<Integer, Integer>> roundsStat = new HashMap<>();
for (int i = 0; i < 5; ++i) {
OutputRecord<Integer> next = result.take();
assertEquals(OutputRecord.Event.EPOCH_WATERMARK_INCREMENTED, next.getEvent());
Tuple2<Integer, Integer> state =
roundsStat.computeIfAbsent(next.getRound(), ignored -> new Tuple2<>(0, 0));
state.f0++;
state.f1 = next.getValue();
}

verifyResult(roundsStat, 5, 1, 4 * (0 + 999) * 1000 / 2);
assertEquals(OutputRecord.Event.TERMINATED, result.take().getEvent());
}
public void testSyncVariableOnlyBoundedIteration() throws Exception {
JobGraph jobGraph = createVariableOnlyJobGraph(4, 1000, false, 0, true, 4, null, result);
miniCluster.executeJobBlocking(jobGraph);

assertEquals(6, result.get().size());
Map<Integer, Tuple2<Integer, Integer>> roundsStat =
computeRoundStat(result.get(), OutputRecord.Event.EPOCH_WATERMARK_INCREMENTED, 5);

verifyResult(roundsStat, 5, 1, 4 * (0 + 999) * 1000 / 2);
assertEquals(OutputRecord.Event.TERMINATED, result.get().take().getEvent());
}

@Test
public void testTerminationCriteria() throws Exception {
try (MiniCluster miniCluster = new MiniCluster(createMiniClusterConfiguration(2, 2))) {
miniCluster.start();

// Create the test job
JobGraph jobGraph =
createJobGraphWithTerminationCriteria(
4,
1000,
false,
0,
true,
4,
new SinkFunction<OutputRecord<Integer>>() {
@Override
public void invoke(OutputRecord<Integer> value, Context context) {
result.add(value);
}
});
miniCluster.executeJobBlocking(jobGraph);

assertEquals(6, result.size());

Map<Integer, Tuple2<Integer, Integer>> roundsStat = new HashMap<>();
for (int i = 0; i < 5; ++i) {
OutputRecord<Integer> next = result.take();
assertEquals(OutputRecord.Event.EPOCH_WATERMARK_INCREMENTED, next.getEvent());
Tuple2<Integer, Integer> state =
roundsStat.computeIfAbsent(next.getRound(), ignored -> new Tuple2<>(0, 0));
state.f0++;
state.f1 = next.getValue();
}

verifyResult(roundsStat, 5, 1, 4 * (0 + 999) * 1000 / 2);
assertEquals(OutputRecord.Event.TERMINATED, result.take().getEvent());
}
public void testSyncVariableOnlyBoundedIterationWithTerminationCriteria() throws Exception {
JobGraph jobGraph = createVariableOnlyJobGraph(4, 1000, false, 0, true, 40, 4, result);
miniCluster.executeJobBlocking(jobGraph);

assertEquals(6, result.get().size());
Map<Integer, Tuple2<Integer, Integer>> roundsStat =
computeRoundStat(result.get(), OutputRecord.Event.EPOCH_WATERMARK_INCREMENTED, 5);

verifyResult(roundsStat, 5, 1, 4 * (0 + 999) * 1000 / 2);
assertEquals(OutputRecord.Event.TERMINATED, result.get().take().getEvent());
}

@Test(timeout = 60000)
public void testSyncVariableAndConstantBoundedIteration() throws Exception {
JobGraph jobGraph = createVariableAndConstantJobGraph(4, 1000, false, 0, true, 4, result);
miniCluster.executeJobBlocking(jobGraph);

assertEquals(6, result.get().size());
Map<Integer, Tuple2<Integer, Integer>> roundsStat =
computeRoundStat(result.get(), OutputRecord.Event.EPOCH_WATERMARK_INCREMENTED, 5);

verifyResult(roundsStat, 5, 1, 4 * (0 + 999) * 1000 / 2);
assertEquals(OutputRecord.Event.TERMINATED, result.get().take().getEvent());
}

static JobGraph createJobGraphWithTerminationCriteria(
private static JobGraph createVariableOnlyJobGraph(
int numSources,
int numRecordsPerSource,
boolean holdSource,
int period,
boolean sync,
int maxRound,
SinkFunction<OutputRecord<Integer>> sinkFunction) {
@Nullable Integer terminationCriteriaRound,
SharedReference<BlockingQueue<OutputRecord<Integer>>> result) {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(1);
DataStream<Integer> variableSource =
env.addSource(new DraftExecutionEnvironment.EmptySource<Integer>() {})
DataStream<EpochRecord> variableSource =
env.addSource(new DraftExecutionEnvironment.EmptySource<EpochRecord>() {})
.setParallelism(numSources)
.name("Variable");
DataStream<Integer> constSource =
DataStream<EpochRecord> constSource =
env.addSource(new SequenceSource(numRecordsPerSource, holdSource, period))
.setParallelism(numSources)
.name("Constants");
.name("Constant");
DataStreamList outputs =
Iterations.iterateBoundedStreamsUntilTermination(
DataStreamList.of(variableSource),
ReplayableDataStreamList.notReplay(constSource),
IterationConfig.newBuilder().build(),
(variableStreams, dataStreams) -> {
SingleOutputStreamOperator<Integer> reducer =
SingleOutputStreamOperator<EpochRecord> reducer =
variableStreams
.<Integer>get(0)
.connect(dataStreams.<Integer>get(0))
.<EpochRecord>get(0)
.connect(dataStreams.<EpochRecord>get(0))
.process(
new TwoInputReduceAllRoundProcessFunction(
sync, maxRound * 10));
sync, maxRound));
return new IterationBodyResult(
DataStreamList.of(
reducer.map(x -> x).setParallelism(numSources)),
reducer.map(new IncrementEpochMap())
.setParallelism(numSources)),
DataStreamList.of(
reducer.getSideOutput(
new OutputTag<OutputRecord<Integer>>(
"output") {})),
reducer.flatMap(new RoundBasedTerminationCriteria(maxRound)));
terminationCriteriaRound == null
? null
: reducer.flatMap(
new RoundBasedTerminationCriteria(
terminationCriteriaRound)));
});
outputs.<OutputRecord<Integer>>get(0).addSink(new CollectSink(result));

return env.getStreamGraph().getJobGraph();
}

private static JobGraph createVariableAndConstantJobGraph(
int numSources,
int numRecordsPerSource,
boolean holdSource,
int period,
boolean sync,
int maxRound,
SharedReference<BlockingQueue<OutputRecord<Integer>>> result) {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(1);
DataStream<EpochRecord> variableSource =
env.addSource(new DraftExecutionEnvironment.EmptySource<EpochRecord>() {})
.setParallelism(numSources)
.name("Variable");
DataStream<EpochRecord> constSource =
env.addSource(new SequenceSource(numRecordsPerSource, holdSource, period))
.setParallelism(numSources)
.name("Constant");
DataStreamList outputs =
Iterations.iterateBoundedStreamsUntilTermination(
DataStreamList.of(variableSource),
ReplayableDataStreamList.notReplay(constSource),
IterationConfig.newBuilder().build(),
(variableStreams, dataStreams) -> {
SingleOutputStreamOperator<EpochRecord> reducer =
variableStreams
.<EpochRecord>get(0)
.connect(dataStreams.<EpochRecord>get(0))
.process(
new TwoInputReduceAllRoundProcessFunction(
sync, maxRound));
return new IterationBodyResult(
DataStreamList.of(
reducer.map(new IncrementEpochMap())
.setParallelism(numSources)),
DataStreamList.of(
reducer.getSideOutput(
new OutputTag<OutputRecord<Integer>>(
"output") {})));
});
outputs.<OutputRecord<Integer>>get(0).addSink(sinkFunction);
outputs.<OutputRecord<Integer>>get(0).addSink(new CollectSink(result));

return env.getStreamGraph().getJobGraph();
}
Expand Down

0 comments on commit efd697d

Please sign in to comment.