Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Clean up Bagel source and interface

  • Loading branch information...
commit c5b3ea755ff8a69aa39dd6e46d57cbe9d5bcbcae 1 parent 19122af
@ankurdave authored
View
109 bagel/src/main/scala/bagel/Pregel.scala
@@ -7,75 +7,81 @@ import scala.collection.mutable.ArrayBuffer
object Pregel extends Logging {
/**
- * Runs a Pregel job on the given vertices, running the specified
- * compute function on each vertex in every superstep. Before
- * beginning the first superstep, sends the given messages to their
- * destination vertices. In the join stage, launches splits
- * separate tasks (where splits is manually specified to work
- * around a bug in Spark).
+ * Runs a Pregel job on the given vertices consisting of the
+ * specified compute function.
*
- * Halts when no more messages are being sent between vertices, and
- * all vertices have voted to halt by setting their state to
- * Inactive.
+ * Before beginning the first superstep, the given messages are sent
+ * to their destination vertices.
+ *
+ * During the job, the specified combiner functions are applied to
+ * messages as they travel between vertices.
+ *
+ * The job halts and returns the resulting set of vertices when no
+ * messages are being sent between vertices and all vertices have
+ * voted to halt by setting their state to inactive.
*/
- def run[V <: Vertex : Manifest, M <: Message : Manifest, C](sc: SparkContext, verts: RDD[(String, V)], msgs: RDD[(String, M)], splits: Int, messageCombiner: (C, M) => C, defaultCombined: () => C, mergeCombined: (C, C) => C, maxSupersteps: Option[Int] = None, superstep: Int = 0)(compute: (V, C, Int) => (V, Iterable[M])): RDD[V] = {
+ def run[V <: Vertex : Manifest, M <: Message : Manifest, C](
+ sc: SparkContext,
+ verts: RDD[(String, V)],
+ msgs: RDD[(String, M)],
+ createCombiner: M => C,
+ mergeMsg: (C, M) => C,
+ mergeCombiners: (C, C) => C,
+ numSplits: Int,
+ superstep: Int = 0
+ )(compute: (V, Option[C], Int) => (V, Iterable[M])): RDD[V] = {
+
logInfo("Starting superstep "+superstep+".")
val startTime = System.currentTimeMillis
// Bring together vertices and messages
- val combinedMsgs = msgs.combineByKey({x => messageCombiner(defaultCombined(), x)}, messageCombiner, mergeCombined, splits)
- logDebug("verts.splits.size = " + verts.splits.size)
- logDebug("combinedMsgs.splits.size = " + combinedMsgs.splits.size)
- logDebug("verts.partitioner = " + verts.partitioner)
- logDebug("combinedMsgs.partitioner = " + combinedMsgs.partitioner)
-
- val joined = verts.groupWith(combinedMsgs)
- logDebug("joined.splits.size = " + joined.splits.size)
- logDebug("joined.partitioner = " + joined.partitioner)
+ val combinedMsgs = msgs.combineByKey(createCombiner, mergeMsg, mergeCombiners, numSplits)
+ val grouped = verts.groupWith(combinedMsgs)
// Run compute on each vertex
- var messageCount = sc.accumulator(0)
- var activeVertexCount = sc.accumulator(0)
- val processed = joined.flatMapValues {
+ var numMsgs = sc.accumulator(0)
+ var numActiveVerts = sc.accumulator(0)
+ val processed = grouped.flatMapValues {
case (Seq(), _) => None
- case (Seq(v), Seq(comb)) =>
- val (newVertex, newMessages) = compute(v, comb, superstep)
+ case (Seq(v), c) =>
+ val (newVert, newMsgs) =
+ compute(v, c match {
+ case Seq(comb) => Some(comb)
+ case Seq() => None
+ }, superstep)
- messageCount += newMessages.size
- if (newVertex.active)
- activeVertexCount += 1
+ numMsgs += newMsgs.size
+ if (newVert.active)
+ numActiveVerts += 1
- Some((newVertex, newMessages))
- case (Seq(v), Seq()) =>
- val (newVertex, newMessages) = compute(v, defaultCombined(), superstep)
-
- messageCount += newMessages.size
- if (newVertex.active)
- activeVertexCount += 1
-
- Some((newVertex, newMessages))
+ Some((newVert, newMsgs))
}.cache
+
// Force evaluation of processed RDD for accurate performance measurements
processed.foreach(x => {})
val timeTaken = System.currentTimeMillis - startTime
logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
- // Check stopping condition and recurse
- val stop = messageCount.value == 0 && activeVertexCount.value == 0
- if (stop || (maxSupersteps.isDefined && superstep >= maxSupersteps.get)) {
- processed.map { _._2._1 }
+ // Check stopping condition and iterate
+ val noActivity = numMsgs.value == 0 && numActiveVerts.value == 0
+ if (noActivity) {
+ processed.map { case (id, (vert, msgs)) => vert }
} else {
- val newVerts = processed.mapValues(_._1)
- val newMsgs = processed.flatMap(x => x._2._2.map(m => (m.targetId, m)))
- run(sc, newVerts, newMsgs, splits, messageCombiner, defaultCombined, mergeCombined, maxSupersteps, superstep + 1)(compute)
+ val newVerts = processed.mapValues { case (vert, msgs) => vert }
+ val newMsgs = processed.flatMap {
+ case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
+ }
+ run(sc, newVerts, newMsgs, createCombiner, mergeMsg, mergeCombiners, numSplits, superstep + 1)(compute)
}
}
}
/**
- * Represents a Pregel vertex. Must be subclassed to store state
- * along with each vertex. Must be annotated with @serializable.
+ * Represents a Pregel vertex.
+ *
+ * Subclasses may store state along with each vertex and must be
+ * annotated with @serializable.
*/
trait Vertex {
def id: String
@@ -83,17 +89,20 @@ trait Vertex {
}
/**
- * Represents a Pregel message to a target vertex. Must be
- * subclassed to contain a payload. Must be annotated with @serializable.
+ * Represents a Pregel message to a target vertex.
+ *
+ * Subclasses may contain a payload to deliver to the target vertex
+ * and must be annotated with @serializable.
*/
trait Message {
def targetId: String
}
/**
- * Represents a directed edge between two vertices. Owned by the
- * source vertex, and contains the ID of the target vertex. Must
- * be subclassed to store state along with each edge. Must be annotated with @serializable.
+ * Represents a directed edge between two vertices.
+ *
+ * Subclasses may store state along each edge and must be annotated
+ * with @serializable.
*/
trait Edge {
def targetId: String
View
15 bagel/src/main/scala/bagel/ShortestPath.scala
@@ -49,12 +49,17 @@ object ShortestPath {
messages.count()+" messages.")
// Do the computation
- def messageCombiner(minSoFar: Int, message: SPMessage): Int =
- min(minSoFar, message.value)
+ def createCombiner(message: SPMessage): Int = message.value
+ def mergeMsg(combiner: Int, message: SPMessage): Int =
+ min(combiner, message.value)
+ def mergeCombiners(a: Int, b: Int): Int = min(a, b)
- val result = Pregel.run(sc, vertices, messages, numSplits, messageCombiner, () => Int.MaxValue, min _) {
- (self: SPVertex, messageMinValue: Int, superstep: Int) =>
- val newValue = min(self.value, messageMinValue)
+ val result = Pregel.run(sc, vertices, messages, createCombiner, mergeMsg, mergeCombiners, numSplits) {
+ (self: SPVertex, messageMinValue: Option[Int], superstep: Int) =>
+ val newValue = messageMinValue match {
+ case Some(minVal) => min(self.value, minVal)
+ case None => self.value
+ }
val outbox =
if (newValue != self.value)
View
99 bagel/src/main/scala/bagel/WikipediaPageRank.scala
@@ -4,7 +4,6 @@ import spark._
import spark.SparkContext._
import scala.collection.mutable.ArrayBuffer
-
import scala.xml.{XML,NodeSeq}
import java.io.{Externalizable,ObjectInput,ObjectOutput,DataOutputStream,DataInputStream}
@@ -14,7 +13,7 @@ import com.esotericsoftware.kryo._
object WikipediaPageRank {
def main(args: Array[String]) {
if (args.length < 4) {
- System.err.println("Usage: PageRank <inputFile> <threshold> <numSplits> <host> [<noCombiner>]")
+ System.err.println("Usage: WikipediaPageRank <inputFile> <threshold> <numSplits> <host> [<noCombiner>]")
System.exit(-1)
}
@@ -52,22 +51,18 @@ object WikipediaPageRank {
}
val outEdges = ArrayBuffer(links.map(link => new PREdge(new String(link.text))): _*)
val id = new String(title)
- (id, (new PRVertex(id, 1.0 / numVertices, outEdges, true)))
- })
- val graph = vertices.groupByKey(numSplits).mapValues(_.head).cache
-
+ (id, new PRVertex(id, 1.0 / numVertices, outEdges, true))
+ }).cache
println("Done parsing input file.")
- println("Input file had "+graph.count+" vertices.")
// Do the computation
val epsilon = 0.01 / numVertices
+ val messages = sc.parallelize(List[(String, PRMessage)]())
val result =
if (noCombiner) {
- val messages = sc.parallelize(List[(String, PRMessage)]())
- Pregel.run[PRVertex, PRMessage, ArrayBuffer[PRMessage]](sc, graph, messages, numSplits, NoCombiner.messageCombiner, NoCombiner.defaultCombined, NoCombiner.mergeCombined)(NoCombiner.compute(numVertices, epsilon))
+ Pregel.run[PRVertex, PRMessage, ArrayBuffer[PRMessage]](sc, vertices, messages, NoCombiner.createCombiner, NoCombiner.mergeMsg, NoCombiner.mergeCombiners, numSplits)(NoCombiner.compute(numVertices, epsilon))
} else {
- val messages = sc.parallelize(List[(String, PRMessage)]())
- Pregel.run[PRVertex, PRMessage, Double](sc, graph, messages, numSplits, Combiner.messageCombiner, Combiner.defaultCombined, Combiner.mergeCombined)(Combiner.compute(numVertices, epsilon))
+ Pregel.run[PRVertex, PRMessage, Double](sc, vertices, messages, Combiner.createCombiner, Combiner.mergeMsg, Combiner.mergeCombiners, numSplits)(Combiner.compute(numVertices, epsilon))
}
// Print the result
@@ -78,19 +73,19 @@ object WikipediaPageRank {
}
object Combiner {
- def messageCombiner(minSoFar: Double, message: PRMessage): Double =
- minSoFar + message.value
+ def createCombiner(message: PRMessage): Double = message.value
- def mergeCombined(a: Double, b: Double) = a + b
+ def mergeMsg(combiner: Double, message: PRMessage): Double =
+ combiner + message.value
- def defaultCombined(): Double = 0.0
+ def mergeCombiners(a: Double, b: Double) = a + b
- def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Double, superstep: Int): (PRVertex, Iterable[PRMessage]) = {
- val newValue =
- if (messageSum != 0)
- 0.15 / numVertices + 0.85 * messageSum
- else
- self.value
+ def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = {
+ val newValue = messageSum match {
+ case Some(msgSum) if msgSum != 0 =>
+ 0.15 / numVertices + 0.85 * msgSum
+ case _ => self.value
+ }
val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30
@@ -106,20 +101,24 @@ object WikipediaPageRank {
}
object NoCombiner {
- def messageCombiner(messagesSoFar: ArrayBuffer[PRMessage], message: PRMessage): ArrayBuffer[PRMessage] =
- messagesSoFar += message
+ def createCombiner(message: PRMessage): ArrayBuffer[PRMessage] =
+ ArrayBuffer(message)
- def mergeCombined(a: ArrayBuffer[PRMessage], b: ArrayBuffer[PRMessage]): ArrayBuffer[PRMessage] =
- a ++= b
+ def mergeMsg(combiner: ArrayBuffer[PRMessage], message: PRMessage): ArrayBuffer[PRMessage] =
+ combiner += message
- def defaultCombined(): ArrayBuffer[PRMessage] = ArrayBuffer[PRMessage]()
+ def mergeCombiners(a: ArrayBuffer[PRMessage], b: ArrayBuffer[PRMessage]): ArrayBuffer[PRMessage] =
+ a ++= b
- def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Seq[PRMessage], superstep: Int): (PRVertex, Iterable[PRMessage]) =
- Combiner.compute(numVertices, epsilon)(self, messages.map(_.value).sum, superstep)
+ def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) =
+ Combiner.compute(numVertices, epsilon)(self, messages match {
+ case Some(msgs) => Some(msgs.map(_.value).sum)
+ case None => None
+ }, superstep)
}
}
-@serializable class PRVertex() extends Vertex with Externalizable {
+@serializable class PRVertex() extends Vertex {
var id: String = _
var value: Double = _
var outEdges: ArrayBuffer[PREdge] = _
@@ -132,29 +131,9 @@ object WikipediaPageRank {
this.outEdges = outEdges
this.active = active
}
-
- def writeExternal(out: ObjectOutput) {
- out.writeUTF(id)
- out.writeDouble(value)
- out.writeInt(outEdges.length)
- for (e <- outEdges)
- out.writeUTF(e.targetId)
- out.writeBoolean(active)
- }
-
- def readExternal(in: ObjectInput) {
- id = in.readUTF()
- value = in.readDouble()
- val numEdges = in.readInt()
- outEdges = new ArrayBuffer[PREdge](numEdges)
- for (i <- 0 until numEdges) {
- outEdges += new PREdge(in.readUTF())
- }
- active = in.readBoolean()
- }
}
-@serializable class PRMessage() extends Message with Externalizable {
+@serializable class PRMessage() extends Message {
var targetId: String = _
var value: Double = _
@@ -163,33 +142,15 @@ object WikipediaPageRank {
this.targetId = targetId
this.value = value
}
-
- def writeExternal(out: ObjectOutput) {
- out.writeUTF(targetId)
- out.writeDouble(value)
- }
-
- def readExternal(in: ObjectInput) {
- targetId = in.readUTF()
- value = in.readDouble()
- }
}
-@serializable class PREdge() extends Edge with Externalizable {
+@serializable class PREdge() extends Edge {
var targetId: String = _
def this(targetId: String) {
this()
this.targetId = targetId
}
-
- def writeExternal(out: ObjectOutput) {
- out.writeUTF(targetId)
- }
-
- def readExternal(in: ObjectInput) {
- targetId = in.readUTF()
- }
}
class PRKryoRegistrator extends KryoRegistrator {

0 comments on commit c5b3ea7

Please sign in to comment.
Something went wrong with that request. Please try again.