[SPARK-12215][ML][DOC] User guide section for KMeans in spark.ml
cc jkbradley

Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>

Closes #10244 from yu-iskw/SPARK-12215.
yu-iskw authored and jkbradley committed Dec 16, 2015
1 parent 22f6cd8 commit 26d70bd
Showing 3 changed files with 100 additions and 28 deletions.
71 changes: 71 additions & 0 deletions docs/ml-clustering.md
@@ -11,6 +11,77 @@ In this section, we introduce the pipeline API for [clustering in mllib](mllib-clustering.html).
* This will become a table of contents (this text will be scraped).
{:toc}

## K-means

[k-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the
most commonly used clustering algorithms; it groups the data points into a
predefined number of clusters. The MLlib implementation includes a parallelized
variant of the [k-means++](http://en.wikipedia.org/wiki/K-means%2B%2B) method
called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf).

`KMeans` is implemented as an `Estimator`: calling `fit()` on a `DataFrame` produces a `KMeansModel` holding the fitted cluster centers.
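
As an illustration of that `Estimator`/`Model` round trip, here is a minimal sketch against the Spark 1.6 APIs; the `SQLContext` named `sqlContext`, the toy data, and the `id` column are assumptions of the sketch, not part of the guide:

```scala
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

// A tiny, hypothetical dataset with two well-separated groups;
// assumes an existing SQLContext named sqlContext.
val dataset = sqlContext.createDataFrame(Seq(
  (1, Vectors.dense(0.0, 0.0)),
  (2, Vectors.dense(0.1, 0.1)),
  (3, Vectors.dense(9.0, 9.0)),
  (4, Vectors.dense(9.1, 9.1))
)).toDF("id", "features")

// Configure the estimator; "k-means||" is the default initMode.
val kmeans = new KMeans()
  .setK(2)
  .setMaxIter(20)
  .setInitMode("k-means||")

// fit() trains and returns the fitted KMeansModel.
val model = kmeans.fit(dataset)
model.clusterCenters.foreach(println)
```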

### Input Columns

<table class="table">
<thead>
<tr>
<th align="left">Param name</th>
<th align="left">Type(s)</th>
<th align="left">Default</th>
<th align="left">Description</th>
</tr>
</thead>
<tbody>
<tr>
<td>featuresCol</td>
<td>Vector</td>
<td>"features"</td>
<td>Feature vector</td>
</tr>
</tbody>
</table>
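
The `featuresCol` column must contain `Vector` values. When the raw features live in separate numeric columns, they can be assembled first; a brief sketch, in which the DataFrame `df` and its columns `x`, `y`, `z` are hypothetical:

```scala
import org.apache.spark.ml.feature.VectorAssembler

// Merge the hypothetical numeric columns x, y, z into a single
// "features" vector column, the default input column of KMeans.
val assembler = new VectorAssembler()
  .setInputCols(Array("x", "y", "z"))
  .setOutputCol("features")
val withFeatures = assembler.transform(df)
```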

### Output Columns

<table class="table">
<thead>
<tr>
<th align="left">Param name</th>
<th align="left">Type(s)</th>
<th align="left">Default</th>
<th align="left">Description</th>
</tr>
</thead>
<tbody>
<tr>
<td>predictionCol</td>
<td>Int</td>
<td>"prediction"</td>
<td>Predicted cluster index</td>
</tr>
</tbody>
</table>
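
Calling `transform()` on the fitted `KMeansModel` appends this prediction column to its input; a short sketch, reusing the hypothetical `model` and `dataset` from the sketch above:

```scala
// Appends an Int "prediction" column holding the index (0 to k-1)
// of the cluster assigned to each row's feature vector.
val predictions = model.transform(dataset)
predictions.select("features", "prediction").show()
```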

### Example

<div class="codetabs">

<div data-lang="scala" markdown="1">
Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.KMeans) for more details.

{% include_example scala/org/apache/spark/examples/ml/KMeansExample.scala %}
</div>

<div data-lang="java" markdown="1">
Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/KMeans.html) for more details.

{% include_example java/org/apache/spark/examples/ml/JavaKMeansExample.java %}
</div>

</div>


## Latent Dirichlet allocation (LDA)

`LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`,
…

8 changes: 6 additions & 2 deletions examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java

@@ -23,18 +23,20 @@
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.catalyst.expressions.GenericRow;
+// $example on$
 import org.apache.spark.ml.clustering.KMeansModel;
 import org.apache.spark.ml.clustering.KMeans;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.VectorUDT;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.catalyst.expressions.GenericRow;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
+// $example off$
 
 
 /**
@@ -74,6 +76,7 @@ public static void main(String[] args) {
     JavaSparkContext jsc = new JavaSparkContext(conf);
     SQLContext sqlContext = new SQLContext(jsc);
 
+    // $example on$
     // Loads data
     JavaRDD<Row> points = jsc.textFile(inputFile).map(new ParsePoint());
     StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
@@ -91,6 +94,7 @@ public static void main(String[] args) {
     for (Vector center: centers) {
       System.out.println(center);
     }
+    // $example off$
 
     jsc.stop();
   }
…

49 changes: 23 additions & 26 deletions examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala

@@ -17,57 +17,54 @@

 package org.apache.spark.examples.ml
 
-import org.apache.spark.{SparkContext, SparkConf}
-import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
-import org.apache.spark.ml.clustering.KMeans
-import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.sql.types.{StructField, StructType}
+// scalastyle:off println
 
+import org.apache.spark.{SparkConf, SparkContext}
+// $example on$
+import org.apache.spark.ml.clustering.KMeans
+import org.apache.spark.mllib.linalg.Vectors
+// $example off$
+import org.apache.spark.sql.{DataFrame, SQLContext}
 
 /**
  * An example demonstrating a k-means clustering.
  * Run with
  * {{{
- *   bin/run-example ml.KMeansExample <file> <k>
+ *   bin/run-example ml.KMeansExample
  * }}}
  */
 object KMeansExample {
 
-  final val FEATURES_COL = "features"
-
   def main(args: Array[String]): Unit = {
-    if (args.length != 2) {
-      // scalastyle:off println
-      System.err.println("Usage: ml.KMeansExample <file> <k>")
-      // scalastyle:on println
-      System.exit(1)
-    }
-    val input = args(0)
-    val k = args(1).toInt
-
     // Creates a Spark context and a SQL context
     val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}")
    val sc = new SparkContext(conf)
     val sqlContext = new SQLContext(sc)
 
-    // Loads data
-    val rowRDD = sc.textFile(input).filter(_.nonEmpty)
-      .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_))
-    val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false)))
-    val dataset = sqlContext.createDataFrame(rowRDD, schema)
+    // $example on$
+    // Creates a DataFrame
+    val dataset: DataFrame = sqlContext.createDataFrame(Seq(
+      (1, Vectors.dense(0.0, 0.0, 0.0)),
+      (2, Vectors.dense(0.1, 0.1, 0.1)),
+      (3, Vectors.dense(0.2, 0.2, 0.2)),
+      (4, Vectors.dense(9.0, 9.0, 9.0)),
+      (5, Vectors.dense(9.1, 9.1, 9.1)),
+      (6, Vectors.dense(9.2, 9.2, 9.2))
+    )).toDF("id", "features")
 
     // Trains a k-means model
     val kmeans = new KMeans()
-      .setK(k)
-      .setFeaturesCol(FEATURES_COL)
+      .setK(2)
+      .setFeaturesCol("features")
+      .setPredictionCol("prediction")
     val model = kmeans.fit(dataset)
 
     // Shows the result
-    // scalastyle:off println
     println("Final Centers: ")
     model.clusterCenters.foreach(println)
-    // scalastyle:on println
+    // $example off$
 
     sc.stop()
   }
 }
+// scalastyle:on println
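
On the six-point toy data above, the two fitted centers should come out near [0.1, 0.1, 0.1] and [9.1, 9.1, 9.1], the means of the two well-separated groups.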
