-
Notifications
You must be signed in to change notification settings - Fork 0
/
convenient.go
328 lines (275 loc) · 8.33 KB
/
convenient.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
package bayes
import (
"encoding/binary"
"hash/crc32"
"unsafe"
"github.com/pkg/errors"
"github.com/zeebo/blake3"
)
// ============================================================================
// Shorthand functions.
// ============================================================================
// This file contains functions for easy-to-use purposes.
//
// Convenient Functions:
// - SetStorage() - Sets the storage used by the predictor.
// - Reset() - Resets the trained data of the predictor.
// - Train() - Trains the predictor with the given items.
// - Predict() - Predicts the next item from the given items.
// - HashTrans() - Returns a unique hash from the input items.
// - GetClass() - Returns the original item value of the given class ID.
// ============================================================================
const (
	// StorageDefault is the default storage used by the predictor. Currently, it
	// is an in-memory log (logmem package).
	StorageDefault = MemoryStorage
	// ScopeIDDefault is the default scope ID on creating an instance of the
	// predictor. It is used by Reset() when recreating the package-level
	// predictor.
	ScopeIDDefault = uint64(0)
)
var (
	// _predictor is the package-level predictor shared by the convenience
	// functions. It is (re)created by Reset().
	_predictor NodeLogger
	// _classes maps a class ID to its original trained value. Populated by
	// Train() via addClass(); read by GetClass() and Predict().
	_classes map[uint64]_Class
	// _storage selects the backend used when Reset() creates the predictor.
	// Change it via SetStorage().
	_storage = StorageDefault
)
// init creates the package-level predictor and class map via Reset() so the
// convenience functions are usable without any explicit setup.
func init() {
	Reset()
}
// ----------------------------------------------------------------------------
// Type: _Class (private)
// ----------------------------------------------------------------------------
// _Class holds the class ID and the original value it was derived from.
type _Class struct {
	Raw any    // Original value as passed to Train().
	ID  uint64 // Class ID computed from the value via convAnyToUint64().
}
// ----------------------------------------------------------------------------
// Public functions
// ----------------------------------------------------------------------------
// GetClass returns the original value of the given class ID.
//
// An unknown class ID yields nil (the zero value of _Class.Raw).
func GetClass(classID uint64) any {
	class, found := _classes[classID]
	if !found {
		return nil
	}
	return class.Raw
}
// HashTrans returns a unique hash from the input transitions. Note that the hash
// is not cryptographically secure.
func HashTrans[T any](transitions ...T) (uint64, error) {
	// Hash the whole transition sequence with BLAKE3.
	hashed, err := getBlake3(transitions...)
	if err != nil {
		return 0, errors.Wrap(err, "failed to get the hash of the transition")
	}
	// Fold the hash and its CRC32C checksum into a single uint64.
	return chopAndMergeBytes(hashed, getCRC32C(hashed))
}
// Predict returns the next class ID inferred from the given items.
//
// To get the original value of the class, use `GetClass()`.
//
//nolint:nonamedreturns // named return is used for readability.
func Predict[T any](items []T) (classID uint64, err error) {
	if _predictor == nil {
		return 0, errors.New("predictor is not initialized")
	}
	// Identify the flow (the item sequence) by its hash.
	flowID, err := HashTrans(items...)
	if err != nil {
		return 0, errors.Wrap(err, "failed to hash the flow")
	}
	// Scan every known class and keep the one with the highest probability.
	var (
		bestClass uint64
		bestProb  float64
	)
	for candidate := range _classes {
		if probability := _predictor.Predict(flowID, candidate); probability > bestProb {
			bestProb = probability
			bestClass = candidate
		}
	}
	return bestClass, nil
}
// Reset resets the train object.
//
// It recreates the package-level predictor with the storage selected via
// SetStorage() and clears all trained class data. Panics if the predictor
// cannot be created.
func Reset() {
	var err error
	_predictor, err = New(_storage, ScopeIDDefault)
	if err != nil {
		panic(err)
	}
	_classes = make(map[uint64]_Class)
}
// SetStorage sets the storage used by the predictor. This won't affect the
// predictors created via `New()`.
//
// Do not forget to `Reset()` the predictor after changing the storage: the
// new storage only takes effect once Reset() rebuilds the package-level
// predictor.
func SetStorage(storage Storage) {
	_storage = storage
}
// Train trains the predictor with the given items.
//
// Once the item appears in the training set, the item is added to the class list.
func Train[T any](items []T) error {
	if _predictor == nil {
		Reset()
	}
	prevItem := uint64(0)
	drill := []uint64{}
	for index, itemRaw := range items {
		item, err := convAnyToUint64(itemRaw)
		if err != nil {
			return errors.Wrap(err, "failed during training iteration")
		}
		// Register the item as a known class so Predict() can return it.
		// This must include the very first item: previously index 0 skipped
		// registration, so an item that only ever appeared first could never
		// become a prediction candidate.
		addClass(item, itemRaw)
		if index == 0 {
			prevItem = item
			drill = append(drill, item)
			continue
		}
		// 101 training. Trains only the predecessor and the successor item.
		// e.g.
		//   previous items --> [1, 2, 3, 4, 5]
		//   following item --> 6
		// will train:
		//   [5] --> 6
		_predictor.Update(prevItem, item)
		// Drill.
		// Trains by repeating the flow of the previous items.
		// e.g.
		//   previous items --> [1, 2, 3, 4, 5]
		//   following item --> 6
		// will train:
		//   [5] --> 6
		//   [4, 5] --> 6
		//   [3, 4, 5] --> 6
		//   [2, 3, 4, 5] --> 6
		//   [1, 2, 3, 4, 5] --> 6
		for i := 0; i < len(drill); i++ {
			// drill holds uint64 values that already passed convAnyToUint64,
			// so HashTrans cannot fail here; the error is safely ignored.
			flowID, _ := HashTrans(drill[i:]...)
			_predictor.Update(flowID, item)
		}
		prevItem = item
		drill = append(drill, item)
	}
	return nil
}
// ----------------------------------------------------------------------------
// Private functions
// ----------------------------------------------------------------------------
// addClass records the mapping from a class ID to its original value so that
// GetClass() / Predict() can recover it later.
//
// The previous implementation type-switched over raw only to store the typed
// value v back into the `any` field — a no-op, since v holds the same dynamic
// value as raw in every arm (and the default arm stored raw directly). A
// single assignment is exactly equivalent.
func addClass(class uint64, raw any) {
	_classes[class] = _Class{ID: class, Raw: raw}
}
// chopAndMergeBytes combines the two input as one in 8 byte length.
//
// The first 4 bytes of the input `a` will be used as the upper half of the
// output, and the first 4 bytes of the input `b` will be used as the bottom
// half of the output.
//
//nolint:varnamelen // short parameter name is readable enough
func chopAndMergeBytes(a, b []byte) (uint64, error) {
	const halfSize = 4
	if len(a) < halfSize || len(b) < halfSize {
		return 0, errors.New("failed to combine bytes. Both of the input must be 4byte or more")
	}
	// Read the leading 4 bytes of each input as big-endian words and pack
	// them into one 64-bit value: a's word on top, b's word below.
	upper := uint64(binary.BigEndian.Uint32(a))
	lower := uint64(binary.BigEndian.Uint32(b))
	return upper<<32 | lower, nil
}
// convAnyToUint64 converts the given value to a uint64 representation.
//
// Integer types are converted directly; floats are truncated toward zero
// (note: converting a negative float to uint64 is implementation-defined in
// Go — callers should avoid negative float inputs). Strings are hashed with
// BLAKE3 and the first 8 bytes of the digest are used (big-endian). Booleans
// map to 1 (true) and 0 (false). Any other type returns an error.
//
//nolint:varnamelen,cyclop
func convAnyToUint64(i interface{}) (uint64, error) {
	switch v := i.(type) {
	case uint64:
		return v, nil
	case uint32:
		return uint64(v), nil
	case uint16:
		return uint64(v), nil
	case uint8:
		// uint8 (alias: byte) was previously missing and fell through to the
		// unsupported-type error, unlike every other fixed-width integer.
		return uint64(v), nil
	case uint:
		return uint64(v), nil
	case int64:
		return uint64(v), nil
	case int32:
		return uint64(v), nil
	case int16:
		return uint64(v), nil
	case int8:
		// int8 was previously missing as well; see uint8 above.
		return uint64(v), nil
	case int:
		return uint64(v), nil
	case float64:
		return uint64(v), nil
	case float32:
		return uint64(v), nil
	case string:
		h := blake3.Sum512([]byte(v))
		return binary.BigEndian.Uint64(h[:]), nil
	case bool:
		if v {
			return uint64(1), nil
		}
		return uint64(0), nil
	}
	return 0, errors.Errorf("failed to convert to uint64. Unsupported type: %T", i)
}
// getBlake3 returns the BLAKE3 hash of the inputs as a byte array.
//
// Each input is converted to a uint64 via convAnyToUint64() and the resulting
// 8-byte chunks are fed to the hasher in order.
func getBlake3[T any](inputs ...T) ([]byte, error) {
	hasher := blake3.New()
	for _, input := range inputs {
		converted, err := convAnyToUint64(input)
		if err != nil {
			return nil, errors.Wrap(err, "failed to convert to uint64")
		}
		// blake3.Hasher.Write() never returns an error.
		// https://github.com/zeebo/blake3/blob/master/api.go#L87-L91
		_, _ = hasher.Write(uint64ToByteArray(converted))
	}
	return hasher.Sum(nil), nil
}
// getCRC32C returns the CRC-32 with Castagnoli polynomial of the input.
//
// The 4-byte result is big-endian, matching what hash.Hash32.Sum() appends.
func getCRC32C(input []byte) []byte {
	// crc32.Checksum computes the sum in one shot, with no digest object.
	checksum := crc32.Checksum(input, crc32.MakeTable(crc32.Castagnoli))
	out := make([]byte, crc32.Size)
	binary.BigEndian.PutUint32(out, checksum)
	return out
}
// uint64ToByteArray converts an unsigned integer to a byte array in little endian.
//
// The previous implementation walked the integer's in-memory bytes with
// unsafe pointer arithmetic, which yields the host's NATIVE byte order —
// little-endian only on little-endian platforms. binary.LittleEndian
// guarantees the documented layout everywhere, with no unsafe indexing.
func uint64ToByteArray(num uint64) []byte {
	size := int(unsafe.Sizeof(num)) // 8 bytes for uint64.
	arr := make([]byte, size)
	binary.LittleEndian.PutUint64(arr, num)
	return arr
}