From 8535b22a7cf1545381bcd7b5febc0d4763f925aa Mon Sep 17 00:00:00 2001 From: wind Date: Mon, 11 Jun 2018 20:40:19 +0800 Subject: [PATCH] add TestingRuntimeContext --- docs/dev/udf_test.md | 423 ++++++++++++++++++ .../util/test/SimpleAggregatingState.java | 51 +++ .../util/test/SimpleFoldingState.java | 51 +++ .../functions/util/test/SimpleListState.java | 68 +++ .../functions/util/test/SimpleMapState.java | 93 ++++ .../util/test/SimpleReducingState.java | 50 +++ .../functions/util/test/SimpleValueState.java | 50 +++ .../functions/util/test/TestingCollector.java | 45 ++ .../util/test/TestingRuntimeContext.java | 383 ++++++++++++++++ .../util/test/TestingRuntimeContextTest.java | 159 +++++++ .../test/TestingRuntimeContextTest.java | 175 ++++++++ 11 files changed, 1548 insertions(+) create mode 100644 docs/dev/udf_test.md create mode 100644 flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleAggregatingState.java create mode 100644 flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleFoldingState.java create mode 100644 flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleListState.java create mode 100644 flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleMapState.java create mode 100644 flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleReducingState.java create mode 100644 flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleValueState.java create mode 100644 flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/TestingCollector.java create mode 100644 flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/TestingRuntimeContext.java create mode 100644 flink-core/src/test/java/org/apache/flink/api/common/functions/util/test/TestingRuntimeContextTest.java create mode 100644 
flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/test/TestingRuntimeContextTest.java diff --git a/docs/dev/udf_test.md b/docs/dev/udf_test.md new file mode 100644 index 0000000000000..c5a103d28f537 --- /dev/null +++ b/docs/dev/udf_test.md @@ -0,0 +1,423 @@ +--- +title: "User-Defined Function Testing" +nav-parent_id: dev +nav-pos: 40 +nav-show_overview: true +nav-id: test-udf +--- + +If you're using some transform operators whose implemented +functions don't have any return like "FlatMapFunction", or "Co +ProcessFunction" or funtions you define on your own, you may +want to use TestingRuntimeContext to test it. +{% highlight java %} +/** + * Initialization + * + * @param isStreaming: if set to true, this is used for DataStream Functions. + * if set to false, this is used for DataSet Functions. + */ +TestingRuntimeContext ctx = new TestingRuntimeContext(isStreaming); +{% endhighlight %} +### DataSet Functions +Assume we're using a WordDistinctFlatMap to remove duplicate words for each line in a textbook. +
+
+{% highlight java %} +public static class WordDistinctFlat extends RichFlatMapFunction { + + @Override + public void flatMap(String value, Collector out) throws Exception { + Set wordsList = new HashSet<>(Arrays.asList(value.split(","))); + for (String word: wordsList) { + out.collect(word); + } + } +} +{% endhighlight %} +
+ +
+{% highlight scala %} +class WordDistinctFlat extends RichFlatMapFunction[String, String] { + override def flatMap(value: String, out: Collector[String]): Unit = { + val wordsList = value.split(",").toSet + wordsList.foreach(word => { + out.collect(word) + }) + } +} +{% endhighlight %} +
+
+It's easy to see that we can't test it directly as we usually do with MapFunction, because the results are collected +by Flink's Collector. Now we use TestingRuntimeContext to test whether the logic inside flatMap is correct. +
+
+{% highlight java %} +@Test +public void testWordDistinctFlat() throws Exception { + // "false" means that this is not a datastream function. + TestingRuntimeContext ctx = new TestingRuntimeContext(false); + WordDistinctFlat flat = new WordDistinctFlat(); + flat.setRuntimeContext(ctx); + flat.flatMap("Eat,Eat,Walk,Run", ctx.getCollector()); + Assert.assertArrayEquals(ctx.getCollectorOutput().toArray(), new String[]{"Eat", "Walk", "Run"}); +} +{% endhighlight %} +
+ +
+{% highlight scala %} +@Test +def testWordDistinctFlat(): Unit = { + import scala.collection.JavaConverters._ + val ctx = new TestingRuntimeContext(false) + val flat = new WordDistinctFlat() + flat.setRuntimeContext(ctx) + flat.flatMap("Eat,Eat,Walk,Run", ctx.getCollector()) + Assert.assertEquals(ctx.getCollectorOutput.asScala.toArray.deep, Array("Eat", "Walk", "Run").deep) +} +{% endhighlight %} +
+
+ +Then we want to exclude some specific words by broadcasting them; the code will look like this. +
+
+{% highlight java %} +@Test +public void testWordDistinctFlat() throws Exception { + TestingRuntimeContext ctx = new TestingRuntimeContext(false); + ctx.setBroadcastVariable("exWords", Collections.singletonList("Eat")); + WordDistinctFlat flat = new WordDistinctFlat(); + flat.setRuntimeContext(ctx); + flat.flatMap("Eat,Eat,Walk,Run", ctx.getCollector()); + Assert.assertArrayEquals(ctx.getCollectorOutput().toArray(), new String[]{"Walk", "Run"}); +} + +public static class WordDistinctFlat extends RichFlatMapFunction { + + @Override + public void flatMap(String value, Collector out) throws Exception { + Set wordsList = new HashSet<>(Arrays.asList(value.split(","))); + List excludeWords = getRuntimeContext().getBroadcastVariable("exWords"); + for (String word: wordsList) { + if (!excludeWords.contains(word)) { + out.collect(word); + } + } + } +} +{% endhighlight %} +
+ +
+{% highlight scala %} +@Test +def testWordDistinctFlat(): Unit = { + import scala.collection.JavaConverters._ + val ctx = new TestingRuntimeContext(false) + ctx.setBroadcastVariable("exWords", Collections.singletonList("Eat")) + val flat = new WordDistinctFlat() + flat.setRuntimeContext(ctx) + flat.flatMap("Eat,Eat,Walk,Run", ctx.getCollector()) + Assert.assertEquals(ctx.getCollectorOutput.asScala.toArray.deep, Array("Walk", "Run").deep) +} + +class WordDistinctFlat extends RichFlatMapFunction[String, String] { + override def flatMap(value: String, out: Collector[String]): Unit = { + val wordsList = value.split(",").toSet + val exclude = getRuntimeContext.getBroadcastVariable("exWords") + wordsList.foreach(word => { + if (!exclude.contains(word)) { + out.collect(word) + } + }) + } +} +{% endhighlight %} +
+
+ + +### DataStream Functions +For most functions in DataStream, you can easily test them in the same way that you test with DataSet +functions; the difference is the **State** in DataStream, which is used very often in user-defined functions. And we've +already offered some simple states to be used in testing. Now we can show you how to use these things +in a real user-defined function. +Let's assume that we're going to build a system for attributing taxi rides and fares, so what we need is to write a user-defined function +to join rides with fares. The source code of this function is from [JoinRidesWithFares](https://github.com/dataArtisans/flink-training-exercises/blob/master/src/main/java/com/dataartisans/flinktraining/exercises/datastream_java/process/JoinRidesWithFares.java). +
+
+{% highlight java %} +static final OutputTag unmatchedRides = new OutputTag("unmatchedRides") {}; +static final OutputTag unmatchedFares = new OutputTag("unmatchedFares") {}; + +static class TaxiRide { + private long eventTime; + + TaxiRide(long eventTime) { this.eventTime = eventTime; } + + Long getEventTime() { return eventTime; } +} + +static class TaxiFare { + private long eventTime; + TaxiFare(long eventTime) { this.eventTime = eventTime; } + Long getEventTime() { return eventTime; } +} + +public static class EnrichmentFunction extends CoProcessFunction> { + // keyed, managed state + private ValueState rideState; + private ValueState fareState; + + @Override + public void open(Configuration config) { + rideState = getRuntimeContext().getState(new ValueStateDescriptor<>("saved ride", TaxiRide.class)); + fareState = getRuntimeContext().getState(new ValueStateDescriptor<>("saved fare", TaxiFare.class)); + } + + @Override + public void onTimer(long timestamp, OnTimerContext ctx, Collector> out) throws Exception { + if (fareState.value() != null) { + ctx.output(unmatchedFares, fareState.value()); + fareState.clear(); + } + if (rideState.value() != null) { + ctx.output(unmatchedRides, rideState.value()); + rideState.clear(); + } + } + + @Override + public void processElement1(TaxiRide ride, Context context, Collector> out) throws Exception { + TaxiFare fare = fareState.value(); + if (fare != null) { + fareState.clear(); + out.collect(new Tuple2(ride, fare)); + } else { + rideState.update(ride); + // as soon as the watermark arrives, we can stop waiting for the corresponding fare + context.timerService().registerEventTimeTimer(ride.getEventTime()); + } + } + + @Override + public void processElement2(TaxiFare fare, Context context, Collector> out) throws Exception { + TaxiRide ride = rideState.value(); + if (ride != null) { + rideState.clear(); + out.collect(new Tuple2(ride, fare)); + } else { + fareState.update(fare); + // wait up to 6 hours for the corresponding 
ride END event, then clear the state + context.timerService().registerEventTimeTimer(fare.getEventTime() + 6 * 60 * 60 * 1000); + } + } +} +{% endhighlight %} +
+
+{% highlight scala %} +val unmatchedRides: OutputTag[TestingRuntimeContextTest.TaxiRide] = new OutputTag[TestingRuntimeContextTest.TaxiRide]("unmatchedRides") {} +val unmatchedFares: OutputTag[TestingRuntimeContextTest.TaxiFare] = new OutputTag[TestingRuntimeContextTest.TaxiFare]("unmatchedFares") {} + +class TaxiRide(eventTime: Long) { + def getEventTime: Long = eventTime +} + +class TaxiFare(eventTime: Long) { + def getEventTime: Long = eventTime +} + +class EnrichmentFunction extends CoProcessFunction[TaxiRide, TaxiFare, Tuple2[TaxiRide, TaxiFare]] { // keyed, managed state + var rideState: ValueState[TaxiRide] = _ + var fareState: ValueState[TaxiFare] = _ + + override def open(config: Configuration): Unit = { + rideState = getRuntimeContext.getState(new ValueStateDescriptor[TaxiRide]("saved ride", classOf[TaxiRide])) + fareState = getRuntimeContext.getState(new ValueStateDescriptor[TaxiFare]("saved fare", classOf[TaxiFare])) + } + + override def onTimer(timestamp: Long, ctx: CoProcessFunction[TaxiRide, TaxiFare, Tuple2[TaxiRide, TaxiFare]]#OnTimerContext, out: Collector[Tuple2[TaxiRide, TaxiFare]]): Unit = { + if (fareState.value != null) { + ctx.output(unmatchedFares, fareState.value) + fareState.clear() + } + if (rideState.value != null) { + ctx.output(unmatchedRides, rideState.value) + rideState.clear() + } + } + + override def processElement1(ride: TaxiRide, context: CoProcessFunction[TaxiRide, TaxiFare, Tuple2[TaxiRide, TaxiFare]]#Context, out: Collector[Tuple2[TaxiRide, TaxiFare]]): Unit = { + val fare = fareState.value + if (fare != null) { + fareState.clear() + out.collect(new Tuple2(ride, fare)) + } + else { + rideState.update(ride) + // as soon as the watermark arrives, we can stop waiting for the corresponding fare + context.timerService.registerEventTimeTimer(ride.getEventTime) + } + } + + override def processElement2(fare: TaxiFare, context: CoProcessFunction[TaxiRide, TaxiFare, Tuple2[TaxiRide, TaxiFare]]#Context, out: Collector[Tuple2[TaxiRide, 
TaxiFare]]): Unit = { + val ride = rideState.value + if (ride != null) { + rideState.clear() + out.collect(new Tuple2(ride, fare)) + } + else { + fareState.update(fare) + // wait up to 6 hours for the corresponding ride END event, then clear the state + context.timerService.registerEventTimeTimer(fare.getEventTime + 6 * 60 * 60 * 1000) + } + } +} +{% endhighlight %} +
+
+Let's try some tests for the EnrichmentFunction. +
+
+{% highlight java %} +@Test +public void testEnrichmentFunction() throws Exception { + TestingRuntimeContext ctx = new TestingRuntimeContext(true); + EnrichmentFunction func = new EnrichmentFunction(); + func.setRuntimeContext(ctx); + + // have to mannually mock the context inside the function. + CoProcessFunction.Context context = mock(EnrichmentFunction.Context.class); + CoProcessFunction.OnTimerContext timerContext = mock(EnrichmentFunction.OnTimerContext.class); + TimerService timerService = mock(TimerService.class); + doAnswer(invocationOnMock -> { + OutputTag outputTag = invocationOnMock.getArgumentAt(0, OutputTag.class); + Object value = invocationOnMock.getArgumentAt(1, Object.class); + ctx.addSideOutput(outputTag, value); + return null; + }).when(timerContext).output(any(OutputTag.class), any()); + doReturn(timerService).when(context).timerService(); + doNothing().when(timerService).registerEventTimeTimer(anyLong()); + + // use simple states inside "org.apache.flink.api.common.functions.util.test" + ValueStateDescriptor rideStateDesc = new ValueStateDescriptor<>("saved ride", TaxiRide.class); + ValueStateDescriptor fareStateDesc = new ValueStateDescriptor<>("saved fare", TaxiFare.class); + ctx.setState(rideStateDesc, new SimpleValueState<>(null)); + ctx.setState(fareStateDesc, new SimpleValueState(null)); + func.open(new Configuration()); + + // receive the first taxi ride. + TaxiRide ride1 = new TaxiRide(1); + func.processElement1(ride1, context, ctx.getCollector()); + Assert.assertEquals(ctx.getState(rideStateDesc).value(), ride1); + + // receive the first taxi fare and do the attribution. + TaxiFare fare1 = new TaxiFare(1); + func.processElement2(fare1, context, ctx.getCollector()); + Assert.assertEquals(ctx.getState(rideStateDesc).value(), null); + Assert.assertEquals(ctx.getCollectorOutput(), Collections.singletonList(new Tuple2(ride1, fare1))); + + // receive the second taxi fare. 
+ TaxiFare fare2 = new TaxiFare(2); + func.processElement2(fare2, context, ctx.getCollector()); + Assert.assertEquals(ctx.getState(fareStateDesc).value(), fare2); + + func.onTimer(0L, timerContext, ctx.getCollector()); + Assert.assertEquals(Collections.singletonList(fare2), ctx.getSideOutput(unmatchedFares)); +} +{% endhighlight %} +
+ +
+{% highlight scala %} +@Test +def testEnrichmentFunction(): Unit = { + val ctx = new TestingRuntimeContext(true) + val func = new EnrichmentFunction + func.setRuntimeContext(ctx) + + // have to mannually mock the context inside the function. + val context = mock(classOf[CoProcessFunction[TaxiRide, TaxiFare, Tuple2[TaxiRide, TaxiFare]]#Context]) + val timerContext = mock(classOf[CoProcessFunction[TaxiRide, TaxiFare, Tuple2[TaxiRide, TaxiFare]]#OnTimerContext]) + val timerService = mock(classOf[TimerService]) + doAnswer(new Answer[Any] { + override def answer(invocationOnMock: InvocationOnMock) = { + val outputTag = invocationOnMock.getArgumentAt(0, classOf[OutputTag[Any]]) + val value = invocationOnMock.getArgumentAt(1, classOf[Any]) + ctx.addSideOutput(outputTag, value) + null + } + }).when(timerContext).output(any(classOf[OutputTag[Any]]), any) + doReturn(timerService).when(context).timerService + doNothing().when(timerService).registerEventTimeTimer(anyLong) + + // use simple states inside "org.apache.flink.api.common.functions.util.test" + val rideStateDesc = new ValueStateDescriptor[TaxiRide]("saved ride", classOf[TaxiRide]) + val fareStateDesc = new ValueStateDescriptor[TaxiFare]("saved fare", classOf[TaxiFare]) + ctx.setState(rideStateDesc, new SimpleValueState[AnyRef](null)) + ctx.setState(fareStateDesc, new SimpleValueState[TaxiFare](null)) + func.open(new Configuration) + + // receive the first taxi ride. + val ride1 = new TaxiRide(1) + func.processElement1(ride1, context, ctx.getCollector()) + Assert.assertEquals(ctx.getState(rideStateDesc).value, ride1) + + // receive the first taxi fare and do the attribution. + val fare1 = new TaxiFare(1) + func.processElement2(fare1, context, ctx.getCollector()) + Assert.assertEquals(ctx.getState(rideStateDesc).value, null) + Assert.assertEquals(ctx.getCollectorOutput, Collections.singletonList(new Tuple2(ride1, fare1))) + + // receive the second taxi fare. 
+ val fare2 = new TaxiFare(2) + func.processElement2(fare2, context, ctx.getCollector()) + Assert.assertEquals(ctx.getState(fareStateDesc).value, fare2) + + func.onTimer(0L, timerContext, ctx.getCollector()) + Assert.assertEquals(Collections.singletonList(fare2), ctx.getSideOutput(unmatchedFares)) +} +{% endhighlight %} +
+
+{% top %} + + + + + + + + + + + + + + + + + + diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleAggregatingState.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleAggregatingState.java new file mode 100644 index 0000000000000..5920e90884927 --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleAggregatingState.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.common.functions.util.test; + +import org.apache.flink.api.common.state.AggregatingState; +import org.apache.flink.api.common.state.AggregatingStateDescriptor; + +/** + * A simple {@link AggregatingState} for testing. 
+ */ +public class SimpleAggregatingState implements AggregatingState { + + private AggregatingStateDescriptor descriptor; + private ACC accumulator; + + public SimpleAggregatingState(AggregatingStateDescriptor descriptor) { + this.descriptor = descriptor; + this.accumulator = this.descriptor.getAggregateFunction().createAccumulator(); + } + + @Override + public OUT get() throws Exception { + return descriptor.getAggregateFunction().getResult(accumulator); + } + + @Override + public void add(IN value) throws Exception { + this.accumulator = descriptor.getAggregateFunction().add(value, accumulator); + } + + @Override + public void clear() { + accumulator = null; + } +} diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleFoldingState.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleFoldingState.java new file mode 100644 index 0000000000000..34a3fe5979c20 --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleFoldingState.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.api.common.functions.util.test; + +import org.apache.flink.api.common.state.FoldingState; +import org.apache.flink.api.common.state.FoldingStateDescriptor; + +/** + * A simple {@link FoldingState} for testing. + */ +public class SimpleFoldingState implements FoldingState { + + private FoldingStateDescriptor descriptor; + private ACC accumulator; + + public SimpleFoldingState(FoldingStateDescriptor descriptor) { + this.descriptor = descriptor; + this.accumulator = descriptor.getDefaultValue(); + } + + @Override + public ACC get() throws Exception { + return accumulator; + } + + @Override + public void add(T value) throws Exception { + accumulator = descriptor.getFoldFunction().fold(accumulator, value); + } + + @Override + public void clear() { + accumulator = null; + } +} diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleListState.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleListState.java new file mode 100644 index 0000000000000..2c4b1dcdd19a6 --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleListState.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.common.functions.util.test; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.util.Preconditions; + +import java.util.ArrayList; +import java.util.List; + +/** + * A simple {@link ListState} for testing. + */ +public class SimpleListState implements ListState { + + private final List list = new ArrayList<>(); + + @Override + public void clear() { + list.clear(); + } + + @Override + public Iterable get() throws Exception { + return list; + } + + @Override + public void add(T value) throws Exception { + Preconditions.checkNotNull(value, "You cannot add null to a ListState."); + list.add(value); + } + + public List getList() { + return list; + } + + @Override + public void update(List values) throws Exception { + clear(); + + addAll(values); + } + + @Override + public void addAll(List values) throws Exception { + if (values != null) { + values.forEach(v -> Preconditions.checkNotNull(v, "You cannot add null to a ListState.")); + list.addAll(values); + } + } +} diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleMapState.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleMapState.java new file mode 100644 index 0000000000000..37ff4c9e519ad --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleMapState.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.common.functions.util.test; + +import org.apache.flink.api.common.state.MapState; + +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; + +/** + * A simple {@link MapState} for testing. + */ +public class SimpleMapState implements MapState { private final MapState originalState; + + private final Map emptyState = Collections.emptyMap(); + + SimpleMapState(MapState originalState) { + this.originalState = originalState; + } + + // ------------------------------------------------------------------------ + + @Override + public V get(K key) throws Exception { + return originalState.get(key); + } + + @Override + public void put(K key, V value) throws Exception { + originalState.put(key, value); + } + + @Override + public void putAll(Map value) throws Exception { + originalState.putAll(value); + } + + @Override + public void clear() { + originalState.clear(); + } + + @Override + public void remove(K key) throws Exception { + originalState.remove(key); + } + + @Override + public boolean contains(K key) throws Exception { + return originalState.contains(key); + } + + @Override + public Iterable> entries() throws Exception { + Iterable> original = originalState.entries(); + return original != null ? original : emptyState.entrySet(); + } + + @Override + public Iterable keys() throws Exception { + Iterable original = originalState.keys(); + return original != null ? 
original : emptyState.keySet(); + } + + @Override + public Iterable values() throws Exception { + Iterable original = originalState.values(); + return original != null ? original : emptyState.values(); + } + + @Override + public Iterator> iterator() throws Exception { + Iterator> original = originalState.iterator(); + return original != null ? original : emptyState.entrySet().iterator(); + } +} diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleReducingState.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleReducingState.java new file mode 100644 index 0000000000000..312c1e5160305 --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleReducingState.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.common.functions.util.test; + +import org.apache.flink.api.common.state.ReducingState; +import org.apache.flink.api.common.state.ReducingStateDescriptor; + +/** + * A simple {@link ReducingState} for testing. 
+ */ +public class SimpleReducingState implements ReducingState { + + private ReducingStateDescriptor descriptor; + private T value; + + public SimpleReducingState(ReducingStateDescriptor descriptor) { + this.descriptor = descriptor; + } + + @Override + public T get() throws Exception { + return value; + } + + @Override + public void add(T value) throws Exception { + this.value = descriptor.getReduceFunction().reduce(this.value, value); + } + + @Override + public void clear() { + value = null; + } +} diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleValueState.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleValueState.java new file mode 100644 index 0000000000000..ca6e52031f1cc --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/SimpleValueState.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.common.functions.util.test; + +import org.apache.flink.api.common.state.ValueState; + +import java.io.IOException; + +/** + * A simple {@link ValueState} for testing. 
+ */ +public class SimpleValueState implements ValueState { + + private T value; + + public SimpleValueState(T value) { + this.value = value; + } + + @Override + public T value() throws IOException { + return value; + } + + @Override + public void update(T value) throws IOException { + this.value = value; + } + + @Override + public void clear() { + this.value = null; + } +} diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/TestingCollector.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/TestingCollector.java new file mode 100644 index 0000000000000..8385f541d1f5b --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/TestingCollector.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.common.functions.util.test; + +import org.apache.flink.util.Collector; + +import java.util.ArrayList; +import java.util.List; + +/** + * A simple {@link Collector} used in {@link TestingRuntimeContext}. 
+ */ +class TestingCollector implements Collector { + + List output = new ArrayList<>(); + + @Override + public void collect(T record) { + output.add(record); + } + + @Override + public void close() {} + + public List getOutput() { + return output; + } + +} diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/TestingRuntimeContext.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/TestingRuntimeContext.java new file mode 100644 index 0000000000000..df224e357c30f --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/test/TestingRuntimeContext.java @@ -0,0 +1,383 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
 */

package org.apache.flink.api.common.functions.util.test;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.accumulators.AccumulatorHelper;
import org.apache.flink.api.common.accumulators.DoubleCounter;
import org.apache.flink.api.common.accumulators.Histogram;
import org.apache.flink.api.common.accumulators.IntCounter;
import org.apache.flink.api.common.accumulators.LongCounter;
import org.apache.flink.api.common.cache.DistributedCache;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.AggregatingState;
import org.apache.flink.api.common.state.AggregatingStateDescriptor;
import org.apache.flink.api.common.state.FoldingState;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapState;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.util.OutputTag;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Testing class for user-defined functions testing.
+ */ +@SuppressWarnings("unchecked") +public class TestingRuntimeContext implements RuntimeContext { + + private MetricGroup metricGroup; + private ExecutionConfig executionConfig; + private ClassLoader userCodeClassLoader; + private Map, State> ctxStates = new HashMap<>(); + private DistributedCache distributedCache; + private TaskInfo taskInfo; + private boolean isStreaming; + + // Broadcast variables + private final HashMap initializedBroadcastVars = new HashMap(); + private final HashMap> uninitializedBroadcastVars = new HashMap>(); + + // Accumulators + private final Map> accumulators = new HashMap<>(); + + // Collector + private TestingCollector collector = new TestingCollector<>(); + private Map, List> sideOutput = new HashMap<>(); + + public TestingRuntimeContext(boolean isStreaming) { + this.isStreaming = isStreaming; + } + + public void setTaskInfo(TaskInfo taskInfo) { + this.taskInfo = taskInfo; + } + + public void setMetricGroup(MetricGroup metricGroup) { + this.metricGroup = metricGroup; + } + + public void setExecutionConfig(ExecutionConfig executionConfig) { + this.executionConfig = executionConfig; + } + + public void setUserCodeClassLoader(ClassLoader userCodeClassLoader) { + this.userCodeClassLoader = userCodeClassLoader; + } + + public void setDistributedCache(DistributedCache distributedCache) { + this.distributedCache = distributedCache; + } + + @Override + public String getTaskName() { + return taskInfo.getTaskName(); + } + + @Override + public MetricGroup getMetricGroup() { + return metricGroup; + } + + @Override + public int getNumberOfParallelSubtasks() { + return taskInfo.getNumberOfParallelSubtasks(); + } + + @Override + public int getMaxNumberOfParallelSubtasks() { + return taskInfo.getMaxNumberOfParallelSubtasks(); + } + + @Override + public int getIndexOfThisSubtask() { + return taskInfo.getIndexOfThisSubtask(); + } + + @Override + public int getAttemptNumber() { + return taskInfo.getAttemptNumber(); + } + + @Override + public String 
getTaskNameWithSubtasks() { + return taskInfo.getTaskNameWithSubtasks(); + } + + @Override + public ExecutionConfig getExecutionConfig() { + return executionConfig; + } + + @Override + public ClassLoader getUserCodeClassLoader() { + return userCodeClassLoader; + } + + @Override + public DistributedCache getDistributedCache() { + return distributedCache; + } + + // Accumulators operations. + + @Override + public void addAccumulator(String name, Accumulator accumulator) { + if (accumulators.containsKey(name)) { + throw new UnsupportedOperationException("The accumulator '" + name + + "' already exists and cannot be added."); + } + accumulators.put(name, accumulator); + } + + @Override + public Accumulator getAccumulator(String name) { + return (Accumulator) accumulators.get(name); + } + + @Override + public Map> getAllAccumulators() { + return Collections.unmodifiableMap(this.accumulators); + } + + @Override + public IntCounter getIntCounter(String name) { + return (IntCounter) getAccumulator(name, IntCounter.class); + } + + @Override + public LongCounter getLongCounter(String name) { + return (LongCounter) getAccumulator(name, LongCounter.class); + } + + @Override + public Histogram getHistogram(String name) { + return (Histogram) getAccumulator(name, Histogram.class); + } + + @Override + public DoubleCounter getDoubleCounter(String name) { + return (DoubleCounter) getAccumulator(name, DoubleCounter.class); + } + + @SuppressWarnings("unchecked") + private Accumulator getAccumulator(String name, Class> accumulatorClass) { + Accumulator accumulator = accumulators.get(name); + + if (accumulator != null) { + AccumulatorHelper.compareAccumulatorTypes(name, accumulator.getClass(), accumulatorClass); + } else { + // Create new accumulator + try { + accumulator = accumulatorClass.newInstance(); + } + catch (Exception e) { + throw new RuntimeException("Cannot create accumulator " + accumulatorClass.getName()); + } + accumulators.put(name, accumulator); + } + return 
(Accumulator) accumulator; + } + + // Broadcast operations. + + @Override + public boolean hasBroadcastVariable(String name) { + if (isStreaming) { + throw new UnsupportedOperationException("This broadcastVariable is only accessible by functions executed on a DataSet"); + } + return false; + } + + @Override + public List getBroadcastVariable(String name) { + if (isStreaming) { + throw new UnsupportedOperationException("This broadcastVariable is only accessible by functions executed on a DataSet"); + } + // check if we have an initialized version + Object o = this.initializedBroadcastVars.get(name); + if (o != null) { + if (o instanceof List) { + return (List) o; + } + else { + throw new IllegalStateException("The broadcast variable with name '" + name + + "' is not a List. A different call must have requested this variable with a BroadcastVariableInitializer."); + } + } + else { + List uninitialized = this.uninitializedBroadcastVars.remove(name); + if (uninitialized != null) { + this.initializedBroadcastVars.put(name, uninitialized); + return (List) uninitialized; + } + else { + throw new IllegalArgumentException("The broadcast variable with name '" + name + "' has not been set."); + } + } + } + + @Override + public C getBroadcastVariableWithInitializer(String name, BroadcastVariableInitializer initializer) { + if (isStreaming) { + throw new UnsupportedOperationException("This broadcastVariable is only accessible by functions executed on a DataSet"); + } + // check if we have an initialized version + Object o = this.initializedBroadcastVars.get(name); + if (o != null) { + return (C) o; + } + else { + List uninitialized = (List) this.uninitializedBroadcastVars.remove(name); + if (uninitialized != null) { + C result = initializer.initializeBroadcastVariable(uninitialized); + this.initializedBroadcastVars.put(name, result); + return result; + } + else { + throw new IllegalArgumentException("The broadcast variable with name '" + name + "' has not been set."); + } + } + 
} + + public void setBroadcastVariable(String name, List value) { + if (isStreaming) { + throw new UnsupportedOperationException("This broadcastVariable is only accessible by functions executed on a DataSet"); + } + this.uninitializedBroadcastVars.put(name, value); + this.initializedBroadcastVars.remove(name); + } + + public void clearBroadcastVariable(String name) { + if (isStreaming) { + throw new UnsupportedOperationException("This broadcastVariable is only accessible by functions executed on a DataSet"); + } + this.uninitializedBroadcastVars.remove(name); + this.initializedBroadcastVars.remove(name); + } + + public void clearAllBroadcastVariables() { + if (isStreaming) { + throw new UnsupportedOperationException("This broadcastVariable is only accessible by functions executed on a DataSet"); + } + this.uninitializedBroadcastVars.clear(); + this.initializedBroadcastVars.clear(); + } + + + // State operations. + + public void setState(StateDescriptor stateProperties, State state) { + ctxStates.put(stateProperties, state); + } + + @Override + public ValueState getState(ValueStateDescriptor stateProperties) { + if (isStreaming) { + return (ValueState) ctxStates.get(stateProperties); + } else { + throw new UnsupportedOperationException("This state is only accessible by functions executed on a KeyedStream"); + } + } + + @Override + public ListState getListState(ListStateDescriptor stateProperties) { + if (isStreaming) { + + return (ListState) ctxStates.get(stateProperties); + } else { + throw new UnsupportedOperationException("This state is only accessible by functions executed on a KeyedStream"); + } + } + + @Override + public ReducingState getReducingState(ReducingStateDescriptor stateProperties) { + if (isStreaming) { + return (ReducingState) ctxStates.get(stateProperties); + } else { + throw new UnsupportedOperationException("This state is only accessible by functions executed on a KeyedStream"); + } + } + + @Override + public AggregatingState 
getAggregatingState(AggregatingStateDescriptor stateProperties) { + if (isStreaming) { + return (AggregatingState) ctxStates.get(stateProperties); + } else { + throw new UnsupportedOperationException("This state is only accessible by functions executed on a KeyedStream"); + } + } + + @Override + public FoldingState getFoldingState(FoldingStateDescriptor stateProperties) { + if (isStreaming) { + return (FoldingState) ctxStates.get(stateProperties); + } else { + throw new UnsupportedOperationException("This state is only accessible by functions executed on a KeyedStream"); + } + } + + @Override + public MapState getMapState(MapStateDescriptor stateProperties) { + if (isStreaming) { + return (MapState) ctxStates.get(stateProperties); + } else { + throw new UnsupportedOperationException("This state is only accessible by functions executed on a KeyedStream"); + } + } + + public TestingCollector getCollector() { + return (TestingCollector) collector; + } + + public List getCollectorOutput() { + return (List) collector.output; + } + + public void addSideOutput(OutputTag tag, T value) { + if (sideOutput.containsKey(tag)) { + List originList = (List) sideOutput.get(tag); + originList.add(value); + sideOutput.put(tag, originList); + } else { + sideOutput.put(tag, Collections.singletonList(value)); + } + } + + public List getSideOutput(OutputTag tag) { + return (List) sideOutput.getOrDefault(tag, Collections.emptyList()); + } + +} diff --git a/flink-core/src/test/java/org/apache/flink/api/common/functions/util/test/TestingRuntimeContextTest.java b/flink-core/src/test/java/org/apache/flink/api/common/functions/util/test/TestingRuntimeContextTest.java new file mode 100644 index 0000000000000..f8084b76325e2 --- /dev/null +++ b/flink-core/src/test/java/org/apache/flink/api/common/functions/util/test/TestingRuntimeContextTest.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.common.functions.util.test; + +import org.apache.flink.api.common.functions.BroadcastVariableInitializer; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.RichCoGroupFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.util.Collector; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.*; + +/** + * Test for the {@link TestingRuntimeContext}. 
+ */ +public class TestingRuntimeContextTest { + + @Test + public void testWordDistinctFlat() throws Exception { + TestingRuntimeContext ctx = new TestingRuntimeContext(false); + ctx.setBroadcastVariable("exWords", Collections.singletonList("Eat")); + WordDistinctFlat flat = new WordDistinctFlat(); + flat.setRuntimeContext(ctx); + flat.flatMap("Eat,Eat,Walk,Run", ctx.getCollector()); + Assert.assertArrayEquals(ctx.getCollectorOutput().toArray(), new String[]{"Walk", "Run"}); + } + + @Test + public void testCharDistinctFlat() throws Exception { + TestingRuntimeContext ctx = new TestingRuntimeContext(false); + ctx.setBroadcastVariable("exclude", Collections.singletonList('c')); + CharDistinctFlat flatFunc = new CharDistinctFlat(); + flatFunc.setRuntimeContext(ctx); + flatFunc.flatMap("abbcd", ctx.getCollector()); + Assert.assertArrayEquals(ctx.getCollectorOutput().toArray(), new Character[]{'a', 'b', 'd'}); + } + + @Test + public void testReverseMap() throws Exception { + ReverseMap m = new ReverseMap(); + String result = m.map("12345678"); + Assert.assertEquals(result, "87654321"); + } + + @Test + public void testAttributionSumCoGroup() throws Exception { + TestingRuntimeContext ctx = new TestingRuntimeContext(false); + ctx.setBroadcastVariable("min", Collections.singletonList(3)); + Collector collector = ctx.getCollector(); + AttributionSumCoGroup cg = new AttributionSumCoGroup(); + cg.setRuntimeContext(ctx); + List> array1 = new ArrayList<>(); + array1.add(new Tuple2<>("1", 2)); + array1.add(new Tuple2<>("1", 3)); + array1.add(new Tuple2<>("1", 4)); + List> array2 = new ArrayList<>(); + array2.add(new Tuple2<>("1", 1)); + cg.coGroup(array1, array2, collector); + Assert.assertArrayEquals(ctx.getCollectorOutput().toArray(), new Integer[]{5, 9}); + } + + /** + * User-defined flatMap function to deduplicate words. 
+ */ + public static class WordDistinctFlat extends RichFlatMapFunction { + + @Override + public void flatMap(String value, Collector out) throws Exception { + Set wordsList = new HashSet<>(Arrays.asList(value.split(","))); + List excludeWords = getRuntimeContext().getBroadcastVariable("exWords"); + for (String word: wordsList) { + if (!excludeWords.contains(word)) { + out.collect(word); + } + } + } + } + + /** + * User-defined flatMap function to deduplicate characters. + */ + public static class CharDistinctFlat extends RichFlatMapFunction { + + @Override + public void flatMap(String value, Collector out) throws Exception { + List excludes = getRuntimeContext().getBroadcastVariable("exclude"); + Set characters = new HashSet<>(); + char[] chars = value.toCharArray(); + for (char c : chars) { + if (!characters.contains(c) && !excludes.contains(c)) { + characters.add(c); + out.collect(c); + } + } + } + } + + /** + * User-defined map function that reverses string. + */ + public static class ReverseMap implements MapFunction { + + @Override + public String map(String value) throws Exception { + return new StringBuffer(value).reverse().toString(); + } + } + + /** + * User-defined map function that cogroups data which is bigger than minimum. 
+ */ + public static class AttributionSumCoGroup extends RichCoGroupFunction, Tuple2, Integer> { + + @Override + public void coGroup(Iterable> first, Iterable> second, Collector out) throws Exception { + int sum = 0; + int minThreshold = getRuntimeContext().getBroadcastVariableWithInitializer("min", new BroadcastVariableInitializer() { + + @Override + public Integer initializeBroadcastVariable(Iterable data) { + for (Integer v : data) { + if (v > 0) { + return v; + } + } + return 0; + } + }); + for (Tuple2 aFirst : first) { + int v = aFirst.f1; + sum += v; + if (sum > minThreshold) { + out.collect(sum); + } + } + } + } + +} diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/test/TestingRuntimeContextTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/test/TestingRuntimeContextTest.java new file mode 100644 index 0000000000000..e2b138845227e --- /dev/null +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/test/TestingRuntimeContextTest.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.api.functions.test; + +import org.apache.flink.api.common.functions.util.test.SimpleValueState; +import org.apache.flink.api.common.functions.util.test.TestingRuntimeContext; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.TimerService; +import org.apache.flink.streaming.api.functions.co.CoProcessFunction; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.Collections; + +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; + +/** + * Test for the {@link TestingRuntimeContext}. 
+ */ +@SuppressWarnings("unchecked") +public class TestingRuntimeContextTest { + + @Test + public void testEnrichmentFunction() throws Exception { + TestingRuntimeContext ctx = new TestingRuntimeContext(true); + EnrichmentFunction func = new EnrichmentFunction(); + func.setRuntimeContext(ctx); + + CoProcessFunction.Context context = mock(EnrichmentFunction.Context.class); + CoProcessFunction.OnTimerContext timerContext = mock(EnrichmentFunction.OnTimerContext.class); + TimerService timerService = mock(TimerService.class); + doAnswer(invocationOnMock -> { + OutputTag outputTag = invocationOnMock.getArgumentAt(0, OutputTag.class); + Object value = invocationOnMock.getArgumentAt(1, Object.class); + ctx.addSideOutput(outputTag, value); + return null; + }).when(timerContext).output(any(OutputTag.class), any()); + doReturn(timerService).when(context).timerService(); + doNothing().when(timerService).registerEventTimeTimer(anyLong()); + + ValueStateDescriptor rideStateDesc = new ValueStateDescriptor<>("saved ride", TaxiRide.class); + ValueStateDescriptor fareStateDesc = new ValueStateDescriptor<>("saved fare", TaxiFare.class); + ctx.setState(rideStateDesc, new SimpleValueState<>(null)); + ctx.setState(fareStateDesc, new SimpleValueState(null)); + func.open(new Configuration()); + + TaxiRide ride1 = new TaxiRide(1); + func.processElement1(ride1, context, ctx.getCollector()); + Assert.assertEquals(ctx.getState(rideStateDesc).value(), ride1); + + TaxiFare fare1 = new TaxiFare(1); + func.processElement2(fare1, context, ctx.getCollector()); + Assert.assertEquals(ctx.getState(rideStateDesc).value(), null); + Assert.assertEquals(ctx.getCollectorOutput(), Collections.singletonList(new Tuple2(ride1, fare1))); + + TaxiFare fare2 = new TaxiFare(2); + func.processElement2(fare2, context, ctx.getCollector()); + Assert.assertEquals(ctx.getState(fareStateDesc).value(), fare2); + + func.onTimer(0L, timerContext, ctx.getCollector()); + Assert.assertEquals(Collections.singletonList(fare2), 
ctx.getSideOutput(unmatchedFares)); + + } + + + static OutputTag unmatchedRides = new OutputTag("unmatchedRides") {}; + static OutputTag unmatchedFares = new OutputTag("unmatchedFares") {}; + + static class TaxiRide { + private long eventTime; + + TaxiRide(long eventTime) { + this.eventTime = eventTime; + } + + Long getEventTime() { + return eventTime; + } + + } + + static class TaxiFare { + private long eventTime; + + TaxiFare(long eventTime) { + this.eventTime = eventTime; + } + + Long getEventTime() { + return eventTime; + } + + } + + /** + * User-defined function that joins rides and fare. + */ + public static class EnrichmentFunction extends CoProcessFunction> { + // keyed, managed state + private ValueState rideState; + private ValueState fareState; + + @Override + public void open(Configuration config) { + rideState = getRuntimeContext().getState(new ValueStateDescriptor<>("saved ride", TaxiRide.class)); + fareState = getRuntimeContext().getState(new ValueStateDescriptor<>("saved fare", TaxiFare.class)); + } + + @Override + public void onTimer(long timestamp, OnTimerContext ctx, Collector> out) throws Exception { + if (fareState.value() != null) { + ctx.output(unmatchedFares, fareState.value()); + fareState.clear(); + } + if (rideState.value() != null) { + ctx.output(unmatchedRides, rideState.value()); + rideState.clear(); + } + } + + @Override + public void processElement1(TaxiRide ride, Context context, Collector> out) throws Exception { + TaxiFare fare = fareState.value(); + if (fare != null) { + fareState.clear(); + out.collect(new Tuple2(ride, fare)); + } else { + rideState.update(ride); + // as soon as the watermark arrives, we can stop waiting for the corresponding fare + context.timerService().registerEventTimeTimer(ride.getEventTime()); + } + } + + @Override + public void processElement2(TaxiFare fare, Context context, Collector> out) throws Exception { + TaxiRide ride = rideState.value(); + if (ride != null) { + rideState.clear(); + out.collect(new 
Tuple2(ride, fare)); + } else { + fareState.update(fare); + // wait up to 6 hours for the corresponding ride END event, then clear the state + context.timerService().registerEventTimeTimer(fare.getEventTime() + 6 * 60 * 60 * 1000); + } + } + } + +}