Bagel: Large-scale graph processing on Spark #48

Merged
merged 10 commits on May 13, 2011
159 changes: 159 additions & 0 deletions bagel/src/main/scala/spark/bagel/Bagel.scala
@@ -0,0 +1,159 @@
package spark.bagel

import spark._
import spark.SparkContext._

import scala.collection.mutable.ArrayBuffer

object Bagel extends Logging {
  /**
   * Runs a Bagel program: applies the compute function to each vertex
   * for the current superstep, then recurses until no vertex is active
   * and no messages are in flight, returning the final vertices.
   */
  def run[V <: Vertex : Manifest, M <: Message : Manifest, C : Manifest, A : Manifest](
    sc: SparkContext,
    verts: RDD[(String, V)],
    msgs: RDD[(String, M)]
  )(
    combiner: Combiner[M, C] = new DefaultCombiner[M],
    aggregator: Aggregator[V, A] = new NullAggregator[V],
    superstep: Int = 0,
    numSplits: Int = sc.numCores
  )(
    compute: (V, Option[C], A, Int) => (V, Iterable[M])
  ): RDD[V] = {

    logInfo("Starting superstep "+superstep+".")
    val startTime = System.currentTimeMillis

    val aggregated = agg(verts, aggregator)
    val combinedMsgs = msgs.combineByKey(
      combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners, numSplits)
    val grouped = verts.groupWith(combinedMsgs)
    val (processed, numMsgs, numActiveVerts) =
      comp[V, M, C](sc, grouped, compute(_, _, aggregated, superstep))

    val timeTaken = System.currentTimeMillis - startTime
    logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))

    // Check stopping condition and iterate
    val noActivity = numMsgs == 0 && numActiveVerts == 0
    if (noActivity) {
      processed.map { case (id, (vert, msgs)) => vert }
    } else {
      val newVerts = processed.mapValues { case (vert, msgs) => vert }
      val newMsgs = processed.flatMap {
        case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
      }
      run(sc, newVerts, newMsgs)(combiner, aggregator, superstep + 1, numSplits)(compute)
    }
  }

  /**
   * Aggregates the given vertices using the given aggregator, or does
   * nothing if it is a NullAggregator.
   */
  def agg[V <: Vertex, A : Manifest](verts: RDD[(String, V)], aggregator: Aggregator[V, A]): A =
    aggregator match {
      case _: NullAggregator[_] =>
        None
      case _ =>
        verts.map {
          case (id, vert) => aggregator.createAggregator(vert)
        }.reduce(aggregator.mergeAggregators(_, _))
    }

  /**
   * Processes the given vertex-message RDD using the compute
   * function. Returns the processed RDD, the number of messages
   * created, and the number of active vertices.
   */
  def comp[V <: Vertex, M <: Message, C](
    sc: SparkContext,
    grouped: RDD[(String, (Seq[V], Seq[C]))],
    compute: (V, Option[C]) => (V, Iterable[M])
  ): (RDD[(String, (V, Iterable[M]))], Int, Int) = {
    val numMsgs = sc.accumulator(0)
    val numActiveVerts = sc.accumulator(0)
    val processed = grouped.flatMapValues {
      // Drop messages sent to nonexistent vertices
      case (Seq(), _) => None
      case (Seq(v), c) =>
        val (newVert, newMsgs) =
          compute(v, c match {
            case Seq(comb) => Some(comb)
            case Seq() => None
          })

        numMsgs += newMsgs.size
        if (newVert.active)
          numActiveVerts += 1

        Some((newVert, newMsgs))
    }.cache

    // Force evaluation of processed RDD for accurate performance measurements
    processed.foreach(x => {})

    (processed, numMsgs.value, numActiveVerts.value)
  }

  /**
   * Converts a compute function that doesn't take an aggregator to
   * one that does, so it can be passed to Bagel.run.
   */
  implicit def addAggregatorArg[
    V <: Vertex : Manifest, M <: Message : Manifest, C
  ](
    compute: (V, Option[C], Int) => (V, Iterable[M])
  ): (V, Option[C], Option[Nothing], Int) => (V, Iterable[M]) = {
    (vert: V, messages: Option[C], aggregator: Option[Nothing], superstep: Int) =>
      compute(vert, messages, superstep)
  }
}

// TODO: Simplify Combiner interface and make it more OO.
trait Combiner[M, C] {
  def createCombiner(msg: M): C
  def mergeMsg(combiner: C, msg: M): C
  def mergeCombiners(a: C, b: C): C
}

trait Aggregator[V, A] {
  def createAggregator(vert: V): A
  def mergeAggregators(a: A, b: A): A
}

@serializable
class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] {
  def createCombiner(msg: M): ArrayBuffer[M] =
    ArrayBuffer(msg)
  def mergeMsg(combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] =
    combiner += msg
  def mergeCombiners(a: ArrayBuffer[M], b: ArrayBuffer[M]): ArrayBuffer[M] =
    a ++= b
}

@serializable
class NullAggregator[V] extends Aggregator[V, Option[Nothing]] {
  def createAggregator(vert: V): Option[Nothing] = None
  def mergeAggregators(a: Option[Nothing], b: Option[Nothing]): Option[Nothing] = None
}

/**
* Represents a Bagel vertex.
*
* Subclasses may store state along with each vertex and must be
* annotated with @serializable.
*/
trait Vertex {
  def id: String
  def active: Boolean
}

/**
* Represents a Bagel message to a target vertex.
*
* Subclasses may contain a payload to deliver to the target vertex
* and must be annotated with @serializable.
*/
trait Message {
  def targetId: String
}

/**
* Represents a directed edge between two vertices.
*
* Subclasses may store state along each edge and must be annotated
* with @serializable.
*/
trait Edge {
  def targetId: String
}
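
As a hedged usage sketch (not part of this PR), the following hypothetical program drives Bagel.run directly: two vertices each increment a counter, message themselves, and go inactive after three supersteps, at which point run detects no activity and returns. CounterVertex, CounterMessage, and the local master are assumptions for illustration.

import scala.collection.mutable.ArrayBuffer

import spark._
import spark.SparkContext._
import spark.bagel._
import spark.bagel.Bagel._

// Hypothetical vertex and message types, following the SPVertex pattern.
@serializable class CounterVertex(val id: String, val count: Int, val active: Boolean) extends Vertex
@serializable class CounterMessage(val targetId: String) extends Message

object CounterExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "CounterExample")

    // Two vertices and an empty initial message RDD.
    val verts = sc.parallelize(List("a", "b"))
      .map(id => (id, new CounterVertex(id, 0, true)))
    val msgs = sc.parallelize(List[(String, CounterMessage)]())

    // Bump the counter, ping self, and go inactive after superstep 2.
    val compute = addAggregatorArg {
      (self: CounterVertex, inbox: Option[ArrayBuffer[CounterMessage]], superstep: Int) =>
        val outbox = if (superstep < 2) List(new CounterMessage(self.id)) else List()
        (new CounterVertex(self.id, self.count + 1, superstep < 2), outbox)
    }

    // Explicit combiner, default (null) aggregator, as in the ShortestPath example below.
    val result = Bagel.run(sc, verts, msgs)(combiner = new DefaultCombiner[CounterMessage])(compute)
    result.foreach(v => println(v.id + " counted to " + v.count))
  }
}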
96 changes: 96 additions & 0 deletions bagel/src/main/scala/spark/bagel/examples/ShortestPath.scala
@@ -0,0 +1,96 @@
package spark.bagel.examples

import spark._
import spark.SparkContext._

import scala.math.min

import spark.bagel._
import spark.bagel.Bagel._

object ShortestPath {
  def main(args: Array[String]) {
    if (args.length < 4) {
      System.err.println("Usage: ShortestPath <graphFile> <startVertex> " +
                         "<numSplits> <host>")
      System.exit(-1)
    }

    val graphFile = args(0)
    val startVertex = args(1)
    val numSplits = args(2).toInt
    val host = args(3)
    val sc = new SparkContext(host, "ShortestPath")

    // Parse the graph data from a file into two RDDs, vertices and messages
    val lines =
      (sc.textFile(graphFile)
       .filter(!_.matches("^\\s*#.*"))
       .map(line => line.split("\t")))

    val vertices: RDD[(String, SPVertex)] =
      (lines.groupBy(line => line(0))
       .map {
         case (vertexId, lines) => {
           val outEdges = lines.collect {
             case Array(_, targetId, edgeValue) =>
               new SPEdge(targetId, edgeValue.toInt)
           }

           (vertexId, new SPVertex(vertexId, Int.MaxValue, outEdges, true))
         }
       })

    val messages: RDD[(String, SPMessage)] =
      (lines.filter(_.length == 2)
       .map {
         case Array(vertexId, messageValue) =>
           (vertexId, new SPMessage(vertexId, messageValue.toInt))
       })

    System.err.println("Read "+vertices.count()+" vertices and "+
                       messages.count()+" messages.")

    // Do the computation
    val compute = addAggregatorArg {
      (self: SPVertex, messageMinValue: Option[Int], superstep: Int) =>
        val newValue = messageMinValue match {
          case Some(minVal) => min(self.value, minVal)
          case None => self.value
        }

        val outbox =
          if (newValue != self.value)
            self.outEdges.map(edge =>
              new SPMessage(edge.targetId, newValue + edge.value))
          else
            List()

        (new SPVertex(self.id, newValue, self.outEdges, false), outbox)
    }
    val result = Bagel.run(sc, vertices, messages)(
      combiner = MinCombiner, numSplits = numSplits)(compute)

    // Print the result
    System.err.println("Shortest path from "+startVertex+" to all vertices:")
    val shortest = result.map(vertex =>
      "%s\t%s\n".format(vertex.id, vertex.value match {
        case x if x == Int.MaxValue => "inf"
        case x => x
      })).collect.mkString
    println(shortest)
  }
}

@serializable
object MinCombiner extends Combiner[SPMessage, Int] {
  def createCombiner(msg: SPMessage): Int =
    msg.value
  def mergeMsg(combiner: Int, msg: SPMessage): Int =
    min(combiner, msg.value)
  def mergeCombiners(a: Int, b: Int): Int =
    min(a, b)
}

@serializable class SPVertex(val id: String, val value: Int, val outEdges: Seq[SPEdge], val active: Boolean) extends Vertex
@serializable class SPEdge(val targetId: String, val value: Int) extends Edge
@serializable class SPMessage(val targetId: String, val value: Int) extends Message
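
For context, the parser above implies a tab-separated input format: three-field lines are directed edges (source, target, weight), two-field lines seed the initial messages (vertexId, startingDistance), and lines beginning with # are skipped. A hypothetical graph.txt seeding vertex a with distance 0 might look as follows and would be run as, say, ShortestPath graph.txt a 2 local (the file name and local master are illustrative, not part of this PR):

# edges: source<TAB>target<TAB>weight
a	b	1
b	c	2
a	c	5
# initial message: seed the start vertex with distance 0
a	0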