forked from samsarahq/thunder
/
sql.go
274 lines (252 loc) · 8.35 KB
/
sql.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
package fields
import (
"database/sql"
"database/sql/driver"
"encoding"
"encoding/json"
"fmt"
"reflect"
"time"
)
// Valuer fulfills the sql/driver.Valuer interface, serializing a struct field value into a
// valid SQL value. The embedded Descriptor supplies the field's Kind and Tags, which steer
// how Value converts it.
type Valuer struct {
	*Descriptor
	// value is the field value to serialize. It may be a pointer; Value dereferences it,
	// or reports NULL when it is nil.
	value reflect.Value
}
// Value satisfies the sql/driver.Valuer interface.
// The returned value is one of the following:
//
//	int64
//	float64
//	bool
//	[]byte
//	string
//	time.Time
//	nil - for NULL values
//
// Conversion proceeds in this order:
//   - nil chans/maps/funcs/pointers/interfaces/slices and invalid values become NULL;
//   - values implementing driver.Valuer serialize themselves;
//   - the `binary` and `string` tags use Marshal/MarshalBinary or MarshalText when the value
//     implements them, and the `json` tag always JSON-encodes;
//   - bool, integer, float and string kinds (including named types built on them) are coerced
//     to bool/int64/float64/string;
//   - anything else is passed through for the driver to accept or reject.
//
// Unsigned integers are narrowed to int64, the only integer type a driver.Value may hold.
// Values above math.MaxInt64 cannot be represented, so they produce an error instead of
// wrapping to a negative number.
func (f Valuer) Value() (driver.Value, error) {
	// Return early if the value is nil. Ideally we would do a `i == nil` comparison here, but
	// unfortunately for us, `nil` is typed and that would always return false. This has to be
	// before `.Interface()` as that method panics otherwise.
	switch f.value.Kind() {
	// IsNil panics if the value isn't one of these kinds.
	case reflect.Chan, reflect.Map, reflect.Func,
		reflect.Ptr, reflect.Interface, reflect.Slice:
		if f.value.IsNil() {
			return nil, nil
		}
	case reflect.Invalid:
		return nil, nil
	}

	i := f.value.Interface()

	// If our interface supports driver.Valuer we can immediately short-circuit as this is what the
	// MySQL driver would do.
	if valuer, ok := i.(driver.Valuer); ok {
		return valuer.Value()
	}

	// Override serialization behavior with tags (these take precedence over how a type would
	// usually be serialized).
	// Example:
	// struct {
	//   Blob proto.Blob        `sql:",binary"` // ensures that Marshal or MarshalBinary is used.
	//   IP   IP                `sql:",string"` // ensures that its MarshalText method
	//                                          // is used for serialization.
	//   JSON map[string]string `sql:",json"`   // ensures that json.Marshal is used on the value.
	// }
	switch {
	case f.Tags.Contains("binary"):
		if iface, ok := i.(marshaler); ok {
			return iface.Marshal()
		}
		if iface, ok := i.(encoding.BinaryMarshaler); ok {
			return iface.MarshalBinary()
		}
	case f.Tags.Contains("string"):
		if iface, ok := i.(encoding.TextMarshaler); ok {
			return iface.MarshalText()
		}
	case f.Tags.Contains("json"):
		if iface, ok := i.(json.Marshaler); ok {
			return iface.MarshalJSON()
		}
		return json.Marshal(i)
	}

	// At this point we have already handled `nil` above, so we can assume that all
	// other values can be coerced into dereferenced types of bool/int/float/string.
	if f.value.Kind() == reflect.Ptr {
		f.value = f.value.Elem()
	}

	// Coerce our value into a valid sql/driver.Value (see sql/driver.IsValue).
	// This not only converts base types into their sql counterparts (like int32 -> int64) but also
	// handles custom types (like `type customString string` -> string).
	switch f.Kind {
	case reflect.Bool:
		return f.value.Bool(), nil
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return f.value.Int(), nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		// reflect.Value.Int panics on unsigned kinds, so read the value with Uint and narrow it,
		// refusing values that would wrap around when stored in an int64.
		u := f.value.Uint()
		if u > 1<<63-1 {
			return nil, fmt.Errorf("unsigned value %d of type %v overflows int64", u, f.value.Type())
		}
		return int64(u), nil
	case reflect.Float32, reflect.Float64:
		return f.value.Float(), nil
	case reflect.String:
		return f.value.String(), nil
	}

	// If we can't figure out what the type is supposed to be, we pass it straight through to SQL,
	// which will return an error if it can't handle it.
	// This means we don't have to handle []byte or time.Time specially, since they'll just pass on
	// through.
	return f.value.Interface(), nil
}
// Compile-time checks that both Valuer and *Valuer satisfy driver.Valuer.
var (
	_ driver.Valuer = Valuer{}
	_ driver.Valuer = (*Valuer)(nil)
)
// Scanner fulfills the sql.Scanner interface, deserializing SQL values into the type dictated
// by its embedded Descriptor.
type Scanner struct {
	*Descriptor
	// value holds the most recently scanned value. It has the Descriptor's Type itself, not a
	// pointer to it.
	value reflect.Value
	// isValid records whether the last Scan produced a real value. It is false for a NULL
	// source, and always true when the type implements sql.Scanner, which handles NULL itself.
	isValid bool
}
// Interface returns the value held by the scanner as an interface{}, or nil when no value
// has been stored yet.
func (s Scanner) Interface() interface{} {
	if !s.value.IsValid() {
		return nil
	}
	return s.value.Interface()
}
// CopyTo hands the most recently scanned value, together with the validity flag recorded by
// Scan, to the copy helper so it can be written into to. This is used when populating structs.
func (s *Scanner) CopyTo(to reflect.Value) {
	src, valid := s.value, s.isValid
	s.copy(src, to, valid)
}
// Scan satisfies the sql.Scanner interface.
// The src value will be one of the following:
//
//	int64
//	float64
//	bool
//	[]byte
//	string
//	time.Time
//	nil - for NULL values
//
// Every call allocates a fresh value of the Descriptor's Type and stores it in s.value.
// s.isValid records whether that value came from a non-NULL source, or from a sql.Scanner
// implementation, which handles NULL itself.
//
// Integer and float sources are range-checked against the destination type. A value that does
// not fit (for example 300 into an int8, or -1 into a uint) returns an error instead of being
// silently truncated or wrapped.
func (s *Scanner) Scan(src interface{}) error {
	// Get a value of the pointer of our type. The Scanner and Unmarshalers should
	// only be implemented as dereference methods, since they would do nothing otherwise. Therefore
	// we can safely assume that we should check for these interfaces on the pointer value.
	ptr := reflect.New(s.Type)
	i := ptr.Interface()
	// Our value however should continue referencing a non-pointer for easier assignment.
	s.value = ptr.Elem()

	// If our interface supports sql.Scanner we can immediately short-circuit as this is what the
	// MySQL driver would do.
	if scanner, ok := i.(sql.Scanner); ok {
		// If we have a scanner it will handle its own validity.
		s.isValid = true
		return scanner.Scan(src)
	}

	// Keep track of whether our value was empty. Null values are left as the zero value; because
	// we're not holding on to pointers, NULL-ness has to be represented as a boolean. This comes
	// _after_ the scanner step, just in case the scanner handles nil differently.
	s.isValid = src != nil
	if !s.isValid {
		return nil
	}

	// Handle coercion into native types []byte and time.Time (the generic coercion below would
	// return an error for them). i is a pointer, so we match on pointer types.
	switch i.(type) {
	case *[]byte:
		if str, ok := src.(string); ok {
			s.value.Set(reflect.ValueOf([]byte(str)))
			return nil
		}
		if b, ok := src.([]byte); ok {
			// Copy the bytes: the sql.Scanner contract says []byte sources are owned by the
			// driver and must not be retained past Scan.
			bCopy := make([]byte, len(b))
			copy(bCopy, b)
			s.value.Set(reflect.ValueOf(bCopy))
			return nil
		}
	case *time.Time:
		if _, ok := src.(time.Time); ok {
			s.value.Set(reflect.ValueOf(src))
			return nil
		}
	}

	// Override deserialization behavior with tags (these take precedence over how a type would
	// usually be deserialized).
	// Example:
	// struct {
	//   Blob proto.Blob        `sql:",binary"` // ensures that Unmarshal or UnmarshalBinary is used.
	//   IP   IP                `sql:",string"` // ensures that its UnmarshalText method
	//                                          // is used for deserialization.
	//   JSON map[string]string `sql:",json"`   // ensures that json.Unmarshal is used on the value.
	// }
	switch {
	case s.Tags.Contains("binary"):
		b, ok := src.([]byte)
		if !ok {
			return fmt.Errorf("binary column must be of type []byte, got %T", src)
		}
		if iface, ok := i.(unmarshaler); ok {
			return iface.Unmarshal(b)
		}
		if iface, ok := i.(encoding.BinaryUnmarshaler); ok {
			return iface.UnmarshalBinary(b)
		}
	case s.Tags.Contains("string"):
		if str, ok := src.(string); ok {
			src = []byte(str)
		}
		b, isByte := src.([]byte)
		if !isByte {
			return fmt.Errorf("string/text column must be of type []byte or string, got %T", src)
		}
		if iface, ok := i.(encoding.TextUnmarshaler); ok {
			return iface.UnmarshalText(b)
		}
	case s.Tags.Contains("json"):
		if str, ok := src.(string); ok {
			src = []byte(str)
		}
		b, isByte := src.([]byte)
		if !isByte {
			return fmt.Errorf("json column must be of type string or []byte, got %T", src)
		}
		// Implicitly will check for json.Unmarshaler.
		return json.Unmarshal(b, i)
	}

	// If a MySQL value can be coerced into our type, we do so here.
	// This not only converts base types into their sql counterparts (like int64 -> int32) but also
	// handles custom types (like string -> `type customString string`).
	switch s.Kind {
	case reflect.Bool:
		var b sql.NullBool
		if err := b.Scan(src); err != nil {
			return err
		}
		s.value.Set(reflect.ValueOf(b.Bool).Convert(s.Type))
		return nil
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		var n sql.NullInt64
		if err := n.Scan(src); err != nil {
			return err
		}
		// reflect.Value.Convert truncates silently, so reject values the destination can't hold.
		if s.value.OverflowInt(n.Int64) {
			return fmt.Errorf("value %d overflows %v", n.Int64, s.Type)
		}
		s.value.Set(reflect.ValueOf(n.Int64).Convert(s.Type))
		return nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		var n sql.NullInt64
		if err := n.Scan(src); err != nil {
			return err
		}
		// Negative values would wrap around to huge unsigned numbers; oversize ones would truncate.
		if n.Int64 < 0 || s.value.OverflowUint(uint64(n.Int64)) {
			return fmt.Errorf("value %d overflows %v", n.Int64, s.Type)
		}
		s.value.Set(reflect.ValueOf(n.Int64).Convert(s.Type))
		return nil
	case reflect.Float32, reflect.Float64:
		var f sql.NullFloat64
		if err := f.Scan(src); err != nil {
			return err
		}
		// Guards float32 destinations against finite values that would become ±Inf.
		if s.value.OverflowFloat(f.Float64) {
			return fmt.Errorf("value %g overflows %v", f.Float64, s.Type)
		}
		s.value.Set(reflect.ValueOf(f.Float64).Convert(s.Type))
		return nil
	case reflect.String:
		var str sql.NullString
		if err := str.Scan(src); err != nil {
			return err
		}
		s.value.Set(reflect.ValueOf(str.String).Convert(s.Type))
		return nil
	}

	return fmt.Errorf("couldn't coerce type %T into %T", src, i)
}
// Compile-time check that *Scanner satisfies sql.Scanner.
var _ sql.Scanner = (*Scanner)(nil)