Skip to content

Commit

Permalink
add schema type building for camelcase support
Browse files Browse the repository at this point in the history
  • Loading branch information
karol-gruszczyk committed Mar 7, 2018
1 parent 8e50a4a commit 8078e5d
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 28 deletions.
109 changes: 82 additions & 27 deletions slothql/schema.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,92 @@
import collections
from typing import Dict

import graphql
from graphql.type.introspection import IntrospectionSchema
from graphql.type.typemap import GraphQLTypeMap

from slothql.utils import snake_to_camelcase
from .types.base import LazyType, resolve_lazy_type

FieldMap = Dict[str, graphql.GraphQLField]


class CamelCaseTypeMap(GraphQLTypeMap):
@classmethod
def reducer(cls, type_map: dict, of_type):
if of_type is None:
return type_map
return super().reducer(map=type_map, type=cls.get_graphql_type(type_map, of_type))

@classmethod
def get_graphql_type(cls, type_map: dict, of_type):
if isinstance(of_type, (graphql.GraphQLNonNull, graphql.GraphQLList)):
return type(of_type)(type=cls.get_graphql_type(type_map, of_type.of_type))
if of_type.name in type_map:
return type_map[of_type.name]
if isinstance(of_type, graphql.GraphQLObjectType):
fields = of_type.fields
of_type = type_map[of_type.name] = graphql.GraphQLObjectType(
name=of_type.name,
fields={},
interfaces=of_type.interfaces,
is_type_of=of_type.is_type_of,
description=of_type.description,
)
of_type.fields = cls.construct_fields(type_map, fields)
return of_type

@classmethod
def construct_fields(cls, type_map: dict, fields: FieldMap) -> FieldMap:
return {
snake_to_camelcase(name): graphql.GraphQLField(
type=cls.get_graphql_type(type_map, field.type),
args=field.args,
resolver=field.resolver,
deprecation_reason=field.deprecation_reason,
description=field.description,
)
for name, field in fields.items()
}


class Schema(graphql.GraphQLSchema):
def __init__(self, query: LazyType, mutation=None, subscription=None, directives=None, types=None,
auto_camelcase: bool = False):
query = resolve_lazy_type(query)._type
super().__init__(
query=self.construct_graphql_type(query) if auto_camelcase else query,
mutation=mutation,
subscription=subscription,
directives=directives,
types=types,
)

def construct_graphql_type(self, obj):
if isinstance(obj, graphql.GraphQLObjectType):
return graphql.GraphQLObjectType(
name=obj.name,
fields={
snake_to_camelcase(name): graphql.GraphQLField(
type=self.construct_graphql_type(field.type),
args=field.args,
resolver=field.resolver,
deprecation_reason=field.deprecation_reason,
description=field.description,
)
for name, field in obj.fields.items()
},
interfaces=obj.interfaces,
is_type_of=obj.is_type_of,
description=obj.description,
)
return obj

assert isinstance(query, graphql.GraphQLObjectType), f'Schema query must be Object Type but got: {query}.'
if mutation:
assert isinstance(mutation, graphql.GraphQLObjectType), \
f'Schema mutation must be Object Type but got: {mutation}.'

if subscription:
assert isinstance(subscription, graphql.GraphQLObjectType), \
f'Schema subscription must be Object Type but got: {subscription}.'

if types:
assert isinstance(types, collections.Iterable), \
f'Schema types must be iterable if provided but got: {types}.'

self._query = query
self._mutation = mutation
self._subscription = subscription
if directives is None:
directives = graphql.specified_directives

assert all(isinstance(d, graphql.GraphQLDirective) for d in directives), \
f'Schema directives must be List[GraphQLDirective] if provided but got: {directives}.'
self._directives = directives

initial_types = [
query,
mutation,
subscription,
IntrospectionSchema,
]
if types:
initial_types += types
self._type_map = CamelCaseTypeMap(initial_types) if auto_camelcase else GraphQLTypeMap(initial_types)

def get_query_type(self):
return self._type_map[self._query.name]
2 changes: 1 addition & 1 deletion slothql/tests/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,4 @@ class Query(slothql.Object):

# shouldn't modify the actual types
assert {'string_field', 'foo_bar'} == Query()._type.fields.keys() == Query()._type._fields.keys()
assert {'some_weird_field'} == FooBar()._type.fields.keys() == FooBar()._type._fields.keys()
assert {'some_weird_field', 'foo_bar'} == FooBar()._type.fields.keys() == FooBar()._type._fields.keys()

0 comments on commit 8078e5d

Please sign in to comment.