Skip to content
Browse files

Refactor and add aggregator support

Refactored out the agg() and comp() methods from Pregel.run.

Defined an implicit conversion to allow applications that don't use
aggregators to avoid including a null argument for the result of the
aggregator in the compute function.
  • Loading branch information...
1 parent c18fa3e commit 563c5e717cc75869c328bba17116313eab9e976b @ankurdave committed
View
111 bagel/src/main/scala/bagel/Pregel.scala
@@ -6,37 +6,62 @@ import spark.SparkContext._
import scala.collection.mutable.ArrayBuffer
object Pregel extends Logging {
- /**
- * Runs a Pregel job on the given vertices consisting of the
- * specified compute function.
- *
- * Before beginning the first superstep, the given messages are sent
- * to their destination vertices.
- *
- * During the job, the specified combiner functions are applied to
- * messages as they travel between vertices.
- *
- * The job halts and returns the resulting set of vertices when no
- * messages are being sent between vertices and all vertices have
- * voted to halt by setting their state to inactive.
- */
- def run[V <: Vertex : Manifest, M <: Message : Manifest, C](
+ def run[V <: Vertex : Manifest, M <: Message : Manifest, C : Manifest, A : Manifest](
sc: SparkContext,
verts: RDD[(String, V)],
- msgs: RDD[(String, M)],
- combiner: Combiner[M, C],
- numSplits: Int,
- superstep: Int = 0
- )(compute: (V, Option[C], Int) => (V, Iterable[M])): RDD[V] = {
+ msgs: RDD[(String, M)]
+ )(
+ combiner: Combiner[M, C] = new DefaultCombiner[M],
+ aggregator: Aggregator[V, A] = new NullAggregator[V],
+ superstep: Int = 0,
+ numSplits: Int = sc.numCores
+ )(
+ compute: (V, Option[C], A, Int) => (V, Iterable[M])
+ ): RDD[V] = {
logInfo("Starting superstep "+superstep+".")
val startTime = System.currentTimeMillis
- // Bring together vertices and messages
+ val aggregated = agg(verts, aggregator)
val combinedMsgs = msgs.combineByKey(combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners, numSplits)
val grouped = verts.groupWith(combinedMsgs)
+ val (processed, numMsgs, numActiveVerts) = comp[V, M, C](sc, grouped, compute(_, _, aggregated, superstep))
+
+ val timeTaken = System.currentTimeMillis - startTime
+ logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
+
+ // Check stopping condition and iterate
+ val noActivity = numMsgs == 0 && numActiveVerts == 0
+ if (noActivity) {
+ processed.map { case (id, (vert, msgs)) => vert }
+ } else {
+ val newVerts = processed.mapValues { case (vert, msgs) => vert }
+ val newMsgs = processed.flatMap {
+ case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
+ }
+ run(sc, newVerts, newMsgs)(combiner, aggregator, superstep + 1, numSplits)(compute)
+ }
+ }
+
+ /**
+ * Aggregates the given vertices using the given aggregator, or does
+ * nothing if it is a NullAggregator.
+ */
+ def agg[V <: Vertex, A : Manifest](verts: RDD[(String, V)], aggregator: Aggregator[V, A]): A = aggregator match {
+ case _: NullAggregator[_] =>
+ None
+ case _ =>
+ verts.map {
+ case (id, vert) => aggregator.createAggregator(vert)
+ }.reduce(aggregator.mergeAggregators(_, _))
+ }
- // Run compute on each vertex
+ /**
+ * Processes the given vertex-message RDD using the compute
+ * function. Returns the processed RDD, the number of messages
+ * created, and the number of active vertices.
+ */
+ def comp[V <: Vertex, M <: Message, C](sc: SparkContext, grouped: RDD[(String, (Seq[V], Seq[C]))], compute: (V, Option[C]) => (V, Iterable[M])): (RDD[(String, (V, Iterable[M]))], Int, Int) = {
var numMsgs = sc.accumulator(0)
var numActiveVerts = sc.accumulator(0)
val processed = grouped.flatMapValues {
@@ -46,7 +71,7 @@ object Pregel extends Logging {
compute(v, c match {
case Seq(comb) => Some(comb)
case Seq() => None
- }, superstep)
+ })
numMsgs += newMsgs.size
if (newVert.active)
@@ -58,30 +83,36 @@ object Pregel extends Logging {
// Force evaluation of processed RDD for accurate performance measurements
processed.foreach(x => {})
- val timeTaken = System.currentTimeMillis - startTime
- logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
+ (processed, numMsgs.value, numActiveVerts.value)
+ }
- // Check stopping condition and iterate
- val noActivity = numMsgs.value == 0 && numActiveVerts.value == 0
- if (noActivity) {
- processed.map { case (id, (vert, msgs)) => vert }
- } else {
- val newVerts = processed.mapValues { case (vert, msgs) => vert }
- val newMsgs = processed.flatMap {
- case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
- }
- run(sc, newVerts, newMsgs, combiner, numSplits, superstep + 1)(compute)
- }
+ /**
+ * Converts a compute function that doesn't take an aggregator to
+ * one that does, so it can be passed to Pregel.run.
+ */
+ implicit def addAggregatorArg[
+ V <: Vertex : Manifest, M <: Message : Manifest, C
+ ](
+ compute: (V, Option[C], Int) => (V, Iterable[M])
+ ): (V, Option[C], Option[Nothing], Int) => (V, Iterable[M]) = {
+ (vert: V, messages: Option[C], aggregator: Option[Nothing], superstep: Int) => compute(vert, messages, superstep)
}
}
+// TODO: Simplify Combiner interface and make it more OO.
trait Combiner[M, C] {
def createCombiner(msg: M): C
def mergeMsg(combiner: C, msg: M): C
def mergeCombiners(a: C, b: C): C
}
-@serializable class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] {
+trait Aggregator[V, A] {
+ def createAggregator(vert: V): A
+ def mergeAggregators(a: A, b: A): A
+}
+
+@serializable
+class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] {
def createCombiner(msg: M): ArrayBuffer[M] =
ArrayBuffer(msg)
def mergeMsg(combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] =
@@ -90,6 +121,12 @@ trait Combiner[M, C] {
a ++= b
}
+@serializable
+class NullAggregator[V] extends Aggregator[V, Option[Nothing]] {
+ def createAggregator(vert: V): Option[Nothing] = None
+ def mergeAggregators(a: Option[Nothing], b: Option[Nothing]): Option[Nothing] = None
+}
+
/**
* Represents a Pregel vertex.
*
View
5 bagel/src/main/scala/bagel/ShortestPath.scala
@@ -5,6 +5,8 @@ import spark.SparkContext._
import scala.math.min
+import bagel.Pregel._
+
object ShortestPath {
def main(args: Array[String]) {
if (args.length < 4) {
@@ -49,7 +51,7 @@ object ShortestPath {
messages.count()+" messages.")
// Do the computation
- val result = Pregel.run(sc, vertices, messages, MinCombiner, numSplits) {
+ val compute = addAggregatorArg {
(self: SPVertex, messageMinValue: Option[Int], superstep: Int) =>
val newValue = messageMinValue match {
case Some(minVal) => min(self.value, minVal)
@@ -65,6 +67,7 @@ object ShortestPath {
(new SPVertex(self.id, newValue, self.outEdges, false), outbox)
}
+ val result = Pregel.run(sc, vertices, messages)(combiner = MinCombiner, numSplits = numSplits)(compute)
// Print the result
System.err.println("Shortest path from "+startVertex+" to all vertices:")
View
6 bagel/src/main/scala/bagel/WikipediaPageRank.scala
@@ -3,6 +3,8 @@ package bagel
import spark._
import spark.SparkContext._
+import bagel.Pregel._
+
import scala.collection.mutable.ArrayBuffer
import scala.xml.{XML,NodeSeq}
@@ -60,9 +62,9 @@ object WikipediaPageRank {
val messages = sc.parallelize(List[(String, PRMessage)]())
val result =
if (noCombiner) {
- Pregel.run(sc, vertices, messages, PRNoCombiner, numSplits)(PRNoCombiner.compute(numVertices, epsilon))
+ Pregel.run(sc, vertices, messages)(numSplits = numSplits)(PRNoCombiner.compute(numVertices, epsilon))
} else {
- Pregel.run(sc, vertices, messages, PRCombiner, numSplits)(PRCombiner.compute(numVertices, epsilon))
+ Pregel.run(sc, vertices, messages)(combiner = PRCombiner, numSplits = numSplits)(PRCombiner.compute(numVertices, epsilon))
}
// Print the result
View
10 bagel/src/test/scala/bagel/BagelSuite.scala
@@ -10,6 +10,8 @@ import scala.collection.mutable.ArrayBuffer
import spark._
+import bagel.Pregel._
+
@serializable class TestVertex(val id: String, val active: Boolean, val age: Int) extends Vertex
@serializable class TestMessage(val targetId: String) extends Message
@@ -20,10 +22,10 @@ class BagelSuite extends FunSuite with Assertions {
val msgs = sc.parallelize(Array[(String, TestMessage)]())
val numSupersteps = 5
val result =
- Pregel.run(sc, verts, msgs, new DefaultCombiner[TestMessage], 1) {
+ Pregel.run(sc, verts, msgs)()(addAggregatorArg {
(self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) =>
(new TestVertex(self.id, superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
- }
+ })
for (vert <- result.collect)
assert(vert.age === numSupersteps)
}
@@ -34,7 +36,7 @@ class BagelSuite extends FunSuite with Assertions {
val msgs = sc.parallelize(Array("a" -> new TestMessage("a")))
val numSupersteps = 5
val result =
- Pregel.run(sc, verts, msgs, new DefaultCombiner[TestMessage], 1) {
+ Pregel.run(sc, verts, msgs)()(addAggregatorArg {
(self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) =>
val msgsOut =
msgs match {
@@ -44,7 +46,7 @@ class BagelSuite extends FunSuite with Assertions {
new ArrayBuffer[TestMessage]()
}
(new TestVertex(self.id, self.active, self.age + 1), msgsOut)
- }
+ })
for (vert <- result.collect)
assert(vert.age === numSupersteps)
}

0 comments on commit 563c5e7

Please sign in to comment.
Something went wrong with that request. Please try again.