Skip to content

Commit

Permalink
[SPARK-4860][pyspark][sql] speeding up sample() and takeSample()
Browse files Browse the repository at this point in the history
This PR modifies the Python `SchemaRDD` to use `sample()` and `takeSample()` from Scala instead of the slower Python implementations from `rdd.py`. This is worthwhile because the `Row`s are already serialized as Java objects.

In order to use the faster `takeSample()`, a `takeSampleToPython()` method was implemented in `SchemaRDD.scala` following the pattern of `collectToPython()`.

Author: jbencook <jbenjamincook@gmail.com>
Author: J. Benjamin Cook <jbenjamincook@gmail.com>

Closes apache#3764 from jbencook/master and squashes the following commits:

6fbc769 [J. Benjamin Cook] [SPARK-4860][pyspark][sql] fixing sloppy indentation for takeSampleToPython() arguments
5170da2 [J. Benjamin Cook] [SPARK-4860][pyspark][sql] fixing typo: from RDD to SchemaRDD
de22f70 [jbencook] [SPARK-4860][pyspark][sql] using sample() method from JavaSchemaRDD
b916442 [jbencook] [SPARK-4860][pyspark][sql] adding sample() to JavaSchemaRDD
020cbdf [jbencook] [SPARK-4860][pyspark][sql] using Scala implementations of `sample()` and `takeSample()`
  • Loading branch information
jbencook authored and JoshRosen committed Dec 24, 2014
1 parent 7e2deb7 commit fd41eb9
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 0 deletions.
28 changes: 28 additions & 0 deletions python/pyspark/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2085,6 +2085,34 @@ def subtract(self, other, numPartitions=None):
else:
raise ValueError("Can only subtract another SchemaRDD")

def sample(self, withReplacement, fraction, seed=None):
    """
    Return a sampled subset of this SchemaRDD.

    The sampling is delegated to the ``sample`` method of the wrapped
    Java SchemaRDD, and the result is wrapped in a new SchemaRDD that
    shares this one's SQL context.

    :param withReplacement: passed through unchanged to the Java
        ``sample`` call.
    :param fraction: sampling fraction; must not be negative.
    :param seed: seed handed to the Java sampler. When it is None, a
        random integer between 0 and ``sys.maxint`` is drawn instead.
    :return: a new SchemaRDD wrapping the sampled Java SchemaRDD.

    >>> srdd = sqlCtx.inferSchema(rdd)
    >>> srdd.sample(False, 0.5, 97).count()
    2L
    """
    assert fraction >= 0.0, "Negative fraction value: %s" % fraction
    if seed is None:
        # No seed supplied by the caller: pick one at random.
        seed = random.randint(0, sys.maxint)
    sampled_jrdd = self._jschema_rdd.sample(withReplacement, fraction, long(seed))
    return SchemaRDD(sampled_jrdd, self.sql_ctx)

def takeSample(self, withReplacement, num, seed=None):
    """
    Return a fixed-size sampled subset of this SchemaRDD.

    The sample is produced by ``takeSampleToPython`` on the base
    SchemaRDD behind the wrapped Java object. The returned bytes are
    read back through ``_collect_iterator_through_file`` and each item
    is passed to the class that ``_create_cls`` builds from this
    SchemaRDD's schema.

    :param withReplacement: passed through unchanged to
        ``takeSampleToPython``.
    :param num: passed through unchanged to ``takeSampleToPython``.
    :param seed: seed handed to the sampler. When it is None, a random
        integer between 0 and ``sys.maxint`` is drawn instead.

    >>> srdd = sqlCtx.inferSchema(rdd)
    >>> srdd.takeSample(False, 2, 97)
    [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
    """
    if seed is None:
        # No seed supplied by the caller: pick one at random.
        seed = random.randint(0, sys.maxint)
    with SCCallSiteSync(self.context) as css:
        base_rdd = self._jschema_rdd.baseSchemaRDD()
        pickled = base_rdd.takeSampleToPython(withReplacement, num, long(seed))
        bytesInJava = pickled.iterator()
    row_cls = _create_cls(self.schema())
    return map(row_cls, self._collect_iterator_through_file(bytesInJava))


def _test():
import doctest
Expand Down
15 changes: 15 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,21 @@ class SchemaRDD(
}.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
}

/**
 * Takes a sample of this SchemaRDD with takeSample() and serializes the resulting rows
 * for pyspark, using the same format as javaToPython and collectToPython.
 *
 * Each sampled row is converted with EvaluatePython.rowToArray, the converted rows are
 * grouped into batches of up to 100, and every batch is pickled into one byte array.
 */
private[sql] def takeSampleToPython(
    withReplacement: Boolean,
    num: Int,
    seed: Long): JList[Array[Byte]] = {
  val fieldTypes = schema.fields.map(_.dataType)
  val pickle = new Pickler
  val sampledRows = this.takeSample(withReplacement, num, seed)
  val convertedRows = sampledRows.map(row => EvaluatePython.rowToArray(row, fieldTypes))
  // One pickled payload per batch of up to 100 rows.
  val pickledBatches = convertedRows.grouped(100).map(batch => pickle.dumps(batch.toArray))
  new java.util.ArrayList(pickledBatches.toIterable)
}

/**
* Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value
* of base RDD functions that do not change schema.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,4 +218,10 @@ class JavaSchemaRDD(
*/
def subtract(other: JavaSchemaRDD, p: Partitioner): JavaSchemaRDD =
this.baseSchemaRDD.subtract(other.baseSchemaRDD, p).toJavaSchemaRDD

/**
 * Return a JavaSchemaRDD holding a sampled version of the underlying dataset.
 *
 * The arguments are forwarded unchanged to the base SchemaRDD's sample(), and the
 * result is converted back with toJavaSchemaRDD.
 */
def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaSchemaRDD = {
  val sampled = this.baseSchemaRDD.sample(withReplacement, fraction, seed)
  sampled.toJavaSchemaRDD
}
}

0 comments on commit fd41eb9

Please sign in to comment.