Skip to content

Commit 40474b0

Browse files
committed
In Pregel, pass iteration number to vprog and sendMsg
To accommodate the operations before the loop, this changes the iteration numbering to start at 1 instead of 0.
1 parent 669e3f0 commit 40474b0

File tree

1 file changed

+26
-10
lines changed

1 file changed

+26
-10
lines changed

graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,26 +109,28 @@ object Pregel extends Logging {
109109
* @return the resulting graph at the end of the computation
110110
*
111111
*/
112-
def apply[VD: ClassTag, ED: ClassTag, A: ClassTag]
112+
def runWithIterationNumber[VD: ClassTag, ED: ClassTag, A: ClassTag]
113113
(graph: Graph[VD, ED],
114114
initialMsg: A,
115115
maxIterations: Int = Int.MaxValue,
116116
activeDirection: EdgeDirection = EdgeDirection.Either)
117-
(vprog: (VertexId, VD, A) => VD,
118-
sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
117+
(vprog: (Int, VertexId, VD, A) => VD,
118+
sendMsg: (Int, EdgeTriplet[VD, ED]) => Iterator[(VertexId, A)],
119119
mergeMsg: (A, A) => A)
120120
: Graph[VD, ED] =
121121
{
122-
var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
122+
var g = graph.mapVertices((vid, vdata) => vprog(0, vid, vdata, initialMsg)).cache()
123123
// compute the messages
124-
var messages = g.mapReduceTriplets(sendMsg, mergeMsg)
124+
var messages = g.mapReduceTriplets(e => sendMsg(0, e), mergeMsg)
125125
var activeMessages = messages.count()
126126
// Loop
127127
var prevG: Graph[VD, ED] = null
128-
var i = 0
129-
while (activeMessages > 0 && i < maxIterations) {
128+
var i = 1
129+
while (activeMessages > 0 && i <= maxIterations) {
130130
// Receive the messages. Vertices that didn't get any messages do not appear in newVerts.
131-
val newVerts = g.vertices.innerJoin(messages)(vprog).cache()
131+
val newVerts = g.vertices.innerJoin(messages) {
132+
(id, attr, msg) => vprog(i, id, attr, msg)
133+
}.cache()
132134
// Update the graph with the new vertices.
133135
prevG = g
134136
g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }
@@ -138,7 +140,8 @@ object Pregel extends Logging {
138140
// Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't
139141
// get to send messages. We must cache messages so it can be materialized on the next line,
140142
// allowing us to uncache the previous iteration.
141-
messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDirection))).cache()
143+
messages = g.mapReduceTriplets(
144+
e => sendMsg(i, e), mergeMsg, Some((newVerts, activeDirection))).cache()
142145
// The call to count() materializes `messages`, `newVerts`, and the vertices of `g`. This
143146
// hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the
144147
// vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g).
@@ -158,4 +161,17 @@ object Pregel extends Logging {
158161
g
159162
} // end of apply
160163

161-
} // end of class Pregel
164+
def apply[VD: ClassTag, ED: ClassTag, A: ClassTag]
165+
(graph: Graph[VD, ED],
166+
initialMsg: A,
167+
maxIterations: Int = Int.MaxValue,
168+
activeDirection: EdgeDirection = EdgeDirection.Either)
169+
(vprog: (VertexId, VD, A) => VD,
170+
sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
171+
mergeMsg: (A, A) => A): Graph[VD, ED] = {
172+
runWithIterationNumber(graph, initialMsg, maxIterations, activeDirection)(
173+
(iteration, id, attr, msg) => vprog(id, attr, msg),
174+
(iteration, e) => sendMsg(e),
175+
mergeMsg)
176+
}
177+
}// end of class Pregel

0 commit comments

Comments
 (0)