From 832f7cc34dec264d14e02ca0bff4373924e62372 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Thu, 11 Jun 2015 16:28:49 -0700
Subject: [PATCH 1/7] add sampleBy to DataFrame

---
 python/pyspark/sql/dataframe.py               | 35 +++++++++++++++++++
 .../org/apache/spark/sql/DataFrame.scala      | 30 ++++++++++++----
 .../org/apache/spark/sql/DataFrameSuite.scala |  8 +++++
 3 files changed, 67 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 9615e576497cd..68e92cb8fe943 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -444,6 +444,41 @@ def sample(self, withReplacement, fraction, seed=None):
         rdd = self._jdf.sample(withReplacement, fraction, long(seed))
         return DataFrame(rdd, self.sql_ctx)
 
+    @since(1.5)
+    def sampleBy(self, col, fractions, seed=None):
+        """
+        Returns a stratified sample without replacement based on the
+        fraction given on each stratum.
+
+        :param col: column that defines strata
+        :param fractions:
+            sampling fraction for each stratum. If a stratum is not
+            specified, we treat its fraction as zero.
+        :param seed: random seed
+        :return: a new DataFrame that represents the stratified sample
+
+        >>> from pyspark.sql.functions import col
+        >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
+        >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0L)
+        >>> sampled.groupBy("key").count().orderBy("key").show()
+        +---+-----+
+        |key|count|
+        +---+-----+
+        |  0|    5|
+        |  1|    8|
+        +---+-----+
+        """
+        if not isinstance(col, str):
+            raise ValueError("col must be a string, but got %r" % type(col))
+        if not isinstance(fractions, dict):
+            raise ValueError("fractions must be a dict but got %r" % type(fractions))
+        for k, v in fractions.items():
+            if not isinstance(k, (float, int, long, basestring)):
+                raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
+            fractions[k] = float(v)
+        seed = seed if seed is not None else random.randint(0, sys.maxsize)
+        return DataFrame(self._jdf.sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
+
     @since(1.4)
     def randomSplit(self, weights, seed=None):
         """Randomly splits this :class:`DataFrame` with the provided weights.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 59f64dd4bc648..38218f4640469 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -18,9 +18,8 @@ package org.apache.spark.sql
 
 import java.io.CharArrayWriter
-import java.util.Properties
+import java.util.{Properties, UUID}
 
-import scala.collection.JavaConversions._
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
@@ -33,11 +32,11 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
+import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
 import org.apache.spark.sql.json.JacksonGenerator
 import org.apache.spark.sql.sources.CreateTableUsingAsSelect
@@ -45,7 +44,6 @@ import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
-
 private[sql] object DataFrame {
   def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
     new DataFrame(sqlContext, logicalPlan)
@@ -947,6 +945,26 @@ class DataFrame private[sql](
     sample(withReplacement, fraction, Utils.random.nextLong)
   }
 
+  /**
+   * Returns a stratified sample without replacement based on the fraction given on each stratum.
+   * @param col column that defines strata
+   * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
+   *                  its fraction as zero.
+   * @param seed random seed
+   * @return a new [[DataFrame]] that represents the stratified sample
+   */
+  def sampleBy(col: String, fractions: Map[Any, Double], seed: Long): DataFrame = {
+    require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
+      s"Fractions must be in [0, 1], but got $fractions.")
+    import org.apache.spark.sql.functions.rand
+    val c = Column(col)
+    val r = rand(seed).as("rand_" + UUID.randomUUID().toString.take(8))
+    val expr = fractions.toSeq.map { case (k, v) =>
+      (c === k) && (r < v)
+    }.reduce(_ || _) || false
+    this.filter(expr)
+  }
+
   /**
    * Randomly splits this [[DataFrame]] with the provided weights.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index bb8621abe64ad..acddedd0e0cc9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -635,4 +635,12 @@ class DataFrameSuite extends QueryTest {
     val res11 = ctx.range(-1).select("id")
     assert(res11.count == 0)
   }
+
+  test("sampleBy") {
+    val df = ctx.range(0, 100).select((col("id") % 3).as("key"))
+    val sampled = df.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
+    checkAnswer(
+      sampled.groupBy("key").count().orderBy("key"),
+      Seq(Row(0, 4), Row(1, 9)))
+  }
 }
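For Map(0 -> 0.1, 1 -> 0.2), the predicate that the Scala sampleBy above assembles expands to ((key === 0) && (rand < 0.1)) || ((key === 1) && (rand < 0.2)) || false, so rows whose stratum is not listed never match. A minimal local sketch of the same per-stratum rule on a plain Scala collection, illustrative only; none of these names are part of the patch:

    import scala.util.Random

    // Keep a row iff an independent uniform draw falls below its stratum's
    // fraction; strata missing from `fractions` default to 0.0 and keep nothing.
    def sampleByLocal[K, R](rows: Seq[R], key: R => K,
                            fractions: Map[K, Double], seed: Long): Seq[R] = {
      val rng = new Random(seed)
      rows.filter(row => rng.nextDouble() < fractions.getOrElse(key(row), 0.0))
    }

    val rows = Seq.tabulate(100)(i => i % 3)
    val kept = sampleByLocal[Int, Int](rows, identity, Map(0 -> 0.1, 1 -> 0.2), 0L)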
From 4a14834f74f3edd45403473c251b9b4e09ad034a Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Thu, 11 Jun 2015 16:46:08 -0700
Subject: [PATCH 2/7] move sampleBy to stat

---
 python/pyspark/sql/dataframe.py               |  7 ++++-
 .../org/apache/spark/sql/DataFrame.scala      | 30 ++++---------------
 .../spark/sql/DataFrameStatFunctions.scala    | 24 +++++++++++++++
 .../apache/spark/sql/DataFrameStatSuite.scala | 12 ++++++--
 .../org/apache/spark/sql/DataFrameSuite.scala |  8 -----
 5 files changed, 46 insertions(+), 35 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 68e92cb8fe943..a4a375f0a0000 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -477,7 +477,7 @@ def sampleBy(self, col, fractions, seed=None):
                 raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
             fractions[k] = float(v)
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
-        return DataFrame(self._jdf.sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
+        return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
 
     @since(1.4)
     def randomSplit(self, weights, seed=None):
         """Randomly splits this :class:`DataFrame` with the provided weights.
@@ -1353,6 +1353,11 @@ def freqItems(self, cols, support=None):
 
     freqItems.__doc__ = DataFrame.freqItems.__doc__
 
+    def sampleBy(self, col, fractions, seed=None):
+        return self.df.sampleBy(col, fractions, seed)
+
+    sampleBy.__doc__ = DataFrame.sampleBy.__doc__
+
 
 def _test():
     import doctest
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 38218f4640469..59f64dd4bc648 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -18,8 +18,9 @@ package org.apache.spark.sql
 
 import java.io.CharArrayWriter
-import java.util.{Properties, UUID}
+import java.util.Properties
 
+import scala.collection.JavaConversions._
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
@@ -32,11 +33,11 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
-import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
 import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
 import org.apache.spark.sql.json.JacksonGenerator
 import org.apache.spark.sql.sources.CreateTableUsingAsSelect
@@ -44,6 +45,7 @@ import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
+
 private[sql] object DataFrame {
   def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
     new DataFrame(sqlContext, logicalPlan)
@@ -945,26 +947,6 @@ class DataFrame private[sql](
     sample(withReplacement, fraction, Utils.random.nextLong)
   }
 
-  /**
-   * Returns a stratified sample without replacement based on the fraction given on each stratum.
-   * @param col column that defines strata
-   * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
-   *                  its fraction as zero.
-   * @param seed random seed
-   * @return a new [[DataFrame]] that represents the stratified sample
-   */
-  def sampleBy(col: String, fractions: Map[Any, Double], seed: Long): DataFrame = {
-    require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
-      s"Fractions must be in [0, 1], but got $fractions.")
-    import org.apache.spark.sql.functions.rand
-    val c = Column(col)
-    val r = rand(seed).as("rand_" + UUID.randomUUID().toString.take(8))
-    val expr = fractions.toSeq.map { case (k, v) =>
-      (c === k) && (r < v)
-    }.reduce(_ || _) || false
-    this.filter(expr)
-  }
-
   /**
    * Randomly splits this [[DataFrame]] with the provided weights.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index edb9ed7bba56a..955d28771b4df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import java.util.UUID
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.execution.stat._
 
@@ -163,4 +165,26 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
   def freqItems(cols: Seq[String]): DataFrame = {
     FrequentItems.singlePassFreqItems(df, cols, 0.01)
   }
+
+  /**
+   * Returns a stratified sample without replacement based on the fraction given on each stratum.
+   * @param col column that defines strata
+   * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
+   *                  its fraction as zero.
+   * @param seed random seed
+   * @return a new [[DataFrame]] that represents the stratified sample
+   *
+   * @since 1.5.0
+   */
+  def sampleBy(col: String, fractions: Map[Any, Double], seed: Long): DataFrame = {
+    require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
+      s"Fractions must be in [0, 1], but got $fractions.")
+    import org.apache.spark.sql.functions.rand
+    val c = Column(col)
+    val r = rand(seed).as("rand_" + UUID.randomUUID().toString.take(8))
+    val expr = fractions.toSeq.map { case (k, v) =>
+      (c === k) && (r < v)
+    }.reduce(_ || _) || false
+    df.filter(expr)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 0d3ff899dad72..3dd46889127ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql
 
 import org.scalatest.Matchers._
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.functions.col
 
-class DataFrameStatSuite extends SparkFunSuite {
+class DataFrameStatSuite extends QueryTest {
   private val sqlCtx = org.apache.spark.sql.test.TestSQLContext
   import sqlCtx.implicits._
 
@@ -98,4 +98,12 @@ class DataFrameStatSuite extends SparkFunSuite {
     val items2 = singleColResults.collect().head
     items2.getSeq[Double](0) should contain (-1.0)
   }
+
+  test("sampleBy") {
+    val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))
+    val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
+    checkAnswer(
+      sampled.groupBy("key").count().orderBy("key"),
+      Seq(Row(0, 4), Row(1, 9)))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index acddedd0e0cc9..bb8621abe64ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -635,12 +635,4 @@ class DataFrameSuite extends QueryTest {
     val res11 = ctx.range(-1).select("id")
     assert(res11.count == 0)
   }
-
-  test("sampleBy") {
-    val df = ctx.range(0, 100).select((col("id") % 3).as("key"))
-    val sampled = df.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
-    checkAnswer(
-      sampled.groupBy("key").count().orderBy("key"),
-      Seq(Row(0, 4), Row(1, 9)))
-  }
 }
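After this move the Scala entry point is df.stat (a DataFrameStatFunctions), and the PySpark method is forwarded through a proxy, so both df.sampleBy and df.stat.sampleBy reach the same JVM implementation. A usage sketch in Scala, assuming a sqlContext as in the doctest:

    import org.apache.spark.sql.functions.col

    // df.stat returns DataFrameStatFunctions, which now owns sampleBy.
    val df = sqlContext.range(0, 100).select((col("id") % 3).as("key"))
    val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)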
From 991f26f4ca51d8e7a214c0da51cabde3ced9169d Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Thu, 11 Jun 2015 18:49:29 -0700
Subject: [PATCH 3/7] fix seed

---
 python/pyspark/sql/dataframe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index a4a375f0a0000..68e33f89f28c8 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -459,7 +459,7 @@ def sampleBy(self, col, fractions, seed=None):
 
         >>> from pyspark.sql.functions import col
         >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
-        >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0L)
+        >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
         >>> sampled.groupBy("key").count().orderBy("key").show()
         +---+-----+
         |key|count|
         +---+-----+
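The doctest can assert exact per-stratum counts only because a fixed seed makes the sample reproducible. A sketch of that contract, reusing the df from the sketch above (reproducibility also presumes the partitioning is unchanged between calls):

    // Same seed and fractions => the same rows are kept on each run.
    val s1 = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), seed = 42L)
    val s2 = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), seed = 42L)
    assert(s1.count() == s2.count())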
From 103beb3782a54d85bdc89853ea98ee5e3eecba63 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Wed, 24 Jun 2015 11:14:01 -0700
Subject: [PATCH 4/7] add Java-friendly sampleBy

---
 .../spark/sql/DataFrameStatFunctions.scala    | 24 ++++++++++++++++---
 .../apache/spark/sql/JavaDataFrameSuite.java  |  9 +++++++
 2 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 955d28771b4df..23fc0cbddea4a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql
 
-import java.util.UUID
+import java.{util => ju, lang => jl}
+
+import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.execution.stat._
@@ -172,19 +174,35 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
    *                  its fraction as zero.
    * @param seed random seed
+   * @tparam T stratum type
    * @return a new [[DataFrame]] that represents the stratified sample
    *
    * @since 1.5.0
    */
-  def sampleBy(col: String, fractions: Map[Any, Double], seed: Long): DataFrame = {
+  def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = {
     require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
       s"Fractions must be in [0, 1], but got $fractions.")
     import org.apache.spark.sql.functions.rand
     val c = Column(col)
-    val r = rand(seed).as("rand_" + UUID.randomUUID().toString.take(8))
+    val r = rand(seed).as("rand_" + ju.UUID.randomUUID().toString.take(8))
     val expr = fractions.toSeq.map { case (k, v) =>
       (c === k) && (r < v)
     }.reduce(_ || _) || false
     df.filter(expr)
   }
+
+  /**
+   * Returns a stratified sample without replacement based on the fraction given on each stratum.
+   * @param col column that defines strata
+   * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
+   *                  its fraction as zero.
+   * @param seed random seed
+   * @tparam T stratum type
+   * @return a new [[DataFrame]] that represents the stratified sample
+   *
+   * @since 1.5.0
+   */
+  def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
+    sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
+  }
 }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 72c42f4fe376b..5d29991092093 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -227,4 +227,13 @@ public void testCovariance() {
     Double result = df.stat().cov("a", "b");
     Assert.assertTrue(Math.abs(result) < 1e-6);
   }
+
+  @Test
+  public void testSampleBy() {
+    DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key"));
+    DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
+    Row[] actual = sampled.groupBy("key").count().orderBy("key").collect();
+    Row[] expected = new Row[] {RowFactory.create(0, 4), RowFactory.create(1, 9)};
+    Assert.assertArrayEquals(expected, actual);
+  }
 }
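The new overload exists because Java callers cannot conveniently build a scala.collection.Map; asScala.toMap yields an immutable Scala map, and the asInstanceOf cast only rebrands java.lang.Double values as scala.Double (the runtime boxes are the same, so nothing is copied). A caller-side sketch from Scala, with assumed names:

    import java.{lang => jl, util => ju}

    val jFractions = new ju.HashMap[Integer, jl.Double]()
    jFractions.put(0, 0.1)
    jFractions.put(1, 0.2)
    // Resolves to the ju.Map overload, which converts and delegates.
    val sampled = df.stat.sampleBy("key", jFractions, 0L)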
From f051afd95ad369b3f596b99f939cc4e54066ccfc Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Wed, 29 Jul 2015 18:07:34 -0700
Subject: [PATCH 5/7] use udf instead of building expression

---
 .../org/apache/spark/sql/DataFrameStatFunctions.scala | 10 +++++-----
 .../test/org/apache/spark/sql/JavaDataFrameSuite.java |  2 +-
 .../org/apache/spark/sql/DataFrameStatSuite.scala     |  2 +-
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index b1459e2e4040f..f307a4f6375b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -185,13 +185,13 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
   def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = {
     require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
       s"Fractions must be in [0, 1], but got $fractions.")
-    import org.apache.spark.sql.functions.rand
+    import org.apache.spark.sql.functions.{rand, udf}
     val c = Column(col)
     val r = rand(seed).as("rand_" + ju.UUID.randomUUID().toString.take(8))
-    val expr = fractions.toSeq.map { case (k, v) =>
-      (c === k) && (r < v)
-    }.reduce(_ || _) || false
-    df.filter(expr)
+    val f = udf { (stratum: Any, x: Double) =>
+      x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0)
+    }
+    df.filter(f(c, r))
   }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 5d29991092093..93971cebcc3a8 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -233,7 +233,7 @@ public void testSampleBy() {
     DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key"));
     DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
     Row[] actual = sampled.groupBy("key").count().orderBy("key").collect();
-    Row[] expected = new Row[] {RowFactory.create(0, 4), RowFactory.create(1, 9)};
+    Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 7)};
     Assert.assertArrayEquals(expected, actual);
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 7c1a83f20ceed..16a1e901a5073 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -136,6 +136,6 @@ class DataFrameStatSuite extends QueryTest {
     val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
     checkAnswer(
       sampled.groupBy("key").count().orderBy("key"),
-      Seq(Row(0, 4), Row(1, 9)))
+      Seq(Row(0, 5), Row(1, 7)))
   }
 }
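Replacing the reduced OR with a user-defined function keeps the filter to a single predicate no matter how many strata appear in fractions, and getOrElse makes the zero-fraction default for unlisted strata explicit rather than a side effect of a missing disjunct. An equivalent standalone construction mirroring the patch, with illustrative names and the df from the earlier sketches:

    import org.apache.spark.sql.functions.{col, rand, udf}

    val fractions = Map[Any, Double](0 -> 0.1, 1 -> 0.2)
    // The closure captures `fractions`; unknown strata get fraction 0.0,
    // and rand(seed) in [0, 1) is never below 0.0, so they are dropped.
    val keep = udf { (stratum: Any, x: Double) =>
      x < fractions.getOrElse(stratum, 0.0)
    }
    val sampled = df.filter(keep(col("key"), rand(0L)))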
From 542bd37272ec949f2795daacc0840a63e26adc0f Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Thu, 30 Jul 2015 13:39:08 -0700
Subject: [PATCH 6/7] update test

---
 .../scala/org/apache/spark/sql/DataFrameStatFunctions.scala     | 2 +-
 .../test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java | 2 +-
 .../test/scala/org/apache/spark/sql/DataFrameStatSuite.scala    | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index f307a4f6375b9..2e68e358f2f1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -187,7 +187,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
       s"Fractions must be in [0, 1], but got $fractions.")
     import org.apache.spark.sql.functions.{rand, udf}
     val c = Column(col)
-    val r = rand(seed).as("rand_" + ju.UUID.randomUUID().toString.take(8))
+    val r = rand(seed)
     val f = udf { (stratum: Any, x: Double) =>
       x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0)
     }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 5c57e2ac472af..2c669bb59a0b5 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -232,7 +232,7 @@ public void testSampleBy() {
     DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key"));
     DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
     Row[] actual = sampled.groupBy("key").count().orderBy("key").collect();
-    Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 7)};
+    Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)};
     Assert.assertArrayEquals(expected, actual);
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 16a1e901a5073..07a675e64f527 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -136,6 +136,6 @@ class DataFrameStatSuite extends QueryTest {
     val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
     checkAnswer(
       sampled.groupBy("key").count().orderBy("key"),
-      Seq(Row(0, 5), Row(1, 7)))
+      Seq(Row(0, 5), Row(1, 8)))
   }
 }
From fbf9044bb589a9651cdb7bf3fab200f668b0d9be Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Thu, 30 Jul 2015 14:17:36 -0700
Subject: [PATCH 7/7] fix python test

---
 python/pyspark/sql/dataframe.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index a7851c04c3a3e..0f3480c239187 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -461,9 +461,10 @@ def sampleBy(self, col, fractions, seed=None):
         +---+-----+
         |key|count|
         +---+-----+
-        |  0|    5|
+        |  0|    3|
         |  1|    8|
         +---+-----+
+
         """
         if not isinstance(col, str):
             raise ValueError("col must be a string, but got %r" % type(col))
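With the series applied end to end, the documented rule for unlisted strata can be checked directly. A closing sketch, assuming the same three-stratum df as the tests:

    import org.apache.spark.sql.functions.col

    val df = sqlContext.range(0, 100).select((col("id") % 3).as("key"))
    // Stratum 2 has no entry in `fractions`, so its fraction is treated
    // as zero and no key == 2 rows survive, regardless of seed.
    val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
    assert(sampled.filter(col("key") === 2).count() == 0)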