diff --git a/flink-core/src/main/java/org/apache/flink/api/common/distributions/DataDistribution.java b/flink-core/src/main/java/org/apache/flink/api/common/distributions/DataDistribution.java
index 321948d8c7819..c0794d6c6b307 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/distributions/DataDistribution.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/distributions/DataDistribution.java
@@ -22,6 +22,7 @@
 import java.io.Serializable;
 
 import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.core.io.IOReadableWritable;
 
 @PublicEvolving
@@ -57,4 +58,10 @@ public interface DataDistribution extends IOReadableWritable, Serializable {
 	 * @return The number of fields in the (composite) key.
 	 */
 	int getNumberOfFields();
+
+	/**
+	 * Gets the types of the key fields by which the data set is partitioned.
+	 * @return The types of the key fields by which the data set is partitioned.
+	 */
+	TypeInformation<?>[] getKeyTypes();
 }
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/Keys.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/Keys.java
index ad21c476cac11..abe41af29b279 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/operators/Keys.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/Keys.java
@@ -42,7 +42,7 @@ public abstract class Keys<T> {
 
 	public abstract int[] computeLogicalKeyPositions();
 
-	protected abstract TypeInformation<?>[] getKeyFieldTypes();
+	public abstract TypeInformation<?>[] getKeyFieldTypes();
 
 	public abstract <E> void validateCustomPartitioner(Partitioner<E> partitioner, TypeInformation<E> typeInfo);
@@ -134,7 +134,7 @@ public int[] computeLogicalKeyPositions() {
 		}
 
 		@Override
-		protected TypeInformation<?>[] getKeyFieldTypes() {
+		public TypeInformation<?>[] getKeyFieldTypes() {
 			TypeInformation<?>[] fieldTypes = new TypeInformation[keyFields.size()];
 			for (int i = 0; i < keyFields.size(); i++) {
 				fieldTypes[i] = keyFields.get(i).getType();
@@ -337,7 +337,7 @@ public int[] computeLogicalKeyPositions() {
 		}
 
 		@Override
-		protected TypeInformation<?>[] getKeyFieldTypes() {
+		public TypeInformation<?>[] getKeyFieldTypes() {
 			TypeInformation<?>[] fieldTypes = new TypeInformation[keyFields.size()];
 			for (int i = 0; i < keyFields.size(); i++) {
 				fieldTypes[i] = keyFields.get(i).getType();
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/PartitionOperatorBase.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/PartitionOperatorBase.java
index fd71facb576f5..4b802aa1ed11e 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/PartitionOperatorBase.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/PartitionOperatorBase.java
@@ -22,6 +22,7 @@
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.distributions.DataDistribution;
 import org.apache.flink.api.common.functions.Partitioner;
 import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.api.common.functions.util.NoOpFunction;
@@ -49,6 +50,8 @@ public static enum PartitionMethod {
 
 	private Partitioner<?> customPartitioner;
 
+	private DataDistribution distribution;
+
 	public PartitionOperatorBase(UnaryOperatorInformation<IN, IN> operatorInfo, PartitionMethod pMethod, int[] keys, String name) {
 		super(new UserCodeObjectWrapper<NoOpFunction>(new NoOpFunction()), operatorInfo, keys, name);
@@ -70,6 +73,14 @@ public Partitioner<?> getCustomPartitioner() {
 		return customPartitioner;
 	}
 
+	public DataDistribution getDistribution() {
+		return this.distribution;
+	}
+
+	public void setDistribution(DataDistribution distribution) {
+		this.distribution = distribution;
+	}
+
 	public void setCustomPartitioner(Partitioner<?> customPartitioner) {
 		if (customPartitioner != null) {
 			int[] keys = getKeyColumns(0);
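Reviewer note: the new `getKeyTypes()` method makes a distribution self-describing, so the operator can validate it against the partition keys (see `PartitionOperator` below). A minimal sketch of an implementation, assuming a hypothetical `UniformIntDistribution` that splits an integer key space `[0, max)` into equally sized ranges — illustration only, not part of this patch:

```java
import java.io.IOException;

import org.apache.flink.api.common.distributions.DataDistribution;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;

// Hypothetical example, not part of this patch: splits an integer
// key space [0, max) into equally sized ranges.
public class UniformIntDistribution implements DataDistribution {

	private int max;

	public UniformIntDistribution() {} // needed for deserialization

	public UniformIntDistribution(int max) {
		this.max = max;
	}

	@Override
	public Object[] getBucketBoundary(int bucketNum, int totalNumBuckets) {
		// Upper (inclusive) boundary of bucket `bucketNum`.
		return new Integer[] { (bucketNum + 1) * max / totalNumBuckets };
	}

	@Override
	public int getNumberOfFields() {
		return 1;
	}

	@Override
	public TypeInformation<?>[] getKeyTypes() {
		// Must match the types of the partition keys; checked by PartitionOperator.
		return new TypeInformation<?>[] { BasicTypeInfo.INT_TYPE_INFO };
	}

	@Override
	public void write(DataOutputView out) throws IOException {
		out.writeInt(max);
	}

	@Override
	public void read(DataInputView in) throws IOException {
		this.max = in.readInt();
	}
}
```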
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/PartitionOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/PartitionOperator.java
index b2b9f6ed03a5e..dc4a018cfe9f7 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/PartitionOperator.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/PartitionOperator.java
@@ -22,6 +22,7 @@
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.Public;
+import org.apache.flink.api.common.distributions.DataDistribution;
 import org.apache.flink.api.common.functions.Partitioner;
 import org.apache.flink.api.common.operators.Keys;
 import org.apache.flink.api.common.operators.Operator;
@@ -33,6 +34,8 @@
 import org.apache.flink.api.common.operators.Keys.SelectorFunctionKeys;
 import org.apache.flink.api.java.tuple.Tuple2;
 
+import java.util.Arrays;
+
 /**
  * This operator represents a partitioning.
  *
@@ -45,35 +48,46 @@ public class PartitionOperator<T> extends SingleInputOperator<T, T, PartitionOperator<T>> {
 
 	private final Partitioner<?> customPartitioner;
-
-
+	private final DataDistribution distribution;
+
+
 	public PartitionOperator(DataSet<T> input, PartitionMethod pMethod, Keys<T> pKeys, String partitionLocationName) {
-		this(input, pMethod, pKeys, null, null, partitionLocationName);
+		this(input, pMethod, pKeys, null, null, null, partitionLocationName);
 	}
-
+
+	public PartitionOperator(DataSet<T> input, PartitionMethod pMethod, Keys<T> pKeys, DataDistribution distribution, String partitionLocationName) {
+		this(input, pMethod, pKeys, null, null, distribution, partitionLocationName);
+	}
+
 	public PartitionOperator(DataSet<T> input, PartitionMethod pMethod, String partitionLocationName) {
-		this(input, pMethod, null, null, null, partitionLocationName);
+		this(input, pMethod, null, null, null, null, partitionLocationName);
 	}
 
 	public PartitionOperator(DataSet<T> input, Keys<T> pKeys, Partitioner<?> customPartitioner, String partitionLocationName) {
-		this(input, PartitionMethod.CUSTOM, pKeys, customPartitioner, null, partitionLocationName);
+		this(input, PartitionMethod.CUSTOM, pKeys, customPartitioner, null, null, partitionLocationName);
 	}
 
-	public <P> PartitionOperator(DataSet<T> input, Keys<T> pKeys, Partitioner<P> customPartitioner,
+	public <P> PartitionOperator(DataSet<T> input, Keys<T> pKeys, Partitioner<P> customPartitioner,
 			TypeInformation<P> partitionerTypeInfo, String partitionLocationName) {
-		this(input, PartitionMethod.CUSTOM, pKeys, customPartitioner, partitionerTypeInfo, partitionLocationName);
+		this(input, PartitionMethod.CUSTOM, pKeys, customPartitioner, partitionerTypeInfo, null, partitionLocationName);
 	}
 
-	private <P> PartitionOperator(DataSet<T> input, PartitionMethod pMethod, Keys<T> pKeys, Partitioner<P> customPartitioner,
-			TypeInformation<P> partitionerTypeInfo, String partitionLocationName)
+	private <P> PartitionOperator(DataSet<T> input, PartitionMethod pMethod, Keys<T> pKeys, Partitioner<P> customPartitioner,
+			TypeInformation<P> partitionerTypeInfo, DataDistribution distribution, String partitionLocationName)
 	{
 		super(input, input.getType());
 
 		Preconditions.checkNotNull(pMethod);
 		Preconditions.checkArgument(pKeys != null || pMethod == PartitionMethod.REBALANCE, "Partitioning requires keys");
 		Preconditions.checkArgument(pMethod != PartitionMethod.CUSTOM || customPartitioner != null, "Custom partioning requires a partitioner.");
-
+		Preconditions.checkArgument(distribution == null || pMethod == PartitionMethod.RANGE, "Customized data distribution is only supported for range partitioning.");
+
+		if (distribution != null) {
+			Preconditions.checkArgument(distribution.getNumberOfFields() == pKeys.getNumberOfKeyFields(), "The number of key fields of the distribution and the range partitioner must be the same.");
+			Preconditions.checkArgument(Arrays.equals(distribution.getKeyTypes(), pKeys.getKeyFieldTypes()), "The key types of the distribution and the range partitioner must be equal.");
+		}
+
 		if (customPartitioner != null) {
 			pKeys.validateCustomPartitioner(customPartitioner, partitionerTypeInfo);
 		}
@@ -82,6 +96,7 @@ private <P> PartitionOperator(DataSet<T> input, PartitionMethod pMethod, Keys<T> pKeys,
 		this.pKeys = pKeys;
 		this.partitionLocationName = partitionLocationName;
 		this.customPartitioner = customPartitioner;
+		this.distribution = distribution;
 	}
 
 	// --------------------------------------------------------------------------------------------
@@ -125,6 +140,7 @@ else if (pMethod == PartitionMethod.HASH || pMethod == PartitionMethod.CUSTOM ||
 			PartitionOperatorBase<T> partitionedInput = new PartitionOperatorBase<>(operatorInfo, pMethod, logicalKeyPositions, name);
 			partitionedInput.setInput(input);
 			partitionedInput.setParallelism(getParallelism());
+			partitionedInput.setDistribution(distribution);
 			partitionedInput.setCustomPartitioner(customPartitioner);
 
 			return partitionedInput;
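The two new preconditions are the contract between a user distribution and the partition keys: same number of fields, same `TypeInformation`. A small sketch of what passes and what fails, reusing the hypothetical `UniformIntDistribution` from above and the `DataSetUtils.partitionByRange` methods added later in this patch:

```java
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.DataSetUtils;

public class DistributionChecks {
	public static void main(String[] args) throws Exception {
		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
		DataSet<Tuple2<Integer, String>> data =
				env.fromElements(new Tuple2<>(3, "a"), new Tuple2<>(12, "b"));

		// OK: key field 0 is an Integer and the distribution declares one INT key type.
		DataSetUtils.partitionByRange(data, new UniformIntDistribution(100), 0);

		// Throws IllegalArgumentException from the Arrays.equals(...) precondition:
		// key field 1 is a String, but the distribution declares an Integer key.
		DataSetUtils.partitionByRange(data, new UniformIntDistribution(100), 1);
	}
}
```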
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java b/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java
index 78e52319820e4..61a71aa32a75b 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java
@@ -21,16 +21,23 @@
 import com.google.common.collect.Lists;
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.api.common.JobExecutionResult;
+import org.apache.flink.api.common.distributions.DataDistribution;
 import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
 import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.operators.Keys;
+import org.apache.flink.api.common.operators.base.PartitionOperatorBase;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.Utils;
+import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.functions.SampleInCoordinator;
 import org.apache.flink.api.java.functions.SampleInPartition;
 import org.apache.flink.api.java.functions.SampleWithFraction;
 import org.apache.flink.api.java.operators.GroupReduceOperator;
 import org.apache.flink.api.java.operators.MapPartitionOperator;
+import org.apache.flink.api.java.operators.PartitionOperator;
 import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.util.AbstractID;
 import org.apache.flink.util.Collector;
@@ -250,6 +257,32 @@ public static <T> DataSet<T> sampleWithSize(
 		return new GroupReduceOperator<>(mapPartitionOperator, input.getType(), sampleInCoordinator, callLocation);
 	}
 
+	// --------------------------------------------------------------------------------------------
+	//  Partition
+	// --------------------------------------------------------------------------------------------
+
+	/**
+	 * Range-partitions a DataSet on the specified tuple field positions.
+	 */
+	public static <T> PartitionOperator<T> partitionByRange(DataSet<T> input, DataDistribution distribution, int... fields) {
+		return new PartitionOperator<>(input, PartitionOperatorBase.PartitionMethod.RANGE, new Keys.ExpressionKeys<>(fields, input.getType(), false), distribution, Utils.getCallLocationName());
+	}
+
+	/**
+	 * Range-partitions a DataSet on the specified fields.
+	 */
+	public static <T> PartitionOperator<T> partitionByRange(DataSet<T> input, DataDistribution distribution, String... fields) {
+		return new PartitionOperator<>(input, PartitionOperatorBase.PartitionMethod.RANGE, new Keys.ExpressionKeys<>(fields, input.getType()), distribution, Utils.getCallLocationName());
+	}
+
+	/**
+	 * Range-partitions a DataSet using the specified key selector function.
+	 */
+	public static <T, K extends Comparable<K>> PartitionOperator<T> partitionByRange(DataSet<T> input, DataDistribution distribution, KeySelector<T, K> keyExtractor) {
+		final TypeInformation<K> keyType = TypeExtractor.getKeySelectorTypes(keyExtractor, input.getType());
+		return new PartitionOperator<>(input, PartitionOperatorBase.PartitionMethod.RANGE, new Keys.SelectorFunctionKeys<>(input.clean(keyExtractor), input.getType(), keyType), distribution, Utils.getCallLocationName());
+	}
+
 	// --------------------------------------------------------------------------------------------
 	//  Checksum
 	// --------------------------------------------------------------------------------------------
diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/PartitionNode.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/PartitionNode.java
index 65a6e0237868a..9ecea6b061716 100644
--- a/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/PartitionNode.java
+++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/PartitionNode.java
@@ -22,6 +22,7 @@
 import java.util.Collections;
 import java.util.List;
 
+import org.apache.flink.api.common.distributions.DataDistribution;
 import org.apache.flink.api.common.functions.Partitioner;
 import org.apache.flink.api.common.operators.Order;
 import org.apache.flink.api.common.operators.Ordering;
@@ -51,7 +52,7 @@ public PartitionNode(PartitionOperatorBase<?> operator) {
 		super(operator);
 
 		OperatorDescriptorSingle descr = new PartitionDescriptor(
-				this.getOperator().getPartitionMethod(), this.keys, operator.getCustomPartitioner());
+				this.getOperator().getPartitionMethod(), this.keys, operator.getCustomPartitioner(), operator.getDistribution());
 		this.possibleProperties = Collections.singletonList(descr);
 	}
 
@@ -88,12 +89,14 @@ public static class PartitionDescriptor extends OperatorDescriptorSingle {
 
 		private final PartitionMethod pMethod;
 		private final Partitioner<?> customPartitioner;
+		private final DataDistribution distribution;
 
-		public PartitionDescriptor(PartitionMethod pMethod, FieldSet pKeys, Partitioner<?> customPartitioner) {
+		public PartitionDescriptor(PartitionMethod pMethod, FieldSet pKeys, Partitioner<?> customPartitioner, DataDistribution distribution) {
 			super(pKeys);
 			this.pMethod = pMethod;
 			this.customPartitioner = customPartitioner;
+			this.distribution = distribution;
 		}
 
 		@Override
@@ -127,7 +130,7 @@ protected List<GlobalPropertiesPair> createPossibleGlobalProperties() {
 				for (int field : this.keys) {
 					ordering.appendOrdering(field, null, Order.ASCENDING);
 				}
-				rgps.setRangePartitioned(ordering);
+				rgps.setRangePartitioned(ordering, distribution);
 				break;
 			default:
 				throw new IllegalArgumentException("Invalid partition method");
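Putting the new public API together — a usage sketch, again with the hypothetical `UniformIntDistribution`; `partitionByRange` and its overloads are the methods added above:

```java
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.DataSetUtils;

public class RangePartitionExample {
	public static void main(String[] args) throws Exception {
		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

		DataSet<Tuple2<Integer, String>> data = env.fromElements(
				new Tuple2<>(5, "a"), new Tuple2<>(42, "b"), new Tuple2<>(77, "c"));

		// Tuple-position overload: partition on field 0 with the caller's
		// distribution, so the optimizer need not inject a sampling step.
		DataSet<Tuple2<Integer, String>> byPosition =
				DataSetUtils.partitionByRange(data, new UniformIntDistribution(100), 0);

		// Key-selector overload: the extracted key type (Integer) must match
		// the single INT key type reported by the distribution.
		DataSet<Tuple2<Integer, String>> bySelector = DataSetUtils.partitionByRange(
				data, new UniformIntDistribution(100),
				new KeySelector<Tuple2<Integer, String>, Integer>() {
					@Override
					public Integer getKey(Tuple2<Integer, String> value) {
						return value.f0;
					}
				});

		byPosition.print();
		bySelector.print();
	}
}
```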
diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/dataproperties/GlobalProperties.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/dataproperties/GlobalProperties.java
index 57ba29d06b1f8..ca17c2ba19fc2 100644
--- a/flink-optimizer/src/main/java/org/apache/flink/optimizer/dataproperties/GlobalProperties.java
+++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/dataproperties/GlobalProperties.java
@@ -22,6 +22,7 @@
 import java.util.Set;
 
 import org.apache.flink.api.common.ExecutionMode;
+import org.apache.flink.api.common.distributions.DataDistribution;
 import org.apache.flink.api.common.functions.Partitioner;
 import org.apache.flink.api.common.operators.Order;
 import org.apache.flink.api.common.operators.Ordering;
@@ -55,6 +56,8 @@ public class GlobalProperties implements Cloneable {
 
 	private Partitioner<?> customPartitioner;
 
+	private DataDistribution distribution;
+
 	// --------------------------------------------------------------------------------------------
 
 	/**
@@ -80,16 +83,38 @@ public void setHashPartitioned(FieldList partitionedFields) {
 		this.partitioningFields = partitionedFields;
 		this.ordering = null;
 	}
-
+
+	/**
+	 * Sets the parameters for range partitioning.
+	 *
+	 * @param ordering Order of the partitioned fields
+	 */
 	public void setRangePartitioned(Ordering ordering) {
 		if (ordering == null) {
 			throw new NullPointerException();
 		}
+
+		this.partitioning = PartitioningProperty.RANGE_PARTITIONED;
+		this.ordering = ordering;
+		this.partitioningFields = ordering.getInvolvedIndexes();
+	}
+
+	/**
+	 * Sets the parameters for range partitioning.
+	 *
+	 * @param ordering Order of the partitioned fields
+	 * @param distribution The data distribution for range partitioning. Users can supply a
+	 *                     customized data distribution; it may also be null.
+	 */
+	public void setRangePartitioned(Ordering ordering, DataDistribution distribution) {
+		if (ordering == null) {
+			throw new NullPointerException();
+		}
 
 		this.partitioning = PartitioningProperty.RANGE_PARTITIONED;
 		this.ordering = ordering;
 		this.partitioningFields = ordering.getInvolvedIndexes();
+		this.distribution = distribution;
 	}
 
 	public void setAnyPartitioning(FieldList partitionedFields) {
@@ -167,6 +192,10 @@ public Partitioner<?> getCustomPartitioner() {
 		return this.customPartitioner;
 	}
 
+	public DataDistribution getDataDistribution() {
+		return this.distribution;
+	}
+
 	// --------------------------------------------------------------------------------------------
 
 	public boolean isPartitionedOnFields(FieldSet fields) {
diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/plan/Channel.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/plan/Channel.java
index 508cc9505acdc..bd2a5949fbb14 100644
--- a/flink-optimizer/src/main/java/org/apache/flink/optimizer/plan/Channel.java
+++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/plan/Channel.java
@@ -427,7 +427,7 @@ public GlobalProperties getGlobalProperties() {
 				this.globalProps.setHashPartitioned(this.shipKeys);
 				break;
 			case PARTITION_RANGE:
-				this.globalProps.setRangePartitioned(Utils.createOrdering(this.shipKeys, this.shipSortOrder));
+				this.globalProps.setRangePartitioned(Utils.createOrdering(this.shipKeys, this.shipSortOrder), this.dataDistribution);
 				break;
 			case FORWARD:
 				break;
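For orientation: `PartitionNode` requests range-partitioned properties by building an ascending `Ordering` over the key fields and now threads the distribution through. In isolation, the new `GlobalProperties` overload behaves like this hedged sketch (helper name hypothetical, not code from this patch):

```java
import org.apache.flink.api.common.distributions.DataDistribution;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.operators.Ordering;
import org.apache.flink.optimizer.dataproperties.GlobalProperties;

public class RangePropertiesSketch {
	// keyPositions and distribution stand in for the node's actual key
	// fields and the (possibly null) user-supplied distribution.
	static GlobalProperties rangeProperties(int[] keyPositions, DataDistribution distribution) {
		Ordering ordering = new Ordering();
		for (int field : keyPositions) {
			ordering.appendOrdering(field, null, Order.ASCENDING);
		}
		GlobalProperties gp = new GlobalProperties();
		gp.setRangePartitioned(ordering, distribution);
		return gp; // gp.getDataDistribution() now returns the distribution
	}
}
```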
diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/RangePartitionRewriter.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/RangePartitionRewriter.java
index 7656dfd172901..b1c5dae7c606b 100644
--- a/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/RangePartitionRewriter.java
+++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/RangePartitionRewriter.java
@@ -109,14 +109,16 @@ public void postVisit(PlanNode node) {
 			// Make sure we only optimize the DAG for range partition, and do not optimize multi times.
 			if (shipStrategy == ShipStrategyType.PARTITION_RANGE) {
-				if (node.isOnDynamicPath()) {
-					throw new InvalidProgramException("Range Partitioning not supported within iterations.");
+				if (channel.getDataDistribution() == null) {
+					if (node.isOnDynamicPath()) {
+						throw new InvalidProgramException("Range Partitioning not supported within iterations if users do not supply the data distribution.");
+					}
+
+					PlanNode channelSource = channel.getSource();
+					List<Channel> newSourceOutputChannels = rewriteRangePartitionChannel(channel);
+					channelSource.getOutgoingChannels().remove(channel);
+					channelSource.getOutgoingChannels().addAll(newSourceOutputChannels);
 				}
-
-				PlanNode channelSource = channel.getSource();
-				List<Channel> newSourceOutputChannels = rewriteRangePartitionChannel(channel);
-				channelSource.getOutgoingChannels().remove(channel);
-				channelSource.getOutgoingChannels().addAll(newSourceOutputChannels);
 			}
 		}
 	}
diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataproperties/MockDistribution.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataproperties/MockDistribution.java
index 483bc514b8606..a35f0d0d0354c 100644
--- a/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataproperties/MockDistribution.java
+++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataproperties/MockDistribution.java
@@ -19,6 +19,8 @@
 package org.apache.flink.optimizer.dataproperties;
 
 import org.apache.flink.api.common.distributions.DataDistribution;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
 
@@ -37,6 +39,11 @@ public int getNumberOfFields() {
 		return 0;
 	}
 
+	@Override
+	public TypeInformation<?>[] getKeyTypes() {
+		return new TypeInformation<?>[]{BasicTypeInfo.INT_TYPE_INFO};
+	}
+
 	@Override
 	public void write(DataOutputView out) throws IOException {
 
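This rewriter change is the core of the feature: when the channel carries a user-supplied distribution, the sampling-based plan rewrite is skipped entirely, which is also why range partitioning inside iterations becomes legal in that case. Conceptually, the runtime can then map a record to a bucket straight from the distribution's boundaries; a simplified, non-authoritative sketch for a single ascending integer key (not Flink's actual runtime partitioner):

```java
import org.apache.flink.api.common.distributions.DataDistribution;

public class BucketLookupSketch {
	// Picks the first bucket whose (inclusive) upper boundary is >= the key,
	// assuming a single ascending Integer key.
	static int selectBucket(DataDistribution distribution, int key, int totalNumBuckets) {
		for (int bucket = 0; bucket < totalNumBuckets - 1; bucket++) {
			Integer upper = (Integer) distribution.getBucketBoundary(bucket, totalNumBuckets)[0];
			if (key <= upper) {
				return bucket;
			}
		}
		return totalNumBuckets - 1; // everything beyond the last boundary
	}
}
```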
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/utils/package.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/utils/package.scala
index 6407093cb58d2..adad9ab4ade34 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/utils/package.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/utils/package.scala
@@ -19,9 +19,16 @@
 package org.apache.flink.api.scala
 
 import org.apache.flink.annotation.PublicEvolving
+import org.apache.flink.api.common.distributions.DataDistribution
+import org.apache.flink.api.common.operators.Keys
+import org.apache.flink.api.common.operators.base.PartitionOperatorBase
+import org.apache.flink.api.common.operators.base.PartitionOperatorBase.PartitionMethod
 import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
 import org.apache.flink.api.java.Utils
 import org.apache.flink.api.java.Utils.ChecksumHashCode
+import org.apache.flink.api.java.functions.KeySelector
+import org.apache.flink.api.java.operators.PartitionOperator
+import org.apache.flink.api.java.typeutils.TypeExtractor
 import org.apache.flink.api.java.utils.{DataSetUtils => jutils}
 import org.apache.flink.util.AbstractID
 
@@ -109,6 +116,59 @@ package object utils {
       wrap(jutils.sampleWithSize(self.javaSet, withReplacement, numSamples, seed))
     }
 
+    // --------------------------------------------------------------------------------------------
+    //  Partitioning
+    // --------------------------------------------------------------------------------------------
+
+    /**
+     * Range-partitions a DataSet on the specified tuple field positions.
+     */
+    def partitionByRange(distribution: DataDistribution, fields: Int*): DataSet[T] = {
+      val op = new PartitionOperator[T](
+        self.javaSet,
+        PartitionMethod.RANGE,
+        new Keys.ExpressionKeys[T](fields.toArray, self.javaSet.getType),
+        distribution,
+        getCallLocationName())
+      wrap(op)
+    }
+
+    /**
+     * Range-partitions a DataSet on the specified fields.
+     */
+    def partitionByRange(distribution: DataDistribution,
+                         firstField: String,
+                         otherFields: String*): DataSet[T] = {
+      val op = new PartitionOperator[T](
+        self.javaSet,
+        PartitionMethod.RANGE,
+        new Keys.ExpressionKeys[T](firstField +: otherFields.toArray, self.javaSet.getType),
+        distribution,
+        getCallLocationName())
+      wrap(op)
+    }
+
+    /**
+     * Range-partitions a DataSet using the specified key selector function.
+     */
+    def partitionByRange[K: TypeInformation](distribution: DataDistribution,
+                                             fun: T => K): DataSet[T] = {
+      val keyExtractor = new KeySelector[T, K] {
+        val cleanFun = self.javaSet.clean(fun)
+        def getKey(in: T) = cleanFun(in)
+      }
+      val op = new PartitionOperator[T](
+        self.javaSet,
+        PartitionMethod.RANGE,
+        new Keys.SelectorFunctionKeys[T, K](
+          keyExtractor,
+          self.javaSet.getType,
+          implicitly[TypeInformation[K]]),
+        distribution,
+        getCallLocationName())
+      wrap(op)
+    }
+
     // --------------------------------------------------------------------------------------------
     //  Checksum
     // --------------------------------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/CustomDistributionITCase.java b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/CustomDistributionITCase.java
new file mode 100644
index 0000000000000..062800fb360a3
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/CustomDistributionITCase.java
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.javaApiOperators;
+
+import org.apache.flink.api.common.distributions.DataDistribution;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.io.DiscardingOutputFormat;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.utils.DataSetUtils;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.test.javaApiOperators.util.CollectionDataSets;
+import org.apache.flink.util.Collector;
+import org.junit.Test;
+
+import java.io.IOException;
+
+import static org.junit.Assert.fail;
+
+
+public class CustomDistributionITCase {
+
+	@Test
+	public void testPartitionWithDistribution1() throws Exception {
+		/*
+		 * Test that records are correctly partitioned on one field
+		 * according to the customized data distribution.
+		 */
+
+		ExecutionEnvironment env = ExecutionEnvironment.createLocalEnvironment();
+
+		DataSet<Tuple3<Integer, Long, String>> input1 = CollectionDataSets.get3TupleDataSet(env);
+		final TestDataDist1 dist = new TestDataDist1();
+
+		env.setParallelism(dist.getParallelism());
+
+		DataSet<Boolean> result = DataSetUtils.partitionByRange(input1, dist, 0)
+			.mapPartition(new RichMapPartitionFunction<Tuple3<Integer, Long, String>, Boolean>() {
+				@Override
+				public void mapPartition(Iterable<Tuple3<Integer, Long, String>> values, Collector<Boolean> out) throws Exception {
+					int partitionIndex = getRuntimeContext().getIndexOfThisSubtask();
+
+					for (Tuple3<Integer, Long, String> s : values) {
+						if ((s.f0 - 1) / 7 != partitionIndex) {
+							fail("Record was not correctly partitioned: " + s.toString());
+						}
+					}
+				}
+			});
+
+		result.output(new DiscardingOutputFormat<Boolean>());
+		env.execute();
+	}
+
+	@Test
+	public void testRangeWithDistribution2() throws Exception {
+		/*
+		 * Test that records are correctly partitioned on two fields
+		 * according to the customized data distribution.
+		 */
+
+		ExecutionEnvironment env = ExecutionEnvironment.createLocalEnvironment();
+
+		DataSet<Tuple3<Integer, Long, String>> input1 = env.fromElements(
+			new Tuple3<>(1, 5L, "Hi"),
+			new Tuple3<>(1, 11L, "Hello"),
+			new Tuple3<>(2, 3L, "World"),
+			new Tuple3<>(2, 13L, "Hello World"),
+			new Tuple3<>(3, 8L, "Say"),
+			new Tuple3<>(4, 0L, "Why"),
+			new Tuple3<>(4, 2L, "Java"),
+			new Tuple3<>(4, 11L, "Say Hello"),
+			new Tuple3<>(5, 2L, "Hi Java"));
+
+		final TestDataDist2 dist = new TestDataDist2();
+
+		env.setParallelism(dist.getParallelism());
+
+		DataSet<Boolean> result = DataSetUtils.partitionByRange(
+			input1.map(new MapFunction<Tuple3<Integer, Long, String>, Tuple3<Integer, Integer, String>>() {
+				@Override
+				public Tuple3<Integer, Integer, String> map(Tuple3<Integer, Long, String> value) throws Exception {
+					return new Tuple3<>(value.f0, value.f1.intValue(), value.f2);
+				}
+			}), dist, 0, 1)
+			.mapPartition(new RichMapPartitionFunction<Tuple3<Integer, Integer, String>, Boolean>() {
+				@Override
+				public void mapPartition(Iterable<Tuple3<Integer, Integer, String>> values, Collector<Boolean> out) throws Exception {
+					int partitionIndex = getRuntimeContext().getIndexOfThisSubtask();
+					boolean checkPartition = true;
+
+					for (Tuple3<Integer, Integer, String> s : values) {
+
+						if (partitionIndex == 0) {
+							if (s.f0 > partitionIndex + 1 || (s.f0 == partitionIndex + 1 && s.f1 > dist.rightBoundary[partitionIndex])) {
+								checkPartition = false;
+							}
+						}
+						else if (partitionIndex > 0 && partitionIndex < dist.getParallelism() - 1) {
+							if (s.f0 > partitionIndex + 1 || (s.f0 == partitionIndex + 1 && s.f1 > dist.rightBoundary[partitionIndex]) ||
+								s.f0 < partitionIndex || (s.f0 == partitionIndex && s.f1 < dist.rightBoundary[partitionIndex - 1])) {
+								checkPartition = false;
+							}
+						}
+						else {
+							if (s.f0 < partitionIndex || (s.f0 == partitionIndex && s.f1 < dist.rightBoundary[partitionIndex - 1])) {
+								checkPartition = false;
+							}
+						}
+
+						if (!checkPartition) {
+							fail("Record was not correctly partitioned: " + s.toString());
+						}
+					}
+				}
+			});
+
+		result.output(new DiscardingOutputFormat<Boolean>());
+		env.execute();
+	}
+
+	/**
+	 * The class is used to test range partitioning with one key.
+	 */
+	public static class TestDataDist1 implements DataDistribution {
+
+		/**
+		 * Constructor of the customized distribution for range partitioning.
+		 */
+		public TestDataDist1() {}
+
+		public int getParallelism() {
+			return 3;
+		}
+
+		@Override
+		public Object[] getBucketBoundary(int bucketNum, int totalNumBuckets) {
+			/*
+			 * For the first test, the boundaries are:
+			 * (0, 7], (7, 14], (14, 21]
+			 */
+			return new Integer[]{(bucketNum + 1) * 7};
+		}
+
+		@Override
+		public int getNumberOfFields() {
+			return 1;
+		}
+
+		@Override
+		public TypeInformation<?>[] getKeyTypes() {
+			return new TypeInformation<?>[]{BasicTypeInfo.INT_TYPE_INFO};
+		}
+
+		@Override
+		public void write(DataOutputView out) throws IOException {
+
+		}
+
+		@Override
+		public void read(DataInputView in) throws IOException {
+
+		}
+	}
+
+	/**
+	 * The class is used to test range partitioning with two keys.
+	 */
+	public static class TestDataDist2 implements DataDistribution {
+
+		public int[] rightBoundary = new int[]{6, 4, 9, 1, 2};
+
+		/**
+		 * Constructor of the customized distribution for range partitioning.
+		 */
+		public TestDataDist2() {}
+
+		public int getParallelism() {
+			return 5;
+		}
+
+		@Override
+		public Object[] getBucketBoundary(int bucketNum, int totalNumBuckets) {
+			/*
+			 * For the second test, the boundaries are:
+			 * ((0, 0), (1, 6)], ((1, 6), (2, 4)], ((2, 4), (3, 9)],
+			 * ((3, 9), (4, 1)], ((4, 1), (5, 2)]
+			 */
+			return new Integer[]{bucketNum + 1, rightBoundary[bucketNum]};
+		}
+
+		@Override
+		public int getNumberOfFields() {
+			return 2;
+		}
+
+		@Override
+		public TypeInformation<?>[] getKeyTypes() {
+			return new TypeInformation<?>[]{BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO};
+		}
+
+		@Override
+		public void write(DataOutputView out) throws IOException {
+
+		}
+
+		@Override
+		public void read(DataInputView in) throws IOException {
+
+		}
+	}
+}
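To make the two-field boundaries in `TestDataDist2` concrete: a composite key `(f0, f1)` is compared lexicographically against each bucket's upper boundary `(bucketNum + 1, rightBoundary[bucketNum])`. A standalone sketch of that comparison (hypothetical helper, mirroring the test's expectations rather than Flink runtime code):

```java
// Standalone illustration of the lexicographic bucket boundaries in TestDataDist2.
public class BoundaryDemo {
	static final int[] RIGHT = {6, 4, 9, 1, 2};

	// First bucket whose upper boundary (bucket + 1, RIGHT[bucket]) is
	// lexicographically >= the key (f0, f1).
	static int bucketOf(int f0, int f1) {
		for (int bucket = 0; bucket < RIGHT.length - 1; bucket++) {
			if (f0 < bucket + 1 || (f0 == bucket + 1 && f1 <= RIGHT[bucket])) {
				return bucket;
			}
		}
		return RIGHT.length - 1;
	}

	public static void main(String[] args) {
		System.out.println(bucketOf(1, 5));   // 0: (1, 5)  <= (1, 6)
		System.out.println(bucketOf(1, 11));  // 1: (1, 11) >  (1, 6), <= (2, 4)
		System.out.println(bucketOf(4, 2));   // 4: (4, 2)  >  (4, 1)
	}
}
```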