Clean up WikiPipelineBenchmark
ankurdave committed Mar 28, 2014
1 parent e8be08e commit 8609184
Showing 1 changed file with 89 additions and 79 deletions.
@@ -14,6 +14,14 @@ import java.util.{HashSet => JHashSet, TreeSet => JTreeSet}
 
 object WikiPipelineBenchmark extends Logging {
 
+  def time[A](label: String)(fn: => A): A = {
+    val startTime = System.currentTimeMillis
+    logWarning("Starting %s...".format(label))
+    val result = fn
+    logWarning("Finished %s. Time: %f".format(label, (System.currentTimeMillis - startTime) / 1000.0))
+    result
+  }
+
   def main(args: Array[String]) = {
 
     val host = args(0)
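The new `time` helper takes its body as a by-name parameter, so the block is evaluated only inside the helper, and it returns the block's result, which lets timed stages nest and feed directly into `val` bindings. A minimal standalone sketch of the same pattern (with `println` standing in for the `Logging` trait so the snippet runs on its own):

object TimeDemo {
  // Same shape as the helper above: evaluate fn, log elapsed seconds, return its result.
  def time[A](label: String)(fn: => A): A = {
    val startTime = System.currentTimeMillis
    println(s"Starting $label...")
    val result = fn
    println(s"Finished $label. Time: ${(System.currentTimeMillis - startTime) / 1000.0}")
    result
  }

  def main(args: Array[String]): Unit = {
    // The by-name parameter means this block runs inside time(), not before it.
    val total = time("summing") { (1L to 1000000L).sum }
    println(total)
  }
}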
@@ -30,13 +38,15 @@ object WikiPipelineBenchmark extends Logging {
       case "graphx" => {
         val rawData = args(2)
         val numIters = args(3).toInt
-        benchmarkGraphx(sc, rawData, numIters)
+        val numPRIters = args(4).toInt
+        val numParts = args(5).toInt
+        benchmarkGraphx(sc, rawData, numIters, numPRIters, numParts)
       }
 
       case "extract" => {
         val rawData = args(2)
         val outBase = args(3)
-        val (vertices, edges) = extractLinkGraph(sc, rawData)
+        val (vertices, edges) = extractLinkGraph(sc, rawData, 128)
         val g = Graph(vertices, edges)
         val cleanG = g.subgraph(x => true, (vid, vd) => vd != null).cache
         val rawEdges = cleanG.edges.map(e => (e.srcId, e.dstId))
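For reference, the positional arguments the benchmark now reads in `graphx` mode, reconstructed from the code above; the mode-selector index is not visible in this diff, so it is an assumption:

// Hypothetical summary of the argument layout after this change:
//   args(0) = host        Spark master
//   args(1)               mode selector, presumably: "graphx" | "extract" | ...
//   args(2) = rawData     path to the Wikipedia XML dump
//   args(3) = numIters    outer pipeline iterations
//   args(4) = numPRIters  PageRank iterations per stage (new)
//   args(5) = numParts    partitions for the extracted link graph (new)

Note that the `extract` path still hard-codes 128 partitions rather than taking `numParts`.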
@@ -57,67 +67,70 @@ object WikiPipelineBenchmark extends Logging {
 
   }
 
-  def benchmarkGraphx(sc: SparkContext, rawData: String, numIters: Int) {
-    val (vertices, edges) = extractLinkGraph(sc, rawData)
-    logWarning("creating graph")
-    val g = Graph(vertices, edges)
-    val cleanG = g.subgraph(x => true, (vid, vd) => vd != null).cache
-    logWarning(s"DIRTY graph has ${g.triplets.count()} EDGES, ${g.vertices.count()} VERTICES")
-    logWarning(s"CLEAN graph has ${cleanG.triplets.count()} EDGES, ${cleanG.vertices.count()} VERTICES")
-    val resultG = pagerankConnComponentsAlt(numIters, cleanG)
+  def benchmarkGraphx(sc: SparkContext, rawData: String, numIters: Int, numPRIters: Int, numParts: Int) {
+    val (vertices, edges) = extractLinkGraph(sc, rawData, numParts)
+    val g = time("graph creation") {
+      val result = Graph(vertices, edges)
+      logWarning("Graph has %d vertex partitions, %d edge partitions".format(result.vertices.partitions.length, result.edges.partitions.length))
+      logWarning(s"DIRTY graph has ${result.triplets.count()} EDGES, ${result.vertices.count()} VERTICES")
+      result
+    }
+    // TODO: try reindexing
+    val cleanG = g.subgraph(x => true, (vid, vd) => vd != null).partitionBy(PartitionStrategy.EdgePartition2D).cache()
+    cleanG.vertices.setName("cleanG vertices")
+    cleanG.edges.setName("cleanG edges")
+    time("graph cleaning and repartitioning") {
+      logWarning(s"CLEAN graph has ${cleanG.triplets.count()} EDGES, ${cleanG.vertices.count()} VERTICES")
+    }
+    val resultG = pagerankConnComponentsAlt(numIters, cleanG, numPRIters)
     logWarning(s"ORIGINAL graph has ${cleanG.triplets.count()} EDGES, ${cleanG.vertices.count()} VERTICES")
     logWarning(s"FINAL graph has ${resultG.triplets.count()} EDGES, ${resultG.vertices.count()} VERTICES")
   }
 
-  def pagerankConnComponentsAlt(numRepetitions: Int, g: Graph[String, Double]): Graph[String, Double] = {
+  def pagerankConnComponentsAlt(numRepetitions: Int, g: Graph[String, Double], numPRIters: Int): Graph[String, Double] = {
     var currentGraph = g
     logWarning("starting iterations")
     for (i <- 0 to numRepetitions) {
-      currentGraph.cache
-      val startTime = System.currentTimeMillis
-      logWarning("starting pagerank")
-      // GRAPH VIEW
-      val ccStartTime = System.currentTimeMillis
-      val ccGraph = ConnectedComponents.run(currentGraph).cache
-      val zeroVal = new JTreeSet[VertexId]()
-      val seqOp = (s: JTreeSet[VertexId], vtuple: (VertexId, VertexId)) => {
-        s.add(vtuple._2)
-        s
-      }
-      val combOp = (s1: JTreeSet[VertexId], s2: JTreeSet[VertexId]) => {
-        s1.addAll(s2)
-        s1
-      }
-      // TABLE VIEW
-      val numCCs = ccGraph.vertices.aggregate(zeroVal)(seqOp, combOp).size()
-      val ccEndTime = System.currentTimeMillis
-      logWarning(s"Connected Components TIMEX: ${(ccEndTime - ccStartTime)/1000.0}")
-      logWarning(s"Number of connected components for iteration $i: $numCCs")
-      val prStartTime = System.currentTimeMillis
-      val pr = PageRank.run(currentGraph, 20).cache
-      pr.vertices.count
-      val prEndTime = System.currentTimeMillis
-      logWarning(s"Pagerank TIMEX: ${(prEndTime - prStartTime)/1000.0}")
-      logWarning("Pagerank completed")
-      // TABLE VIEW
-      val prAndTitle = currentGraph.outerJoinVertices(pr.vertices)({(id: VertexId, title: String, rank: Option[Double]) => (title, rank.getOrElse(0.0))}).cache
-      prAndTitle.vertices.count
-      // logWarning("join completed.")
-      val top20 = prAndTitle.vertices.top(20)(Ordering.by((entry: (VertexId, (String, Double))) => entry._2._2))
-      logWarning(s"Top20 for iteration $i:\n${top20.mkString("\n")}")
-      val top20verts = top20.map(_._1).toSet
-      // filter out top 20 vertices
-      val filterTop20 = {(v: VertexId, d: String) =>
-        !top20verts.contains(v)
+      currentGraph.cache()
+      currentGraph.vertices.setName("currentGraph vertices %d".format(i))
+      currentGraph.edges.setName("currentGraph edges %d".format(i))
+      time("stage %d".format(i)) {
+        // GRAPH VIEW
+        time("connected components, stage %d".format(i)) {
+          val ccGraph = ConnectedComponents.run(currentGraph)
+          val numCCs = ccGraph.vertices.map { case (id, cc) => cc }.distinct(1).count
+          logWarning(s"Number of connected components for iteration $i: $numCCs")
+          ccGraph.unpersistVertices(blocking = false)
+        }
+        val pr =
+          time("pagerank, stage %d".format(i)) {
+            PageRank.run(currentGraph, numPRIters)
+          }
+        // TABLE VIEW
+        val top20verts =
+          time("top 20 pages, stage %d".format(i)) {
+            val prAndTitle = currentGraph.outerJoinVertices(pr.vertices)({(id: VertexId, title: String, rank: Option[Double]) => (title, rank.getOrElse(0.0))})
+            val top20 = prAndTitle.vertices.top(20)(Ordering.by((entry: (VertexId, (String, Double))) => entry._2._2))
+            pr.unpersistVertices(blocking = false)
+            prAndTitle.unpersistVertices(blocking = false)
+            logWarning(s"Top20 for iteration $i:\n${top20.mkString("\n")}")
+            top20.map(_._1).toSet
+          }
+        val newGraph =
+          time("filter out top 20 pages, stage %d".format(i)) {
+            // filter out top 20 vertices
+            val filterTop20 = {(v: VertexId, d: String) =>
+              !top20verts.contains(v)
+            }
+            val result = currentGraph.subgraph(x => true, filterTop20).cache()
+            result.vertices.setName("newGraph vertices %d".format(i))
+            result.edges.setName("newGraph edges %d".format(i))
+            result.vertices.count
+            result
+          }
+        currentGraph.unpersistVertices(blocking = false)
+        currentGraph = newGraph
       }
-      val newGraph = currentGraph.subgraph(x => true, filterTop20).cache
-      newGraph.vertices.count
-      logWarning(s"TOTAL_TIMEX iter $i ${(System.currentTimeMillis - startTime)/1000.0}")
-      currentGraph.unpersistVertices(blocking = false)
-      ccGraph.unpersistVertices(blocking = false)
-      pr.unpersistVertices(blocking = false)
-      prAndTitle.unpersistVertices(blocking = false)
-      currentGraph = newGraph
     }
     currentGraph
   }
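The main functional change in this hunk, repeated in pipelinePostProcessing below, is how connected components are counted: the old code aggregated every component id into one `JTreeSet` on the driver, while the new code deduplicates the ids inside the cluster with `distinct(1)` and brings back only the count. A self-contained sketch of the new approach on plain (vertexId, componentId) pairs (toy data and a local master; names are illustrative):

import org.apache.spark.{SparkConf, SparkContext}

object CCCountDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("cc-count"))
    // (vertex id, component id) pairs, shaped like ConnectedComponents.run(...).vertices
    val ccVertices = sc.parallelize(Seq((1L, 1L), (2L, 1L), (3L, 3L), (4L, 3L), (5L, 5L)))

    // New approach: drop the vertex ids, deduplicate the component ids into a
    // single partition, and count. Only the count returns to the driver.
    val numCCs = ccVertices.map { case (_, cc) => cc }.distinct(1).count
    println(s"Number of connected components: $numCCs")  // prints 3

    sc.stop()
  }
}

The loop also names each cached RDD for the web UI and unpersists intermediates (ccGraph, pr, prAndTitle, and the previous currentGraph) as soon as they are consumed, rather than deferring every unpersist to the end of the iteration.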
Expand All @@ -141,29 +154,25 @@ object WikiPipelineBenchmark extends Logging {
}
}

def extractLinkGraph(sc: SparkContext, rawData: String): (RDD[(VertexId, String)], RDD[Edge[Double]]) = {
def extractLinkGraph(sc: SparkContext, rawData: String, numParts: Int): (RDD[(VertexId, String)], RDD[Edge[Double]]) = {
val conf = new Configuration
conf.set("key.value.separator.in.input.line", " ")
conf.set("xmlinput.start", "<page>")
conf.set("xmlinput.end", "</page>")

logWarning("about to load xml rdd")
val xmlRDD = sc.newAPIHadoopFile(rawData, classOf[XmlInputFormat], classOf[LongWritable], classOf[Text], conf)
.map(t => t._2.toString)
// xmlRDD.count
logWarning(s"XML RDD counted. Found ${xmlRDD.count} raw articles.")
val repartXMLRDD = xmlRDD.repartition(128)
logWarning(s"XML RDD repartitioned. Found ${repartXMLRDD.count} raw articles.")

val allArtsRDD = repartXMLRDD.map { raw => new WikiArticle(raw) }.cache
logWarning(s"Total articles: Found ${allArtsRDD.count} UNPARTITIONED articles.")

val wikiRDD = allArtsRDD.filter { art => art.relevant }.cache //.repartition(128)
logWarning(s"wikiRDD counted. Found ${wikiRDD.count} relevant articles in ${wikiRDD.partitions.size} partitions")
.map(t => t._2.toString).coalesce(numParts, false)
val allArtsRDD = xmlRDD.map { raw => new WikiArticle(raw) }

val wikiRDD =
time("filter relevant articles") {
val result = allArtsRDD.filter { art => art.relevant }.cache().setName("wikiRDD")
logWarning(s"wikiRDD counted. Found ${result.count} relevant articles in ${result.partitions.size} partitions")
result
}
val vertices: RDD[(VertexId, String)] = wikiRDD.map { art => (art.vertexID, art.title) }
val edges: RDD[Edge[Double]] = wikiRDD.flatMap { art => art.edges }
(vertices, edges)

}

def pipelinePostProcessing(sc: SparkContext, basePath: String, iter: Int) {
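extractLinkGraph now narrows the freshly loaded dump to numParts partitions with `coalesce(numParts, false)` instead of the old `repartition(128)`. With shuffle disabled, coalesce merges existing partitions locally rather than reshuffling every record. A standalone sketch of the difference (toy data and a local master; names are illustrative):

import org.apache.spark.{SparkConf, SparkContext}

object CoalesceDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("coalesce-demo"))
    val rdd = sc.parallelize(1 to 1000, numSlices = 32)

    // repartition(8) is coalesce(8, shuffle = true): a full shuffle of all rows.
    val reshuffled = rdd.repartition(8)

    // coalesce(8, shuffle = false) merges the 32 partitions into 8 without a
    // shuffle: cheaper, but it inherits whatever skew the input splits had.
    val merged = rdd.coalesce(8, shuffle = false)

    println(s"repartition -> ${reshuffled.partitions.length} partitions")
    println(s"coalesce    -> ${merged.partitions.length} partitions")
    sc.stop()
  }
}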
@@ -176,16 +185,17 @@ object WikiPipelineBenchmark extends Logging {
     val rankAndTitle = artNames.join(pageranks)
     val top20 = rankAndTitle.top(20)(Ordering.by((entry: (VertexId, (String, Double))) => entry._2._2))
     logWarning(s"Top20 for iteration $iter:\n${top20.mkString("\n")}")
-    val zeroVal = new JTreeSet[VertexId]()
-    val seqOp = (s: JTreeSet[VertexId], vtuple: (VertexId, VertexId)) => {
-      s.add(vtuple._2)
-      s
-    }
-    val combOp = (s1: JTreeSet[VertexId], s2: JTreeSet[VertexId]) => {
-      s1.addAll(s2)
-      s1
-    }
-    val numCCs = connComponents.aggregate(zeroVal)(seqOp, combOp).size()
+    val numCCs = connComponents.map{ case (id, cc) => cc }.distinct(1).count
+    // val zeroVal = new JTreeSet[VertexId]()
+    // val seqOp = (s: JTreeSet[VertexId], vtuple: (VertexId, VertexId)) => {
+    //   s.add(vtuple._2)
+    //   s
+    // }
+    // val combOp = (s1: JTreeSet[VertexId], s2: JTreeSet[VertexId]) => {
+    //   s1.addAll(s2)
+    //   s1
+    // }
+    // val numCCs = connComponents.aggregate(zeroVal)(seqOp, combOp).size()
     logWarning(s"Number of connected components for iteration $iter: $numCCs")
     val top20verts = top20.map(_._1).toSet
     val newVertices = artNames.filter { case (v, d) => !top20verts.contains(v) }
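The top-20 extraction, both in the loop above and here, relies on `RDD.top` with a custom `Ordering`, which keeps a bounded top-k per partition and merges the results instead of sorting the whole RDD. A toy sketch of the pattern (illustrative data, local master):

import org.apache.spark.{SparkConf, SparkContext}

object TopKDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("topk"))
    // (vertex id, (title, rank)) pairs, shaped like prAndTitle.vertices
    val ranked = sc.parallelize(Seq((1L, ("Foo", 0.9)), (2L, ("Bar", 2.5)), (3L, ("Baz", 1.7))))

    // Top 2 by rank, descending, without a full sort-and-collect.
    val top2 = ranked.top(2)(Ordering.by((entry: (Long, (String, Double))) => entry._2._2))
    top2.foreach(println)  // (2,(Bar,2.5)) then (3,(Baz,1.7))

    sc.stop()
  }
}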
