From 3832f2137676a76d6d06a0bb6dbcedcba801910b Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Sat, 8 Sep 2018 15:30:49 +0200
Subject: [PATCH 1/7] Adding overloaded sampleBy with Column type

---
 .../spark/sql/DataFrameStatFunctions.scala    | 49 +++++++++++++++----
 .../apache/spark/sql/DataFrameStatSuite.scala | 20 +++++++-
 2 files changed, 59 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index a41753098966e..72cf58bc8f901 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -370,15 +370,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * @since 1.5.0
    */
   def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = {
-    require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
-      s"Fractions must be in [0, 1], but got $fractions.")
-    import org.apache.spark.sql.functions.{rand, udf}
-    val c = Column(col)
-    val r = rand(seed)
-    val f = udf { (stratum: Any, x: Double) =>
-      x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0)
-    }
-    df.filter(f(c, r))
+    sampleBy(Column(col), fractions, seed)
   }
 
   /**
@@ -396,6 +388,45 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
     sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
   }
 
+  /**
+   * Returns a stratified sample without replacement based on the fraction given on each stratum.
+   * @param col column that defines strata
+   * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
+   *                  its fraction as zero.
+   * @param seed random seed
+   * @tparam T stratum type
+   * @return a new `DataFrame` that represents the stratified sample
+   *
+   * The stratified sample can be performed over multiple columns:
+   * {{{
+   *    import org.apache.spark.sql.Row
+   *    import org.apache.spark.sql.functions.struct
+   *
+   *    val df = spark.createDataFrame(Seq(("Bob", 17), ("Alice", 10), ("Nico", 8), ("Bob", 17),
+   *      ("Alice", 10))).toDF("name", "age")
+   *    val fractions = Map(Row("Alice", 10) -> 0.3, Row("Nico", 8) -> 1.0)
+   *    df.stat.sampleBy(struct($"name", $"age"), fractions, 36L).show()
+   *      +-----+---+
+   *      | name|age|
+   *      +-----+---+
+   *      | Nico|  8|
+   *      |Alice| 10|
+   *      +-----+---+
+   * }}}
+   *
+   * @since 3.0.0
+   */
+  def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = {
+    require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
+      s"Fractions must be in [0, 1], but got $fractions.")
+    import org.apache.spark.sql.functions.{rand, udf}
+    val r = rand(seed)
+    val f = udf { (stratum: Any, x: Double) =>
+      x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0)
+    }
+    df.filter(f(col, r))
+  }
+
   /**
    * Builds a Count-min Sketch over a specified column.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 8eae35325faea..589873b9c3ea4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.Matchers._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.execution.stat.StatFunctions
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions.{col, lit, struct}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
@@ -374,6 +374,24 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
       Seq(Row(0, 6), Row(1, 11)))
   }
 
+  test("sampleBy one column") {
+    val df = spark.range(0, 100).select((col("id") % 3).as("key"))
+    val sampled = df.stat.sampleBy($"key", Map(0 -> 0.1, 1 -> 0.2), 0L)
+    checkAnswer(
+      sampled.groupBy("key").count().orderBy("key"),
+      Seq(Row(0, 6), Row(1, 11)))
+  }
+
+  test("sampleBy multiple columns") {
+    val df = spark.range(0, 100)
+      .select(lit("Foo").as("name"), (col("id") % 3).as("key"))
+    val sampled = df.stat.sampleBy(
+      struct($"name", $"key"), Map(Row("Foo", 0) -> 0.1, Row("Foo", 1) -> 0.2), 0L)
+    checkAnswer(
+      sampled.groupBy("key").count().orderBy("key"),
+      Seq(Row(0, 6), Row(1, 11)))
+  }
+
   // This test case only verifies that `DataFrame.countMinSketch()` methods do return
   // `CountMinSketch`es that meet required specs. Test cases for `CountMinSketch` can be found in
   // `CountMinSketchSuite` in project spark-sketch.
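A note on the mechanics of this patch: the String overload now just wraps its argument in `Column` and delegates, and the sampling itself is a filter against a seeded `rand()` column. A minimal equivalence sketch — not the library code itself — assuming a SparkSession in scope as `spark` and specializing the stratum type to Long:

{{{
import org.apache.spark.sql.functions.{col, rand, udf}

val df = spark.range(0, 100).select((col("id") % 3).as("key"))
val fractions = Map(0L -> 0.1, 1L -> 0.2)
// Same effect as df.stat.sampleBy(col("key"), fractions, 0L): keep a row iff a
// seeded uniform draw falls below the fraction assigned to the row's stratum.
val keep = udf { (stratum: Long, x: Double) => x < fractions.getOrElse(stratum, 0.0) }
val sampled = df.filter(keep(col("key"), rand(0L)))
}}}

Strata absent from the map get fraction 0.0, so their rows never survive the filter; that is why the require only needs to bound the supplied fractions to [0, 1].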
From 5cd3229ce8bfe894dac8ebc097109da237d95401 Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Sat, 8 Sep 2018 15:39:30 +0200
Subject: [PATCH 2/7] Adding overloaded sampleBy with Column type for Java

---
 .../spark/sql/DataFrameStatFunctions.scala       | 16 ++++++++++++++++
 .../org/apache/spark/sql/JavaDataFrameSuite.java | 11 +++++++++++
 2 files changed, 27 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 72cf58bc8f901..7c12432d33c33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -427,6 +427,22 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
     df.filter(f(col, r))
   }
 
+  /**
+   * (Java-specific) Returns a stratified sample without replacement based on the fraction given
+   * on each stratum.
+   * @param col column that defines strata
+   * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
+   *                  its fraction as zero.
+   * @param seed random seed
+   * @tparam T stratum type
+   * @return a new `DataFrame` that represents the stratified sample
+   *
+   * @since 3.0.0
+   */
+  def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
+    sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
+  }
+
   /**
    * Builds a Count-min Sketch over a specified column.
    *
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 69a2904f5f3fe..3f37e5814ccaa 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -290,6 +290,17 @@ public void testSampleBy() {
     Assert.assertTrue(2 <= actual.get(1).getLong(1) && actual.get(1).getLong(1) <= 13);
   }
 
+  @Test
+  public void testSampleByColumn() {
+    Dataset<Row> df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
+    Dataset<Row> sampled = df.stat().sampleBy(col("key"), ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
+    List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList();
+    Assert.assertEquals(0, actual.get(0).getLong(0));
+    Assert.assertTrue(0 <= actual.get(0).getLong(1) && actual.get(0).getLong(1) <= 8);
+    Assert.assertEquals(1, actual.get(1).getLong(0));
+    Assert.assertTrue(2 <= actual.get(1).getLong(1) && actual.get(1).getLong(1) <= 13);
+  }
+
   @Test
   public void pivot() {
     Dataset<Row> df = spark.table("courseSales");
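The Java-specific overload in this patch is purely a type bridge before delegating to the Scala one. A standalone sketch of that conversion, assuming the `ju`/`jl` package aliases and the JavaConverters import already present in DataFrameStatFunctions.scala:

{{{
import java.{lang => jl, util => ju}
import scala.collection.JavaConverters._

val javaFractions: ju.Map[String, jl.Double] = new ju.HashMap[String, jl.Double]()
javaFractions.put("Alice", 0.5)
// ju.Map[T, jl.Double] -> immutable Map[T, Double]; the asInstanceOf is safe at
// runtime because jl.Double is also the boxed representation of scala.Double.
val scalaFractions: Map[String, Double] =
  javaFractions.asScala.toMap.asInstanceOf[Map[String, Double]]
}}}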
From e2e61498c47da9d7b36d2e0727ce8642d5d71472 Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Sat, 8 Sep 2018 16:56:36 +0200
Subject: [PATCH 3/7] Adding overloaded sampleBy with Column type for Python

---
 python/pyspark/sql/dataframe.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1affc9b4fcf6c..cfcc0b15b4b51 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -880,10 +880,14 @@ def sampleBy(self, col, fractions, seed=None):
         |  0|    5|
         |  1|    9|
         +---+-----+
+        >>> dataset.sampleBy(col("key"), fractions={2: 1.0}, seed=0).count()
+        33
 
         """
-        if not isinstance(col, basestring):
-            raise ValueError("col must be a string, but got %r" % type(col))
+        if isinstance(col, basestring):
+            col = Column(col)
+        elif not isinstance(col, Column):
+            raise ValueError("col must be a string or a column, but got %r" % type(col))
         if not isinstance(fractions, dict):
             raise ValueError("fractions must be a dict but got %r" % type(fractions))
         for k, v in fractions.items():
@@ -891,7 +895,7 @@ def sampleBy(self, col, fractions, seed=None):
                 raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
             fractions[k] = float(v)
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
-        return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
+        return DataFrame(self._jdf.stat().sampleBy(col._jc, self._jmap(fractions), seed), self.sql_ctx)
 
     @since(1.4)
     def randomSplit(self, weights, seed=None):

From 7e7794153924b824dc5fe5f05375c8b9950ef539 Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Sat, 8 Sep 2018 18:26:31 +0200
Subject: [PATCH 4/7] Making Python code style checker happy

---
 python/pyspark/sql/dataframe.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index cfcc0b15b4b51..1e6a5009e7875 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -895,7 +895,8 @@ def sampleBy(self, col, fractions, seed=None):
                 raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
             fractions[k] = float(v)
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
-        return DataFrame(self._jdf.stat().sampleBy(col._jc, self._jmap(fractions), seed), self.sql_ctx)
+        return DataFrame(self._jdf.stat()\
+            .sampleBy(col._jc, self._jmap(fractions), seed), self.sql_ctx)
 
     @since(1.4)
     def randomSplit(self, weights, seed=None):
From 2845bca09797a34e930e6aca42f198ec5cbd95e3 Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Sat, 8 Sep 2018 18:40:55 +0200
Subject: [PATCH 5/7] Removing unneeded backslash

---
 python/pyspark/sql/dataframe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1e6a5009e7875..c853907c8643b 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -895,7 +895,7 @@ def sampleBy(self, col, fractions, seed=None):
                 raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
             fractions[k] = float(v)
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
-        return DataFrame(self._jdf.stat()\
+        return DataFrame(self._jdf.stat()
             .sampleBy(col._jc, self._jmap(fractions), seed), self.sql_ctx)
 
     @since(1.4)

From e85175e18e95d7751748d4615792579375859786 Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Mon, 10 Sep 2018 23:40:30 +0200
Subject: [PATCH 6/7] Addressing Hyukjin Kwon's review comments

---
 python/pyspark/sql/dataframe.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index c853907c8643b..bf6b990487617 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -883,6 +883,8 @@ def sampleBy(self, col, fractions, seed=None):
         >>> dataset.sampleBy(col("key"), fractions={2: 1.0}, seed=0).count()
         33
 
+        .. versionchanged:: 3.0
+           Added sampling by a column of :class:`Column`
         """
         if isinstance(col, basestring):
             col = Column(col)
@@ -894,9 +896,9 @@ def sampleBy(self, col, fractions, seed=None):
             if not isinstance(k, (float, int, long, basestring)):
                 raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
             fractions[k] = float(v)
+        col = col._jc
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
-        return DataFrame(self._jdf.stat()
-            .sampleBy(col._jc, self._jmap(fractions), seed), self.sql_ctx)
+        return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
 
     @since(1.4)
     def randomSplit(self, weights, seed=None):
From 1740d60a9bdc1c84b1d74d7637411396b9fbff75 Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Thu, 20 Sep 2018 12:01:48 +0200
Subject: [PATCH 7/7] Re-targeting to 2.5 instead of 3.0

---
 python/pyspark/sql/dataframe.py                             | 2 +-
 .../scala/org/apache/spark/sql/DataFrameStatFunctions.scala | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index bf6b990487617..21bc69b8236fd 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -883,7 +883,7 @@ def sampleBy(self, col, fractions, seed=None):
         >>> dataset.sampleBy(col("key"), fractions={2: 1.0}, seed=0).count()
         33
 
-        .. versionchanged:: 3.0
+        .. versionchanged:: 2.5
            Added sampling by a column of :class:`Column`
         """
         if isinstance(col, basestring):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 7c12432d33c33..75b84773bd0b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -414,7 +414,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * +-----+---+
    * }}}
    *
-   * @since 3.0.0
+   * @since 2.5.0
    */
   def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = {
     require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
@@ -437,7 +437,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * @tparam T stratum type
    * @return a new `DataFrame` that represents the stratified sample
    *
-   * @since 3.0.0
+   * @since 2.5.0
   */
   def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
     sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
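With all seven patches applied, the String and Column overloads coexist on DataFrameStatFunctions. A brief usage sketch mirroring the new tests, assuming a SparkSession in scope as `spark`:

{{{
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{col, lit, struct}

val df = spark.range(0, 100)
  .select(lit("Foo").as("name"), (col("id") % 3).as("key"))

// Single stratum column: the String and Column overloads are interchangeable.
val byString = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
val byColumn = df.stat.sampleBy(col("key"), Map(0 -> 0.1, 1 -> 0.2), 0L)

// Multi-column strata: pass a struct and key the fractions by matching Rows.
val byStruct = df.stat.sampleBy(
  struct(col("name"), col("key")), Map(Row("Foo", 0) -> 0.1, Row("Foo", 1) -> 0.2), 0L)
}}}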