/
hook.go
110 lines (94 loc) · 2.67 KB
/
hook.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
package psql
import (
"context"
"fmt"
"regexp"
"strings"
"github.com/opentracing/opentracing-go"
otlog "github.com/opentracing/opentracing-go/log"
"github.com/spf13/cast"
)
// TracingHook is a SQL query hook whose Before/After/OnError methods record
// queries on OpenTracing spans carried in the request context.
type TracingHook struct {
	// tracer is the value handed to NewTracingHook; it may be nil.
	tracer opentracing.Tracer
}

// NewTracingHook returns a new *TracingHook that keeps a reference to tracer.
func NewTracingHook(tracer opentracing.Tracer) *TracingHook {
	hook := new(TracingHook)
	hook.tracer = tracer
	return hook
}
// Statement classifiers used by getOperationName.
//
// They are compiled once at package initialisation instead of on every query.
// (?i) makes matching case-insensitive without upper-casing (and copying) the
// whole statement. \b stops identifiers such as "selected_rows" or a literal
// "settings" from being mistaken for SQL keywords.
var (
	tracingSelectRe = regexp.MustCompile(`(?i)\bSELECT\b`)
	tracingInsertRe = regexp.MustCompile(`(?i)\bINSERT\s+INTO\b`)
	tracingUpdateRe = regexp.MustCompile(`(?i)\bUPDATE\s+.+\s+SET\b`)
	tracingDeleteRe = regexp.MustCompile(`(?i)\bDELETE\s+FROM\b`)
)

// getOperationName classifies query by the DML keywords it contains.
//
// It returns "DELETE", "UPDATE", "INSERT" or "SELECT", or "database" when none
// of them match. Keywords are searched anywhere in the statement, so a
// statement containing several of them (e.g. INSERT ... SELECT, or DELETE with
// a sub-select) is labelled by precedence: DELETE > UPDATE > INSERT > SELECT.
// Keywords inside string literals or SQL comments are not excluded and can
// still influence the result.
func (h *TracingHook) getOperationName(query string) string {
	switch {
	case tracingDeleteRe.MatchString(query):
		return "DELETE"
	case tracingUpdateRe.MatchString(query):
		return "UPDATE"
	case tracingInsertRe.MatchString(query):
		return "INSERT"
	case tracingSelectRe.MatchString(query):
		return "SELECT"
	default:
		return "database"
	}
}
// Before runs before a query is executed.
//
// When ctx already carries a span, Before does the following:
//   - starts a child span named "database" with the hook's tracer;
//   - tags it with "operation" = h.getOperationName(query);
//   - logs the statement and, if present, its positional arguments;
//   - returns a context carrying the new span.
//
// When ctx is nil or carries no span, ctx is returned unchanged and nothing
// is traced. The returned error is always nil.
func (h *TracingHook) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
	if ctx == nil {
		return ctx, nil
	}
	parent := opentracing.SpanFromContext(ctx)
	if parent == nil {
		// Only queries that run inside an existing trace are traced.
		return ctx, nil
	}

	// Use the tracer injected through NewTracingHook. Fall back to the global
	// tracer (the previous behaviour) when none was supplied, so a nil tracer
	// cannot cause a panic.
	tracer := h.tracer
	if tracer == nil {
		tracer = opentracing.GlobalTracer()
	}
	span := tracer.StartSpan("database", opentracing.ChildOf(parent.Context()))
	ctx = opentracing.ContextWithSpan(ctx, span)

	span.SetTag("operation", h.getOperationName(query))
	span.LogFields(otlog.String("statement", query))

	if len(args) > 0 {
		parts := make([]string, 0, len(args))
		for i, arg := range args {
			// NOTE(review): fmt gives "$" no special meaning, so each entry is
			// emitted as "$$<n>:<value>". The format is kept as-is because
			// trace consumers may already depend on it.
			parts = append(parts, fmt.Sprintf(`$$%s:%s`, cast.ToString(i+1), cast.ToString(arg)))
		}
		span.LogFields(otlog.String("args", strings.Join(parts, ",")))
	}
	return ctx, nil
}
// After runs once a query has been executed.
//
// If ctx carries a span (normally the child span Before placed there), After
// tags it with error=false and finishes it. ctx is returned unchanged, and
// the returned error is always nil.
func (h *TracingHook) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
	if ctx == nil {
		return ctx, nil
	}
	span := opentracing.SpanFromContext(ctx)
	if span == nil {
		return ctx, nil
	}
	defer span.Finish()
	span.SetTag("error", false)
	return ctx, nil
}
// OnError runs when executing query fails.
//
// If ctx carries a span (normally the child span Before placed there),
// OnError tags it with error=true, logs err's message on it and finishes it.
// err is always returned unchanged, so the caller still observes the failure.
func (h *TracingHook) OnError(ctx context.Context, err error, query string, args ...interface{}) error {
	if ctx == nil {
		return err
	}
	span := opentracing.SpanFromContext(ctx)
	if span == nil {
		return err
	}
	defer span.Finish()
	span.SetTag("error", true)
	// Calling err.Error() on a nil error would panic inside the hook.
	if err != nil {
		span.LogFields(otlog.Message(err.Error()))
	}
	return err
}