Wrap a few more RDD functions in an operation scope
Andrew Or committed May 4, 2015
1 parent 3ffe566 commit 1c310e4
Showing 3 changed files with 33 additions and 17 deletions.
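
The change itself is mechanical: each public method is wrapped in `self.withScope { ... }` so that the RDDs it creates are grouped under the name of the enclosing API call (this grouping feeds the web UI's DAG visualization). As a rough illustration of the pattern only — a minimal, hypothetical sketch, not Spark's actual `RDDOperationScope` machinery; the names `OperationScopeSketch`, `withScope(name)` and `activeScope` are invented for this example:

  object OperationScopeSketch {
    // Tracks the name of the operation currently being built on this thread.
    private val current = new ThreadLocal[Option[String]] {
      override def initialValue(): Option[String] = None
    }

    // Evaluates `body` with `name` recorded as the active operation and
    // restores the previous value when the body finishes.
    def withScope[T](name: String)(body: => T): T = {
      val outer = current.get()
      current.set(Some(name))
      try body finally current.set(outer)
    }

    // What a hypothetical RDD constructor would read to tag itself with a scope.
    def activeScope: Option[String] = current.get()
  }
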
38 changes: 27 additions & 11 deletions core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
@@ -30,45 +30,57 @@ import org.apache.spark.util.StatCounter
  */
 class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
   /** Add up the elements in this RDD. */
-  def sum(): Double = {
+  def sum(): Double = self.withScope {
     self.fold(0.0)(_ + _)
   }

   /**
    * Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and
    * count of the RDD's elements in one operation.
    */
-  def stats(): StatCounter = {
+  def stats(): StatCounter = self.withScope {
     self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b))
   }

   /** Compute the mean of this RDD's elements. */
-  def mean(): Double = stats().mean
+  def mean(): Double = self.withScope {
+    stats().mean
+  }

   /** Compute the variance of this RDD's elements. */
-  def variance(): Double = stats().variance
+  def variance(): Double = self.withScope {
+    stats().variance
+  }

   /** Compute the standard deviation of this RDD's elements. */
-  def stdev(): Double = stats().stdev
+  def stdev(): Double = self.withScope {
+    stats().stdev
+  }

   /**
    * Compute the sample standard deviation of this RDD's elements (which corrects for bias in
    * estimating the standard deviation by dividing by N-1 instead of N).
    */
-  def sampleStdev(): Double = stats().sampleStdev
+  def sampleStdev(): Double = self.withScope {
+    stats().sampleStdev
+  }

   /**
    * Compute the sample variance of this RDD's elements (which corrects for bias in
    * estimating the variance by dividing by N-1 instead of N).
    */
-  def sampleVariance(): Double = stats().sampleVariance
+  def sampleVariance(): Double = self.withScope {
+    stats().sampleVariance
+  }

   /**
    * :: Experimental ::
    * Approximate operation to return the mean within a timeout.
    */
   @Experimental
-  def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
+  def meanApprox(
+      timeout: Long,
+      confidence: Double = 0.95): PartialResult[BoundedDouble] = self.withScope {
     val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
     val evaluator = new MeanEvaluator(self.partitions.length, confidence)
     self.context.runApproximateJob(self, processPartition, evaluator, timeout)
@@ -79,7 +91,9 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
    * Approximate operation to return the sum within a timeout.
    */
   @Experimental
-  def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
+  def sumApprox(
+      timeout: Long,
+      confidence: Double = 0.95): PartialResult[BoundedDouble] = self.withScope {
     val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
     val evaluator = new SumEvaluator(self.partitions.length, confidence)
     self.context.runApproximateJob(self, processPartition, evaluator, timeout)
@@ -93,7 +107,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
    * If the RDD contains infinity, NaN throws an exception
    * If the elements in RDD do not vary (max == min) always returns a single bucket.
    */
-  def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = {
+  def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = self.withScope {
     // Scala's built-in range has issues. See #SI-8782
     def customRange(min: Double, max: Double, steps: Int): IndexedSeq[Double] = {
       val span = max - min
@@ -140,7 +154,9 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
    * the maximum value of the last position and all NaN entries will be counted
    * in that bucket.
    */
-  def histogram(buckets: Array[Double], evenBuckets: Boolean = false): Array[Long] = {
+  def histogram(
+      buckets: Array[Double],
+      evenBuckets: Boolean = false): Array[Long] = self.withScope {
     if (buckets.length < 2) {
       throw new IllegalArgumentException("buckets array must have at least two elements")
     }
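
Caller-facing behavior of these double-RDD methods does not change; the wrapper only affects how the resulting operations are grouped. A hedged usage sketch, assuming a spark-shell session where `sc` is already bound to a SparkContext:

  val xs = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0))
  val summary = xs.stats()                      // mean, variance and count in a single pass
  val (bucketEdges, counts) = xs.histogram(2)   // two evenly spaced buckets between min and max
  println(s"mean=${summary.mean} stdev=${summary.stdev} counts=${counts.toSeq}")
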
8 changes: 3 additions & 5 deletions core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -47,8 +47,6 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
     self: RDD[P])
   extends Logging with Serializable
 {
-  // TODO: Don't forget to scope me later
-
   private val ordering = implicitly[Ordering[K]]

   /**
@@ -59,7 +57,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
    */
   // TODO: this currently doesn't work on P other than Tuple2!
   def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length)
-      : RDD[(K, V)] =
+      : RDD[(K, V)] = self.withScope
   {
     val part = new RangePartitioner(numPartitions, self, ascending)
     new ShuffledRDD[K, V, V](self, part)
@@ -73,7 +71,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
    * This is more efficient than calling `repartition` and then sorting within each partition
    * because it can push the sorting down into the shuffle machinery.
    */
-  def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)] = {
+  def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)] = self.withScope {
     new ShuffledRDD[K, V, V](self, partitioner).setKeyOrdering(ordering)
   }

@@ -83,7 +81,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
    * performed efficiently by only scanning the partitions that might contain matching elements.
    * Otherwise, a standard `filter` is applied to all partitions.
    */
-  def filterByRange(lower: K, upper: K): RDD[P] = {
+  def filterByRange(lower: K, upper: K): RDD[P] = self.withScope {

     def inRange(k: K): Boolean = ordering.gteq(k, lower) && ordering.lteq(k, upper)

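
The same applies to the ordered-RDD methods above; their signatures are untouched. A hedged sketch of typical call sites (again assuming spark-shell's predefined `sc`):

  import org.apache.spark.HashPartitioner

  val pairs = sc.parallelize(Seq(3 -> "c", 1 -> "a", 2 -> "b"))
  val sorted = pairs.sortByKey()                  // range-partitions, then sorts by key
  val shuffledSorted = pairs.repartitionAndSortWithinPartitions(new HashPartitioner(2))
  val clipped = pairs.filterByRange(1, 2)         // prunes partitions only if range-partitioned, else a plain filter
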
4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
@@ -85,7 +85,9 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag
    * byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported
    * file system.
    */
-  def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) {
+  def saveAsSequenceFile(
+      path: String,
+      codec: Option[Class[_ <: CompressionCodec]] = None): Unit = self.withScope {
     def anyToWritable[U <% Writable](u: U): Writable = u

     // TODO We cannot force the return type of `anyToWritable` be same as keyWritableClass and
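
Finally, a hedged sketch of the call site for the rewritten saveAsSequenceFile (spark-shell, `sc` predefined; the output path is illustrative):

  val records = sc.parallelize(Seq("a" -> 1, "b" -> 2))
  records.saveAsSequenceFile("/tmp/records-seqfile")   // keys are converted to Text, values to IntWritable
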
