Skip to content

Commit

Permalink
[FLINK-2149][gelly] Simplified Jaccard Example
Browse files Browse the repository at this point in the history
This PR simplifies Gelly's Jaccard example by using the more efficient reduceOnNeighbors rather than groupReduceOnEdges.

Author: andralungu <lungu.andra@gmail.com>

Closes #770 from andralungu/jaccardImprovement and squashes the following commits:

6e77f8d [andralungu] [FLINK-2149][gelly] Simplified Jaccard Example
  • Loading branch information
andralungu authored and andra committed Jun 20, 2015
1 parent 9ee4fa5 commit b2be80d
Showing 1 changed file with 49 additions and 55 deletions.
Expand Up @@ -23,16 +23,13 @@
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.EdgeDirection;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.ReduceNeighborsFunction;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.EdgesFunction;
import org.apache.flink.graph.Triplet;
import org.apache.flink.graph.example.utils.JaccardSimilarityMeasureData;
import org.apache.flink.types.NullValue;
import org.apache.flink.util.Collector;

import java.util.HashSet;

Expand Down Expand Up @@ -66,34 +63,45 @@ public static void main(String [] args) throws Exception {

DataSet<Edge<Long, Double>> edges = getEdgesDataSet(env);

Graph<Long, NullValue, Double> graph = Graph.fromDataSet(edges, env);
Graph<Long, HashSet<Long>, Double> graph = Graph.fromDataSet(edges,
new MapFunction<Long, HashSet<Long>>() {

DataSet<Vertex<Long, HashSet<Long>>> verticesWithNeighbors =
graph.groupReduceOnEdges(new GatherNeighbors(), EdgeDirection.ALL);
@Override
public HashSet<Long> map(Long id) throws Exception {
HashSet<Long> neighbors = new HashSet<Long>();
neighbors.add(id);

Graph<Long, HashSet<Long>, Double> graphWithVertexValues = Graph.fromDataSet(verticesWithNeighbors, edges, env);
return new HashSet<Long>(neighbors);
}
}, env);

// the edge value will be the Jaccard similarity coefficient(number of common neighbors/ all neighbors)
DataSet<Tuple3<Long, Long, Double>> edgesWithJaccardWeight = graphWithVertexValues.getTriplets()
.map(new WeighEdgesMapper());
// create the set of neighbors
DataSet<Tuple2<Long, HashSet<Long>>> computedNeighbors =
graph.reduceOnNeighbors(new GatherNeighbors(), EdgeDirection.ALL);

DataSet<Edge<Long, Double>> result = graphWithVertexValues.joinWithEdges(edgesWithJaccardWeight,
new MapFunction<Tuple2<Double, Double>, Double>() {
// join with the vertices to update the node values
Graph<Long, HashSet<Long>, Double> graphWithVertexValues =
graph.joinWithVertices(computedNeighbors, new MapFunction<Tuple2<HashSet<Long>, HashSet<Long>>,
HashSet<Long>>() {

@Override
public Double map(Tuple2<Double, Double> value) throws Exception {
return value.f1;
public HashSet<Long> map(Tuple2<HashSet<Long>, HashSet<Long>> tuple2) throws Exception {
return tuple2.f1;
}
}).getEdges();
});

// compare neighbors, compute Jaccard
DataSet<Edge<Long, Double>> edgesWithJaccardValues =
graphWithVertexValues.getTriplets().map(new ComputeJaccard());

// emit result
if (fileOutput) {
result.writeAsCsv(outputPath, "\n", ",");
edgesWithJaccardValues.writeAsCsv(outputPath, "\n", ",");

// since file sinks are lazy, we trigger the execution explicitly
env.execute("Executing Jaccard Similarity Measure");
} else {
result.print();
edgesWithJaccardValues.print();
}

}
Expand All @@ -106,20 +114,14 @@ public String getDescription() {
/**
* Each vertex will have a HashSet containing its neighbor ids as value.
*/
private static final class GatherNeighbors implements EdgesFunction<Long, Double, Vertex<Long, HashSet<Long>>> {
@SuppressWarnings("serial")
private static final class GatherNeighbors implements ReduceNeighborsFunction<HashSet<Long>> {

@Override
public void iterateEdges(Iterable<Tuple2<Long, Edge<Long, Double>>> edges,
Collector<Vertex<Long, HashSet<Long>>> out) throws Exception {

HashSet<Long> neighborsHashSet = new HashSet<Long>();
long vertexId = -1;

for(Tuple2<Long, Edge<Long, Double>> edge : edges) {
neighborsHashSet.add(getNeighborID(edge));
vertexId = edge.f0;
}
out.collect(new Vertex<Long, HashSet<Long>>(vertexId, neighborsHashSet));
public HashSet<Long> reduceNeighbors(HashSet<Long> first,
HashSet<Long> second) {
first.addAll(second);
return new HashSet<Long>(first);
}
}

Expand All @@ -134,37 +136,29 @@ public void iterateEdges(Iterable<Tuple2<Long, Edge<Long, Double>>> edges,
*
* The Jaccard similarity coefficient is then the intersection/union.
*/
private static class WeighEdgesMapper implements MapFunction<Triplet<Long, HashSet<Long>, Double>,
Tuple3<Long, Long, Double>> {
@SuppressWarnings("serial")
private static final class ComputeJaccard implements
MapFunction<Triplet<Long, HashSet<Long>, Double>, Edge<Long, Double>> {

@Override
public Tuple3<Long, Long, Double> map(Triplet<Long, HashSet<Long>, Double> triplet)
throws Exception {
public Edge<Long, Double> map(Triplet<Long, HashSet<Long>, Double> triplet) throws Exception {

Vertex<Long, HashSet<Long>> source = triplet.getSrcVertex();
Vertex<Long, HashSet<Long>> target = triplet.getTrgVertex();
Vertex<Long, HashSet<Long>> srcVertex = triplet.getSrcVertex();
Vertex<Long, HashSet<Long>> trgVertex = triplet.getTrgVertex();

long unionPlusIntersection = source.getValue().size() + target.getValue().size();
// within a HashSet, all elements are distinct
source.getValue().addAll(target.getValue());
// the source value contains the union
long union = source.getValue().size();
long intersection = unionPlusIntersection - union;
Long x = srcVertex.getId();
Long y = trgVertex.getId();
HashSet<Long> neighborSetY = trgVertex.getValue();

return new Tuple3<Long, Long, Double>(source.getId(), target.getId(), (double) intersection/union);
}
}
double unionPlusIntersection = srcVertex.getValue().size() + neighborSetY.size();
// within a HashSet, all elements are distinct
HashSet<Long> unionSet = new HashSet<Long>();
unionSet.addAll(srcVertex.getValue());
unionSet.addAll(neighborSetY);
double union = unionSet.size();
double intersection = unionPlusIntersection - union;

/**
* Helper method that extracts the neighborId given an edge.
* @param edge
* @return
*/
private static Long getNeighborID(Tuple2<Long, Edge<Long, Double>> edge) {
if(edge.f1.getSource() == edge.f0) {
return edge.f1.getTarget();
} else {
return edge.f1.getSource();
return new Edge<Long, Double>(x, y, intersection/union);
}
}

Expand Down

0 comments on commit b2be80d

Please sign in to comment.