-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathinterface.go
62 lines (54 loc) · 1.43 KB
/
interface.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
package layer
import "C"
import (
	"fmt"
	"sync"

	tf "github.com/galeone/tensorflow/tensorflow/go"
)
// DataType names a scalar value type of the TensorFlow type system by its
// string spelling (for example "float32").
type DataType string

// Types of scalar values in the TensorFlow type system, grouped by family.
//
// NOTE(review): some entries look like aliases of one another (Double vs
// Float64, Half vs Float16); they are distinct string values here, so confirm
// that consumers treat them as equivalent where intended.
const (
	// Floating-point types.
	Float16  DataType = "float16"
	Half     DataType = "half"
	Bfloat16 DataType = "bfloat16"
	Float32  DataType = "float32"
	Float64  DataType = "float64"
	Double   DataType = "double"

	// Signed integer types.
	Int8  DataType = "int8"
	Int16 DataType = "int16"
	Int32 DataType = "int32"
	Int64 DataType = "int64"

	// Unsigned integer types.
	Uint8  DataType = "uint8"
	Uint16 DataType = "uint16"
	Uint32 DataType = "uint32"
	Uint64 DataType = "uint64"

	// Quantized integer types.
	Qint8   DataType = "qint8"
	Quint8  DataType = "quint8"
	Qint16  DataType = "qint16"
	Quint16 DataType = "quint16"
	Qint32  DataType = "qint32"

	// Complex types.
	Complex    DataType = "complex"
	Complex64  DataType = "complex64"
	Complex128 DataType = "complex128"

	// Non-numeric types.
	Bool   DataType = "bool"
	String DataType = "string"
)
// Layer is the interface through which this package handles model layers.
//
// Methods are grouped below as: identity and wiring, tensor metadata, and
// Keras export. NOTE(review): the descriptions are inferred from method names
// and signatures only; confirm them against the concrete implementations.
type Layer interface {
	// GetName returns the layer's name.
	GetName() string
	// SetInputs connects the given layers as this layer's inputs and returns
	// a Layer (presumably the receiver, to allow chaining — confirm).
	SetInputs(inputs ...Layer) Layer
	// GetInputs returns the layers connected as inputs (presumably those
	// passed to SetInputs — confirm).
	GetInputs() []Layer

	// GetShape returns the layer's shape (presumably its output shape —
	// confirm).
	GetShape() tf.Shape
	// GetDtype returns the layer's DataType.
	GetDtype() DataType
	// GetLayerWeights returns the layer's weight tensors.
	GetLayerWeights() []*tf.Tensor

	// GetKerasLayerConfig returns the layer's Keras configuration; the
	// dynamic type of the value is chosen by each implementation.
	GetKerasLayerConfig() interface{}
	// GetCustomLayerDefinition returns a string related to a custom layer
	// definition. NOTE(review): its format, and whether "" means "none", is
	// not visible here — confirm.
	GetCustomLayerDefinition() string
}
// uniqueNameCounts records, per base name, how many names UniqueName has
// generated so far. Any access must hold uniqueNameMu.
var uniqueNameCounts = make(map[string]int)

// uniqueNameMu guards uniqueNameCounts. Without it, concurrent calls to
// UniqueName would write the map concurrently, which the Go runtime treats as
// a fatal error.
var uniqueNameMu sync.Mutex

// UniqueName returns name suffixed with "_<n>", where n is how many times
// UniqueName has been called with that same name in this process, starting
// at 1. For example, the first two calls with "dense" return "dense_1" and
// "dense_2". It is safe to call from multiple goroutines.
//
// NOTE(review): uniqueness holds only among names produced by this function;
// a caller-chosen name such as "dense_1" can still collide with a generated
// one.
func UniqueName(name string) string {
	uniqueNameMu.Lock()
	uniqueNameCounts[name]++
	count := uniqueNameCounts[name]
	uniqueNameMu.Unlock()

	// Formatting happens outside the critical section; count is a local copy.
	return fmt.Sprintf("%s_%d", name, count)
}
// String returns the TensorFlow type name stored in d, implementing
// fmt.Stringer.
//
// It uses a value receiver so that both DataType values, including the
// non-addressable package constants such as Float32, and *DataType satisfy
// fmt.Stringer.
func (d DataType) String() string {
	return string(d)
}