diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
index 65a1a8c68f6d2..81493580137cd 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
@@ -335,7 +335,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
    * }
    * }}}
    */
-  def outerJoinVertices[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)])
+  def outerJoinVertices[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)], destructive: Boolean = false)
       (mapFunc: (VertexId, VD, Option[U]) => VD2)
     : Graph[VD2, ED]
 
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index 5d98d3b83b69b..1786fa75b5c17 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -132,7 +132,7 @@ object Pregel extends Logging {
       val newVerts = g.vertices.innerJoin(messages)(vprog).cache()
       // Update the graph with the new vertices.
       prevG = g
-      g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }
+      g = g.outerJoinVertices(newVerts, destructive = i > 0) { (vid, old, newOpt) => newOpt.getOrElse(old) }
       g.cache()
 
       val oldMessages = messages
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
index 3a6beb46746d3..dd8519fe1de7c 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
@@ -281,7 +281,7 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
   } // end of mapReduceTriplets
 
   override def outerJoinVertices[U: ClassTag, VD2: ClassTag]
-      (other: RDD[(VertexId, U)])
+      (other: RDD[(VertexId, U)], destructive: Boolean = false)
       (updateF: (VertexId, VD, Option[U]) => VD2): Graph[VD2, ED] =
   {
     if (classTag[VD] equals classTag[VD2]) {
@@ -290,7 +290,7 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
       val changedVerts = vertices.asInstanceOf[VertexRDD[VD2]].diff(newVerts)
       val newReplicatedVertexView = new ReplicatedVertexView[VD2](
         changedVerts, edges, routingTable,
-        Some(replicatedVertexView.asInstanceOf[ReplicatedVertexView[VD2]]))
+        Some(replicatedVertexView.asInstanceOf[ReplicatedVertexView[VD2]]), destructive)
       new GraphImpl(newVerts, edges, routingTable, newReplicatedVertexView)
     } else {
       // updateF does not preserve type, so we must re-replicate all vertices
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala
index b7bca80784670..38530b78fa005 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala
@@ -42,7 +42,8 @@ class ReplicatedVertexView[VD: ClassTag](
     updatedVerts: VertexRDD[VD],
     edges: EdgeRDD[_],
     routingTable: RoutingTable,
-    prevViewOpt: Option[ReplicatedVertexView[VD]] = None) {
+    prevViewOpt: Option[ReplicatedVertexView[VD]] = None,
+    destructive: Boolean = false) {
 
   /**
    * Within each edge partition, create a local map from vid to an index into the attribute
@@ -122,6 +123,7 @@ class ReplicatedVertexView[VD: ClassTag](
     val shippedVerts = routingTable.get(includeSrc, includeDst)
       .zipPartitions(verts)(ReplicatedVertexView.buildBuffer(_, _)(vdTag))
       .partitionBy(edges.partitioner.get)
+    val destructiveLocal = destructive // to avoid closure capture
     // TODO: Consider using a specialized shuffler.
     prevViewOpt match {
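For reference, a minimal usage sketch against the patched API. The intended semantics of the new flag are inferred from the Pregel change (destructive = i > 0, i.e. only once the previous graph has been replaced and will not be read again), so treat the safety condition below as an assumption rather than documented behavior; refreshRanks and the Graph[Double, Int] element types are illustrative, not part of the patch.

    import org.apache.spark.graphx._
    import org.apache.spark.rdd.RDD

    // Illustrative caller: merge new per-vertex values into an existing graph.
    // Assumption: destructive = true is only safe when no other reference to the
    // pre-update graph (or its replicated view) will be used afterwards.
    def refreshRanks(graph: Graph[Double, Int],
                     newRanks: RDD[(VertexId, Double)]): Graph[Double, Int] = {
      graph.outerJoinVertices(newRanks, destructive = true) {
        (vid, oldRank, newOpt) => newOpt.getOrElse(oldRank)
      }
    }

Note that the fast path touched in GraphImpl.outerJoinVertices only applies when the vertex attribute type is unchanged (classTag[VD] equals classTag[VD2]), which is the case in this sketch and in the Pregel loop.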