In [72]:
// Display the active SparkSession available in this notebook kernel.
spark

org.apache.spark.sql.SparkSession@5dbcf890

In [73]:
import org.apache.spark.graphx._
/**
 * Single-source shortest paths (Dijkstra's algorithm) on a directed GraphX graph.
 *
 * The edge attribute is used as the edge weight; Dijkstra's algorithm requires it to be
 * non-negative. While iterating, each vertex carries `(visited, tentativeDistance)`. Each round
 * selects the unvisited vertex with the smallest tentative distance, relaxes its out-edges with
 * `aggregateMessages`, and marks it visited. The loop runs at most `|V| - 1` rounds and stops
 * early once every remaining unvisited vertex is unreachable.
 *
 * @param g      input graph whose edge attribute is the edge weight
 * @param origin id of the source vertex
 * @return `g` with each vertex attribute replaced by `(originalAttr, distanceFromOrigin)`;
 *         vertices not reachable from `origin` (or absent from the run) get `Double.MaxValue`
 */
def dijkstra[VD](g: Graph[VD, Double], origin: VertexId): Graph[(VD, Double), Double] = {
    // Initial state: nothing visited; only the origin has a known distance.
    var distGraph = g.mapVertices { (vertexId, _) =>
        (false, if (vertexId == origin) 0.0 else Double.MaxValue)
    }.cache()

    var remainingRounds = g.vertices.count() - 1
    var unreachableOnly = false

    while (remainingRounds > 0 && !unreachableOnly) {
        // Unvisited vertex with the smallest tentative distance. `reduce` is used instead of a
        // fold with a made-up (0L, ...) zero value, which could return that placeholder id when
        // all unvisited distances tie at Double.MaxValue. The loop bound guarantees at least one
        // unvisited vertex remains, so `reduce` never sees an empty RDD.
        val (currentVertexId, (_, currentDist)) =
            distGraph.vertices.filter { case (_, (visited, _)) => !visited }.reduce(
                (a, b) => if (a._2._2 < b._2._2) a else b
            )

        if (currentDist == Double.MaxValue) {
            // Every unvisited vertex is unreachable from the origin; no further round can
            // lower any distance, so stop here.
            unreachableOnly = true
        } else {
            // Relax the out-edges of the selected vertex. Its distance is already known on the
            // driver, so only the edge attribute is needed (no vertex attributes are shipped).
            // aggregateMessages returns a VertexRDD of candidate distances per destination.
            val newDistances = distGraph.aggregateMessages[Double](
                triplet =>
                    if (triplet.srcId == currentVertexId)
                        triplet.sendToDst(currentDist + triplet.attr),
                (a, b) => math.min(a, b),
                TripletFields.EdgeOnly
            )

            // outerJoinVertices returns a new Graph: mark the selected vertex visited and keep
            // the smaller of the old and candidate distances.
            val previous = distGraph
            distGraph = distGraph.outerJoinVertices(newDistances) { (vertexId, props, newDist) =>
                (props._1 || vertexId == currentVertexId,
                 math.min(props._2, newDist.getOrElse(Double.MaxValue)))
            }.cache()

            // Materialize the new graph before releasing the previous one, so later rounds read
            // cached partitions instead of recomputing an ever-growing lineage.
            distGraph.vertices.count()
            distGraph.edges.count()
            previous.unpersistVertices(blocking = false)
            previous.edges.unpersist(blocking = false)

            remainingRounds -= 1
        }
    }

    // Pair each original vertex attribute with its final distance.
    g.outerJoinVertices(distGraph.vertices) { (_, props, dist) =>
        (props, dist.map(_._2).getOrElse(Double.MaxValue))
    }
}

dijkstra: [VD](g: org.apache.spark.graphx.Graph[VD,Double], origin: org.apache.spark.graphx.VertexId)org.apache.spark.graphx.Graph[(VD, Double),Double]


In [74]:
// Sample graph: seven labelled vertices (ids 1..7 -> "A".."G").
val myVertices = sc.makeRDD((1L to 7L).zip(Seq("A", "B", "C", "D", "E", "F", "G")))

// Weighted directed edges written as (src, dst, weight) triples.
val myEdges = sc.makeRDD(
    Seq(
        (1L, 2L, 7.0), (1L, 4L, 5.0), (2L, 3L, 8.0),
        (2L, 4L, 9.0), (2L, 5L, 7.0), (3L, 5L, 5.0),
        (4L, 5L, 15.0), (4L, 6L, 6.0), (5L, 6L, 8.0),
        (5L, 7L, 9.0), (6L, 7L, 11.0)
    ).map { case (src, dst, weight) => Edge(src, dst, weight) }
)

val myGraph = Graph(myVertices, myEdges)

// Shortest distances from vertex 1, each paired with the vertex label.
dijkstra(myGraph, 1L).vertices.map { case (_, labelAndDist) => labelAndDist }.collect()

myVertices = ParallelCollectionRDD[1728] at makeRDD at <console>:85
myEdges = ParallelCollectionRDD[1729] at makeRDD at <console>:92
myGraph = org.apache.spark.graphx.impl.GraphImpl@71464fe


Array((A,0.0), (B,7.0), (C,15.0), (D,5.0), (E,14.0), (F,11.0), (G,22.0))

# Debug phase

In [75]:
// Initial vertex state for origin 1: (visited, tentative distance).
var distGraph = myGraph.mapVertices { (vertexId, _) =>
    val initialDist = if (vertexId == 1L) 0.0 else Double.MaxValue
    (false, initialDist)
}

distGraph.vertices.collect()

distGraph = org.apache.spark.graphx.impl.GraphImpl@26155e0a


Array((1,(false,0.0)), (2,(false,1.7976931348623157E308)), (3,(false,1.7976931348623157E308)), (4,(false,1.7976931348623157E308)), (5,(false,1.7976931348623157E308)), (6,(false,1.7976931348623157E308)), (7,(false,1.7976931348623157E308)))

In [76]:
// Select the unvisited vertex with the smallest tentative distance. `reduce` replaces a fold
// with a (0L, ...) zero value: that zero is merged per partition, and on a Double.MaxValue tie
// it could be returned as if 0L were a real vertex. `reduce` requires a non-empty unvisited set.
val currentVertexId = distGraph.vertices.filter { case (_, (visited, _)) => !visited }.reduce(
    (a, b) => if (a._2._2 < b._2._2) a else b
)._1

currentVertexId = 1


1

In [77]:
distGraph.vertices.filter { case (_, (visited, _)) => !visited }.collect() // vertices not yet visited

Array((1,(false,0.0)), (2,(false,1.7976931348623157E308)), (3,(false,1.7976931348623157E308)), (4,(false,1.7976931348623157E308)), (5,(false,1.7976931348623157E308)), (6,(false,1.7976931348623157E308)), (7,(false,1.7976931348623157E308)))

In [78]:
// Relax the out-edges of the selected vertex: each destination receives
// (source distance + edge weight), and multiple offers are merged by taking the minimum.
val newDistances = distGraph.aggregateMessages[Double](
    ctx => if (ctx.srcId == currentVertexId) ctx.sendToDst(ctx.srcAttr._2 + ctx.attr),
    _ min _
)

newDistances.collect() // VertexRDD of (destination id, candidate distance)

newDistances = VertexRDDImpl[1855] at RDD at VertexRDD.scala:57


Array((2,7.0), (4,5.0))

In [79]:
dijkstra(origin = 1L, g = myGraph).vertices.collect() // (id, (label, distance from vertex 1))

Array((1,(A,0.0)), (2,(B,7.0)), (3,(C,15.0)), (4,(D,5.0)), (5,(E,14.0)), (6,(F,11.0)), (7,(G,22.0)))