Permalink
Browse files

Add Bagel, an implementation of Pregel on Spark

  • Loading branch information...
1 parent 94ba95b commit c0736f6f68e47b82e3634252f8dba4f709a33ba5 @ankurdave committed Apr 13, 2011
View
103 bagel/src/main/scala/bagel/Pregel.scala
@@ -0,0 +1,103 @@
+package bagel
+
+import spark._
+import spark.SparkContext._
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.ArrayBuffer
+
+object Pregel extends Logging {
+ /**
+ * Runs a Pregel job on the given vertices, running the specified
+ * compute function on each vertex in every superstep. Before
+ * beginning the first superstep, sends the given messages to their
+ * destination vertices. In the join stage, launches splits
+ * separate tasks (where splits is manually specified to work
+ * around a bug in Spark).
+ *
+ * Halts when no more messages are being sent between vertices, and
+ * all vertices have voted to halt by setting their state to
+ * Inactive.
+ */
+ def run[V <: Vertex : Manifest, M <: Message : Manifest, C](sc: SparkContext, verts: RDD[(String, V)], msgs: RDD[(String, M)], splits: Int, messageCombiner: (C, M) => C, defaultCombined: () => C, mergeCombined: (C, C) => C, superstep: Int = 0)(compute: (V, C, Int) => (V, Iterable[M])): RDD[V] = {
+ println("Starting superstep "+superstep+".")
+ val startTime = System.currentTimeMillis
+
+ // Bring together vertices and messages
+ println("Joining vertices and messages...")
+ val combinedMsgs = msgs.combineByKey({x => messageCombiner(defaultCombined(), x)}, messageCombiner, mergeCombined, splits)
+ println("verts.splits.size = " + verts.splits.size)
+ println("combinedMsgs.splits.size = " + combinedMsgs.splits.size)
+ println("verts.partitioner = " + verts.partitioner)
+ println("combinedMsgs.partitioner = " + combinedMsgs.partitioner)
+ val joined = verts.groupWith(combinedMsgs)
+ println("joined.splits.size = " + joined.splits.size)
+ println("joined.partitioner = " + joined.partitioner)
+ //val joined = graph.groupByKeyAsymmetrical(messageCombiner, defaultCombined, mergeCombined, splits)
+ println("Done joining vertices and messages.")
+
+ // Run compute on each vertex
+ println("Running compute on each vertex...")
+ var messageCount = sc.accumulator(0)
+ var activeVertexCount = sc.accumulator(0)
+ val processed = joined.flatMapValues {
+ case (Seq(), _) => None
+ case (Seq(v), Seq(comb)) =>
+ val (newVertex, newMessages) = compute(v, comb, superstep)
+ messageCount += newMessages.size
+ if (newVertex.active)
+ activeVertexCount += 1
+ Some((newVertex, newMessages))
+ //val result = ArrayBuffer[(String, Either[V, M])]((newVertex.id, Left(newVertex)))
+ //result ++= newMessages.map(m => (m.targetId, Right(m)))
+ case (Seq(v), Seq()) =>
+ val (newVertex, newMessages) = compute(v, defaultCombined(), superstep)
+ messageCount += newMessages.size
+ if (newVertex.active)
+ activeVertexCount += 1
+ Some((newVertex, newMessages))
+ }.cache
+ //MATEI: Added this
+ processed.foreach(x => {})
+ println("Done running compute on each vertex.")
+
+ println("Checking stopping condition...")
+ val stop = messageCount.value == 0 && activeVertexCount.value == 0
+
+ val timeTaken = System.currentTimeMillis - startTime
+ println("Superstep %d took %d s".format(superstep, timeTaken / 1000))
+
+ val newVerts = processed.mapValues(_._1)
+ val newMsgs = processed.flatMap(x => x._2._2.map(m => (m.targetId, m)))
+
+ if (superstep >= 10)
+ processed.map { _._2._1 }
+ else
+ run(sc, newVerts, newMsgs, splits, messageCombiner, defaultCombined, mergeCombined, superstep + 1)(compute)
+ }
+}
+
+/**
+ * Represents a Pregel vertex. Must be subclassed to store state
+ * along with each vertex. Must be annotated with @serializable.
+ */
+trait Vertex {
+ def id: String
+ def active: Boolean
+}
+
+/**
+ * Represents a Pregel message to a target vertex. Must be
+ * subclassed to contain a payload. Must be annotated with @serializable.
+ */
+trait Message {
+ def targetId: String
+}
+
+/**
+ * Represents a directed edge between two vertices. Owned by the
+ * source vertex, and contains the ID of the target vertex. Must
+ * be subclassed to store state along with each edge. Must be annotated with @serializable.
+ */
+trait Edge {
+ def targetId: String
+}
View
86 bagel/src/main/scala/bagel/ShortestPath.scala
@@ -0,0 +1,86 @@
+package bagel
+
+import spark._
+import spark.SparkContext._
+
+import scala.math.min
+
+/*
+object ShortestPath {
+ def main(args: Array[String]) {
+ if (args.length < 4) {
+ System.err.println("Usage: ShortestPath <graphFile> <startVertex> " +
+ "<numSplits> <host>")
+ System.exit(-1)
+ }
+
+ val graphFile = args(0)
+ val startVertex = args(1)
+ val numSplits = args(2).toInt
+ val host = args(3)
+ val sc = new SparkContext(host, "ShortestPath")
+
+ // Parse the graph data from a file into two RDDs, vertices and messages
+ val lines =
+ (sc.textFile(graphFile)
+ .filter(!_.matches("^\\s*#.*"))
+ .map(line => line.split("\t")))
+
+ val vertices: RDD[(String, Either[SPVertex, SPMessage])] =
+ (lines.groupBy(line => line(0))
+ .map {
+ case (vertexId, lines) => {
+ val outEdges = lines.collect {
+ case Array(_, targetId, edgeValue) =>
+ new SPEdge(targetId, edgeValue.toInt)
+ }
+
+ (vertexId, Left[SPVertex, SPMessage](new SPVertex(vertexId, Int.MaxValue, outEdges, true)))
+ }
+ })
+
+ val messages: RDD[(String, Either[SPVertex, SPMessage])] =
+ (lines.filter(_.length == 2)
+ .map {
+ case Array(vertexId, messageValue) =>
+ (vertexId, Right[SPVertex, SPMessage](new SPMessage(vertexId, messageValue.toInt)))
+ })
+
+ val graph: RDD[(String, Either[SPVertex, SPMessage])] = vertices ++ messages
+
+ System.err.println("Read "+vertices.count()+" vertices and "+
+ messages.count()+" messages.")
+
+ // Do the computation
+ def messageCombiner(minSoFar: Int, message: SPMessage): Int =
+ min(minSoFar, message.value)
+
+ val result = Pregel.run(sc, graph, numSplits, messageCombiner, () => Int.MaxValue, min _) {
+ (self: SPVertex, messageMinValue: Int, superstep: Int) =>
+ val newValue = min(self.value, messageMinValue)
+
+ val outbox =
+ if (newValue != self.value)
+ self.outEdges.map(edge =>
+ new SPMessage(edge.targetId, newValue + edge.value))
+ else
+ List()
+
+ (new SPVertex(self.id, newValue, self.outEdges, false), outbox)
+ }
+
+ // Print the result
+ System.err.println("Shortest path from "+startVertex+" to all vertices:")
+ val shortest = result.map(vertex =>
+ "%s\t%s\n".format(vertex.id, vertex.value match {
+ case x if x == Int.MaxValue => "inf"
+ case x => x
+ })).collect.mkString
+ println(shortest)
+ }
+}
+
+@serializable class SPVertex(val id: String, val value: Int, val outEdges: Seq[SPEdge], val active: Boolean) extends Vertex
+@serializable class SPEdge(val targetId: String, val value: Int) extends Edge
+@serializable class SPMessage(val targetId: String, val value: Int) extends Message
+*/
View
201 bagel/src/main/scala/bagel/WikipediaPageRank.scala
@@ -0,0 +1,201 @@
+package bagel
+
+import spark._
+import spark.SparkContext._
+
+import scala.collection.mutable.ArrayBuffer
+
+import scala.xml.{XML,NodeSeq}
+
+import java.io.{Externalizable,ObjectInput,ObjectOutput,DataOutputStream,DataInputStream}
+
+import com.esotericsoftware.kryo._
+
+object WikipediaPageRank {
+ def main(args: Array[String]) {
+ if (args.length < 4) {
+ System.err.println("Usage: PageRank <inputFile> <threshold> <numSplits> <host> [<noCombiner>]")
+ System.exit(-1)
+ }
+
+ System.setProperty("spark.serialization", "spark.KryoSerialization")
+ System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName)
+
+ val inputFile = args(0)
+ val threshold = args(1).toDouble
+ val numSplits = args(2).toInt
+ val host = args(3)
+ val noCombiner = args.length > 4 && args(4).nonEmpty
+ val sc = new SparkContext(host, "WikipediaPageRank")
+
+ // Parse the Wikipedia page data into a graph
+ val input = sc.textFile(inputFile)
+
+ println("Counting vertices...")
+ val numVertices = input.count()
+ println("Done counting vertices.")
+
+ println("Parsing input file...")
+ val vertices: RDD[(String, PRVertex)] = input.map(line => {
+ val fields = line.split("\t")
+ val (title, body) = (fields(1), fields(3).replace("\\n", "\n"))
+ val links =
+ if (body == "\\N")
+ NodeSeq.Empty
+ else
+ try {
+ XML.loadString(body) \\ "link" \ "target"
+ } catch {
+ case e: org.xml.sax.SAXParseException =>
+ System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body)
+ NodeSeq.Empty
+ }
+ val outEdges = ArrayBuffer(links.map(link => new PREdge(new String(link.text))): _*)
+ val id = new String(title)
+ (id, (new PRVertex(id, 1.0 / numVertices, outEdges, true)))
+ })
+ val graph = vertices.groupByKey(numSplits).mapValues(_.head).cache
+
+ println("Done parsing input file.")
+ println("Input file had "+graph.count+" vertices.")
+
+ // Do the computation
+ val epsilon = 0.01 / numVertices
+ val result =
+ if (noCombiner) {
+ val messages = sc.parallelize(List[(String, PRMessage)]())
+ Pregel.run[PRVertex, PRMessage, ArrayBuffer[PRMessage]](sc, graph, messages, numSplits, NoCombiner.messageCombiner, NoCombiner.defaultCombined, NoCombiner.mergeCombined)(NoCombiner.compute(numVertices, epsilon))
+ } else {
+ val messages = sc.parallelize(List[(String, PRMessage)]())
+ Pregel.run[PRVertex, PRMessage, Double](sc, graph, messages, numSplits, Combiner.messageCombiner, Combiner.defaultCombined, Combiner.mergeCombined)(Combiner.compute(numVertices, epsilon))
+ }
+
+ // Print the result
+ System.err.println("Articles with PageRank >= "+threshold+":")
+ val top = result.filter(_.value >= threshold).map(vertex =>
+ "%s\t%s\n".format(vertex.id, vertex.value)).collect.mkString
+ println(top)
+ }
+
+ object Combiner {
+ def messageCombiner(minSoFar: Double, message: PRMessage): Double =
+ minSoFar + message.value
+
+ def mergeCombined(a: Double, b: Double) = a + b
+
+ def defaultCombined(): Double = 0.0
+
+ def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Double, superstep: Int): (PRVertex, Iterable[PRMessage]) = {
+ val newValue =
+ if (messageSum != 0)
+ 0.15 / numVertices + 0.85 * messageSum
+ else
+ self.value
+
+ val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30
+
+ val outbox =
+ if (!terminate)
+ self.outEdges.map(edge =>
+ new PRMessage(edge.targetId, newValue / self.outEdges.size))
+ else
+ ArrayBuffer[PRMessage]()
+
+ (new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox)
+ }
+ }
+
+ object NoCombiner {
+ def messageCombiner(messagesSoFar: ArrayBuffer[PRMessage], message: PRMessage): ArrayBuffer[PRMessage] =
+ messagesSoFar += message
+
+ def mergeCombined(a: ArrayBuffer[PRMessage], b: ArrayBuffer[PRMessage]): ArrayBuffer[PRMessage] =
+ a ++= b
+
+ def defaultCombined(): ArrayBuffer[PRMessage] = ArrayBuffer[PRMessage]()
+
+ def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Seq[PRMessage], superstep: Int): (PRVertex, Iterable[PRMessage]) =
+ Combiner.compute(numVertices, epsilon)(self, messages.map(_.value).sum, superstep)
+ }
+}
+
+@serializable class PRVertex() extends Vertex with Externalizable {
+ var id: String = _
+ var value: Double = _
+ var outEdges: ArrayBuffer[PREdge] = _
+ var active: Boolean = true
+
+ def this(id: String, value: Double, outEdges: ArrayBuffer[PREdge], active: Boolean) {
+ this()
+ this.id = id
+ this.value = value
+ this.outEdges = outEdges
+ this.active = active
+ }
+
+ def writeExternal(out: ObjectOutput) {
+ out.writeUTF(id)
+ out.writeDouble(value)
+ out.writeInt(outEdges.length)
+ for (e <- outEdges)
+ out.writeUTF(e.targetId)
+ out.writeBoolean(active)
+ }
+
+ def readExternal(in: ObjectInput) {
+ id = in.readUTF()
+ value = in.readDouble()
+ val numEdges = in.readInt()
+ outEdges = new ArrayBuffer[PREdge](numEdges)
+ for (i <- 0 until numEdges) {
+ outEdges += new PREdge(in.readUTF())
+ }
+ active = in.readBoolean()
+ }
+}
+
+@serializable class PRMessage() extends Message with Externalizable {
+ var targetId: String = _
+ var value: Double = _
+
+ def this(targetId: String, value: Double) {
+ this()
+ this.targetId = targetId
+ this.value = value
+ }
+
+ def writeExternal(out: ObjectOutput) {
+ out.writeUTF(targetId)
+ out.writeDouble(value)
+ }
+
+ def readExternal(in: ObjectInput) {
+ targetId = in.readUTF()
+ value = in.readDouble()
+ }
+}
+
+@serializable class PREdge() extends Edge with Externalizable {
+ var targetId: String = _
+
+ def this(targetId: String) {
+ this()
+ this.targetId = targetId
+ }
+
+ def writeExternal(out: ObjectOutput) {
+ out.writeUTF(targetId)
+ }
+
+ def readExternal(in: ObjectInput) {
+ targetId = in.readUTF()
+ }
+}
+
+class PRKryoRegistrator extends KryoRegistrator {
+ def registerClasses(kryo: Kryo) {
+ kryo.register(classOf[PRVertex])
+ kryo.register(classOf[PRMessage])
+ kryo.register(classOf[PREdge])
+ }
+}
View
2 project/build/SparkProject.scala
@@ -14,6 +14,8 @@ extends ParentProject(info) with IdeaProject
lazy val examples =
project("examples", "Spark Examples", new ExamplesProject(_), core)
+ lazy val bagel = project("bagel", "Bagel", core)
+
class CoreProject(info: ProjectInfo)
extends DefaultProject(info) with Eclipsify with IdeaProject with DepJar with XmlTestReport
{}

0 comments on commit c0736f6

Please sign in to comment.