Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Adding Java K-Means example

  • Loading branch information...
commit 568ddf73307f125227bced4277fcc3a5be0adb28 1 parent b990cae
@MLnick authored
View
111 examples/src/main/java/spark/examples/JavaKMeans.java
@@ -0,0 +1,111 @@
+package spark.examples;
+
+import scala.Tuple2;
+import spark.api.java.JavaPairRDD;
+import spark.api.java.JavaRDD;
+import spark.api.java.JavaSparkContext;
+import spark.api.java.function.Function;
+import spark.api.java.function.PairFunction;
+import spark.util.Vector;
+
+import java.util.List;
+import java.util.Map;
+
+public class JavaKMeans {
+
+ /** Parses numbers split by whitespace to a vector */
+ static Vector parseVector(String line) {
+ String[] splits = line.split(" ");
+ double[] data = new double[splits.length];
+ int i = 0;
+ for (String s : splits)
+ data[i] = Double.parseDouble(splits[i++]);
+ return new Vector(data);
+ }
+
+ /** Computes the vector to which the input vector is closest using squared distance */
+ static int closestPoint(Vector p, List<Vector> centers) {
+ int bestIndex = 0;
+ double closest = Double.POSITIVE_INFINITY;
+ for (int i = 0; i < centers.size(); i++) {
+ double tempDist = p.squaredDist(centers.get(i));
+ if (tempDist < closest) {
+ closest = tempDist;
+ bestIndex = i;
+ }
+ }
+ return bestIndex;
+ }
+
+ /** Computes the mean across all vectors in the input set of vectors */
+ static Vector average(List<Vector> ps) {
+ int numVectors = ps.size();
+ Vector out = new Vector(ps.get(0).elements());
+ // start from i = 1 since we already copied index 0 above
+ for (int i = 1; i < numVectors; i++) {
+ out.addInPlace(ps.get(i));
+ }
+ return out.divide(numVectors);
+ }
+
+ public static void main(String[] args) throws Exception {
+ if (args.length < 4) {
+ System.err.println("Usage: SparkKMeans <master> <file> <k> <convergeDist>");
+ System.exit(1);
+ }
+ JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans",
+ System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
+ String path = args[1];
+ int K = Integer.parseInt(args[2]);
+ double convergeDist = Double.parseDouble(args[3]);
+
+ JavaRDD<Vector> data = sc.textFile(path).map(
+ new Function<String, Vector>() {
+ @Override
+ public Vector call(String line) throws Exception {
+ return parseVector(line);
+ }
+ }
+ ).cache();
+
+ final List<Vector> centroids = data.takeSample(false, K, 42);
+
+ double tempDist;
+ do {
+ // allocate each vector to closest centroid
+ JavaPairRDD<Integer, Vector> closest = data.map(
+ new PairFunction<Vector, Integer, Vector>() {
+ @Override
+ public Tuple2<Integer, Vector> call(Vector vector) throws Exception {
+ return new Tuple2<Integer, Vector>(
+ closestPoint(vector, centroids), vector);
+ }
+ }
+ );
+
+ // group by cluster id and average the vectors within each cluster to compute centroids
+ JavaPairRDD<Integer, List<Vector>> pointsGroup = closest.groupByKey();
+ Map<Integer, Vector> newCentroids = pointsGroup.mapValues(
+ new Function<List<Vector>, Vector>() {
+ public Vector call(List<Vector> ps) throws Exception {
+ return average(ps);
+ }
+ }).collectAsMap();
+ tempDist = 0.0;
+ for (int i = 0; i < K; i++) {
+ tempDist += centroids.get(i).squaredDist(newCentroids.get(i));
+ }
+ for (Map.Entry<Integer, Vector> t: newCentroids.entrySet()) {
+ centroids.set(t.getKey(), t.getValue());
+ }
+ System.out.println("Finished iteration (delta = " + tempDist + ")");
+ } while (tempDist > convergeDist);
+
+ System.out.println("Final centers:");
+ for (Vector c : centroids)
+ System.out.println(c);
+
+ System.exit(0);
+
+}
+}
View
1  examples/src/main/scala/spark/examples/SparkKMeans.scala
@@ -64,6 +64,7 @@ object SparkKMeans {
for (newP <- newPoints) {
kPoints(newP._1) = newP._2
}
+ println("Finished iteration (delta = " + tempDist + ")")
}
println("Final centers:")
Please sign in to comment.
Something went wrong with that request. Please try again.