/
dialect.go
74 lines (65 loc) · 1.52 KB
/
dialect.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
package sql
import (
"github.com/aacfactory/errors"
"github.com/aacfactory/fns/context"
"github.com/aacfactory/fns/runtime"
"github.com/aacfactory/fns/services"
)
var (
	// dialectContextKeyPrefix is prepended to an endpoint name to build the
	// context-local key under which ForceDialect and Dialect store the dialect.
	dialectContextKeyPrefix = []byte("@fns:sql:dialect:")
	// dialectFnName is the function name Dialect requests from the endpoint,
	// and the name reported by dialectFn.
	dialectFnName = []byte("dialect")
)
// ForceDialect stores dialect as the SQL dialect for the endpoint selected by
// ctx. Subsequent calls to Dialect with the same ctx find it in the context and
// return it without requesting the endpoint.
//
// The endpoint is the non-empty name returned by used(ctx), falling back to
// endpointName. The value is written with ctx.SetLocalValue under the key
// dialectContextKeyPrefix+endpoint. The write is skipped when the same dialect
// is already stored. ctx itself is returned.
func ForceDialect(ctx context.Context, dialect string) context.Context {
	ep := endpointName
	if epn := used(ctx); len(epn) > 0 {
		ep = epn
	}
	// Clip the shared package-level prefix to its length so append always
	// allocates a fresh array. Otherwise any spare capacity in the prefix's
	// backing array would be shared (and raced on) by concurrent callers.
	prefix := dialectContextKeyPrefix[:len(dialectContextKeyPrefix):len(dialectContextKeyPrefix)]
	key := append(prefix, ep...)
	if stored, has := context.LocalValue[string](ctx, key); has && stored == dialect {
		return ctx
	}
	ctx.SetLocalValue(key, dialect)
	return ctx
}
// Dialect returns the SQL dialect of the endpoint selected by ctx.
//
// The endpoint is the non-empty name returned by used(ctx), falling back to
// endpointName. A dialect already stored in the context for that endpoint is
// returned as-is. Such a value comes from ForceDialect or an earlier successful
// call to Dialect.
//
// Otherwise the endpoint's dialectFnName function is requested through
// runtime.Endpoints. The decoded string is stored in the context and returned.
//
// On failure dialect is "" and err is either the request error (returned
// unwrapped) or a "sql: dialect failed" warning wrapping the decode error.
func Dialect(ctx context.Context) (dialect string, err error) {
	ep := endpointName
	if epn := used(ctx); len(epn) > 0 {
		ep = epn
	}
	// Clip the shared package-level prefix to its length so append always
	// allocates a fresh array rather than writing into spare capacity that
	// concurrent callers would share.
	prefix := dialectContextKeyPrefix[:len(dialectContextKeyPrefix):len(dialectContextKeyPrefix)]
	key := append(prefix, ep...)
	if cached, has := context.LocalValue[string](ctx, key); has {
		return cached, nil
	}
	eps := runtime.Endpoints(ctx)
	response, requestErr := eps.Request(ctx, ep, dialectFnName, nil)
	if requestErr != nil {
		return "", requestErr
	}
	value, decodeErr := services.ValueOfResponse[string](response)
	if decodeErr != nil {
		// Do not leak a partially decoded value alongside the error.
		return "", errors.Warning("sql: dialect failed").WithCause(decodeErr)
	}
	ctx.SetLocalValue(key, value)
	return value, nil
}
// dialectFn is a service function that answers every request with a fixed
// dialect string. Its name is dialectFnName, the same name Dialect requests
// from an endpoint (registration happens elsewhere — presumably by the
// owning service; confirm).
type dialectFn struct {
	// dialect is the value returned by every call to Handle.
	dialect string
}

// Name returns the string form of dialectFnName.
func (f *dialectFn) Name() string {
	name := string(dialectFnName)
	return name
}

// Internal always reports true.
func (f *dialectFn) Internal() bool { return true }

// Readonly always reports false.
// NOTE(review): Handle only reads f.dialect; confirm false is intentional.
func (f *dialectFn) Readonly() bool { return false }

// Handle ignores the request and returns the configured dialect with a nil error.
func (f *dialectFn) Handle(_ services.Request) (v interface{}, err error) {
	return f.dialect, nil
}