diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/GroupReduceOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/GroupReduceOperator.java index 42553a008da49..1feab0c51ffc5 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/operators/GroupReduceOperator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/GroupReduceOperator.java @@ -18,9 +18,10 @@ package org.apache.flink.api.java.operators; +import org.apache.flink.api.common.functions.CombineFunction; import org.apache.flink.api.common.functions.GroupCombineFunction; import org.apache.flink.api.common.functions.GroupReduceFunction; -import org.apache.flink.api.common.operators.Keys; +import org.apache.flink.api.java.operators.translation.CombineToGroupCombineWrapper; import org.apache.flink.api.common.operators.Operator; import org.apache.flink.api.common.operators.Order; import org.apache.flink.api.common.operators.Ordering; @@ -30,6 +31,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.SemanticPropUtil; import org.apache.flink.api.common.operators.Keys.SelectorFunctionKeys; +import org.apache.flink.api.common.operators.Keys.ExpressionKeys; import org.apache.flink.api.java.operators.translation.PlanUnwrappingReduceGroupOperator; import org.apache.flink.api.java.operators.translation.PlanUnwrappingSortedReduceGroupOperator; import org.apache.flink.api.java.tuple.Tuple2; @@ -52,7 +54,7 @@ public class GroupReduceOperator extends SingleInputUdfOperator function; + private GroupReduceFunction function; private final Grouping grouper; @@ -68,12 +70,12 @@ public class GroupReduceOperator extends SingleInputUdfOperator input, TypeInformation resultType, GroupReduceFunction function, String defaultName) { super(input, resultType); - + this.function = function; this.grouper = null; this.defaultName = defaultName; - checkCombinability(); + this.combinable = checkCombinability(); } /** @@ -84,18 +86,18 @@ public GroupReduceOperator(DataSet input, TypeInformation resultType, G */ public GroupReduceOperator(Grouping input, TypeInformation resultType, GroupReduceFunction function, String defaultName) { super(input != null ? input.getInputDataSet() : null, resultType); - + this.function = function; this.grouper = input; this.defaultName = defaultName; - checkCombinability(); + this.combinable = checkCombinability(); UdfOperatorUtils.analyzeSingleInputUdf(this, GroupReduceFunction.class, defaultName, function, grouper.keys); } - private void checkCombinability() { - if (function instanceof GroupCombineFunction) { + private boolean checkCombinability() { + if (function instanceof GroupCombineFunction || function instanceof CombineFunction) { // check if the generic types of GroupCombineFunction and GroupReduceFunction match, i.e., // GroupCombineFunction and GroupReduceFunction. @@ -110,7 +112,9 @@ private void checkCombinability() { if (((ParameterizedType) genInterface).getRawType().equals(GroupReduceFunction.class)) { reduceTypes = ((ParameterizedType) genInterface).getActualTypeArguments(); // get parameters of GroupCombineFunction - } else if (((ParameterizedType) genInterface).getRawType().equals(GroupCombineFunction.class)) { + } else if ((((ParameterizedType) genInterface).getRawType().equals(GroupCombineFunction.class)) || + (((ParameterizedType) genInterface).getRawType().equals(CombineFunction.class))) { + combineTypes = ((ParameterizedType) genInterface).getActualTypeArguments(); } } @@ -120,24 +124,25 @@ private void checkCombinability() { combineTypes != null && combineTypes.length == 2) { if (reduceTypes[0].equals(combineTypes[0]) && reduceTypes[0].equals(combineTypes[1])) { - this.combinable = true; + return true; } else { LOG.warn("GroupCombineFunction cannot be used as combiner for GroupReduceFunction. " + "Generic types are incompatible."); - this.combinable = false; + return false; } } else if (reduceTypes == null || reduceTypes.length != 2) { LOG.warn("Cannot check generic types of GroupReduceFunction. " + "Enabling combiner but combine function might fail at runtime."); - this.combinable = true; + return true; } else { LOG.warn("Cannot check generic types of GroupCombineFunction. " + "Enabling combiner but combine function might fail at runtime."); - this.combinable = true; + return true; } } + return false; } @@ -156,13 +161,18 @@ public boolean isCombinable() { } public GroupReduceOperator setCombinable(boolean combinable) { - // sanity check that the function is a subclass of the combine interface - if (combinable && !(function instanceof GroupCombineFunction)) { - throw new IllegalArgumentException("The function does not implement the combine interface."); + + if(combinable) { + // sanity check that the function is a subclass of the combine interface + if (!checkCombinability()) { + throw new IllegalArgumentException("Either the function does not implement a combine interface, " + + "or the types of the combine() and reduce() methods are not compatible."); + } + this.combinable = true; + } + else { + this.combinable = false; } - - this.combinable = combinable; - return this; } @@ -191,10 +201,16 @@ public SingleInputSemanticProperties getSemanticProperties() { // -------------------------------------------------------------------------------------------- @Override + @SuppressWarnings("unchecked") protected GroupReduceOperatorBase translateToDataFlow(Operator input) { String name = getName() != null ? getName() : "GroupReduce at " + defaultName; - + + // wrap CombineFunction in GroupCombineFunction if combinable + this.function = (combinable && function instanceof CombineFunction) ? + new CombineToGroupCombineWrapper((CombineFunction) function) : + function; + // distinguish between grouped reduce and non-grouped reduce if (grouper == null) { // non grouped reduce @@ -236,7 +252,7 @@ selectorKeys, sortKeys, groupOrder, function, getResultType(), name, input, isCo return po; } } - else if (grouper.getKeys() instanceof Keys.ExpressionKeys) { + else if (grouper.getKeys() instanceof ExpressionKeys) { int[] logicalKeyPositions = grouper.getKeys().computeLogicalKeyPositions(); UnaryOperatorInformation operatorInfo = new UnaryOperatorInformation<>(getInputType(), getResultType()); diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/translation/CombineToGroupCombineWrapper.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/translation/CombineToGroupCombineWrapper.java new file mode 100644 index 0000000000000..87c1e33c510f0 --- /dev/null +++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/translation/CombineToGroupCombineWrapper.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.java.operators.translation; + +import com.google.common.base.Preconditions; +import org.apache.flink.api.common.functions.CombineFunction; +import org.apache.flink.api.common.functions.GroupCombineFunction; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.util.Collector; + +/** + * A wrapper the wraps a function that implements both {@link CombineFunction} and {@link GroupReduceFunction} interfaces + * and makes it look like a function that implements {@link GroupCombineFunction} and {@link GroupReduceFunction} to the runtime. + */ +public class CombineToGroupCombineWrapper & GroupReduceFunction> + implements GroupCombineFunction, GroupReduceFunction { + + private final F wrappedFunction; + + public CombineToGroupCombineWrapper(F wrappedFunction) { + this.wrappedFunction = Preconditions.checkNotNull(wrappedFunction); + } + + @Override + public void combine(Iterable values, Collector out) throws Exception { + IN outValue = wrappedFunction.combine(values); + out.collect(outValue); + } + + @Override + public void reduce(Iterable values, Collector out) throws Exception { + wrappedFunction.reduce(values, out); + } +} diff --git a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/ReduceWithCombinerITCase.java b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/ReduceWithCombinerITCase.java new file mode 100644 index 0000000000000..685a9ac969a2c --- /dev/null +++ b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/ReduceWithCombinerITCase.java @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.test.javaApiOperators; + +import org.apache.flink.api.common.functions.CombineFunction; +import org.apache.flink.api.common.functions.GroupCombineFunction; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.operators.UnsortedGrouping; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.test.util.MultipleProgramsTestBase; +import org.apache.flink.util.Collector; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.List; + +@SuppressWarnings("serial") +@RunWith(Parameterized.class) +public class ReduceWithCombinerITCase extends MultipleProgramsTestBase { + + public ReduceWithCombinerITCase(TestExecutionMode mode) { + super(TestExecutionMode.CLUSTER); + } + + @Test + public void testReduceOnNonKeyedDataset() throws Exception { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(4); + + // creates the input data and distributes them evenly among the available downstream tasks + DataSet> input = createNonKeyedInput(env); + List> actual = input.reduceGroup(new NonKeyedCombReducer()).collect(); + String expected = "10,true\n"; + + compareResultAsTuples(actual, expected); + } + + @Test + public void testForkingReduceOnNonKeyedDataset() throws Exception { + + // set up the execution environment + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(4); + + // creates the input data and distributes them evenly among the available downstream tasks + DataSet> input = createNonKeyedInput(env); + + DataSet> r1 = input.reduceGroup(new NonKeyedCombReducer()); + DataSet> r2 = input.reduceGroup(new NonKeyedGroupCombReducer()); + + List> actual = r1.union(r2).collect(); + String expected = "10,true\n10,true\n"; + compareResultAsTuples(actual, expected); + } + + @Test + public void testReduceOnKeyedDataset() throws Exception { + + // set up the execution environment + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(4); + + // creates the input data and distributes them evenly among the available downstream tasks + DataSet> input = createKeyedInput(env); + List> actual = input.groupBy(0).reduceGroup(new KeyedCombReducer()).collect(); + String expected = "k1,6,true\nk2,4,true\n"; + + compareResultAsTuples(actual, expected); + } + + @Test + public void testReduceOnKeyedDatasetWithSelector() throws Exception { + + // set up the execution environment + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(4); + + // creates the input data and distributes them evenly among the available downstream tasks + DataSet> input = createKeyedInput(env); + + List> actual = input + .groupBy(new KeySelectorX()) + .reduceGroup(new KeyedCombReducer()) + .collect(); + String expected = "k1,6,true\nk2,4,true\n"; + + compareResultAsTuples(actual, expected); + } + + @Test + public void testForkingReduceOnKeyedDataset() throws Exception { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(4); + + // creates the input data and distributes them evenly among the available downstream tasks + DataSet> input = createKeyedInput(env); + + UnsortedGrouping> counts = input.groupBy(0); + + DataSet> r1 = counts.reduceGroup(new KeyedCombReducer()); + DataSet> r2 = counts.reduceGroup(new KeyedGroupCombReducer()); + + List> actual = r1.union(r2).collect(); + String expected = "k1,6,true\n" + + "k2,4,true\n" + + "k1,6,true\n" + + "k2,4,true\n"; + compareResultAsTuples(actual, expected); + } + + @Test + public void testForkingReduceOnKeyedDatasetWithSelection() throws Exception { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(4); + + // creates the input data and distributes them evenly among the available downstream tasks + DataSet> input = createKeyedInput(env); + + UnsortedGrouping> counts = input.groupBy(new KeySelectorX()); + + DataSet> r1 = counts.reduceGroup(new KeyedCombReducer()); + DataSet> r2 = counts.reduceGroup(new KeyedGroupCombReducer()); + + List> actual = r1.union(r2).collect(); + String expected = "k1,6,true\n" + + "k2,4,true\n" + + "k1,6,true\n" + + "k2,4,true\n"; + + compareResultAsTuples(actual, expected); + } + + private DataSet> createNonKeyedInput(ExecutionEnvironment env) { + return env.fromCollection(Arrays.asList( + new Tuple2<>(1, false), + new Tuple2<>(1, false), + new Tuple2<>(1, false), + new Tuple2<>(1, false), + new Tuple2<>(1, false), + new Tuple2<>(1, false), + new Tuple2<>(1, false), + new Tuple2<>(1, false), + new Tuple2<>(1, false), + new Tuple2<>(1, false)) + ).rebalance(); + } + + private static class NonKeyedCombReducer implements CombineFunction, Tuple2>, + GroupReduceFunction,Tuple2> { + + @Override + public Tuple2 combine(Iterable> values) throws Exception { + int sum = 0; + boolean flag = true; + + for(Tuple2 tuple : values) { + sum += tuple.f0; + flag &= !tuple.f1; + + } + return new Tuple2<>(sum, flag); + } + + @Override + public void reduce(Iterable> values, Collector> out) throws Exception { + int sum = 0; + boolean flag = true; + for(Tuple2 tuple : values) { + sum += tuple.f0; + flag &= tuple.f1; + } + out.collect(new Tuple2<>(sum, flag)); + } + } + + private static class NonKeyedGroupCombReducer implements GroupCombineFunction, Tuple2>, + GroupReduceFunction,Tuple2> { + + @Override + public void reduce(Iterable> values, Collector> out) throws Exception { + int sum = 0; + boolean flag = true; + for(Tuple2 tuple : values) { + sum += tuple.f0; + flag &= tuple.f1; + } + out.collect(new Tuple2<>(sum, flag)); + } + + @Override + public void combine(Iterable> values, Collector> out) throws Exception { + int sum = 0; + boolean flag = true; + for(Tuple2 tuple : values) { + sum += tuple.f0; + flag &= !tuple.f1; + } + out.collect(new Tuple2<>(sum, flag)); + } + } + + private DataSet> createKeyedInput(ExecutionEnvironment env) { + return env.fromCollection(Arrays.asList( + new Tuple3<>("k1", 1, false), + new Tuple3<>("k1", 1, false), + new Tuple3<>("k1", 1, false), + new Tuple3<>("k2", 1, false), + new Tuple3<>("k1", 1, false), + new Tuple3<>("k1", 1, false), + new Tuple3<>("k2", 1, false), + new Tuple3<>("k2", 1, false), + new Tuple3<>("k1", 1, false), + new Tuple3<>("k2", 1, false)) + ).rebalance(); + } + + public static class KeySelectorX implements KeySelector, String> { + private static final long serialVersionUID = 1L; + @Override + public String getKey(Tuple3 in) { + return in.f0; + } + } + + private class KeyedCombReducer implements CombineFunction, Tuple3>, + GroupReduceFunction, Tuple3> { + + @Override + public Tuple3 combine(Iterable> values) throws Exception { + String key = null; + int sum = 0; + boolean flag = true; + + for(Tuple3 tuple : values) { + key = (key == null) ? tuple.f0 : key; + sum += tuple.f1; + flag &= !tuple.f2; + } + return new Tuple3<>(key, sum, flag); + } + + @Override + public void reduce(Iterable> values, Collector> out) throws Exception { + String key = null; + int sum = 0; + boolean flag = true; + + for(Tuple3 tuple : values) { + key = (key == null) ? tuple.f0 : key; + sum += tuple.f1; + flag &= tuple.f2; + } + out.collect(new Tuple3<>(key, sum, flag)); + } + } + + private class KeyedGroupCombReducer implements GroupCombineFunction, Tuple3>, + GroupReduceFunction, Tuple3> { + + @Override + public void combine(Iterable> values, Collector> out) throws Exception { + String key = null; + int sum = 0; + boolean flag = true; + + for(Tuple3 tuple : values) { + key = (key == null) ? tuple.f0 : key; + sum += tuple.f1; + flag &= !tuple.f2; + } + out.collect(new Tuple3<>(key, sum, flag)); + } + + @Override + public void reduce(Iterable> values, Collector> out) throws Exception { + String key = null; + int sum = 0; + boolean flag = true; + + for(Tuple3 tuple : values) { + key = (key == null) ? tuple.f0 : key; + sum += tuple.f1; + flag &= tuple.f2; + } + out.collect(new Tuple3<>(key, sum, flag)); + } + } +}