-
Notifications
You must be signed in to change notification settings - Fork 4
/
constant.go
85 lines (71 loc) · 2.39 KB
/
constant.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
package opset13
import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)
// Constant represents the ONNX constant operator. It takes no inputs and
// yields a single output tensor whose contents are taken from the node's one
// attribute when Init is called.
type Constant struct {
	// value is the tensor built by Init and handed out by Apply. It is the
	// zero (nil) interface value until Init successfully assigns it.
	value tensor.Tensor
}
// newConstant returns a freshly allocated, zero-valued constant operator.
// The operator holds no value until Init is called on it.
func newConstant() ops.Operator {
	return new(Constant)
}
// Init initializes the constant operator from the node's single attribute.
// It supports all constant types except `sparse_value`, `value_string`, and
// `value_strings`.
//
// Accepted attributes:
//   - "value": converted with onnx.TensorFromProto; a conversion error is
//     returned as-is.
//   - "value_float" / "value_int": stored as a scalar tensor.
//   - "value_floats" / "value_ints": stored as a 1-D tensor whose length is the
//     number of elements. The attribute's slice is used directly as the
//     backing store (it is not copied).
//
// It returns ops.ErrInvalidAttributeCount when the node does not carry exactly
// one attribute, and ops.ErrUnsupportedAttribute for any other attribute name.
// The operator's stored value is only updated when Init succeeds.
func (c *Constant) Init(n *onnx.NodeProto) error {
	attrs := n.GetAttribute()
	if got := len(attrs); got != 1 {
		return ops.ErrInvalidAttributeCount(1, got, c)
	}

	attr := attrs[0]
	name := attr.GetName()

	var val tensor.Tensor

	switch name {
	case "value":
		t, err := onnx.TensorFromProto(attr.GetT())
		if err != nil {
			return err
		}

		val = t
	case "value_float":
		val = tensor.New(tensor.FromScalar(attr.GetF()))
	case "value_int":
		val = tensor.New(tensor.FromScalar(attr.GetI()))
	case "value_floats":
		fs := attr.GetFloats()
		val = tensor.New(tensor.WithShape(len(fs)), tensor.WithBacking(fs))
	case "value_ints":
		is := attr.GetInts()
		val = tensor.New(tensor.WithShape(len(is)), tensor.WithBacking(is))
	case "sparse_value", "value_string", "value_strings":
		// Recognized constant attributes that this implementation does not
		// handle; they are reported exactly like unknown attribute names.
		fallthrough
	default:
		return ops.ErrUnsupportedAttribute(name, c)
	}

	c.value = val

	return nil
}
// Apply applies the constant operator. Any inputs are ignored; the result is a
// one-element slice holding the tensor stored by Init, and the error is always
// nil.
//
// NOTE(review): the same tensor instance is returned on every call, so a
// consumer that mutates it in place would change the constant for later runs —
// confirm downstream operators treat their inputs as read-only.
func (c *Constant) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) {
	outputs := make([]tensor.Tensor, 1)
	outputs[0] = c.value

	return outputs, nil
}
// ValidateInputs validates the inputs that will be given to Apply for this
// operator by delegating to the shared ops.ValidateInputs helper, passing this
// operator so the helper can consult its input-count and type constraints.
func (c *Constant) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
	validated, err := ops.ValidateInputs(c, inputs)

	return validated, err
}
// GetMinInputs returns the minimum number of input tensors this operator
// expects. A constant takes no inputs, so this is zero.
func (c *Constant) GetMinInputs() int {
	const minInputs = 0

	return minInputs
}
// GetMaxInputs returns the maximum number of input tensors this operator
// expects. A constant takes no inputs, so this is zero.
func (c *Constant) GetMaxInputs() int {
	const maxInputs = 0

	return maxInputs
}
// GetInputTypeConstraints returns a list. Every element represents a set of
// allowed tensor dtypes for the corresponding input tensor. Because a constant
// has no inputs, the list is empty (but non-nil).
func (c *Constant) GetInputTypeConstraints() [][]tensor.Dtype {
	return make([][]tensor.Dtype, 0)
}
// String implements fmt.Stringer, giving a human-readable name for this
// operator that can be used when formatting errors or messages.
func (c *Constant) String() string {
	const name = "constant operator"

	return name
}