Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introducing an Improved Pregel API #1217

Closed
wants to merge 12 commits into from
147 changes: 134 additions & 13 deletions graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,58 @@ import org.apache.spark.Logging


/**
* Implements a Pregel-like bulk-synchronous message-passing API.
* The Pregel Vertex class contains the vertex attribute and the boolean flag
* indicating whether the vertex is active.
*
* Unlike the original Pregel API, the GraphX Pregel API factors the sendMessage computation over
* edges, enables the message sending computation to read both vertex attributes, and constrains
* messages to the graph structure. These changes allow for substantially more efficient
* distributed execution while also exposing greater flexibility for graph-based computation.
* @param attr the vertex prorperty during the pregel computation (e.g.,
* PageRank)
* @param isActive a flag indicating whether the vertex is active
* @tparam T the type of the vertex property
*/
sealed case class PregelVertex[@specialized T]
(attr: T, isActive: Boolean = true) extends Product2[T, Boolean] {
override def _1: T = attr
override def _2: Boolean = isActive
}


/**
* The Pregel API enables users to express iterative graph algorithms in GraphX
* and is loosely based on the Google and GraphLab APIs.
*
* At a high-level iterative graph algorithms like PageRank recursively define
* vertex properties in terms of the properties of neighboring vertices. These
* recursive properties are then computed through iterative fixed-point
* computations. For example, the PageRank of a web-page can be defined as a
* weighted sum of the PageRank of web-pages that link to that page and is
* computed by iteratively updating the PageRank of each page until a
* fixed-point is reached (the PageRank values stop changing).
*
* The GraphX Pregel API expresses iterative graph algorithms as vertex-programs
* which can send and receive messages from neighboring vertices. Vertex
* programs have the following logic:
*
* {{{
* while ( there are active vertices ) {
* for ( v in allVertices ) {
* // Send messages to neighbors
* if (isActive) {
* for (nbr in neighbors(v)) {
* val msgs: List[(id, msg)] = computeMsgs(Triplet(v, nbr))
* for ((id, msg) <- msgs) {
* messageSum(id) = reduce(messageSum(id), msg)
* }
* }
* }
* // Receive the "sum" of the messages to v and update the property
* (vertexProperty(v), isActive) =
* vertexProgram(vertexProperty(v), messagesSum(v))
* }
* }
* }}}
*
* The user defined `vertexProgram`, `computeMessage`, and `reduce` functions
* capture the core logic of graph algorithms.
*
* @example We can use the Pregel abstraction to implement PageRank:
* {{{
Expand All @@ -41,19 +87,90 @@ import org.apache.spark.Logging
* // Set the vertex attributes to the initial pagerank values
* .mapVertices((id, attr) => 1.0)
*
* def vertexProgram(id: VertexId, attr: Double, msgSum: Double): Double =
* resetProb + (1.0 - resetProb) * msgSum
* def sendMessage(id: VertexId, edge: EdgeTriplet[Double, Double]): Iterator[(VertexId, Double)] =
* Iterator((edge.dstId, edge.srcAttr * edge.attr))
* // Define the vertex program and message calculation functions.
* def vertexProgram(iter: Int, id: VertexId, oldV: PregelVertex[Double],
* msgSum: Option[Double]) = {
* PregelVertex(resetProb + (1.0 - resetProb) * msgSum.getOrElse(0.0))
* }
*
* def computeMsgs(iter: Int, edge: EdgeTriplet[PregelVertex[Double], Double]) = {
* Iterator((edge.dstId, edge.srcAttr.attr * edge.attr))
* }
*
* def messageCombiner(a: Double, b: Double): Double = a + b
* val initialMessage = 0.0
* // Execute Pregel for a fixed number of iterations.
* Pregel(pagerankGraph, initialMessage, numIter)(
* vertexProgram, sendMessage, messageCombiner)
*
* // Run PageRank
* val prGraph = Pregel.run(pagerankGraph, numIter, activeDirection = EdgeDirection.Out)(
* vertexProgram, sendMessage, messageCombiner).cache()
*
* // Normalize the pagerank vector:
* val normalizer: Double = prGraph.vertices.map(x => x._2).reduce(_ + _)
*
* prGraph.mapVertices((id, pr) => pr / normalizer)
*
* }}}
*
*/
object Pregel extends Logging {
/**
* The new Pregel API.
*/
def run[VD: ClassTag, ED: ClassTag, A: ClassTag]
(graph: Graph[VD, ED],
maxIterations: Int = Int.MaxValue,
activeDirection: EdgeDirection = EdgeDirection.Either)
(vertexProgram: (Int, VertexId, PregelVertex[VD], Option[A]) => PregelVertex[VD],
computeMsgs: (Int, EdgeTriplet[PregelVertex[VD], ED]) => Iterator[(VertexId, A)],
mergeMsg: (A, A) => A)
: Graph[VD, ED] =
{
// Initialize the graph with all vertices active
var currengGraph: Graph[PregelVertex[VD], ED] =
graph.mapVertices { (vid, vdata) => PregelVertex(vdata) }.cache()
// Determine the set of vertices that did not vote to halt
var activeVertices = currengGraph.vertices
var numActive = activeVertices.count()
var iteration = 0
while (numActive > 0 && iteration < maxIterations) {
// get a reference to the current graph to enable unprecistance.
val prevG = currengGraph

// Compute the messages for all the active vertices
val messages = currengGraph.mapReduceTriplets( t => computeMsgs(iteration, t), mergeMsg,
Some((activeVertices, activeDirection)))

// Receive the messages to the subset of active vertices
currengGraph = currengGraph.outerJoinVertices(messages){ (vid, pVertex, msgOpt) =>
// If the vertex voted to halt and received no message then we can skip the vertex program
if (!pVertex.isActive && msgOpt.isEmpty) {
pVertex
} else {
// The vertex program is either active or received a message (or both).
// A vertex program should vote to halt again even if it has previously voted to halt
vertexProgram(iteration, vid, pVertex, msgOpt)
}
}.cache()

// Recompute the active vertices (those that have not voted to halt)
activeVertices = currengGraph.vertices.filter(v => v._2._2)

// Force all computation!
numActive = activeVertices.count()

// Unpersist the RDDs hidden by newly-materialized RDDs
//prevG.unpersistVertices(blocking=false)
//prevG.edges.unpersist(blocking=false)

//println("Finished Iteration " + i)
// g.vertices.foreach(println(_))

logInfo("Pregel finished iteration " + iteration)
// count the iteration
iteration += 1
}
currengGraph.mapVertices((id, vdata) => vdata.attr)
} // end of apply


/**
* Execute a Pregel-like iterative vertex-parallel abstraction. The
Expand Down Expand Up @@ -109,6 +226,7 @@ object Pregel extends Logging {
* @return the resulting graph at the end of the computation
*
*/
// @deprecated ("Switching to Pregel.run.", "1.1")
def apply[VD: ClassTag, ED: ClassTag, A: ClassTag]
(graph: Graph[VD, ED],
initialMsg: A,
Expand Down Expand Up @@ -158,4 +276,7 @@ object Pregel extends Logging {
g
} // end of apply




} // end of class Pregel
38 changes: 22 additions & 16 deletions graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import scala.reflect.ClassTag

import org.apache.spark.Logging
import org.apache.spark.graphx._
import org.apache.spark.graphx.Pregel._

/**
* PageRank algorithm implementation. There are two implementations of PageRank implemented.
Expand Down Expand Up @@ -144,34 +145,39 @@ object PageRank extends Logging {
}
// Set the weight on the edges based on the degree
.mapTriplets( e => 1.0 / e.srcAttr )
// Set the vertex attributes to (initalPR, delta = 0)
.mapVertices( (id, attr) => (0.0, 0.0) )
// Set the vertex attributes to (currentPr, deltaToSend)
.mapVertices( (id, attr) => (resetProb, (1.0 - resetProb) * resetProb) )
.cache()

// Define the three functions needed to implement PageRank in the GraphX
// version of Pregel
def vertexProgram(id: VertexId, attr: (Double, Double), msgSum: Double): (Double, Double) = {
val (oldPR, lastDelta) = attr
val newPR = oldPR + (1.0 - resetProb) * msgSum
(newPR, newPR - oldPR)
def vertexProgram(iter: Int, id: VertexId, vertex: PregelVertex[(Double, Double)],
msgSum: Option[Double]) = {
var (oldPR, pendingDelta) = vertex.attr
val newPR = oldPR + msgSum.getOrElse(0.0)
// if we were active then we sent the pending delta on the last iteration
if (vertex.isActive) {
pendingDelta = 0.0
}
pendingDelta += (1.0 - resetProb) * msgSum.getOrElse(0.0)
val isActive = math.abs(pendingDelta) >= tol
PregelVertex((newPR, pendingDelta), isActive)
}

def sendMessage(edge: EdgeTriplet[(Double, Double), Double]) = {
if (edge.srcAttr._2 > tol) {
Iterator((edge.dstId, edge.srcAttr._2 * edge.attr))
} else {
Iterator.empty
}
def sendMessage(iter: Int, edge: EdgeTriplet[PregelVertex[(Double, Double)], Double]) = {
val PregelVertex((srcPr, srcDelta), srcIsActive) = edge.srcAttr
assert(srcIsActive)
Iterator((edge.dstId, srcDelta * edge.attr))
}

def messageCombiner(a: Double, b: Double): Double = a + b

// The initial message received by all vertices in PageRank
val initialMessage = resetProb / (1.0 - resetProb)

// Execute a dynamic version of Pregel.
Pregel(pagerankGraph, initialMessage, activeDirection = EdgeDirection.Out)(
val prGraph = Pregel.run(pagerankGraph, activeDirection = EdgeDirection.Out)(
vertexProgram, sendMessage, messageCombiner)
.mapVertices((vid, attr) => attr._1)
.cache()
val normalizer: Double = prGraph.vertices.map(x => x._2).reduce(_ + _)
prGraph.mapVertices((id, pr) => pr / normalizer)
} // end of deltaPageRank
}
Original file line number Diff line number Diff line change
Expand Up @@ -265,4 +265,19 @@ object GraphGenerators {
Graph.fromEdgeTuples(edges, 1)
} // end of starGraph


/**
* Create a cycle graph.
*
* @param sc the spark context in which to construct the graph
* @param nverts the number of vertices in the cycle
*
* @return A cycle graph containing `nverts` vertices with vertex 0
* being the center vertex.
*/
def cycleGraph(sc: SparkContext, nverts: Int): Graph[Int, Int] = {
val edges: RDD[(VertexId, VertexId)] = sc.parallelize(0 until nverts).map(vid => (vid, (vid + 1) % nverts))
Graph.fromEdgeTuples(edges, 1)
} // end of starGraph

} // end of Graph Generators