-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
context.go
197 lines (161 loc) · 5.29 KB
/
context.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
package graphql
import (
"context"
"fmt"
"sync"
"github.com/vektah/gqlparser/ast"
"github.com/vektah/gqlparser/gqlerror"
)
// Resolver is a function that, given a context, produces a result value or an error.
type Resolver func(ctx context.Context) (res interface{}, err error)
// FieldMiddleware wraps a Resolver: it receives the context together with the
// next Resolver to call and returns a result or an error in the same shape a
// Resolver does.
type FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)
// RequestMiddleware wraps request-level processing: it receives the context
// together with the next handler to call and returns a byte slice, as the
// handler itself does.
// NOTE(review): the bytes are presumably the serialized response — confirm against the caller.
type RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte
// RequestContext holds the per-request state used by the functions in this
// file: the query and its variables, the configured hooks, and the errors and
// extensions accumulated so far. It contains mutexes, so it is handled by
// pointer throughout this file.
type RequestContext struct {
	// RawQuery is set from the query argument of NewRequestContext.
	RawQuery string
	// Variables is set from the variables argument of NewRequestContext.
	Variables map[string]interface{}
	// Doc is set from the doc argument of NewRequestContext.
	Doc *ast.QueryDocument
	// ErrorPresenter will be used to generate the error
	// message from errors given to Error() and Errorf().
	ErrorPresenter ErrorPresenterFunc
	// Recover defaults to DefaultRecover in NewRequestContext.
	// NOTE(review): presumably turns a recovered panic into an error — confirm at the call site.
	Recover RecoverFunc
	// ResolverMiddleware, DirectiveMiddleware and RequestMiddleware default to
	// pass-through implementations in NewRequestContext.
	ResolverMiddleware FieldMiddleware
	DirectiveMiddleware FieldMiddleware
	RequestMiddleware RequestMiddleware
	// errorsMu is held by Error, Errorf and HasError while they access Errors.
	errorsMu sync.Mutex
	// Errors collects the presented errors appended by Error and Errorf.
	Errors gqlerror.List
	// extensionsMu is held by RegisterExtension while it accesses Extensions.
	extensionsMu sync.Mutex
	// Extensions maps extension keys to their values; RegisterExtension
	// allocates it on first use.
	Extensions map[string]interface{}
}
// DefaultResolverMiddleware is the pass-through resolver middleware: it calls
// next with ctx and returns whatever next returns.
func DefaultResolverMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
	res, err = next(ctx)
	return res, err
}
// DefaultDirectiveMiddleware is the pass-through directive middleware: it
// calls next with ctx and returns whatever next returns.
func DefaultDirectiveMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
	res, err = next(ctx)
	return res, err
}
func DefaultRequestMiddleware(ctx context.Context, next func(ctx context.Context) []byte) []byte {
return next(ctx)
}
// NewRequestContext builds a RequestContext for the given document, raw query
// and variables, with every hook set to its package default.
func NewRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *RequestContext {
	rc := new(RequestContext)

	// Per-request inputs.
	rc.Doc = doc
	rc.RawQuery = query
	rc.Variables = variables

	// Hooks start out as the package defaults.
	rc.ResolverMiddleware = DefaultResolverMiddleware
	rc.DirectiveMiddleware = DefaultDirectiveMiddleware
	rc.RequestMiddleware = DefaultRequestMiddleware
	rc.Recover = DefaultRecover
	rc.ErrorPresenter = DefaultErrorPresenter

	return rc
}
// key is the unexported type of the context keys declared in this file. Using
// a package-private type keeps these keys distinct from context keys defined
// by any other package.
type key string
// Context keys under which this file stores its values.
const (
	// request is the key used by WithRequestContext and GetRequestContext.
	request key = "request_context"
	// resolver is the key used by WithResolverContext and GetResolverContext.
	resolver key = "resolver_context"
)
// GetRequestContext returns the RequestContext stored in ctx by
// WithRequestContext, or nil when ctx carries none.
//
// The lookup uses the comma-ok form of the type assertion, the same way
// GetResolverContext does, so a missing or unexpected value yields nil rather
// than a panic.
func GetRequestContext(ctx context.Context) *RequestContext {
	if rc, ok := ctx.Value(request).(*RequestContext); ok {
		return rc
	}
	return nil
}
// WithRequestContext returns a context derived from ctx that carries rc, so
// that GetRequestContext can retrieve it later.
func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context {
	derived := context.WithValue(ctx, request, rc)
	return derived
}
// ResolverContext describes one step in a chain of resolvers. Contexts are
// linked through Parent by WithResolverContext, and Path walks that chain to
// build the location of the current field.
type ResolverContext struct {
	// Parent is the enclosing resolver context. WithResolverContext sets it to
	// whatever resolver context the incoming context already carried, which
	// is nil at the top of the chain.
	Parent *ResolverContext
	// The name of the type this field belongs to
	Object string
	// These are the args after processing, they can be mutated in middleware to change what the resolver will get.
	Args map[string]interface{}
	// The raw field
	Field CollectedField
	// The index of array in path. When non-nil, Path uses this value for the
	// context's path element instead of the field alias.
	Index *int
	// The result object of resolver
	Result interface{}
}
// Path returns the location of this resolver as a list running from the
// outermost ancestor down to r. Each context in the chain contributes its
// array index when Index is set, otherwise its field alias when a field is
// attached; a context with neither contributes nothing. A nil receiver yields
// a nil path.
func (r *ResolverContext) Path() []interface{} {
	var segments []interface{}

	// Follow the Parent links, which collects the elements leaf-first.
	for node := r; node != nil; node = node.Parent {
		switch {
		case node.Index != nil:
			segments = append(segments, *node.Index)
		case node.Field.Field != nil:
			segments = append(segments, node.Field.Alias)
		}
	}

	// Reverse in place so the result reads root-first.
	for lo, hi := 0, len(segments)-1; lo < hi; lo, hi = lo+1, hi-1 {
		segments[lo], segments[hi] = segments[hi], segments[lo]
	}

	return segments
}
// GetResolverContext returns the ResolverContext stored in ctx by
// WithResolverContext, or nil when ctx does not carry one.
func GetResolverContext(ctx context.Context) *ResolverContext {
	if rc, ok := ctx.Value(resolver).(*ResolverContext); ok {
		return rc
	}
	return nil
}
// WithResolverContext returns a context derived from ctx that carries rc.
// Before storing it, rc.Parent is overwritten with the resolver context that
// ctx already carried (nil if there was none), which links rc into the chain
// that Path walks.
func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context {
	enclosing := GetResolverContext(ctx)
	rc.Parent = enclosing
	return context.WithValue(ctx, resolver, rc)
}
// CollectFieldsCtx is a convenience wrapper around CollectFields that takes
// the selection set from the resolver context stored in ctx. It panics with a
// nil dereference if ctx carries no resolver context.
func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField {
	selections := GetResolverContext(ctx).Field.Selections
	return CollectFields(ctx, selections, satisfies)
}
// Errorf builds an error from format and args, passes it through the
// ErrorPresenter, and appends the result to Errors. The errors mutex is held
// for the whole call, including the presenter invocation.
func (c *RequestContext) Errorf(ctx context.Context, format string, args ...interface{}) {
	c.errorsMu.Lock()
	defer c.errorsMu.Unlock()

	raw := fmt.Errorf(format, args...)
	presented := c.ErrorPresenter(ctx, raw)
	c.Errors = append(c.Errors, presented)
}
// Error passes err through the ErrorPresenter and appends the result to
// Errors. The errors mutex is held for the whole call, including the
// presenter invocation.
func (c *RequestContext) Error(ctx context.Context, err error) {
	c.errorsMu.Lock()
	defer c.errorsMu.Unlock()

	presented := c.ErrorPresenter(ctx, err)
	c.Errors = append(c.Errors, presented)
}
// HasError reports whether any error recorded so far has a path equal to the
// path of rctx, i.e. whether the current field has already errored.
func (c *RequestContext) HasError(rctx *ResolverContext) bool {
	c.errorsMu.Lock()
	defer c.errorsMu.Unlock()

	target := rctx.Path()
	for i := range c.Errors {
		if equalPath(c.Errors[i].Path, target) {
			return true
		}
	}

	return false
}
// equalPath reports whether a and b have the same length and pairwise equal
// elements, compared with ==. A nil path and an empty path are equal.
func equalPath(a []interface{}, b []interface{}) bool {
	if len(a) != len(b) {
		return false
	}

	for idx, elem := range a {
		if elem != b[idx] {
			return false
		}
	}

	return true
}
// AddError is a convenience function that records err on the RequestContext
// carried by ctx.
func AddError(ctx context.Context, err error) {
	rc := GetRequestContext(ctx)
	rc.Error(ctx, err)
}
// AddErrorf is a convenience function that records a formatted error on the
// RequestContext carried by ctx.
func AddErrorf(ctx context.Context, format string, args ...interface{}) {
	rc := GetRequestContext(ctx)
	rc.Errorf(ctx, format, args...)
}
// RegisterExtension stores value under key in Extensions. It returns an error,
// leaving the existing value untouched, if key has already been registered.
func (c *RequestContext) RegisterExtension(key string, value interface{}) error {
	c.extensionsMu.Lock()
	defer c.extensionsMu.Unlock()

	// First registration: a map that does not exist yet cannot hold key, so
	// create it with the new entry directly.
	if c.Extensions == nil {
		c.Extensions = map[string]interface{}{key: value}
		return nil
	}

	if _, taken := c.Extensions[key]; taken {
		return fmt.Errorf("extension already registered for key %s", key)
	}

	c.Extensions[key] = value
	return nil
}