From 075e7ef3f27af91c5190d039770cf15b08a66c81 Mon Sep 17 00:00:00 2001 From: "Sachathamakul, Patrachai (Agoda)" Date: Sun, 8 Oct 2017 17:24:44 +0700 Subject: [PATCH 1/3] Added flatten functions for RDD and Dataset --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 7 +++++++ sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 8798dfc925362..af36b2d458854 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -381,6 +381,13 @@ abstract class RDD[T: ClassTag]( new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.flatMap(cleanF)) } + /** + * Return a new RDD by flattening a traversable collection into a collection itself. + */ + def flatten[U: ClassTag](implicit f: T => TraversableOnce[U]): RDD[U] = withScope { + this.flatMap(y => y) + } + /** * Return a new RDD containing only the elements that satisfy a predicate. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b70dfc05330f8..897c9b1c2a1a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2542,6 +2542,11 @@ class Dataset[T] private[sql]( def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] = mapPartitions(_.flatMap(func)) + /** + * Returns a new Dataset by flattening a traversable collection into a collection itself. 
+ */ + def flatten[U: Encoder](implicit func: T => TraversableOnce[U]): Dataset[U] = mapPartitions(_.flatMap(x => x)) + /** * :: Experimental :: * (Java-specific) From 261e45a9a2298df2d4d1f9adc1ca1ced22e90b60 Mon Sep 17 00:00:00 2001 From: "Sachathamakul, Patrachai (Agoda)" Date: Sun, 8 Oct 2017 19:12:28 +0700 Subject: [PATCH 2/3] Added @since 2.3.0 in Dataset flatten function Added unit tests in both RDDSuite.scala and DatasetSuite.scala --- .../test/scala/org/apache/spark/rdd/RDDSuite.scala | 1 + .../src/main/scala/org/apache/spark/sql/Dataset.scala | 2 ++ .../scala/org/apache/spark/sql/DatasetSuite.scala | 11 +++++++++++ 3 files changed, 14 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index e994d724c462f..b62e0153642b5 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -63,6 +63,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4")) assert(nums.filter(_ > 2).collect().toList === List(3, 4)) assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) + assert(sc.makeRDD(Seq(Array(1, 2), Array(3, 4))).flatten.collect().toList === List(1, 2, 3, 4)) assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 897c9b1c2a1a9..83a5ba47e590b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2544,6 +2544,8 @@ class Dataset[T] private[sql]( /** * Returns a new Dataset by 
flattening a traversable collection into a collection itself. + * + * @since 2.3.0 */ def flatten[U: Encoder](implicit func: T => TraversableOnce[U]): Dataset[U] = mapPartitions(_.flatMap(x => x)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index dace6825ee40e..34a51e4f06c36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1341,6 +1341,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq(1).toDS().map(_ => ("", TestForTypeAlias.seqOfTupleTypeAlias)), ("", Seq((1, 1), (2, 2)))) } + test("SPARK-22152: a dataset of sequence should be flattened") { + val ds: Dataset[Seq[Int]] = Seq(Seq(1, 2, 3)).toDS + val out: Dataset[Int] = ds.flatten + assert(out.collect.toSeq == Seq(1, 2, 3)) + } + + test("SPARK-22152: a dataset of option elements should be flattened") { + val ds: Dataset[Option[String]] = Seq(Some("a"), None, Some("b")).toDS + val out: Dataset[String] = ds.flatten + assert(out.collect.toSeq == Seq("a", "b")) + } } case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String]) From cc08623519f4ddfdfcc883557c4cc53f11e6f0f7 Mon Sep 17 00:00:00 2001 From: "Sachathamakul, Patrachai (Agoda)" Date: Mon, 9 Oct 2017 19:06:09 +0700 Subject: [PATCH 3/3] changed x => x to identity fixed style error --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 2 +- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index af36b2d458854..683b3d3874f53 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -385,7 +385,7 @@ abstract class RDD[T: ClassTag]( * Return a new RDD by flattening a 
traversable collection into a collection itself. */ def flatten[U: ClassTag](implicit f: T => TraversableOnce[U]): RDD[U] = withScope { - this.flatMap(y => y) + this.flatMap(identity(_)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 83a5ba47e590b..bcd16fc5cbca2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2547,7 +2547,8 @@ class Dataset[T] private[sql]( * * @since 2.3.0 */ - def flatten[U: Encoder](implicit func: T => TraversableOnce[U]): Dataset[U] = mapPartitions(_.flatMap(x => x)) + def flatten[U: Encoder](implicit func: T => TraversableOnce[U]): Dataset[U] = + mapPartitions(_.flatMap(x => x)) /** * :: Experimental ::