/
quadratic_discriminant.js
72 lines (67 loc) · 1.48 KB
/
quadratic_discriminant.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import Matrix from '../util/matrix.js'
/**
 * Quadratic discriminant analysis.
 *
 * Each category is modelled with its own mean vector and covariance matrix, and a
 * sample is assigned to the category whose discriminant score is largest.
 */
export default class QuadraticDiscriminant {
	// https://arxiv.org/abs/1906.02590
	// https://online.stat.psu.edu/stat508/book/export/html/696
	constructor() {
		this._m = []
		this._s = []
		this._sinv = []
		this._c = []
		this._categories = []
	}

	/**
	 * Fit model.
	 *
	 * For every distinct label k this stores the class mean, the class covariance S_k,
	 * its inverse, and the constant term log(n_k / n) - log|S_k| / 2 of the discriminant.
	 * Categories are kept in order of first appearance in `y`.
	 *
	 * @param {Array<Array<number>>} x Training data
	 * @param {*[]} y Target values
	 * @throws {Error} If `x` and `y` do not have the same length
	 */
	fit(x, y) {
		// Validate before resetting state so a bad call does not wipe a fitted model.
		if (x.length !== y.length) {
			throw new Error('x and y must have the same length')
		}
		this._m = []
		this._s = []
		this._sinv = []
		this._c = []
		this._categories = []

		// Group samples by label in a single pass. Map keys use SameValueZero (like Set),
		// so labels such as NaN still collect their samples, which `y[i] === k` would miss.
		const groups = new Map()
		for (let i = 0; i < y.length; i++) {
			const group = groups.get(y[i])
			if (group) {
				group.push(x[i])
			} else {
				groups.set(y[i], [x[i]])
			}
		}

		const n = x.length
		for (const [k, xk] of groups) {
			const mat = Matrix.fromArray(xk)
			this._m.push(mat.mean(0))
			const s = mat.cov()
			this._s.push(s)
			this._sinv.push(s.inv())
			// log prior of the class minus half the log-determinant of its covariance
			this._c.push(Math.log(mat.rows / n) - Math.log(s.det()) / 2)
			this._categories.push(k)
		}
	}

	/**
	 * Returns predicted categories.
	 *
	 * The score of category k for sample x is c_k - (x - m_k) S_k^-1 (x - m_k)^T / 2;
	 * the highest score wins (the first category wins ties). If no category is available
	 * (e.g. the model was not fitted), the prediction is `undefined`.
	 *
	 * @param {Array<Array<number>>} data Sample data
	 * @returns {*[]} Predicted values
	 */
	predict(data) {
		const classCount = this._m.length
		return data.map(d => {
			const sample = new Matrix(1, d.length, d)
			let bestIndex = -1
			let bestScore = -Infinity
			for (let i = 0; i < classCount; i++) {
				const diff = Matrix.sub(sample, this._m[i])
				const score = this._c[i] - diff.dot(this._sinv[i]).dot(diff.t).toScaler() / 2
				if (bestScore < score) {
					bestScore = score
					bestIndex = i
				}
			}
			return this._categories[bestIndex]
		})
	}
}