diff --git a/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java b/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java
index 61a71aa32a75b..756bb55776ea1 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java
@@ -60,7 +60,7 @@ public final class DataSetUtils {
	 * @param input the DataSet received as input
	 * @return a data set containing tuples of subtask index, number of elements mappings.
	 */
-	private static <T> DataSet<Tuple2<Integer, Long>> countElements(DataSet<T> input) {
+	public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) {
 		return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
 			@Override
 			public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
@@ -68,7 +68,6 @@ public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> ou
 				for (T value : values) {
 					counter++;
 				}
-
 				out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), counter));
 			}
 		});
@@ -83,7 +82,7 @@ public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> ou
	 */
	public static <T> DataSet<Tuple2<Long, T>> zipWithIndex(DataSet<T> input) {
 
-		DataSet<Tuple2<Integer, Long>> elementCount = countElements(input);
+		DataSet<Tuple2<Integer, Long>> elementCount = countElementsPerPartition(input);
 
 		return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/utils/package.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/utils/package.scala
index adad9ab4ade34..d543998c1eb5e 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/utils/package.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/utils/package.scala
@@ -47,6 +47,20 @@ package object utils {
   @PublicEvolving
   implicit class DataSetUtils[T: TypeInformation : ClassTag](val self: DataSet[T]) {
 
+    /**
+     * Method that goes over all the elements in each partition in order to retrieve
+     * the total number of elements.
+     *
+     * @return a data set of tuple2 consisting of (subtask index, number of elements mappings)
+     */
+    def countElementsPerPartition: DataSet[(Int, Long)] = {
+      implicit val typeInfo = createTuple2TypeInformation[Int, Long](
+        BasicTypeInfo.INT_TYPE_INFO.asInstanceOf[TypeInformation[Int]],
+        BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]]
+      )
+      wrap(jutils.countElementsPerPartition(self.javaSet)).map { t => (t.f0.toInt, t.f1.toLong) }
+    }
+
     /**
      * Method that takes a set of subtask index, total number of elements mappings
      * and assigns ids to all the elements from the input data set.
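For reference, a minimal usage sketch of the newly public Java API. The class name is hypothetical and the even 25-per-subtask split in the comment is illustrative only; actual per-partition counts depend on how generateSequence splits its range across subtasks.

    import org.apache.flink.api.java.DataSet;
    import org.apache.flink.api.java.ExecutionEnvironment;
    import org.apache.flink.api.java.tuple.Tuple2;
    import org.apache.flink.api.java.utils.DataSetUtils;

    public class CountPerPartitionExample {
        public static void main(String[] args) throws Exception {
            ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
            env.setParallelism(4);

            // 100 elements distributed over 4 parallel subtasks
            DataSet<Long> numbers = env.generateSequence(0, 99);

            // one (subtask index, element count) tuple per subtask
            DataSet<Tuple2<Integer, Long>> counts =
                DataSetUtils.countElementsPerPartition(numbers);

            // e.g. (0,25), (1,25), (2,25), (3,25)
            counts.print();
        }
    }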
diff --git a/flink-tests/src/test/java/org/apache/flink/test/util/DataSetUtilsITCase.java b/flink-tests/src/test/java/org/apache/flink/test/util/DataSetUtilsITCase.java
index 4ccc6e24ba006..afbcb8901a206 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/util/DataSetUtilsITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/util/DataSetUtilsITCase.java
@@ -18,8 +18,6 @@
 
 package org.apache.flink.test.util;
 
-import com.google.common.collect.Lists;
-import com.google.common.collect.Sets;
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.ExecutionEnvironment;
@@ -32,8 +30,10 @@
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 
@@ -44,13 +44,25 @@ public DataSetUtilsITCase(TestExecutionMode mode) {
 		super(mode);
 	}
 
+	@Test
+	public void testCountElementsPerPartition() throws Exception {
+		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+		long expectedSize = 100L;
+		DataSet<Long> numbers = env.generateSequence(0, expectedSize - 1);
+
+		DataSet<Tuple2<Integer, Long>> ds = DataSetUtils.countElementsPerPartition(numbers);
+
+		Assert.assertEquals(env.getParallelism(), ds.count());
+		Assert.assertEquals(expectedSize, ds.sum(1).collect().get(0).f1.longValue());
+	}
+
 	@Test
 	public void testZipWithIndex() throws Exception {
 		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
 		long expectedSize = 100L;
 		DataSet<Long> numbers = env.generateSequence(0, expectedSize - 1);
 
-		List<Tuple2<Long, Long>> result = Lists.newArrayList(DataSetUtils.zipWithIndex(numbers).collect());
+		List<Tuple2<Long, Long>> result = new ArrayList<>(DataSetUtils.zipWithIndex(numbers).collect());
 
 		Assert.assertEquals(expectedSize, result.size());
 		// sort result by created index
@@ -79,7 +91,7 @@ public Long map(Tuple2<Long, Long> value) throws Exception {
 			}
 		});
 
-		Set<Long> result = Sets.newHashSet(ids.collect());
+		Set<Long> result = new HashSet<>(ids.collect());
 
 		Assert.assertEquals(expectedSize, result.size());
 	}
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/util/DataSetUtilsITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/util/DataSetUtilsITCase.scala
index 25ecc9c5bdc00..83dd2a4e8fe8c 100644
--- a/flink-tests/src/test/scala/org/apache/flink/api/scala/util/DataSetUtilsITCase.scala
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/util/DataSetUtilsITCase.scala
@@ -34,7 +34,7 @@ class DataSetUtilsITCase (
   @Test
   @throws(classOf[Exception])
   def testZipWithIndex(): Unit = {
-    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val env = ExecutionEnvironment.getExecutionEnvironment
 
     val expectedSize = 100L
 
@@ -52,7 +52,7 @@ class DataSetUtilsITCase (
   @Test
   @throws(classOf[Exception])
   def testZipWithUniqueId(): Unit = {
-    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val env = ExecutionEnvironment.getExecutionEnvironment
 
     val expectedSize = 100L
 
@@ -73,4 +73,19 @@ class DataSetUtilsITCase (
     Assert.assertEquals(checksum.getCount, 15)
     Assert.assertEquals(checksum.getChecksum, 55)
   }
+
+  @Test
+  @throws(classOf[Exception])
+  def testCountElementsPerPartition(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    val expectedSize = 100L
+
+    val numbers = env.generateSequence(0, expectedSize - 1)
+
+    val ds = numbers.countElementsPerPartition
+
+    Assert.assertEquals(env.getParallelism, ds.collect().size)
+    Assert.assertEquals(expectedSize, ds.sum(1).collect().head._2)
+  }
 }
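These per-partition counts are what zipWithIndex builds on (see the first hunk above, where it now calls countElementsPerPartition): each subtask's first index is the sum of the element counts of all lower-numbered subtasks. A simplified, self-contained sketch of that offset computation follows; the helper name is hypothetical and this is not the actual Flink implementation, which hands elementCount to a RichMapPartitionFunction.

    import java.util.List;
    import org.apache.flink.api.java.tuple.Tuple2;

    public class OffsetSketch {
        // Hypothetical helper: the first index assigned by `subtask` equals
        // the total number of elements held by lower-numbered subtasks.
        static long startOffset(int subtask, List<Tuple2<Integer, Long>> counts) {
            long offset = 0;
            for (Tuple2<Integer, Long> c : counts) {
                // f0 = subtask index, f1 = element count for that subtask
                if (c.f0 < subtask) {
                    offset += c.f1;
                }
            }
            return offset;
        }
    }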