/
mda.js
126 lines (115 loc) · 3.17 KB
/
mda.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import Matrix from '../util/matrix.js'
/**
* Mixture discriminant analysis
*/
export default class MixtureDiscriminant {
// http://www.personal.psu.edu/jol2/course/stat597e/notes2/mda.pdf
/**
* @param {number} r Number of components
*/
constructor(r) {
this._r = r
}
/**
* Initialize this model.
*
* @param {Array<Array<number>>} x Training data
* @param {*[]} y Target values
*/
init(x, y) {
this._x = x
this._y = y
this._d = this._x[0].length
this._c = [...new Set(y)]
this._p = Matrix.random(this._x.length, this._r)
this._p.div(this._p.sum(1))
this._a = []
for (let i = 0; i < this._c.length; i++) {
const xi = this._x.filter((v, k) => this._y[k] === this._c[i])
this._a[i] = xi.length / this._x.length
}
this._mstep()
}
_mstep() {
this._pi = Matrix.zeros(this._c.length, this._r)
this._m = []
for (let i = 0; i < this._c.length; i++) {
this._m[i] = Matrix.zeros(this._r, this._d)
}
const count = Matrix.zeros(this._c.length, 1)
const psum = []
for (let i = 0; i < this._c.length; i++) {
psum[i] = Matrix.zeros(1, this._r)
}
for (let i = 0; i < this._x.length; i++) {
const cidx = this._c.indexOf(this._y[i])
for (let r = 0; r < this._r; r++) {
const pir = this._p.at(i, r)
this._pi.addAt(cidx, r, pir)
for (let d = 0; d < this._d; d++) {
this._m[cidx].addAt(r, d, this._x[i][d] * pir)
}
}
count.addAt(cidx, 0, 1)
psum[cidx].add(this._p.row(i))
}
this._pi.div(count)
for (let i = 0; i < this._c.length; i++) {
this._m[i].div(psum[i].t)
}
this._s = Matrix.zeros(this._d, this._d)
const yidx = this._y.map((v, i) => [v, i])
for (let i = 0; i < this._c.length; i++) {
let xi = this._x.filter((v, k) => this._y[k] === this._c[i])
xi = Matrix.fromArray(xi)
const pi = this._p.row(yidx.filter(v => v[0] === this._c[i]).map(v => v[1]))
for (let r = 0; r < this._r; r++) {
const xir = Matrix.sub(xi, this._m[i].row(r))
const v = xir.tDot(Matrix.mult(xir, pi.col(r)))
this._s.add(v)
}
}
this._s.div(this._x.length)
this._sinv = this._s.inv()
}
/**
* Fit model.
*/
fit() {
this._p = Matrix.zeros(this._x.length, this._r)
for (let i = 0; i < this._x.length; i++) {
const cidx = this._c.indexOf(this._y[i])
const x = new Matrix(1, this._d, this._x[i])
for (let r = 0; r < this._r; r++) {
const xir = Matrix.sub(x, this._m[cidx].row(r))
const v = xir.dot(this._sinv).dot(xir.t).toScaler()
this._p.addAt(i, r, this._pi.at(cidx, r) * Math.exp(-v / 2))
}
}
this._p.div(this._p.sum(1))
this._mstep()
}
/**
* Returns predicted categories.
*
* @param {Array<Array<number>>} data Sample data
* @returns {*[]} Predicted values
*/
predict(data) {
const x = Matrix.fromArray(data)
const p = Matrix.zeros(x.rows, this._c.length)
for (let i = 0; i < this._c.length; i++) {
for (let r = 0; r < this._r; r++) {
const xi = Matrix.sub(x, this._m[i].row(r))
const v = xi.dot(this._sinv)
v.mult(xi)
const s = v.sum(1)
s.map(v => this._pi.at(i, r) * Math.exp(-v / 2))
for (let n = 0; n < x.rows; n++) {
p.addAt(n, i, this._a[i] * s.at(n, 0))
}
}
}
return p.argmax(1).value.map(v => this._c[v])
}
}