-
Notifications
You must be signed in to change notification settings - Fork 1
/
exact_bytes.go
137 lines (122 loc) · 3.53 KB
/
exact_bytes.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
// Copyright © 2021 The Things Industries B.V.
// SPDX-License-Identifier: Apache-2.0
package types
import (
"encoding/base64"
"encoding/hex"
"fmt"
"strings"
"github.com/TheThingsIndustries/protoc-gen-go-flags/flagsplugin"
"github.com/spf13/pflag"
)
// GetExactBytes looks up the flag called name in fs and returns the bytes it
// currently holds, together with whether the flag was changed on the command
// line.
//
// The name is first normalized with toDash (presumably mapping '_' to '-';
// confirm against its definition). If no such flag exists, a
// *flagsplugin.ErrFlagNotFound carrying the normalized name is returned. The
// flag's value must be an *ExactBytesValue; any other value type panics.
func GetExactBytes(fs *pflag.FlagSet, name string) (value []byte, set bool, err error) {
	dashed := toDash.Replace(name)
	f := fs.Lookup(dashed)
	if f == nil {
		return nil, false, &flagsplugin.ErrFlagNotFound{FlagName: dashed}
	}
	holder := f.Value.(*ExactBytesValue)
	return holder.Value, f.Changed, nil
}
// ExactBytesValue implements pflag.Value interface.
//
// It holds a byte slice that Set only accepts when the decoded input is
// exactly length bytes long.
type ExactBytesValue struct {
	// length is the exact number of bytes Set requires; it also drives Type().
	length int
	// Value holds the most recently parsed bytes. It is nil until Set succeeds
	// (or until the field is assigned directly).
	Value []byte
}
// Set implements pflag.Value interface.
//
// Set parses s as either a hex string or a base64 string that must decode to
// exactly ebv.length bytes. Any trailing '=' padding is stripped first. The
// encoding is then chosen by the remaining length:
//   - hex.EncodedLen(length) characters are decoded as hex;
//   - base64.RawStdEncoding.EncodedLen(length) characters are decoded as
//     unpadded standard base64, after normalization with
//     flagsplugin.Base64Replacer (presumably mapping the URL-safe alphabet to
//     the standard one; confirm).
//
// For length 1 both encodings are 2 characters long, and hex takes
// precedence. On any error ebv.Value is left unchanged.
func (ebv *ExactBytesValue) Set(s string) error {
	// Strip ALL trailing padding. Padded base64 ends in one or two '='
	// depending on length%3, and hex never contains '='. TrimSuffix would
	// only remove one and reject valid "==" inputs.
	trimmed := strings.TrimRight(s, "=")
	hexLen := hex.EncodedLen(ebv.length)
	base64Len := base64.RawStdEncoding.EncodedLen(ebv.length)
	switch len(trimmed) {
	case hexLen:
		b, err := hex.DecodeString(trimmed)
		if err != nil {
			return fmt.Errorf("invalid hex bytes: %w", err)
		}
		ebv.Value = b
	case base64Len:
		b, err := base64.RawStdEncoding.DecodeString(flagsplugin.Base64Replacer.Replace(trimmed))
		if err != nil {
			return fmt.Errorf("invalid base64 bytes: %w", err)
		}
		ebv.Value = b
	default:
		// Report both units: the expected size is in bytes, while the input
		// length is measured in characters.
		return fmt.Errorf(
			"invalid bytes length: want %d bytes (%d hex or %d base64 characters), got %d characters",
			ebv.length, hexLen, base64Len, len(trimmed),
		)
	}
	return nil
}
// Type implements pflag.Value interface.
//
// The type name embeds the required byte length, e.g. "8-bytes".
func (ebv *ExactBytesValue) Type() string {
	// Sprint inserts no space here because "-bytes" is a string operand.
	return fmt.Sprint(ebv.length, "-bytes")
}
// String implements pflag.Value interface.
//
// The current bytes are rendered as lowercase hex. A nil or empty Value
// yields "".
func (ebv *ExactBytesValue) String() string {
	out := make([]byte, hex.EncodedLen(len(ebv.Value)))
	hex.Encode(out, ebv.Value)
	return string(out)
}
// New8BytesFlag defines a new flag that holds a byte array of length 8.
//
// The returned flag's Value is an *ExactBytesValue requiring exactly 8 bytes.
// opts are applied to the flag via flagsplugin.ApplyOptions before it is
// returned.
func New8BytesFlag(name, usage string, opts ...flagsplugin.FlagOption) *pflag.Flag {
	f := new(pflag.Flag)
	f.Name, f.Usage = name, usage
	f.Value = &ExactBytesValue{length: 8}
	flagsplugin.ApplyOptions(f, opts...)
	return f
}
// GetExactBytesSlice looks up the flag called name in fs and returns the list
// of byte slices it currently holds, together with whether the flag was
// changed on the command line.
//
// The name is first normalized with toDash (presumably mapping '_' to '-';
// confirm against its definition). If no such flag exists, a
// *flagsplugin.ErrFlagNotFound carrying the normalized name is returned. The
// flag's value must be an *ExactBytesSliceValue; any other value type panics.
//
// The outer slice is freshly allocated (non-nil even when empty). The inner
// byte slices share storage with the flag's elements.
func GetExactBytesSlice(fs *pflag.FlagSet, name string) (value [][]byte, set bool, err error) {
	dashed := toDash.Replace(name)
	f := fs.Lookup(dashed)
	if f == nil {
		return nil, false, &flagsplugin.ErrFlagNotFound{FlagName: dashed}
	}
	elems := f.Value.(*ExactBytesSliceValue).Values
	value = make([][]byte, 0, len(elems))
	for i := range elems {
		value = append(value, elems[i].Value)
	}
	return value, f.Changed, nil
}
// ExactBytesSliceValue implements pflag.Value interface.
//
// It accumulates a list of ExactBytesValue elements that all share the same
// required byte length.
type ExactBytesSliceValue struct {
	// length is the exact byte length required of every element; it is copied
	// into each ExactBytesValue that Set creates.
	length int
	// Values holds the parsed elements in the order Set appended them.
	Values []ExactBytesValue
}
// Set implements pflag.Value interface.
//
// Set splits s into elements with flagsplugin.SplitSliceElements and parses
// each element as an ExactBytesValue of ebv.length bytes. The parsed elements
// are appended to ebv.Values. The update is all-or-nothing: if splitting or
// any element fails, the error is returned and ebv.Values is left unchanged.
//
// NOTE(review): repeated Set calls accumulate. Unlike pflag's built-in slice
// flags, the first Set does not replace a pre-existing value.
func (ebv *ExactBytesSliceValue) Set(s string) error {
	vs, err := flagsplugin.SplitSliceElements(s)
	if err != nil {
		return err
	}
	// Parse into a scratch slice first, so a bad element cannot leave the flag
	// holding only part of this update.
	parsed := make([]ExactBytesValue, 0, len(vs))
	for _, v := range vs {
		ev := ExactBytesValue{length: ebv.length}
		if err := ev.Set(v); err != nil {
			return err
		}
		parsed = append(parsed, ev)
	}
	ebv.Values = append(ebv.Values, parsed...)
	return nil
}
// Type implements pflag.Value interface.
//
// The type name embeds the per-element byte length, e.g. "8-bytes". This is
// the same name the scalar ExactBytesValue uses.
func (ebv *ExactBytesSliceValue) Type() string {
	n := ebv.length
	return fmt.Sprintf("%v-bytes", n)
}
// String implements pflag.Value interface.
//
// An empty slice renders as "". Otherwise each element's String() (hex) is
// joined with flagsplugin.JoinSliceElements and the result is wrapped in
// square brackets.
func (ebv *ExactBytesSliceValue) String() string {
	count := len(ebv.Values)
	if count > 0 {
		parts := make([]string, count)
		for i := range ebv.Values {
			parts[i] = ebv.Values[i].String()
		}
		return "[" + flagsplugin.JoinSliceElements(parts) + "]"
	}
	return ""
}
// New8BytesSliceFlag defines a new flag that holds a slice of byte arrays of length 8.
//
// The returned flag's Value is an *ExactBytesSliceValue whose elements must
// each be exactly 8 bytes. opts are applied to the flag via
// flagsplugin.ApplyOptions before it is returned.
func New8BytesSliceFlag(name, usage string, opts ...flagsplugin.FlagOption) *pflag.Flag {
	holder := &ExactBytesSliceValue{length: 8}
	f := &pflag.Flag{Name: name, Usage: usage, Value: holder}
	flagsplugin.ApplyOptions(f, opts...)
	return f
}