From fd41eb9574280b5cfee9b94b4f92e4c44363fb14 Mon Sep 17 00:00:00 2001
From: jbencook
Date: Tue, 23 Dec 2014 17:46:24 -0800
Subject: [PATCH] [SPARK-4860][pyspark][sql] speeding up `sample()` and
 `takeSample()`

This PR modifies the python `SchemaRDD` to use `sample()` and `takeSample()`
from Scala instead of the slower python implementations from `rdd.py`. This is
worthwhile because the `Row`s are already serialized as Java objects.

In order to use the faster `takeSample()`, a `takeSampleToPython()` method was
implemented in `SchemaRDD.scala` following the pattern of `collectToPython()`.

Author: jbencook
Author: J. Benjamin Cook

Closes #3764 from jbencook/master and squashes the following commits:

6fbc769 [J. Benjamin Cook] [SPARK-4860][pyspark][sql] fixing sloppy indentation for takeSampleToPython() arguments
5170da2 [J. Benjamin Cook] [SPARK-4860][pyspark][sql] fixing typo: from RDD to SchemaRDD
de22f70 [jbencook] [SPARK-4860][pyspark][sql] using sample() method from JavaSchemaRDD
b916442 [jbencook] [SPARK-4860][pyspark][sql] adding sample() to JavaSchemaRDD
020cbdf [jbencook] [SPARK-4860][pyspark][sql] using Scala implementations of `sample()` and `takeSample()`
---
 python/pyspark/sql.py                          | 28 +++++++++++++++++++++
 .../org/apache/spark/sql/SchemaRDD.scala       | 15 ++++++++++
 .../spark/sql/api/java/JavaSchemaRDD.scala     |  6 ++++
 3 files changed, 49 insertions(+)

diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 469f82473af97..9807a84a66f11 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -2085,6 +2085,34 @@ def subtract(self, other, numPartitions=None):
         else:
             raise ValueError("Can only subtract another SchemaRDD")
 
+    def sample(self, withReplacement, fraction, seed=None):
+        """
+        Return a sampled subset of this SchemaRDD.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.sample(False, 0.5, 97).count()
+        2L
+        """
+        assert fraction >= 0.0, "Negative fraction value: %s" % fraction
+        seed = seed if seed is not None else random.randint(0, sys.maxint)
+        rdd = self._jschema_rdd.sample(withReplacement, fraction, long(seed))
+        return SchemaRDD(rdd, self.sql_ctx)
+
+    def takeSample(self, withReplacement, num, seed=None):
+        """Return a fixed-size sampled subset of this SchemaRDD.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.takeSample(False, 2, 97)
+        [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
+        """
+        seed = seed if seed is not None else random.randint(0, sys.maxint)
+        with SCCallSiteSync(self.context) as css:
+            bytesInJava = self._jschema_rdd.baseSchemaRDD() \
+                .takeSampleToPython(withReplacement, num, long(seed)) \
+                .iterator()
+        cls = _create_cls(self.schema())
+        return map(cls, self._collect_iterator_through_file(bytesInJava))
+
 
 def _test():
     import doctest
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 7baf8ffcef787..856b10f1a8fd8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -437,6 +437,21 @@ class SchemaRDD(
     }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
   }
 
+  /**
+   * Serializes the Array[Row] returned by SchemaRDD's takeSample(), using the same
+   * format as javaToPython and collectToPython. It is used by pyspark.
+   */
+  private[sql] def takeSampleToPython(
+      withReplacement: Boolean,
+      num: Int,
+      seed: Long): JList[Array[Byte]] = {
+    val fieldTypes = schema.fields.map(_.dataType)
+    val pickle = new Pickler
+    new java.util.ArrayList(this.takeSample(withReplacement, num, seed).map { row =>
+      EvaluatePython.rowToArray(row, fieldTypes)
+    }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
+  }
+
   /**
    * Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value
    * of base RDD functions that do not change schema.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
index ac4844f9b9290..5b9c612487ace 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
@@ -218,4 +218,10 @@ class JavaSchemaRDD(
    */
   def subtract(other: JavaSchemaRDD, p: Partitioner): JavaSchemaRDD =
     this.baseSchemaRDD.subtract(other.baseSchemaRDD, p).toJavaSchemaRDD
+
+  /**
+   * Return a SchemaRDD with a sampled version of the underlying dataset.
+   */
+  def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaSchemaRDD =
+    this.baseSchemaRDD.sample(withReplacement, fraction, seed).toJavaSchemaRDD
 }
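
Usage sketch: a minimal PySpark 1.x session exercising the new methods end to
end. This is illustrative only, not part of the patch; it assumes a local
SparkContext, and the three-row fixture `rdd` is hypothetical, shaped to
mirror the doctests above.

    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row

    sc = SparkContext("local", "sample-demo")
    sqlCtx = SQLContext(sc)

    # Build a small SchemaRDD shaped like the doctest fixtures.
    rdd = sc.parallelize([Row(field1=i, field2="row%d" % i) for i in [1, 2, 3]])
    srdd = sqlCtx.inferSchema(rdd)

    # Both calls now run in the JVM rather than sampling row-by-row in Python.
    sampled = srdd.sample(False, 0.5, 97)    # returns a new SchemaRDD
    print sampled.count()
    print srdd.takeSample(False, 2, 97)      # returns a list of Row objects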