Skip to content

Commit

Permalink
[FLINK-3657] [dataSet] Change access of DataSetUtils.countElements() …
Browse files Browse the repository at this point in the history
…to 'public'

This closes #1829
  • Loading branch information
smarthi authored and fhueske committed Apr 15, 2016
1 parent d938c5f commit 5f993c6
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 9 deletions.
Expand Up @@ -60,15 +60,14 @@ public final class DataSetUtils {
* @param input the DataSet received as input * @param input the DataSet received as input
* @return a data set containing tuples of subtask index, number of elements mappings. * @return a data set containing tuples of subtask index, number of elements mappings.
*/ */
private static <T> DataSet<Tuple2<Integer, Long>> countElements(DataSet<T> input) { public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) {
return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() { return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
@Override @Override
public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception { public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
long counter = 0; long counter = 0;
for (T value : values) { for (T value : values) {
counter++; counter++;
} }

out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), counter)); out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), counter));
} }
}); });
Expand All @@ -83,7 +82,7 @@ public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> ou
*/ */
public static <T> DataSet<Tuple2<Long, T>> zipWithIndex(DataSet<T> input) { public static <T> DataSet<Tuple2<Long, T>> zipWithIndex(DataSet<T> input) {


DataSet<Tuple2<Integer, Long>> elementCount = countElements(input); DataSet<Tuple2<Integer, Long>> elementCount = countElementsPerPartition(input);


return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() { return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {


Expand Down
Expand Up @@ -47,6 +47,20 @@ package object utils {
@PublicEvolving @PublicEvolving
implicit class DataSetUtils[T: TypeInformation : ClassTag](val self: DataSet[T]) { implicit class DataSetUtils[T: TypeInformation : ClassTag](val self: DataSet[T]) {


/**
* Method that goes over all the elements in each partition in order to retrieve
* the total number of elements.
*
* @return a data set of tuple2 consisting of (subtask index, number of elements mappings)
*/
def countElementsPerPartition: DataSet[(Int, Long)] = {
implicit val typeInfo = createTuple2TypeInformation[Int, Long](
BasicTypeInfo.INT_TYPE_INFO.asInstanceOf[TypeInformation[Int]],
BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]]
)
wrap(jutils.countElementsPerPartition(self.javaSet)).map { t => (t.f0.toInt, t.f1.toLong)}
}

/** /**
* Method that takes a set of subtask index, total number of elements mappings * Method that takes a set of subtask index, total number of elements mappings
* and assigns ids to all the elements from the input data set. * and assigns ids to all the elements from the input data set.
Expand Down
Expand Up @@ -18,8 +18,6 @@


package org.apache.flink.test.util; package org.apache.flink.test.util;


import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.ExecutionEnvironment;
Expand All @@ -32,8 +30,10 @@
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.runners.Parameterized;


import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;


Expand All @@ -44,13 +44,25 @@ public DataSetUtilsITCase(TestExecutionMode mode) {
super(mode); super(mode);
} }


@Test
public void testCountElementsPerPartition() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
long expectedSize = 100L;
DataSet<Long> numbers = env.generateSequence(0, expectedSize - 1);

DataSet<Tuple2<Integer, Long>> ds = DataSetUtils.countElementsPerPartition(numbers);

Assert.assertEquals(env.getParallelism(), ds.count());
Assert.assertEquals(expectedSize, ds.sum(1).collect().get(0).f1.longValue());
}

@Test @Test
public void testZipWithIndex() throws Exception { public void testZipWithIndex() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
long expectedSize = 100L; long expectedSize = 100L;
DataSet<Long> numbers = env.generateSequence(0, expectedSize - 1); DataSet<Long> numbers = env.generateSequence(0, expectedSize - 1);


List<Tuple2<Long, Long>> result = Lists.newArrayList(DataSetUtils.zipWithIndex(numbers).collect()); List<Tuple2<Long, Long>> result = new ArrayList<>(DataSetUtils.zipWithIndex(numbers).collect());


Assert.assertEquals(expectedSize, result.size()); Assert.assertEquals(expectedSize, result.size());
// sort result by created index // sort result by created index
Expand Down Expand Up @@ -79,7 +91,7 @@ public Long map(Tuple2<Long, Long> value) throws Exception {
} }
}); });


Set<Long> result = Sets.newHashSet(ids.collect()); Set<Long> result = new HashSet<>(ids.collect());


Assert.assertEquals(expectedSize, result.size()); Assert.assertEquals(expectedSize, result.size());
} }
Expand Down
Expand Up @@ -34,7 +34,7 @@ class DataSetUtilsITCase (
@Test @Test
@throws(classOf[Exception]) @throws(classOf[Exception])
def testZipWithIndex(): Unit = { def testZipWithIndex(): Unit = {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment


val expectedSize = 100L val expectedSize = 100L


Expand All @@ -52,7 +52,7 @@ class DataSetUtilsITCase (
@Test @Test
@throws(classOf[Exception]) @throws(classOf[Exception])
def testZipWithUniqueId(): Unit = { def testZipWithUniqueId(): Unit = {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment


val expectedSize = 100L val expectedSize = 100L


Expand All @@ -73,4 +73,19 @@ class DataSetUtilsITCase (
Assert.assertEquals(checksum.getCount, 15) Assert.assertEquals(checksum.getCount, 15)
Assert.assertEquals(checksum.getChecksum, 55) Assert.assertEquals(checksum.getChecksum, 55)
} }

@Test
@throws(classOf[Exception])
def testCountElementsPerPartition(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment

val expectedSize = 100L

val numbers = env.generateSequence(0, expectedSize - 1)

val ds = numbers.countElementsPerPartition

Assert.assertEquals(env.getParallelism, ds.collect().size)
Assert.assertEquals(expectedSize, ds.sum(1).collect().head._2)
}
} }

0 comments on commit 5f993c6

Please sign in to comment.