Skip to content

Commit c1b22d7

Browse files
committed
Improved GraphQL server integration
1 parent 15fe6c5 commit c1b22d7

File tree

2 files changed

+42
-37
lines changed

2 files changed

+42
-37
lines changed

flask_graphql/graphqlview.py

+14-27
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import json
2+
from functools import partial
23

34
from flask import Response, request
45
from flask.views import View
56

67
from graphql.type.schema import GraphQLSchema
7-
from graphql_server import run_http_query, HttpQueryError, default_format_error, load_json_body, format_execution_result
8+
from graphql_server import run_http_query, HttpQueryError, default_format_error, load_json_body, encode_execution_results, json_encode
89

910
from .render_graphiql import render_graphiql
1011

@@ -54,8 +55,10 @@ def render_graphiql(self, params, result):
5455
graphiql_template=self.graphiql_template,
5556
)
5657

58+
format_error = staticmethod(default_format_error)
59+
encode = staticmethod(json_encode)
60+
5761
def dispatch_request(self):
58-
5962
try:
6063
request_method = request.method.lower()
6164
data = self.parse_body()
@@ -72,24 +75,19 @@ def dispatch_request(self):
7275
query_data=request.args,
7376
batch_enabled=self.batch,
7477
catch=catch,
78+
7579
# Execute options
7680
root_value=self.get_root_value(),
7781
context_value=self.get_context(),
7882
middleware=self.get_middleware(),
7983
executor=self.get_executor(),
8084
)
81-
responses = [
82-
format_execution_result(execution_result, default_format_error)
83-
for execution_result in execution_results
84-
]
85-
result, status_codes = zip(*responses)
86-
status_code = max(status_codes)
87-
88-
# If is not batch
89-
if not isinstance(data, list):
90-
result = result[0]
91-
92-
result = self.json_encode(result, pretty)
85+
result, status_code = encode_execution_results(
86+
execution_results,
87+
is_batch=isinstance(data, list),
88+
format_error=self.format_error,
89+
encode=partial(self.encode, pretty=pretty)
90+
)
9391

9492
if show_graphiql:
9593
return self.render_graphiql(
@@ -105,8 +103,8 @@ def dispatch_request(self):
105103

106104
except HttpQueryError as e:
107105
return Response(
108-
self.json_encode({
109-
'errors': [default_format_error(e)]
106+
self.encode({
107+
'errors': [self.format_error(e)]
110108
}),
111109
status=e.status_code,
112110
headers=e.headers,
@@ -131,17 +129,6 @@ def parse_body(self):
131129

132130
return {}
133131

134-
@staticmethod
135-
def json_encode(data, pretty=False):
136-
if not pretty:
137-
return json.dumps(data, separators=(',', ':'))
138-
139-
return json.dumps(
140-
data,
141-
indent=2,
142-
separators=(',', ': ')
143-
)
144-
145132
def should_display_graphiql(self):
146133
if not self.graphiql or 'raw' in request.args:
147134
return False

graphql_server/__init__.py

+28-10
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def run_http_query(schema, request_method, data, query_data=None, batch_enabled=
6767
extra_data = {}
6868
# If is a batch request, we don't consume the data from the query
6969
if not is_batch:
70-
extra_data = query_data
70+
extra_data = query_data or {}
7171

7272
all_params = [get_graphql_params(entry, extra_data) for entry in data]
7373

@@ -82,6 +82,31 @@ def run_http_query(schema, request_method, data, query_data=None, batch_enabled=
8282
return responses, all_params
8383

8484

85+
def encode_execution_results(execution_results, format_error, is_batch, encode):
86+
responses = [
87+
format_execution_result(execution_result, format_error)
88+
for execution_result in execution_results
89+
]
90+
result, status_codes = zip(*responses)
91+
status_code = max(status_codes)
92+
93+
if not is_batch:
94+
result = result[0]
95+
96+
return encode(result), status_code
97+
98+
99+
def json_encode(data, pretty=False):
100+
if not pretty:
101+
return json.dumps(data, separators=(',', ':'))
102+
103+
return json.dumps(
104+
data,
105+
indent=2,
106+
separators=(',', ': ')
107+
)
108+
109+
85110
def load_json_variables(variables):
86111
if variables and isinstance(variables, six.text_type):
87112
try:
@@ -111,21 +136,14 @@ def get_response(schema, params, catch=None, allow_only_query=False, **kwargs):
111136
**kwargs
112137
)
113138
except catch:
114-
execution_result = ExecutionResult(
115-
data=None,
116-
invalid=True,
117-
)
118-
# return GraphQLResponse(None, 400)
119-
139+
return None
140+
120141
return execution_result
121142

122143

123144
def format_execution_result(execution_result, format_error):
124145
status_code = 200
125146

126-
if isinstance(execution_result, Promise):
127-
execution_result = execution_result.get()
128-
129147
if execution_result:
130148
response = {}
131149

0 commit comments

Comments
 (0)