
Manually merge pull request #175 by Imran Rashid

mateiz committed Sep 11, 2012
1 parent 995982b commit 6d7f907e73e9702c0dbd0e41e4a52022c0b81d3d
@@ -3,6 +3,7 @@ package spark
import java.io._
import scala.collection.mutable.Map
import scala.collection.generic.Growable
/**
 * A datatype that can be accumulated, i.e. has a commutative and associative +.
@@ -92,6 +93,29 @@ trait AccumulableParam[R, T] extends Serializable {
  def zero(initialValue: R): R
}
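For contrast, implementing the trait by hand for one concrete collection type looks roughly like this (a sketch only, not part of this commit; the SetAccum object in the test suite further down is the in-tree equivalent):

import scala.collection.mutable

// Hypothetical hand-rolled param for a single collection type. The
// GrowableAccumulableParam added below generalizes this pattern to any
// Growable/TraversableOnce collection.
object IntSetAccumParam extends AccumulableParam[mutable.Set[Int], Int] {
  def addAccumulator(set: mutable.Set[Int], elem: Int): mutable.Set[Int] = { set += elem; set }
  def addInPlace(s1: mutable.Set[Int], s2: mutable.Set[Int]): mutable.Set[Int] = { s1 ++= s2; s1 }
  def zero(initialValue: mutable.Set[Int]): mutable.Set[Int] = mutable.Set.empty[Int]
}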
class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
  extends AccumulableParam[R,T] {
  def addAccumulator(growable: R, elem: T) : R = {
    growable += elem
    growable
  }

  def addInPlace(t1: R, t2: R) : R = {
    t1 ++= t2
    t1
  }

  def zero(initialValue: R): R = {
    // We need to clone initialValue, but it's hard to specify that R should also be Cloneable.
    // Instead we'll serialize it to a buffer and load it back.
    val ser = (new spark.JavaSerializer).newInstance()
    val copy = ser.deserialize[R](ser.serialize(initialValue))
    copy.clear()   // In case it contained stuff
    copy
  }
}
/**
 * A simpler value of [[spark.Accumulable]] where the result type being accumulated is the same
 * as the types of elements being merged.
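The view bound R <% Growable[T] with TraversableOnce[T] with Serializable is satisfied by the standard mutable collections, and zero() clones the initial value through a serialize/deserialize round trip before clearing it. A minimal sketch of using the new param directly (illustrative values, not part of this commit):

import scala.collection.mutable
import spark.GrowableAccumulableParam

// += folds in one element, ++= merges two partial results, and zero() yields
// an empty copy of the initial collection (cloned via serialization, then cleared).
val param = new GrowableAccumulableParam[mutable.ArrayBuffer[Int], Int]
val partial = param.addAccumulator(mutable.ArrayBuffer(1, 2), 3)    // ArrayBuffer(1, 2, 3)
val merged  = param.addInPlace(partial, mutable.ArrayBuffer(4, 5))  // ArrayBuffer(1, 2, 3, 4, 5)
val empty   = param.zero(merged)                                    // ArrayBuffer()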
@@ -7,6 +7,7 @@ import akka.actor.Actor
import akka.actor.Actor._
import scala.collection.mutable.ArrayBuffer
import scala.collection.generic.Growable
import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
@@ -307,6 +308,16 @@ class SparkContext(
  def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
    new Accumulable(initialValue, param)

  /**
   * Create an accumulator from a "mutable collection" type.
   *
   * Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by
   * standard mutable collections. So you can use this with mutable Map, Set, etc.
   */
  def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = {
    val param = new GrowableAccumulableParam[R,T]
    new Accumulable(initialValue, param)
  }

  // Keep around a weak hash map of values to Cached versions?
  def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
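A minimal end-to-end sketch of the new SparkContext.accumulableCollection API (the local master, application name, and example data are illustrative, not from this commit):

import scala.collection.mutable
import spark.SparkContext

val sc = new SparkContext("local", "accumulableCollection example")
// Tasks add elements with +=; the driver reads the merged collection via .value.
val badRecords = sc.accumulableCollection(mutable.ArrayBuffer[String]())
sc.parallelize(Seq("ok", "bad:1", "ok", "bad:2")).foreach { line =>
  if (line.startsWith("bad")) badRecords += line
}
println(badRecords.value)  // ArrayBuffer(bad:1, bad:2) -- merge order is not guaranteed
sc.stop()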
@@ -56,7 +56,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
    }
  }
  implicit object SetAccum extends AccumulableParam[mutable.Set[Any], Any] {
    def addInPlace(t1: mutable.Set[Any], t2: mutable.Set[Any]) : mutable.Set[Any] = {
      t1 ++= t2
@@ -71,7 +70,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
    }
  }
  test ("value not readable in tasks") {
    import SetAccum._
    val maxI = 1000
@@ -89,4 +87,29 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
    }
  }

  test ("collection accumulators") {
    val maxI = 1000
    for (nThreads <- List(1, 10)) {
      // test single & multi-threaded
      val sc = new SparkContext("local[" + nThreads + "]", "test")
      val setAcc = sc.accumulableCollection(mutable.HashSet[Int]())
      val bufferAcc = sc.accumulableCollection(mutable.ArrayBuffer[Int]())
      val mapAcc = sc.accumulableCollection(mutable.HashMap[Int,String]())
      val d = sc.parallelize((1 to maxI) ++ (1 to maxI))
      d.foreach {
        x => {setAcc += x; bufferAcc += x; mapAcc += (x -> x.toString)}
      }
      // Note that this is typed correctly -- no casts necessary
      setAcc.value.size should be (maxI)
      bufferAcc.value.size should be (2 * maxI)
      mapAcc.value.size should be (maxI)
      for (i <- 1 to maxI) {
        setAcc.value should contain(i)
        bufferAcc.value should contain(i)
        mapAcc.value should contain (i -> i.toString)
      }
      sc.stop()
    }
  }
}
