/
utils.go
159 lines (124 loc) · 4.12 KB
/
utils.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
// Copyright 2019 Aporeto Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package manipmemory
import (
"fmt"
"reflect"
"strings"
memdb "github.com/hashicorp/go-memdb"
"go.aporeto.io/elemental"
)
// stringBasedFieldIndex is a memdb indexer that extracts a struct field by
// name using reflection and builds an index on that field. It is meant for
// fields whose declared type is not string but whose underlying kind is,
// for example:
//
//	type ABC string
//
// It implements the memdb indexer interface (FromObject / FromArgs).
type stringBasedFieldIndex struct {
	// Field is the name of the struct field to index.
	Field string
	// Lowercase, when true, lower-cases values both when indexing objects
	// and when building lookup keys from arguments.
	Lowercase bool
}

// FromObject implements the memdb indexer interface.
//
// obj must be a struct or a non-nil pointer to a struct whose Field has an
// underlying kind of string. An empty value yields (false, nil, nil). An
// error is returned, rather than a reflect panic, when obj is not a struct,
// the field does not exist, or the field is not string based.
func (s *stringBasedFieldIndex) FromObject(obj interface{}) (bool, []byte, error) {
	v := reflect.Indirect(reflect.ValueOf(obj)) // Dereference the pointer if any

	// FieldByName panics on the zero Value (nil pointer) and on non-structs.
	if v.Kind() != reflect.Struct {
		return false, nil,
			fmt.Errorf("object %#v is not a struct or a pointer to a struct", obj)
	}

	fv := v.FieldByName(s.Field)
	if !fv.IsValid() {
		return false, nil,
			fmt.Errorf("field '%s' for %#v is invalid", s.Field, obj)
	}

	// reflect.Value.String does not fail on other kinds: it returns a
	// placeholder such as "<int Value>", which would be silently indexed.
	if fv.Kind() != reflect.String {
		return false, nil,
			fmt.Errorf("field '%s' for %#v is not string based", s.Field, obj)
	}

	val := fv.String()
	if val == "" {
		return false, nil, nil
	}

	return true, s.encode(val), nil
}

// FromArgs implements the memdb indexer interface. It expects exactly one
// argument whose underlying kind is string and returns its index key.
func (s *stringBasedFieldIndex) FromArgs(args ...interface{}) ([]byte, error) {
	if len(args) != 1 {
		return nil, fmt.Errorf("must provide only a single argument")
	}

	// Inspect a reflect.Value rather than reflect.TypeOf: TypeOf(nil) is a
	// nil Type and calling Kind on it panics, while the zero Value simply
	// reports reflect.Invalid.
	v := reflect.ValueOf(args[0])
	if v.Kind() != reflect.String {
		return nil, fmt.Errorf("argument must be a string: %#v", args[0])
	}

	return s.encode(v.String()), nil
}

// encode turns a raw value into an index key: it applies the optional
// lower-casing and appends the null terminator. It is shared by FromObject
// and FromArgs so that stored keys and lookup keys always match.
func (s *stringBasedFieldIndex) encode(val string) []byte {
	if s.Lowercase {
		val = strings.ToLower(val)
	}
	// Add the null character as a terminator
	return append([]byte(val), 0)
}
// createSchema builds the memdb table schema for the identity described by c.
//
// The table is named after c.Identity.Category. Each entry of c.Indexes
// produces one memdb.IndexSchema, stored under the index name, with an
// indexer chosen by the entry's Type. Every index is declared with
// AllowMissing: true. An unsupported index type aborts the conversion and
// returns an error.
func createSchema(c *IdentitySchema) (*memdb.TableSchema, error) {
	table := &memdb.TableSchema{
		Name:    c.Identity.Category,
		Indexes: map[string]*memdb.IndexSchema{},
	}

	for _, idx := range c.Indexes {
		var indexer memdb.Indexer

		switch idx.Type {
		case IndexTypeSlice:
			indexer = &memdb.StringSliceFieldIndex{Field: idx.Attribute}
		case IndexTypeMap:
			indexer = &memdb.StringMapFieldIndex{Field: idx.Attribute}
		case IndexTypeString:
			indexer = &memdb.StringFieldIndex{Field: idx.Attribute}
		case IndexTypeBoolean:
			// Bind the attribute name to a per-iteration variable so the
			// closure does not observe later loop iterations.
			field := idx.Attribute
			indexer = &memdb.ConditionalIndex{
				Conditional: func(obj interface{}) (bool, error) {
					return boolIndex(obj, field)
				},
			}
		case IndexTypeStringBased:
			indexer = &stringBasedFieldIndex{Field: idx.Attribute}
		default:
			return nil, fmt.Errorf("invalid index type: %d", idx.Type)
		}

		table.Indexes[idx.Name] = &memdb.IndexSchema{
			Name:         idx.Name,
			Unique:       idx.Unique,
			Indexer:      indexer,
			AllowMissing: true,
		}
	}

	return table, nil
}
// boolIndex is a conditional indexer for booleans: it reports the value of
// the bool field named field on obj, which must be a struct or a non-nil
// pointer to a struct.
//
// It returns an error, rather than letting reflect panic, when obj is not a
// struct (including a nil pointer), when the field does not exist, or when
// the field's kind is not bool.
func boolIndex(obj interface{}, field string) (bool, error) {
	v := reflect.Indirect(reflect.ValueOf(obj)) // Dereference the pointer if any

	// FieldByName panics on the zero Value (nil pointer) and on non-structs.
	if v.Kind() != reflect.Struct {
		return false, fmt.Errorf("object %#v is not a struct or a pointer to a struct", obj)
	}

	fv := v.FieldByName(field)
	if !fv.IsValid() {
		return false, fmt.Errorf("field '%s' for %#v is invalid", field, obj)
	}

	// reflect.Value.Bool panics for any kind other than Bool.
	if fv.Kind() != reflect.Bool {
		return false, fmt.Errorf("field '%s' for %#v is not a boolean", field, obj)
	}

	return fv.Bool(), nil
}
// mergeIn copies every entry of *source into *target. Entries of *target
// that share a key with *source are overwritten by the source value.
//
// A nil source pointer or an empty source map is a no-op. If *target is a
// nil map and there is something to merge, a new map is allocated and stored
// through target, because assigning into a nil map panics.
func mergeIn(target, source *map[string]elemental.Identifiable) {
	if source == nil || len(*source) == 0 {
		return
	}

	if *target == nil {
		*target = make(map[string]elemental.Identifiable, len(*source))
	}

	dst := *target
	for k, v := range *source {
		dst[k] = v
	}
}
// intersection replaces *target with a new map holding the entries of
// *source whose keys are also present in *target; values are taken from
// source. When queryStart is true the membership test is skipped and *target
// becomes a copy of *source (presumably so the first clause of a query can
// seed the result set — confirm against callers).
func intersection(target, source *map[string]elemental.Identifiable, queryStart bool) {
	result := make(map[string]elemental.Identifiable)

	for key, value := range *source {
		if !queryStart {
			if _, found := (*target)[key]; !found {
				continue
			}
		}
		result[key] = value
	}

	*target = result
}