diff --git a/graphx/src/main/scala/org/apache/spark/graphx/WikiPipelineBenchmark.scala b/graphx/src/main/scala/org/apache/spark/graphx/WikiPipelineBenchmark.scala
index 4aa2e793152e0..355734505a1bd 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/WikiPipelineBenchmark.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/WikiPipelineBenchmark.scala
@@ -14,6 +14,14 @@ import java.util.{HashSet => JHashSet, TreeSet => JTreeSet}
 
 object WikiPipelineBenchmark extends Logging {
 
+  def time[A](label: String)(fn: => A): A = {
+    val startTime = System.currentTimeMillis
+    logWarning("Starting %s...".format(label))
+    val result = fn
+    logWarning("Finished %s. Time: %f".format(label, (System.currentTimeMillis - startTime) / 1000.0))
+    result
+  }
+
   def main(args: Array[String]) = {
 
     val host = args(0)
@@ -30,13 +38,15 @@ object WikiPipelineBenchmark extends Logging {
       case "graphx" => {
         val rawData = args(2)
         val numIters = args(3).toInt
-        benchmarkGraphx(sc, rawData, numIters)
+        val numPRIters = args(4).toInt
+        val numParts = args(5).toInt
+        benchmarkGraphx(sc, rawData, numIters, numPRIters, numParts)
       }
 
       case "extract" => {
         val rawData = args(2)
         val outBase = args(3)
-        val (vertices, edges) = extractLinkGraph(sc, rawData)
+        val (vertices, edges) = extractLinkGraph(sc, rawData, 128)
         val g = Graph(vertices, edges)
         val cleanG = g.subgraph(x => true, (vid, vd) => vd != null).cache
         val rawEdges = cleanG.edges.map(e => (e.srcId, e.dstId))
@@ -57,67 +67,70 @@ object WikiPipelineBenchmark extends Logging {
 
   }
 
-  def benchmarkGraphx(sc: SparkContext, rawData: String, numIters: Int) {
-    val (vertices, edges) = extractLinkGraph(sc, rawData)
-    logWarning("creating graph")
-    val g = Graph(vertices, edges)
-    val cleanG = g.subgraph(x => true, (vid, vd) => vd != null).cache
-    logWarning(s"DIRTY graph has ${g.triplets.count()} EDGES, ${g.vertices.count()} VERTICES")
-    logWarning(s"CLEAN graph has ${cleanG.triplets.count()} EDGES, ${cleanG.vertices.count()} VERTICES")
-    val resultG = pagerankConnComponentsAlt(numIters, cleanG)
+  def benchmarkGraphx(sc: SparkContext, rawData: String, numIters: Int, numPRIters: Int, numParts: Int) {
+    val (vertices, edges) = extractLinkGraph(sc, rawData, numParts)
+    val g = time("graph creation") {
+      val result = Graph(vertices, edges)
+      logWarning("Graph has %d vertex partitions, %d edge partitions".format(result.vertices.partitions.length, result.edges.partitions.length))
+      logWarning(s"DIRTY graph has ${result.triplets.count()} EDGES, ${result.vertices.count()} VERTICES")
+      result
+    }
+    // TODO: try reindexing
+    val cleanG = g.subgraph(x => true, (vid, vd) => vd != null).partitionBy(PartitionStrategy.EdgePartition2D).cache()
+    cleanG.vertices.setName("cleanG vertices")
+    cleanG.edges.setName("cleanG edges")
+    time("graph cleaning and repartitioning") {
+      logWarning(s"CLEAN graph has ${cleanG.triplets.count()} EDGES, ${cleanG.vertices.count()} VERTICES")
+    }
+    val resultG = pagerankConnComponentsAlt(numIters, cleanG, numPRIters)
     logWarning(s"ORIGINAL graph has ${cleanG.triplets.count()} EDGES, ${cleanG.vertices.count()} VERTICES")
     logWarning(s"FINAL graph has ${resultG.triplets.count()} EDGES, ${resultG.vertices.count()} VERTICES")
   }
 
-  def pagerankConnComponentsAlt(numRepetitions: Int, g: Graph[String, Double]): Graph[String, Double] = {
+  def pagerankConnComponentsAlt(numRepetitions: Int, g: Graph[String, Double], numPRIters: Int): Graph[String, Double] = {
     var currentGraph = g
     logWarning("starting iterations")
     for (i <- 0 to numRepetitions) {
-      currentGraph.cache
-      val startTime = System.currentTimeMillis
-      logWarning("starting pagerank")
-      // GRAPH VIEW
-      val ccStartTime = System.currentTimeMillis
-      val ccGraph = ConnectedComponents.run(currentGraph).cache
-      val zeroVal = new JTreeSet[VertexId]()
-      val seqOp = (s: JTreeSet[VertexId], vtuple: (VertexId, VertexId)) => {
-        s.add(vtuple._2)
-        s
-      }
-      val combOp = (s1: JTreeSet[VertexId], s2: JTreeSet[VertexId]) => {
-        s1.addAll(s2)
-        s1
-      }
-      // TABLE VIEW
-      val numCCs = ccGraph.vertices.aggregate(zeroVal)(seqOp, combOp).size()
-      val ccEndTime = System.currentTimeMillis
-      logWarning(s"Connected Components TIMEX: ${(ccEndTime - ccStartTime)/1000.0}")
-      logWarning(s"Number of connected components for iteration $i: $numCCs")
-      val prStartTime = System.currentTimeMillis
-      val pr = PageRank.run(currentGraph, 20).cache
-      pr.vertices.count
-      val prEndTime = System.currentTimeMillis
-      logWarning(s"Pagerank TIMEX: ${(prEndTime - prStartTime)/1000.0}")
-      logWarning("Pagerank completed")
-      // TABLE VIEW
-      val prAndTitle = currentGraph.outerJoinVertices(pr.vertices)({(id: VertexId, title: String, rank: Option[Double]) => (title, rank.getOrElse(0.0))}).cache
-      prAndTitle.vertices.count
-      // logWarning("join completed.")
-      val top20 = prAndTitle.vertices.top(20)(Ordering.by((entry: (VertexId, (String, Double))) => entry._2._2))
-      logWarning(s"Top20 for iteration $i:\n${top20.mkString("\n")}")
-      val top20verts = top20.map(_._1).toSet
-      // filter out top 20 vertices
-      val filterTop20 = {(v: VertexId, d: String) =>
-        !top20verts.contains(v)
+      currentGraph.cache()
+      currentGraph.vertices.setName("currentGraph vertices %d".format(i))
+      currentGraph.edges.setName("currentGraph edges %d".format(i))
+      time("stage %d".format(i)) {
+        // GRAPH VIEW
+        time("connected components, stage %d".format(i)) {
+          val ccGraph = ConnectedComponents.run(currentGraph)
+          val numCCs = ccGraph.vertices.map { case (id, cc) => cc }.distinct(1).count
+          logWarning(s"Number of connected components for iteration $i: $numCCs")
+          ccGraph.unpersistVertices(blocking = false)
+        }
+        val pr =
+          time("pagerank, stage %d".format(i)) {
+            PageRank.run(currentGraph, numPRIters)
+          }
+        // TABLE VIEW
+        val top20verts =
+          time("top 20 pages, stage %d".format(i)) {
+            val prAndTitle = currentGraph.outerJoinVertices(pr.vertices)({(id: VertexId, title: String, rank: Option[Double]) => (title, rank.getOrElse(0.0))})
+            val top20 = prAndTitle.vertices.top(20)(Ordering.by((entry: (VertexId, (String, Double))) => entry._2._2))
+            pr.unpersistVertices(blocking = false)
+            prAndTitle.unpersistVertices(blocking = false)
+            logWarning(s"Top20 for iteration $i:\n${top20.mkString("\n")}")
+            top20.map(_._1).toSet
+          }
+        val newGraph =
+          time("filter out top 20 pages, stage %d".format(i)) {
+            // filter out top 20 vertices
+            val filterTop20 = {(v: VertexId, d: String) =>
+              !top20verts.contains(v)
+            }
+            val result = currentGraph.subgraph(x => true, filterTop20).cache()
+            result.vertices.setName("newGraph vertices %d".format(i))
+            result.edges.setName("newGraph edges %d".format(i))
+            result.vertices.count
+            result
+          }
+        currentGraph.unpersistVertices(blocking = false)
+        currentGraph = newGraph
       }
-      val newGraph = currentGraph.subgraph(x => true, filterTop20).cache
-      newGraph.vertices.count
-      logWarning(s"TOTAL_TIMEX iter $i ${(System.currentTimeMillis - startTime)/1000.0}")
-      currentGraph.unpersistVertices(blocking = false)
-      ccGraph.unpersistVertices(blocking = false)
-      pr.unpersistVertices(blocking = false)
-      prAndTitle.unpersistVertices(blocking = false)
-      currentGraph = newGraph
     }
     currentGraph
   }
 
@@ -141,29 +154,25 @@ object WikiPipelineBenchmark extends Logging {
     }
   }
 
-  def extractLinkGraph(sc: SparkContext, rawData: String): (RDD[(VertexId, String)], RDD[Edge[Double]]) = {
+  def extractLinkGraph(sc: SparkContext, rawData: String, numParts: Int): (RDD[(VertexId, String)], RDD[Edge[Double]]) = {
     val conf = new Configuration
     conf.set("key.value.separator.in.input.line", " ")
     conf.set("xmlinput.start", "<page>")
     conf.set("xmlinput.end", "</page>")
-    logWarning("about to load xml rdd")
     val xmlRDD = sc.newAPIHadoopFile(rawData, classOf[XmlInputFormat], classOf[LongWritable], classOf[Text], conf)
-      .map(t => t._2.toString)
-    // xmlRDD.count
-    logWarning(s"XML RDD counted. Found ${xmlRDD.count} raw articles.")
-    val repartXMLRDD = xmlRDD.repartition(128)
-    logWarning(s"XML RDD repartitioned. Found ${repartXMLRDD.count} raw articles.")
-
-    val allArtsRDD = repartXMLRDD.map { raw => new WikiArticle(raw) }.cache
-    logWarning(s"Total articles: Found ${allArtsRDD.count} UNPARTITIONED articles.")
-
-    val wikiRDD = allArtsRDD.filter { art => art.relevant }.cache //.repartition(128)
-    logWarning(s"wikiRDD counted. Found ${wikiRDD.count} relevant articles in ${wikiRDD.partitions.size} partitions")
+      .map(t => t._2.toString).coalesce(numParts, false)
+    val allArtsRDD = xmlRDD.map { raw => new WikiArticle(raw) }
+
+    val wikiRDD =
+      time("filter relevant articles") {
+        val result = allArtsRDD.filter { art => art.relevant }.cache().setName("wikiRDD")
+        logWarning(s"wikiRDD counted. Found ${result.count} relevant articles in ${result.partitions.size} partitions")
+        result
+      }
 
     val vertices: RDD[(VertexId, String)] = wikiRDD.map { art => (art.vertexID, art.title) }
     val edges: RDD[Edge[Double]] = wikiRDD.flatMap { art => art.edges }
     (vertices, edges)
-
   }
 
   def pipelinePostProcessing(sc: SparkContext, basePath: String, iter: Int) {
@@ -176,16 +185,17 @@ object WikiPipelineBenchmark extends Logging {
     val rankAndTitle = artNames.join(pageranks)
     val top20 = rankAndTitle.top(20)(Ordering.by((entry: (VertexId, (String, Double))) => entry._2._2))
     logWarning(s"Top20 for iteration $iter:\n${top20.mkString("\n")}")
-    val zeroVal = new JTreeSet[VertexId]()
-    val seqOp = (s: JTreeSet[VertexId], vtuple: (VertexId, VertexId)) => {
-      s.add(vtuple._2)
-      s
-    }
-    val combOp = (s1: JTreeSet[VertexId], s2: JTreeSet[VertexId]) => {
-      s1.addAll(s2)
-      s1
-    }
-    val numCCs = connComponents.aggregate(zeroVal)(seqOp, combOp).size()
+    val numCCs = connComponents.map{ case (id, cc) => cc }.distinct(1).count
+    // val zeroVal = new JTreeSet[VertexId]()
+    // val seqOp = (s: JTreeSet[VertexId], vtuple: (VertexId, VertexId)) => {
+    //   s.add(vtuple._2)
+    //   s
+    // }
+    // val combOp = (s1: JTreeSet[VertexId], s2: JTreeSet[VertexId]) => {
+    //   s1.addAll(s2)
+    //   s1
+    // }
+    // val numCCs = connComponents.aggregate(zeroVal)(seqOp, combOp).size()
     logWarning(s"Number of connected components for iteration $iter: $numCCs")
     val top20verts = top20.map(_._1).toSet
     val newVertices = artNames.filter { case (v, d) => !top20verts.contains(v) }
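
Note (not part of the patch): the new `time` helper takes a call-by-name parameter `fn: => A`, so the wrapped block is not evaluated until inside the wrapper and the measured interval brackets exactly that block; calls also nest cleanly, which is how the patch times "connected components", "pagerank", etc. inside an enclosing "stage %d". A minimal standalone sketch of the same pattern; `TimingSketch` and the stage labels below are illustrative, not taken from the patch:

    object TimingSketch {
      // Call-by-name parameter: `fn` is not evaluated at the call site,
      // so the timer covers exactly the wrapped block.
      def time[A](label: String)(fn: => A): A = {
        val start = System.currentTimeMillis
        val result = fn // the block executes here
        println("Finished %s. Time: %f".format(label, (System.currentTimeMillis - start) / 1000.0))
        result
      }

      def main(args: Array[String]): Unit = {
        // Nested calls time each stage separately, mirroring the patch's
        // per-stage wrappers inside "stage %d".
        val total = time("outer stage") {
          val xs = time("build") { (1 to 1000000).toArray }
          time("sum") { xs.map(_.toLong).sum }
        }
        println(s"total = $total")
      }
    }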
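A second note on the component counting: the removed `aggregate` built a `JTreeSet` of component ids per partition on the executors and merged the sets on the driver, whereas `distinct(1).count` deduplicates the ids with a shuffle inside the cluster and returns only a count. A sketch of the new counting path over stand-in data; the `local[2]` context and the sample pairs are illustrative, not from the patch:

    import org.apache.spark.{SparkConf, SparkContext}

    object DistinctCountSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("cc-count-sketch"))
        // Stand-in for ccGraph.vertices: (vertexId, componentId) pairs.
        val ccIds = sc.parallelize(Seq((1L, 1L), (2L, 1L), (3L, 3L), (4L, 3L), (5L, 5L)))
        // The patch's approach: keep only the component ids, deduplicate
        // into a single partition, and count -- no driver-side TreeSet.
        val numCCs = ccIds.map { case (_, cc) => cc }.distinct(1).count
        println(s"Number of connected components: $numCCs") // prints 3
        sc.stop()
      }
    }

A related choice in `extractLinkGraph`: `coalesce(numParts, false)` avoids the full shuffle that `repartition` performs, though without the shuffle it can only reduce the partition count, not increase it.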