/
random.js
51 lines (45 loc) · 1.17 KB
/
random.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import Layer from './base.js'
import Matrix from '../../../util/matrix.js'
import Tensor from '../../../util/tensor.js'
/**
 * Random layer.
 *
 * Produces values generated by `Matrix.randn` / `Tensor.randn` with the configured
 * mean and variance. `calc()` takes no inputs, so the output does not depend on the
 * values flowing into this layer.
 */
export default class RandomLayer extends Layer {
	/**
	 * @param {object} config config
	 * @param {number | number[] | string} config.size Size of output.
	 *   - number: output is a Matrix of shape `rows x size`
	 *   - number[]: output is a Tensor of shape `[rows, ...size]`
	 *   - string: name of another node; its `lastOutputSize` is used as the full shape
	 * @param {number} [config.mean] Mean of values
	 * @param {number} [config.variance] Variance of values
	 */
	constructor({ size, mean = 0, variance = 1, ...rest }) {
		super(rest)
		this._size = size
		this._mean = mean
		this._variance = variance
		// Leading dimension for number/array sizes; replaced by bind().
		this._rows = 1
	}

	/**
	 * Store the number of rows to generate when `size` is a number or an array.
	 * @param {object} param0 Bind parameters
	 * @param {number} param0.n Number of rows
	 */
	bind({ n }) {
		this._rows = n
	}

	/**
	 * Generate the output by calling `randn` with the configured mean and variance.
	 * @returns {Matrix | Tensor} A Matrix when the resulting shape is 2-D
	 *   (number size, or a referenced node with a 2-D output), otherwise a Tensor.
	 */
	calc() {
		if (typeof this._size === 'string') {
			// Copy the full shape of the referenced node's last output; `_rows` is not used here.
			const sizes = this.graph.getNode(this._size).lastOutputSize
			if (sizes.length === 2) {
				return Matrix.randn(sizes[0], sizes[1], this._mean, this._variance)
			}
			return Tensor.randn(sizes, this._mean, this._variance)
		}
		if (Array.isArray(this._size)) {
			return Tensor.randn([this._rows, ...this._size], this._mean, this._variance)
		}
		return Matrix.randn(this._rows, this._size, this._mean, this._variance)
	}

	/**
	 * No-op: `calc()` reads no inputs, so there is nothing to propagate backward.
	 */
	grad() {}

	/**
	 * Serializable configuration of this layer.
	 *
	 * Includes `mean` and `variance`. Without them, rebuilding the layer through the
	 * constructor would silently fall back to the defaults (0 and 1).
	 * @returns {{type: string, size: number | number[] | string, mean: number, variance: number}} Layer config
	 */
	toObject() {
		return {
			type: 'random',
			size: this._size,
			mean: this._mean,
			variance: this._variance,
		}
	}
}
// Register this layer class via the static `registLayer` helper that RandomLayer
// inherits from Layer (defined in ./base.js, not visible here). Presumably this maps
// the 'random' type name used in toObject() to this class so configs can be rebuilt
// — confirm against base.js.
RandomLayer.registLayer()