Skip to content

Commit

Permalink
[FLINK-7][iteration] Add bounded all-round iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
gaoyunhaii committed Oct 25, 2021
1 parent fed9dca commit 5a8880b
Show file tree
Hide file tree
Showing 6 changed files with 464 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ public static IterationConfigBuilder newBuilder() {
return new IterationConfigBuilder();
}

/**
 * Returns the operator life cycle stored in this configuration.
 *
 * @return the value of the {@code operatorLifeCycle} field.
 */
public OperatorLifeCycle getOperatorLifeCycle() {
    return this.operatorLifeCycle;
}

/** The builder of the {@link IterationConfig}. */
public static class IterationConfigBuilder {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.flink.iteration;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.iteration.compile.DraftExecutionEnvironment;
import org.apache.flink.iteration.operator.HeadOperatorFactory;
Expand All @@ -31,6 +32,7 @@
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.util.OutputTag;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -109,11 +111,78 @@ public static DataStreamList createIteration(
mayHaveCriteria || iterationBodyResult.getTerminationCriteria() == null,
"The current iteration type does not support the termination criteria.");

// TODO: will consider the termination criteria in the next.
if (iterationBodyResult.getTerminationCriteria() != null) {
addCriteriaStream(
iterationBodyResult.getTerminationCriteria(),
iterationId,
env,
draftEnv,
initVariableStreams,
headStreams,
totalInitVariableParallelism);
}

return addOutputs(getActualDataStreams(iterationBodyResult.getOutputStreams(), draftEnv));
}

/**
 * Wires the termination criteria stream produced by the iteration body into the job graph.
 *
 * <p>The steps are, in order:
 *
 * <ol>
 *   <li>Resolve the draft criteria stream to its actual stream and verify that its type is
 *       exactly {@link IterationRecordTypeInfo}.
 *   <li>Create an empty placeholder source with the same name and parallelism as the criteria
 *       stream, then add an input and a head for it.
 *   <li>Add a tail for the actual criteria stream and put that head and tail into one
 *       co-location group.
 *   <li>Union the "fake" side outputs of the head and the tail into a single
 *       "criteria-discard" operator.
 *   <li>Report the criteria parallelism to both the variable heads and the criteria head.
 * </ol>
 *
 * @param draftCriteriaStream the criteria stream as declared in the draft environment.
 * @param iterationId the id of the iteration, also used to build the co-location group key.
 * @param env the environment the placeholder source is added to.
 * @param draftEnv the draft environment used to look up the actual criteria stream.
 * @param initVariableStreams the initial variable streams; only their count is used here.
 * @param headStreams the heads of the variable streams that get the criteria parallelism set.
 * @param totalInitVariableParallelism forwarded unchanged to {@code addHeads}.
 */
private static void addCriteriaStream(
        DataStream<?> draftCriteriaStream,
        IterationID iterationId,
        StreamExecutionEnvironment env,
        DraftExecutionEnvironment draftEnv,
        DataStreamList initVariableStreams,
        DataStreamList headStreams,
        int totalInitVariableParallelism) {
    // Look up the real stream behind the draft one; it is expected to carry IterationRecords.
    DataStream<?> actualCriteria = draftEnv.getActualStream(draftCriteriaStream.getId());
    TypeInformation<?> criteriaType = actualCriteria.getType();
    checkState(
            criteriaType.getClass().equals(IterationRecordTypeInfo.class),
            "The termination criteria should always returns IterationRecord.");
    TypeInformation<?> innerType =
            ((IterationRecordTypeInfo<?>) criteriaType).getInnerTypeInfo();

    // Placeholder source mirroring the name and parallelism of the criteria stream.
    // NOTE(review): the raw EmptySource type looks deliberate so that returns(innerType)
    // accepts the wildcard type information - confirm before adding type arguments.
    DataStream<?> placeholderSource =
            env.addSource(new DraftExecutionEnvironment.EmptySource())
                    .returns(innerType)
                    .name(actualCriteria.getTransformation().getName())
                    .setParallelism(actualCriteria.getParallelism());
    DataStreamList placeholderSources = DataStreamList.of(placeholderSource);
    DataStreamList placeholderInputs = addInputs(placeholderSources, false);

    int variableCount = initVariableStreams.size();
    DataStreamList criteriaHeads =
            addHeads(
                    placeholderSources,
                    placeholderInputs,
                    iterationId,
                    totalInitVariableParallelism,
                    true,
                    variableCount);
    DataStreamList criteriaTails =
            addTails(DataStreamList.of(actualCriteria), iterationId, variableCount);

    // The criteria head and tail share one co-location group.
    String criteriaGroupKey = "co-" + iterationId.toHexString() + "-cri";
    criteriaHeads.get(0).getTransformation().setCoLocationGroupKey(criteriaGroupKey);
    criteriaTails.get(0).getTransformation().setCoLocationGroupKey(criteriaGroupKey);

    // Co-located tasks have to live in the same region, so a fake operator consuming a side
    // output of both the head and the tail is added to connect them.
    OutputTag<IterationRecord<Integer>> fakeTag =
            new OutputTag<IterationRecord<Integer>>("fake") {};
    DataStream<IterationRecord<Integer>> headFakeOutput =
            ((SingleOutputStreamOperator<?>) criteriaHeads.get(0)).getSideOutput(fakeTag);
    DataStream<IterationRecord<Integer>> tailFakeOutput =
            ((SingleOutputStreamOperator<?>) criteriaTails.get(0)).getSideOutput(fakeTag);
    headFakeOutput
            .union(tailFakeOutput)
            .map(x -> x)
            .returns(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
            .name("criteria-discard")
            .setParallelism(1);

    // Finally tell every head operator how many criteria subtasks it has to count.
    setCriteriaParallelism(headStreams, actualCriteria.getParallelism());
    setCriteriaParallelism(criteriaHeads, actualCriteria.getParallelism());
}

/**
 * Collects the type information of every stream in the given list.
 *
 * @param dataStreams the streams to inspect.
 * @return the result of applying {@link DataStream#getType()} to each stream via {@code map}.
 */
private static List<TypeInformation<?>> getTypeInfos(DataStreamList dataStreams) {
    List<TypeInformation<?>> typeInfos = map(dataStreams, DataStream::getType);
    return typeInfos;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.apache.flink.annotation.Experimental;
import org.apache.flink.iteration.operator.allround.AllRoundOperatorWrapper;
import org.apache.flink.util.Preconditions;

/**
* A helper class to create iterations. To construct an iteration, Users are required to provide
Expand Down Expand Up @@ -111,6 +112,15 @@ public static DataStreamList iterateBoundedStreamsUntilTermination(
ReplayableDataStreamList dataStreams,
IterationConfig config,
IterationBody body) {
return null;
Preconditions.checkArgument(
config.getOperatorLifeCycle() == IterationConfig.OperatorLifeCycle.ALL_ROUND);
Preconditions.checkArgument(dataStreams.getReplayedDataStreams().size() == 0);

return IterationFactory.createIteration(
initVariableStreams,
new DataStreamList(dataStreams.getNonReplayedStreams()),
body,
new AllRoundOperatorWrapper(),
true);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,106 @@ public void processElement2(
assertSame(vertices.get(3).getCoLocationGroup(), vertices.get(7).getCoLocationGroup());
assertSame(vertices.get(4).getCoLocationGroup(), vertices.get(9).getCoLocationGroup());
}

/**
 * Builds a bounded iteration with two variable streams, one non-replayed constant stream and a
 * termination criteria stream, then checks the names, parallelisms and co-location groups of
 * the vertices in the generated job graph. The job is never executed; only the topology is
 * verified.
 */
@Test
public void testBoundedIterationWithTerminationCriteria() {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
// Two variable sources (parallelism 2 each) and one constant source (parallelism 3).
DataStream<Integer> variableSource1 =
env.addSource(new DraftExecutionEnvironment.EmptySource<Integer>() {})
.setParallelism(2)
.name("Variable0");
DataStream<Integer> variableSource2 =
env.addSource(new DraftExecutionEnvironment.EmptySource<Integer>() {})
.setParallelism(2)
.name("Variable1");

DataStream<Integer> constantSource =
env.addSource(new DraftExecutionEnvironment.EmptySource<Integer>() {})
.setParallelism(3)
.name("Constant");

// The body unions both variable streams, connects them with the constant stream in a
// "Processor" operator, and returns the "Feedback0"/"Feedback1" streams, the "output" side
// output of the processor, and the "Termination" stream (parallelism 5).
DataStreamList result =
Iterations.iterateBoundedStreamsUntilTermination(
DataStreamList.of(variableSource1, variableSource2),
ReplayableDataStreamList.notReplay(constantSource),
IterationConfig.newBuilder().build(),
(variableStreams, dataStreams) -> {
SingleOutputStreamOperator<Integer> processor =
variableStreams
.<Integer>get(0)
.union(variableStreams.<Integer>get(1))
.connect(dataStreams.<Integer>get(0))
.process(
new CoProcessFunction<
Integer, Integer, Integer>() {
// Both callbacks are intentionally empty: no records are
// processed because the job graph is only inspected.
@Override
public void processElement1(
Integer value,
Context ctx,
Collector<Integer> out)
throws Exception {}

@Override
public void processElement2(
Integer value,
Context ctx,
Collector<Integer> out)
throws Exception {}
})
.name("Processor")
.setParallelism(4);

return new IterationBodyResult(
DataStreamList.of(
processor
.map(x -> x)
.name("Feedback0")
.setParallelism(2),
processor
.map(x -> x)
.name("Feedback1")
.setParallelism(3)),
DataStreamList.of(
processor.getSideOutput(
new OutputTag<Integer>("output") {})),
processor.map(x -> x).name("Termination").setParallelism(5));
});
result.get(0).addSink(new DiscardingSink<>()).name("Sink").setParallelism(4);

// Expected vertices in topological order from the sources; the index comments are the
// positions used by the co-location assertions at the end of the test.
List<String> expectedVertexNames =
Arrays.asList(
/* 0 */ "Source: Variable0 -> input-Variable0",
/* 1 */ "Source: Variable1 -> input-Variable1",
/* 2 */ "Source: Constant -> input-Constant",
/* 3 */ "Source: Termination -> input-Termination",
/* 4 */ "head-Variable0",
/* 5 */ "head-Variable1",
/* 6 */ "Processor -> output-SideOutput -> Sink: Sink",
/* 7 */ "Feedback0",
/* 8 */ "tail-Feedback0",
/* 9 */ "Feedback1",
/* 10 */ "tail-Feedback1",
/* 11 */ "Termination",
/* 12 */ "tail-Termination",
/* 13 */ "head-Termination",
/* 14 */ "criteria-discard");
// Expected parallelisms, listed in the same order as expectedVertexNames.
List<Integer> expectedParallelisms =
Arrays.asList(2, 2, 3, 5, 2, 2, 4, 2, 2, 3, 3, 5, 5, 5, 1);

JobGraph jobGraph = env.getStreamGraph().getJobGraph();
List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();
assertEquals(
expectedVertexNames,
vertices.stream().map(JobVertex::getName).collect(Collectors.toList()));
assertEquals(
expectedParallelisms,
vertices.stream().map(JobVertex::getParallelism).collect(Collectors.toList()));

// Every head vertex must have a co-location group and share it with its matching tail:
// head-Variable0/tail-Feedback0 (4/8), head-Variable1/tail-Feedback1 (5/10) and
// head-Termination/tail-Termination (13/12).
assertNotNull(vertices.get(4).getCoLocationGroup());
assertNotNull(vertices.get(5).getCoLocationGroup());
assertNotNull(vertices.get(13).getCoLocationGroup());
assertSame(vertices.get(4).getCoLocationGroup(), vertices.get(8).getCoLocationGroup());
assertSame(vertices.get(5).getCoLocationGroup(), vertices.get(10).getCoLocationGroup());
assertSame(vertices.get(13).getCoLocationGroup(), vertices.get(12).getCoLocationGroup());
}
}

0 comments on commit 5a8880b

Please sign in to comment.