/
categorical_naive_bayes.js
87 lines (83 loc) · 1.97 KB
/
categorical_naive_bayes.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
/**
 * Categorical naive bayes
 *
 * Every feature is treated as a categorical variable. For each class and each
 * feature the model stores the additively smoothed relative frequency of every
 * category observed during fitting.
 *
 * NOTE: no class prior is applied; `probability` returns the plain product of
 * the per-feature conditional probabilities, so the values are not normalized
 * across classes.
 */
export default class CategoricalNaiveBayes {
	// https://scikit-learn.org/stable/modules/naive_bayes.html#categorical-naive-bayes
	/**
	 * @param {number} [alpha] Smoothing parameter
	 */
	constructor(alpha = 1.0) {
		this._alpha = alpha
	}

	/**
	 * Builds, for every feature, a lookup from category value to its position in `this._cand`.
	 *
	 * The first occurrence wins, which mirrors what `indexOf` would return.
	 * Unlike `indexOf`, a `Map` also finds `NaN` keys.
	 *
	 * @private
	 * @returns {Array<Map<*, number>>} One lookup table per feature
	 */
	_categoryIndex() {
		return this._cand.map(values => {
			const index = new Map()
			values.forEach((value, pos) => {
				if (!index.has(value)) {
					index.set(value, pos)
				}
			})
			return index
		})
	}

	/**
	 * Fit model.
	 *
	 * The number of features is taken from the first row of the first call and kept
	 * for later calls. The category list of each feature is also kept between calls;
	 * categories that appear for the first time in a later call are appended to it.
	 * Class labels and probabilities are recomputed from the data of the current call only.
	 *
	 * @param {Array<Array<*>>} datas Training data
	 * @param {*[]} labels Target values
	 */
	fit(datas, labels) {
		if (!this._cand) {
			this._d = datas[0].length
			this._cand = []
			for (let d = 0; d < this._d; d++) {
				this._cand[d] = []
			}
		}

		// Register categories in order of first appearance. Values unseen so far are
		// appended so that their samples are counted instead of being silently dropped.
		const candIndex = this._categoryIndex()
		for (const row of datas) {
			for (let d = 0; d < this._d; d++) {
				if (!candIndex[d].has(row[d])) {
					candIndex[d].set(row[d], this._cand[d].length)
					this._cand[d].push(row[d])
				}
			}
		}

		this._labels = [...new Set(labels)]
		const labelIndex = new Map(this._labels.map((label, k) => [label, k]))

		// counts[k][d][c]: number of samples of class k whose feature d has category c.
		const counts = this._labels.map(() => {
			const perFeature = []
			for (let d = 0; d < this._d; d++) {
				perFeature[d] = Array(this._cand[d].length).fill(0)
			}
			return perFeature
		})
		for (let i = 0; i < datas.length; i++) {
			const k = labelIndex.get(labels[i])
			if (k === undefined) {
				// A row without a corresponding label belongs to no class.
				continue
			}
			for (let d = 0; d < this._d; d++) {
				counts[k][d][candIndex[d].get(datas[i][d])]++
			}
		}

		// Additive smoothing: (count + alpha) / (class total + alpha * number of categories).
		this._prob = counts.map(perFeature =>
			perFeature.map(c => {
				const total = c.reduce((s, v) => s + v, 0)
				return c.map(v => (v + this._alpha) / (total + this._alpha * c.length))
			})
		)
	}

	/**
	 * Returns predicted probabilities.
	 *
	 * Each row holds one value per class, in the order of the class labels found by
	 * the last `fit` call. A feature whose value is not a known category is skipped
	 * for that sample, i.e. it contributes no evidence for any class.
	 *
	 * @param {Array<Array<*>>} datas Sample data
	 * @returns {Array<Array<number>>} Predicted values
	 */
	probability(datas) {
		const candIndex = this._categoryIndex()
		return datas.map(v => {
			const p = Array(this._labels.length).fill(1)
			for (let d = 0; d < this._d; d++) {
				const i = candIndex[d].get(v[d])
				if (i === undefined) {
					// Unknown category: multiplying by a missing entry would turn every class into NaN.
					continue
				}
				for (let k = 0; k < this._labels.length; k++) {
					p[k] *= this._prob[k][d][i]
				}
			}
			return p
		})
	}

	/**
	 * Returns predicted categories.
	 *
	 * For each sample the class with the largest value from `probability` is chosen;
	 * on ties the class that comes first wins. If no class has a value greater than
	 * zero, `null` is returned for that sample.
	 *
	 * @param {Array<Array<*>>} datas Sample data
	 * @returns {*[]} Predicted values
	 */
	predict(datas) {
		const prob = this.probability(datas)
		return prob.map(v => {
			let max_p = 0
			let max_c = -1
			for (let i = 0; i < this._labels.length; i++) {
				if (v[i] > max_p) {
					max_p = v[i]
					max_c = i
				}
			}
			return max_c < 0 ? null : this._labels[max_c]
		})
	}
}