/
macqueen_kmeans.js
74 lines (70 loc) · 1.43 KB
/
macqueen_kmeans.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
/**
* MacQueen's k-Means algorithm
*/
export default class MacQueenKMeans {
// Some methods for classification and analysis of multivariate observations
/**
* @param {number} k Number of clusters
*/
constructor(k) {
this._k = k
this._c = []
this._n = []
this._d = (a, b) => Math.sqrt(a.reduce((s, v, i) => s + (v - b[i]) ** 2, 0))
}
/**
* Centroids
*
* @type {Array<Array<number>>}
*/
get centroids() {
return this._c
}
/**
* Fit model.
*
* @param {Array<Array<number>>} datas Training data
*/
fit(datas) {
const n = datas.length
for (let i = 0; i < n; i++) {
if (this._c.length < this._k) {
this._c.push(datas[i].concat())
this._n.push(1)
continue
}
let min_d = Infinity
let min_k = 0
for (let k = 0; k < this._k; k++) {
const d = this._d(datas[i], this._c[k])
if (d < min_d) {
min_d = d
min_k = k
}
}
this._c[min_k] = this._c[min_k].map((c, j) => (c * this._n[min_k] + datas[i][j]) / (this._n[min_k] + 1))
this._n[min_k]++
}
}
/**
* Returns predicted categories.
*
* @param {Array<Array<number>>} datas Sample data
* @returns {number[]} Predicted values
*/
predict(datas) {
const p = []
for (let i = 0; i < datas.length; i++) {
let min_d = Infinity
p[i] = -1
for (let k = 0; k < this._c.length; k++) {
const d = this._d(datas[i], this._c[k])
if (d < min_d) {
min_d = d
p[i] = k
}
}
}
return p
}
}