Skip to content

Commit

Permalink
copy complexity to RequestContext
Browse files Browse the repository at this point in the history
  • Loading branch information
vvakame committed Oct 29, 2018
1 parent 0d5c65b commit a027ac2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
4 changes: 4 additions & 0 deletions graphql/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ type RequestContext struct {
RawQuery string
Variables map[string]interface{}
Doc *ast.QueryDocument

ComplexityLimit int
OperationComplexity int

// ErrorPresenter will be used to generate the error
// message from errors given to Error().
ErrorPresenter ErrorPresenterFunc
Expand Down
19 changes: 11 additions & 8 deletions handler/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type Config struct {
complexityLimit int
}

func (c *Config) newRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *graphql.RequestContext {
func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext {
reqCtx := graphql.NewRequestContext(doc, query, variables)
if hook := c.recover; hook != nil {
reqCtx.Recover = hook
Expand All @@ -59,6 +59,12 @@ func (c *Config) newRequestContext(doc *ast.QueryDocument, query string, variabl
reqCtx.Tracer = &graphql.NopTracer{}
}

if c.complexityLimit > 0 {
reqCtx.ComplexityLimit = c.complexityLimit
operationComplexity := complexity.Calculate(es, op, variables)
reqCtx.OperationComplexity = operationComplexity
}

return reqCtx
}

Expand Down Expand Up @@ -298,7 +304,7 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc
sendError(w, http.StatusUnprocessableEntity, err)
return
}
reqCtx := cfg.newRequestContext(doc, reqParams.Query, vars)
reqCtx := cfg.newRequestContext(exec, doc, op, reqParams.Query, vars)
ctx := graphql.WithRequestContext(r.Context(), reqCtx)

defer func() {
Expand All @@ -308,12 +314,9 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc
}
}()

if cfg.complexityLimit > 0 {
queryComplexity := complexity.Calculate(exec, op, vars)
if queryComplexity > cfg.complexityLimit {
sendErrorf(w, http.StatusUnprocessableEntity, "query has complexity %d, which exceeds the limit of %d", queryComplexity, cfg.complexityLimit)
return
}
if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > cfg.complexityLimit {
sendErrorf(w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", operationComplexity, cfg.complexityLimit)
return
}

switch op.Operation {
Expand Down

0 comments on commit a027ac2

Please sign in to comment.