/
KMeans.scala
249 lines (220 loc) · 8.02 KB
/
KMeans.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.examples.scala.clustering
import org.apache.flink.api.common.functions._
import org.apache.flink.api.scala._
import org.apache.flink.configuration.Configuration
import org.apache.flink.examples.java.clustering.util.KMeansData
import scala.collection.JavaConverters._
/**
* This example implements a basic K-Means clustering algorithm.
*
* K-Means is an iterative clustering algorithm and works as follows:
* K-Means is given a set of data points to be clustered and an initial set of ''K'' cluster
* centers.
* In each iteration, the algorithm computes the distance of each data point to each cluster center.
* Each point is assigned to the cluster center which is closest to it.
* Subsequently, each cluster center is moved to the center (''mean'') of all points that have
* been assigned to it.
* The moved cluster centers are fed into the next iteration.
* The algorithm terminates after a fixed number of iterations (as in this implementation)
* or if cluster centers do not (significantly) move in an iteration.
* This is the Wikipedia entry for the
* [[http://en.wikipedia.org/wiki/K-means_clustering K-Means Clustering algorithm]].
*
* This implementation works on two-dimensional data points.
* It computes an assignment of data points to cluster centers, i.e.,
* each data point is annotated with the id of the final cluster (center) it belongs to.
*
* Input files are plain text files and must be formatted as follows:
*
* - Data points are represented as two double values separated by a blank character.
* Data points are separated by newline characters.
* For example `"1.2 2.3\n5.3 7.2\n"` gives two data points (x=1.2, y=2.3) and (x=5.3,
* y=7.2).
* - Cluster centers are represented by an integer id and a point value.
* For example `"1 6.2 3.2\n2 2.9 5.7\n"` gives two centers (id=1, x=6.2,
* y=3.2) and (id=2, x=2.9, y=5.7).
*
* Usage:
* {{{
* KMeans <points path> <centers path> <result path> <num iterations>
* }}}
* If no parameters are provided, the program is run with default data from
* [[org.apache.flink.examples.java.clustering.util.KMeansData]]
* and 10 iterations.
*
* This example shows how to use:
*
* - Bulk iterations
* - Broadcast variables in bulk iterations
* - Custom Java objects (POJOs)
*/
object KMeans {
/**
 * Program entry point.
 *
 * Parses the program arguments and, if they are acceptable, assembles the K-Means job:
 * the centroids are refined for `numIterations` rounds inside a bulk iteration, every point
 * is then labelled with the id of its nearest final centroid, and the labelled points are
 * either written to `outputPath` or printed, depending on `fileOutput`.
 *
 * @param args program arguments, see `parseParameters` for the accepted forms
 */
def main(args: Array[String]): Unit = {
  // Nothing is set up when the arguments are rejected.
  if (parseParameters(args)) {
    val env = ExecutionEnvironment.getExecutionEnvironment

    val points: DataSet[Point] = getPointDataSet(env)
    val centroids: DataSet[Centroid] = getCentroidDataSet(env)

    val finalCentroids = centroids.iterate(numIterations) { currentCentroids =>
      points
        // tag every point with the id of its nearest current centroid
        .map(new SelectNearestCenter).withBroadcastSet(currentCentroids, "centroids")
        // attach a count of one so that the number of points per centroid can be summed
        .map { assigned => (assigned._1, assigned._2, 1L) }
        .groupBy(0)
        // per centroid id: sum up the coordinates and the point counts
        .reduce { (left, right) => (left._1, left._2.add(right._2), left._3 + right._3) }
        // coordinate sum divided by count gives the new centroid position
        .map { summed => new Centroid(summed._1, summed._2.div(summed._3)) }
    }

    // label every point with its nearest centroid after the last iteration
    val clusteredPoints: DataSet[(Int, Point)] =
      points.map(new SelectNearestCenter).withBroadcastSet(finalCentroids, "centroids")

    if (fileOutput) {
      clusteredPoints.writeAsCsv(outputPath, "\n", " ")
    }
    else {
      clusteredPoints.print()
    }

    env.execute("Scala KMeans Example")
  }
}
/**
 * Reads the program arguments into the run configuration fields of this object.
 *
 * Three cases are distinguished:
 *  - no arguments: a hint about the default run is printed and the built-in defaults are kept;
 *  - exactly four arguments: they are taken as points path, centers path, result path and
 *    number of iterations, in that order;
 *  - any other number of arguments: a usage message is printed to standard error.
 *
 * `fileOutput` is set to `true` as soon as at least one argument is supplied.
 *
 * @param programArguments the arguments passed to `main`
 * @return `true` if the program should run, `false` if the arguments were rejected
 * @throws NumberFormatException if the fourth argument is not a valid integer
 */
private def parseParameters(programArguments: Array[String]): Boolean = {
  if (programArguments.length > 0) {
    fileOutput = true
    if (programArguments.length == 4) {
      pointsPath = programArguments(0)
      centersPath = programArguments(1)
      outputPath = programArguments(2)
      numIterations = Integer.parseInt(programArguments(3))
      true
    }
    else {
      System.err.println("Usage: KMeans <points path> <centers path> <result path> <num " +
        "iterations>")
      // Each branch must yield its own result: a bare `false` followed by further
      // statements would be discarded and the method would report success.
      false
    }
  }
  else {
    System.out.println("Executing K-Means example with default parameters and built-in default " +
      "data.")
    System.out.println(" Provide parameters to read input data from files.")
    System.out.println(" See the documentation for the correct format of input files.")
    System.out.println(" We provide a data generator to create synthetic input files for this " +
      "program.")
    System.out.println(" Usage: KMeans <points path> <centers path> <result path> <num " +
      "iterations>")
    true
  }
}
/**
 * Provides the data points to be clustered.
 *
 * When `fileOutput` is set, the points are read from `pointsPath` as two blank-separated
 * double columns; otherwise the built-in rows of `KMeansData.POINTS` are used.
 *
 * @param env the execution environment the data set is created in
 * @return the data set of points
 */
private def getPointDataSet(env: ExecutionEnvironment): DataSet[Point] = {
  if (!fileOutput) {
    // default data: every row is matched as an (x, y) pair and cast to doubles
    val defaultPoints = KMeansData.POINTS map {
      case Array(px, py) => new Point(px.asInstanceOf[Double], py.asInstanceOf[Double])
    }
    env.fromCollection(defaultPoints)
  }
  else {
    val coordinates = env.readCsvFile[(Double, Double)](
      pointsPath,
      fieldDelimiter = ' ',
      includedFields = Array(0, 1))
    coordinates.map { xy => new Point(xy._1, xy._2) }
  }
}
/**
 * Provides the initial cluster centers.
 *
 * When `fileOutput` is set, the centers are read from `centersPath` as blank-separated
 * rows of an integer id followed by two doubles; otherwise the built-in rows of
 * `KMeansData.CENTROIDS` are used.
 *
 * @param env the execution environment the data set is created in
 * @return the data set of initial centroids
 */
private def getCentroidDataSet(env: ExecutionEnvironment): DataSet[Centroid] = {
  if (!fileOutput) {
    // default data: every row is matched as an (id, x, y) triple and cast to its field types
    val defaultCentroids = KMeansData.CENTROIDS map {
      case Array(cid, cx, cy) =>
        new Centroid(cid.asInstanceOf[Int], cx.asInstanceOf[Double], cy.asInstanceOf[Double])
    }
    env.fromCollection(defaultCentroids)
  }
  else {
    val rows = env.readCsvFile[(Int, Double, Double)](
      centersPath,
      fieldDelimiter = ' ',
      includedFields = Array(0, 1, 2))
    rows.map { row => new Centroid(row._1, row._2, row._3) }
  }
}
// Run configuration. parseParameters assigns these fields from the program arguments;
// the values below are what a run without arguments uses.

// Set to true once at least one program argument is given. Despite its name it selects
// both file-based input (points and centers) and CSV file output; when false the built-in
// default data is used and the result is printed.
private var fileOutput: Boolean = false
// Path of the points input file (first program argument); null until assigned.
private var pointsPath: String = null
// Path of the cluster centers input file (second program argument); null until assigned.
private var centersPath: String = null
// Path the clustered points are written to (third program argument); null until assigned.
private var outputPath: String = null
// Number of bulk iterations to run (fourth program argument); defaults to 10.
private var numIterations: Int = 10
/**
 * A mutable point in the two-dimensional plane.
 *
 * `add`, `div` and `clear` modify this instance in place; `add` and `div` additionally
 * return it so that calls can be chained.
 *
 * @param x the horizontal coordinate
 * @param y the vertical coordinate
 */
class Point(var x: Double, var y: Double) extends Serializable {

  /** Creates a point at the origin. */
  def this() = this(0, 0)

  /**
   * Adds the coordinates of `other` to this point, in place.
   *
   * @param other the point whose coordinates are added
   * @return this instance, after the update
   */
  def add(other: Point): Point = {
    x = x + other.x
    y = y + other.y
    this
  }

  /**
   * Divides both coordinates by `other`, in place.
   *
   * @param other the divisor
   * @return this instance, after the update
   */
  def div(other: Long): Point = {
    x = x / other
    y = y / other
    this
  }

  /**
   * Computes the straight-line distance between this point and `other`.
   *
   * @param other the point to measure the distance to
   * @return the Euclidean distance
   */
  def euclideanDistance(other: Point): Double = {
    val dx = x - other.x
    val dy = y - other.y
    Math.sqrt(dx * dx + dy * dy)
  }

  /** Moves this point back to the origin. */
  def clear(): Unit = {
    x = 0.0
    y = 0.0
  }

  /** Formats the point as its two coordinates separated by a single blank. */
  override def toString: String = s"$x $y"
}
/**
 * A cluster center: a two-dimensional [[Point]] that additionally carries an integer id.
 *
 * @param id the id of the cluster this centroid represents
 * @param x the horizontal coordinate
 * @param y the vertical coordinate
 */
class Centroid(var id: Int, x: Double, y: Double) extends Point(x, y) {

  /** Creates a centroid with id 0 located at the origin. */
  def this() = this(0, 0, 0)

  /** Creates a centroid with the given id at the coordinates of `p`. */
  def this(id: Int, p: Point) = this(id, p.x, p.y)

  /** Formats the centroid as its id followed by a blank and the point representation. */
  override def toString: String = s"$id ${super.toString}"
}
/**
 * Map function that pairs every data point with the id of the centroid closest to it.
 *
 * The candidate centroids are taken from the broadcast variable named "centroids".
 */
final class SelectNearestCenter extends RichMapFunction[Point, (Int, Point)] {
  private var centroids: Traversable[Centroid] = null

  /** Fetches the broadcast centroids and keeps them as a Scala collection. */
  override def open(parameters: Configuration): Unit = {
    val broadcast = getRuntimeContext.getBroadcastVariable[Centroid]("centroids")
    centroids = broadcast.asScala
  }

  /**
   * Finds the centroid with the smallest Euclidean distance to `p`.
   *
   * @param p the point to assign
   * @return the id of the nearest centroid together with `p`; the id is -1 when no centroid
   *         is closer than `Double.MaxValue`, in particular when there are no centroids
   */
  def map(p: Point): (Int, Point) = {
    val start = (Double.MaxValue, -1)
    // Strict comparison: on equal distances the centroid seen first is kept.
    val (_, nearestId) = centroids.foldLeft(start) { (best, candidate) =>
      val distance = p.euclideanDistance(candidate)
      if (distance < best._1) (distance, candidate.id) else best
    }
    (nearestId, p)
  }
}
}