Skip to content

Commit

Permalink
[FLINK-4460] Catch side output OutputTag id clashes
Browse files Browse the repository at this point in the history
This also adds tests.
  • Loading branch information
aljoscha committed Mar 17, 2017
1 parent b6afef3 commit 20d8d67
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.util.Preconditions;

import java.util.HashMap;
import java.util.Map;

import static java.util.Objects.requireNonNull;

/**
Expand All @@ -46,6 +49,13 @@ public class SingleOutputStreamOperator<T> extends DataStream<T> {
/** Indicate this is a non-parallel operator and cannot set a non-1 degree of parallelism. **/
protected boolean nonParallel = false;

/**
* We keep track of the side outputs that were already requested and their types. With this,
* we can catch the case when a side output with a matching id is requested for a different
* type because this would lead to problems at runtime.
*/
private Map<OutputTag<?>, TypeInformation> requestedSideOutputs = new HashMap<>();

protected SingleOutputStreamOperator(StreamExecutionEnvironment environment, StreamTransformation<T> transformation) {
super(environment, transformation);
}
Expand Down Expand Up @@ -425,9 +435,22 @@ public SingleOutputStreamOperator<T> slotSharingGroup(String slotSharingGroup) {
*
* @see org.apache.flink.streaming.api.functions.ProcessFunction.Context#output(OutputTag, Object)
*/
public <X> DataStream<X> getSideOutput(OutputTag<X> sideOutputTag){
sideOutputTag = clean(sideOutputTag);
SideOutputTransformation<X> sideOutputTransformation = new SideOutputTransformation<>(this.getTransformation(), requireNonNull(sideOutputTag));
public <X> DataStream<X> getSideOutput(OutputTag<X> sideOutputTag) {
sideOutputTag = clean(requireNonNull(sideOutputTag));

// make a defensive copy
sideOutputTag = new OutputTag<X>(sideOutputTag.getId(), sideOutputTag.getTypeInfo());

TypeInformation<?> type = requestedSideOutputs.get(sideOutputTag);
if (type != null && !type.equals(sideOutputTag.getTypeInfo())) {
throw new UnsupportedOperationException("A side output with a matching id was " +
"already requested with a different type. This is not allowed, side output " +
"ids need to be unique.");
}

requestedSideOutputs.put(sideOutputTag, sideOutputTag.getTypeInfo());

SideOutputTransformation<X> sideOutputTransformation = new SideOutputTransformation<>(this.getTransformation(), sideOutputTag);
return new DataStream<>(this.getExecutionEnvironment(), sideOutputTransformation);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,23 @@ public void addVirtualSideOutputNode(Integer originalId, Integer virtualId, Outp
throw new IllegalStateException("Already has virtual output node with id " + virtualId);
}

// verify that we don't already have a virtual node for the given originalId/outputTag
// combination with a different TypeInformation. This would indicate that someone is trying
// to read a side output from an operation with a different type for the same side output
// id.

for (Tuple2<Integer, OutputTag> tag : virtualSideOutputNodes.values()) {
if (!tag.f0.equals(originalId)) {
// different source operator
continue;
}

if (!tag.f1.getTypeInfo().equals(outputTag.getTypeInfo())) {
throw new IllegalArgumentException("Trying to add a side input for the same id " +
"with a different type. This is not allowed.");
}
}

virtualSideOutputNodes.put(virtualId, new Tuple2<>(originalId, outputTag));
}

Expand Down Expand Up @@ -356,7 +373,8 @@ public void addEdge(Integer upStreamVertexID, Integer downStreamVertexID, int ty
downStreamVertexID,
typeNumber,
null,
new ArrayList<String>(), null);
new ArrayList<String>(),
null);

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase;
import org.apache.flink.test.streaming.runtime.util.TestListResultSink;
import org.apache.flink.util.Collector;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import javax.annotation.Nullable;
import java.io.Serializable;
Expand All @@ -56,6 +58,9 @@
*/
public class SideOutputITCase extends StreamingMultipleProgramsTestBase implements Serializable {

@Rule
public transient ExpectedException expectedException = ExpectedException.none();

static List<Integer> elements = new ArrayList<>();
static {
elements.add(1);
Expand All @@ -65,14 +70,14 @@ public class SideOutputITCase extends StreamingMultipleProgramsTestBase implemen
elements.add(4);
}

private final static OutputTag<String> sideOutputTag1 = new OutputTag<String>("side"){};
private final static OutputTag<String> sideOutputTag2 = new OutputTag<String>("other-side"){};

/**
* Verify that watermarks are forwarded to all side outputs.
*/
@Test
public void testWatermarkForwarding() throws Exception {
final OutputTag<String> sideOutputTag1 = new OutputTag<String>("side"){};
final OutputTag<String> sideOutputTag2 = new OutputTag<String>("other-side"){};

TestListResultSink<String> sideOutputResultSink1 = new TestListResultSink<>();
TestListResultSink<String> sideOutputResultSink2 = new TestListResultSink<>();
TestListResultSink<String> resultSink = new TestListResultSink<>();
Expand Down Expand Up @@ -167,6 +172,8 @@ public String map(Integer value) throws Exception {

@Test
public void testSideOutputWithMultipleConsumers() throws Exception {
final OutputTag<String> sideOutputTag = new OutputTag<String>("side"){};

TestListResultSink<String> sideOutputResultSink1 = new TestListResultSink<>();
TestListResultSink<String> sideOutputResultSink2 = new TestListResultSink<>();
TestListResultSink<Integer> resultSink = new TestListResultSink<>();
Expand All @@ -184,12 +191,12 @@ public void testSideOutputWithMultipleConsumers() throws Exception {
public void processElement(
Integer value, Context ctx, Collector<Integer> out) throws Exception {
out.collect(value);
ctx.output(sideOutputTag1, "sideout-" + String.valueOf(value));
ctx.output(sideOutputTag, "sideout-" + String.valueOf(value));
}
});

passThroughtStream.getSideOutput(sideOutputTag1).addSink(sideOutputResultSink1);
passThroughtStream.getSideOutput(sideOutputTag1).addSink(sideOutputResultSink2);
passThroughtStream.getSideOutput(sideOutputTag).addSink(sideOutputResultSink1);
passThroughtStream.getSideOutput(sideOutputTag).addSink(sideOutputResultSink2);
passThroughtStream.addSink(resultSink);
env.execute();

Expand All @@ -200,6 +207,8 @@ public void processElement(

@Test
public void testSideOutputWithMultipleConsumersWithObjectReuse() throws Exception {
final OutputTag<String> sideOutputTag = new OutputTag<String>("side"){};

TestListResultSink<String> sideOutputResultSink1 = new TestListResultSink<>();
TestListResultSink<String> sideOutputResultSink2 = new TestListResultSink<>();
TestListResultSink<Integer> resultSink = new TestListResultSink<>();
Expand All @@ -218,12 +227,12 @@ public void testSideOutputWithMultipleConsumersWithObjectReuse() throws Exceptio
public void processElement(
Integer value, Context ctx, Collector<Integer> out) throws Exception {
out.collect(value);
ctx.output(sideOutputTag1, "sideout-" + String.valueOf(value));
ctx.output(sideOutputTag, "sideout-" + String.valueOf(value));
}
});

passThroughtStream.getSideOutput(sideOutputTag1).addSink(sideOutputResultSink1);
passThroughtStream.getSideOutput(sideOutputTag1).addSink(sideOutputResultSink2);
passThroughtStream.getSideOutput(sideOutputTag).addSink(sideOutputResultSink1);
passThroughtStream.getSideOutput(sideOutputTag).addSink(sideOutputResultSink2);
passThroughtStream.addSink(resultSink);
env.execute();

Expand All @@ -232,13 +241,13 @@ public void processElement(
assertEquals(Arrays.asList(1, 2, 3, 4, 5), resultSink.getSortedResult());
}

/**
* Test ProcessFunction side output.
*/
@Test
public void testProcessFunctionSideOutput() throws Exception {
TestListResultSink<String> sideOutputResultSink = new TestListResultSink<>();
TestListResultSink<Integer> resultSink = new TestListResultSink<>();
public void testSideOutputNameClash() throws Exception {
final OutputTag<String> sideOutputTag1 = new OutputTag<String>("side"){};
final OutputTag<Integer> sideOutputTag2 = new OutputTag<Integer>("side"){};

TestListResultSink<String> sideOutputResultSink1 = new TestListResultSink<>();
TestListResultSink<Integer> sideOutputResultSink2 = new TestListResultSink<>();

StreamExecutionEnvironment see = StreamExecutionEnvironment.getExecutionEnvironment();
see.setParallelism(3);
Expand All @@ -254,22 +263,23 @@ public void processElement(
Integer value, Context ctx, Collector<Integer> out) throws Exception {
out.collect(value);
ctx.output(sideOutputTag1, "sideout-" + String.valueOf(value));
ctx.output(sideOutputTag2, 13);
}
});

passThroughtStream.getSideOutput(sideOutputTag1).addSink(sideOutputResultSink);
passThroughtStream.addSink(resultSink);
see.execute();
passThroughtStream.getSideOutput(sideOutputTag1).addSink(sideOutputResultSink1);

assertEquals(Arrays.asList("sideout-1", "sideout-2", "sideout-3", "sideout-4", "sideout-5"), sideOutputResultSink.getSortedResult());
assertEquals(Arrays.asList(1, 2, 3, 4, 5), resultSink.getSortedResult());
expectedException.expect(UnsupportedOperationException.class);
passThroughtStream.getSideOutput(sideOutputTag2).addSink(sideOutputResultSink2);
}

/**
* Test keyed ProcessFunction side output.
* Test ProcessFunction side output.
*/
@Test
public void testKeyedProcessFunctionSideOutput() throws Exception {
public void testProcessFunctionSideOutput() throws Exception {
final OutputTag<String> sideOutputTag = new OutputTag<String>("side"){};

TestListResultSink<String> sideOutputResultSink = new TestListResultSink<>();
TestListResultSink<Integer> resultSink = new TestListResultSink<>();

Expand All @@ -279,68 +289,77 @@ public void testKeyedProcessFunctionSideOutput() throws Exception {
DataStream<Integer> dataStream = see.fromCollection(elements);

SingleOutputStreamOperator<Integer> passThroughtStream = dataStream
.keyBy(new KeySelector<Integer, Integer>() {
private static final long serialVersionUID = 1L;

@Override
public Integer getKey(Integer value) throws Exception {
return value;
}
})
.process(new ProcessFunction<Integer, Integer>() {
private static final long serialVersionUID = 1L;

@Override
public void processElement(
Integer value, Context ctx, Collector<Integer> out) throws Exception {
out.collect(value);
ctx.output(sideOutputTag1, "sideout-" + String.valueOf(value));
ctx.output(sideOutputTag, "sideout-" + String.valueOf(value));
}
});

passThroughtStream.getSideOutput(sideOutputTag1).addSink(sideOutputResultSink);
passThroughtStream.getSideOutput(sideOutputTag).addSink(sideOutputResultSink);
passThroughtStream.addSink(resultSink);
see.execute();

assertEquals(Arrays.asList("sideout-1", "sideout-2", "sideout-3", "sideout-4", "sideout-5"), sideOutputResultSink.getSortedResult());
assertEquals(Arrays.asList(1, 2, 3, 4, 5), resultSink.getSortedResult());
}


/**
* Test ProcessFunction side outputs with wrong {@code OutputTag}.
* Test keyed ProcessFunction side output.
*/
@Test
public void testProcessFunctionSideOutputWithWrongTag() throws Exception {
public void testKeyedProcessFunctionSideOutput() throws Exception {
final OutputTag<String> sideOutputTag = new OutputTag<String>("side"){};

TestListResultSink<String> sideOutputResultSink = new TestListResultSink<>();
TestListResultSink<Integer> resultSink = new TestListResultSink<>();

StreamExecutionEnvironment see = StreamExecutionEnvironment.getExecutionEnvironment();
see.setParallelism(3);

DataStream<Integer> dataStream = see.fromCollection(elements);

dataStream
SingleOutputStreamOperator<Integer> passThroughtStream = dataStream
.keyBy(new KeySelector<Integer, Integer>() {
private static final long serialVersionUID = 1L;

@Override
public Integer getKey(Integer value) throws Exception {
return value;
}
})
.process(new ProcessFunction<Integer, Integer>() {
private static final long serialVersionUID = 1L;

@Override
public void processElement(
Integer value, Context ctx, Collector<Integer> out) throws Exception {
out.collect(value);
ctx.output(sideOutputTag2, "sideout-" + String.valueOf(value));
ctx.output(sideOutputTag, "sideout-" + String.valueOf(value));
}
}).getSideOutput(sideOutputTag1).addSink(sideOutputResultSink);
});

passThroughtStream.getSideOutput(sideOutputTag).addSink(sideOutputResultSink);
passThroughtStream.addSink(resultSink);
see.execute();

assertEquals(Arrays.asList(), sideOutputResultSink.getSortedResult());
assertEquals(Arrays.asList("sideout-1", "sideout-2", "sideout-3", "sideout-4", "sideout-5"), sideOutputResultSink.getSortedResult());
assertEquals(Arrays.asList(1, 2, 3, 4, 5), resultSink.getSortedResult());
}


/**
* Test keyed ProcessFunction side outputs with wrong {@code OutputTag}.
* Test ProcessFunction side outputs with wrong {@code OutputTag}.
*/
@Test
public void testKeyedProcessFunctionSideOutputWithWrongTag() throws Exception {
public void testProcessFunctionSideOutputWithWrongTag() throws Exception {
final OutputTag<String> sideOutputTag1 = new OutputTag<String>("side"){};
final OutputTag<String> sideOutputTag2 = new OutputTag<String>("other-side"){};

TestListResultSink<String> sideOutputResultSink = new TestListResultSink<>();

StreamExecutionEnvironment see = StreamExecutionEnvironment.getExecutionEnvironment();
Expand All @@ -349,14 +368,6 @@ public void testKeyedProcessFunctionSideOutputWithWrongTag() throws Exception {
DataStream<Integer> dataStream = see.fromCollection(elements);

dataStream
.keyBy(new KeySelector<Integer, Integer>() {
private static final long serialVersionUID = 1L;

@Override
public Integer getKey(Integer value) throws Exception {
return value;
}
})
.process(new ProcessFunction<Integer, Integer>() {
private static final long serialVersionUID = 1L;

Expand All @@ -373,7 +384,6 @@ public void processElement(
assertEquals(Arrays.asList(), sideOutputResultSink.getSortedResult());
}


private static class TestWatermarkAssigner implements AssignerWithPunctuatedWatermarks<Integer> {
private static final long serialVersionUID = 1L;

Expand Down

0 comments on commit 20d8d67

Please sign in to comment.