Skip to content

Commit f054f10

Browse files
committed
Add KMeans clustering algorithm
1 parent 916012d commit f054f10

File tree

2 files changed

+164
-0
lines changed

2 files changed

+164
-0
lines changed
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
package net.zomis.machlearn.clustering;
2+
3+
import java.util.ArrayList;
4+
import java.util.Arrays;
5+
import java.util.List;
6+
import java.util.Random;
7+
8+
public class KMeans {
9+
10+
public static void main(String[] args) {
11+
Random random = new Random(42);
12+
double[][] inputs = new double[12][2];
13+
for (int i = 0; i < inputs.length; i++) {
14+
inputs[i] = new double[] { random.nextDouble(), random.nextDouble() };
15+
}
16+
System.out.println("a = [");
17+
Arrays.stream(inputs).forEach(d -> System.out.println(Arrays.toString(d) + ";"));
18+
System.out.println(']');
19+
int[] clusters = cluster(inputs, 2, 100, random);
20+
System.out.println("clusters = " + Arrays.toString(clusters) + ';');
21+
System.out.println("a(:,4) = clusters'");
22+
}
23+
24+
private static int[] cluster(double[][] inputs, int clusterCount, int repetitions, Random random) {
25+
// PERFORM FEATURE-SCALING ON INPUTS
26+
27+
int[] bestClusters = null;
28+
double bestCost = 0;
29+
for (int iteration = 0; iteration < repetitions; iteration++) {
30+
KMeansResult result = performClustering(inputs, clusterCount, random);
31+
int[] clusters = result.getClusters();
32+
double[][] centroids = result.getCentroids();
33+
34+
double totalCost = 0;
35+
for (int i = 0; i < inputs.length; i++) {
36+
int cluster = clusters[i];
37+
double[] centroid = centroids[cluster];
38+
double distance = eucledianDistanceSquared(inputs[i], centroid);
39+
totalCost += distance;
40+
}
41+
if (bestClusters == null || totalCost < bestCost) {
42+
bestCost = totalCost;
43+
bestClusters = clusters;
44+
}
45+
}
46+
return bestClusters;
47+
}
48+
49+
private static KMeansResult performClustering(double[][] inputs, int clusterCount, Random random) {
50+
int[] clusters = new int[inputs.length];
51+
double[][] centroids = new double[clusterCount][inputs[0].length];
52+
int[] trainingSetCentroids = new int[centroids.length];
53+
for (int i = 0; i < centroids.length; i++) {
54+
// Initialize centroids to random training set, don't initialize to the same trainingSet
55+
int trainingSet;
56+
do {
57+
trainingSet = random.nextInt(inputs.length);
58+
trainingSetCentroids[i] = trainingSet;
59+
} while (isTaken(trainingSetCentroids, i, trainingSet));
60+
centroids[i] = Arrays.copyOf(inputs[trainingSet], inputs[trainingSet].length);
61+
}
62+
63+
/* Repeat until convergence:
64+
* 1. Mark the clusters according to which one is closest
65+
* 2. Move centroids
66+
*/
67+
boolean changed = true;
68+
while (changed) {
69+
changed = changeClusters(centroids, clusters, inputs);
70+
moveCentroids(centroids, clusters, inputs);
71+
}
72+
return new KMeansResult(clusters, centroids);
73+
}
74+
75+
private static void moveCentroids(double[][] centroids, int[] clusters, double[][] inputs) {
76+
List<List<Integer>> trainingSetsInCluster = new ArrayList<>(centroids.length);
77+
for (int i = 0; i < centroids.length; i++) {
78+
trainingSetsInCluster.add(new ArrayList<>());
79+
}
80+
81+
for (int i = 0; i < inputs.length; i++) {
82+
int cluster = clusters[i];
83+
trainingSetsInCluster.get(cluster).add(i);
84+
}
85+
86+
for (int c = 0; c < trainingSetsInCluster.size(); c++) {
87+
double[] sums = new double[inputs[0].length];
88+
List<Integer> trainingSets = trainingSetsInCluster.get(c);
89+
for (int i : trainingSets) {
90+
for (int j = 0; j < inputs[i].length; j++) {
91+
sums[j] += inputs[i][j];
92+
}
93+
}
94+
centroids[c] = Arrays.stream(sums).map(d -> d / trainingSets.size()).toArray();
95+
}
96+
}
97+
98+
private static boolean changeClusters(double[][] centroids, int[] clusters, double[][] inputs) {
99+
boolean changed = false;
100+
for (int i = 0; i < inputs.length; i++) {
101+
int oldCluster = clusters[i];
102+
clusters[i] = findClosestCluster(inputs[i], centroids);
103+
changed = changed || (oldCluster != clusters[i]);
104+
}
105+
return changed;
106+
}
107+
108+
private static int findClosestCluster(double[] input, double[][] centroids) {
109+
double minDistance = eucledianDistanceSquared(input, centroids[0]);
110+
int closestIndex = 0;
111+
for (int i = 1; i < centroids.length; i++) {
112+
double distance = eucledianDistanceSquared(input, centroids[i]);
113+
if (distance < minDistance) {
114+
minDistance = distance;
115+
closestIndex = i;
116+
}
117+
}
118+
return closestIndex;
119+
}
120+
121+
private static double eucledianDistanceSquared(double[] input, double[] centroid) {
122+
if (input.length != centroid.length) {
123+
throw new IllegalArgumentException("Values must be of same length. Input has length " + input.length +
124+
"while centroid has length " + centroid.length);
125+
}
126+
double sum = 0;
127+
for (int i = 0; i < input.length; i++) {
128+
double diff = input[i] - centroid[i];
129+
sum += diff * diff;
130+
}
131+
return sum;
132+
}
133+
134+
private static boolean isTaken(int[] centroids, int upToIndex, int current) {
135+
for (int i = 0; i < upToIndex; i++) {
136+
if (centroids[i] == current) {
137+
return true;
138+
}
139+
}
140+
return false;
141+
}
142+
143+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package net.zomis.machlearn.clustering;
2+
3+
public class KMeansResult {
4+
5+
private final int[] clusters;
6+
private final double[][] centroids;
7+
8+
public KMeansResult(int[] clusters, double[][] centroids) {
9+
this.clusters = clusters;
10+
this.centroids = centroids;
11+
}
12+
13+
public double[][] getCentroids() {
14+
return centroids;
15+
}
16+
17+
public int[] getClusters() {
18+
return clusters;
19+
}
20+
21+
}

0 commit comments

Comments
 (0)