/
mutual_information.js
64 lines (60 loc) · 1.56 KB
/
mutual_information.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import Histogram from './histogram.js'
/**
* Mutual information feature selector
*/
export default class MutualInformationFeatureSelection {
	// https://qiita.com/shimopino/items/5fee7504c7acf044a521
	// https://qiita.com/hyt-sasaki/items/ffaab049e46f800f7cbf
	/**
	 * Mutual information feature selector.
	 *
	 * Each feature is scored by a histogram-based estimate of its mutual information with the target.
	 * The `k` highest-scoring features are kept.
	 *
	 * @param {number} k Number of selected features
	 * @param {number} [bins=40] Number of histogram bins (passed to Histogram as `count`) used to discretize values
	 */
	constructor(k, bins = 40) {
		this._k = k
		this._bins = bins
	}

	/**
	 * Estimate the mutual information I(A; B) in nats from histogram counts.
	 *
	 * @param {Array<Array<number>>} a Samples of the first variable (only column 0 is used)
	 * @param {Array<Array<number>>} b Samples of the second variable (only column 0 is used); same length as `a`
	 * @returns {number} Estimated mutual information
	 */
	_mutual_information(a, b) {
		const histogram = new Histogram({ count: this._bins })
		const ha = histogram.fit(a)
		const hb = histogram.fit(b)
		const hab = histogram.fit(a.map((v, i) => [v[0], b[i][0]]))
		const n = a.length
		let mi = 0
		for (let i = 0; i < ha.length; i++) {
			for (let j = 0; j < hb.length; j++) {
				// Empty cells contribute 0 (p log p -> 0); the marginal checks also avoid dividing by zero.
				if (hab[i][j] > 0 && ha[i] > 0 && hb[j] > 0) {
					const pab = hab[i][j] / n
					const pa = ha[i] / n
					const pb = hb[j] / n
					mi += pab * Math.log(pab / (pa * pb))
				}
			}
		}
		// pab, pa and pb are already probabilities, so the sum is the mutual information itself.
		// Dividing by n again would make the score shrink with the sample count.
		return mi
	}

	/**
	 * Fit model.
	 *
	 * Computes the mutual information of every feature column with the target and stores
	 * the features ordered by decreasing score.
	 *
	 * @param {Array<Array<number>>} x Training data
	 * @param {Array<Array<number>>} y Target values (only column 0 is used)
	 * @throws {Error} If `x` and `y` have different numbers of samples
	 */
	fit(x, y) {
		if (x.length !== y.length) {
			throw new Error('x and y must have the same number of samples.')
		}
		const dim = x.length === 0 ? 0 : x[0].length
		const imp = []
		for (let i = 0; i < dim; i++) {
			const a = x.map(v => [v[i]])
			imp.push(this._mutual_information(a, y))
		}
		// Pairs of [score, feature index], highest score first; stable sort keeps lower indices first on ties.
		this._importance = imp.map((v, i) => [v, i])
		this._importance.sort((a, b) => b[0] - a[0])
	}

	/**
	 * Returns feature selected values.
	 *
	 * @param {Array<Array<number>>} x Sample data
	 * @returns {Array<Array<number>>} Predicted values (each row restricted to the top-k features, in descending score order)
	 * @throws {Error} If called before `fit`
	 */
	predict(x) {
		if (!this._importance) {
			throw new Error('fit must be called before predict.')
		}
		const selected = this._importance.slice(0, this._k).map(im => im[1])
		return x.map(d => selected.map(i => d[i]))
	}
}