-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathTruncatedNormal.go
executable file
·59 lines (49 loc) · 1.15 KB
/
TruncatedNormal.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
package initializer
// ITruncatedNormal holds the settings of a "TruncatedNormal" initializer.
// Values are configured through the chainable Set* methods and emitted by
// GetKerasLayerConfig.
type ITruncatedNormal struct {
	mean   float64     // emitted under the "mean" config key
	name   string      // emitted as the top-level "name" field
	seed   interface{} // emitted under the "seed" config key as-is (nil allowed)
	stddev float64     // emitted under the "stddev" config key
}
// TruncatedNormal returns a new initializer configured with mean 0,
// stddev 0.05, no seed (nil) and an empty name. Use the Set* methods on the
// result to override these defaults.
func TruncatedNormal() *ITruncatedNormal {
	// mean (0), name ("") and seed (nil) already hold their zero values;
	// only the non-zero default needs to be assigned.
	tn := new(ITruncatedNormal)
	tn.stddev = 0.05
	return tn
}
// SetMean stores mean (emitted under the "mean" config key) and returns the
// receiver so that calls can be chained.
func (i *ITruncatedNormal) SetMean(mean float64) *ITruncatedNormal {
	i.mean = mean
	return i
}
// SetName stores name (emitted as the top-level "name" field of the config)
// and returns the receiver so that calls can be chained.
func (i *ITruncatedNormal) SetName(name string) *ITruncatedNormal {
	i.name = name
	return i
}
// SetSeed stores seed unchanged (any value, including nil) for the "seed"
// config key and returns the receiver so that calls can be chained.
func (i *ITruncatedNormal) SetSeed(seed interface{}) *ITruncatedNormal {
	i.seed = seed
	return i
}
// SetStddev stores stddev (emitted under the "stddev" config key) and
// returns the receiver so that calls can be chained.
func (i *ITruncatedNormal) SetStddev(stddev float64) *ITruncatedNormal {
	i.stddev = stddev
	return i
}
// jsonConfigITruncatedNormal is the serializable shape returned by
// ITruncatedNormal.GetKerasLayerConfig. Its JSON tags map the fields to the
// "class_name", "name" and "config" keys.
type jsonConfigITruncatedNormal struct {
	ClassName string                 `json:"class_name"`
	Name      string                 `json:"name"`
	Config    map[string]interface{} `json:"config"`
}
// GetKerasLayerConfig returns a jsonConfigITruncatedNormal value (not a
// pointer) describing this initializer. The value has class name
// "TruncatedNormal", the configured name, and a config map with the keys
// "mean", "seed" and "stddev".
func (i *ITruncatedNormal) GetKerasLayerConfig() interface{} {
	params := make(map[string]interface{}, 3)
	params["mean"] = i.mean
	params["seed"] = i.seed // nil seed is passed through unchanged
	params["stddev"] = i.stddev

	var cfg jsonConfigITruncatedNormal
	cfg.ClassName = "TruncatedNormal"
	cfg.Name = i.name
	cfg.Config = params
	return cfg
}
// GetCustomLayerDefinition always returns the empty string. Presumably this
// initializer needs no custom definition emitted alongside it; confirm
// against the callers that consume this method.
func (i *ITruncatedNormal) GetCustomLayerDefinition() string {
	return ""
}