Skip to content

Commit

Permalink
Merge 580e5b2 into d62e931
Browse files Browse the repository at this point in the history
  • Loading branch information
karol-gruszczyk committed Mar 7, 2018
2 parents d62e931 + 580e5b2 commit 04b0118
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 17 deletions.
105 changes: 96 additions & 9 deletions slothql/schema.py
@@ -1,14 +1,101 @@
import collections
from typing import Dict

import graphql
from graphql.type.introspection import IntrospectionSchema
from graphql.type.typemap import GraphQLTypeMap

from slothql.utils import snake_to_camelcase
from .types.base import LazyType, resolve_lazy_type

FieldMap = Dict[str, graphql.GraphQLField]


class CamelCaseTypeMap(GraphQLTypeMap):
# FIXME: this is a really bad workaround, needs to be fixed ASAP
_type_map = {}

def __init__(self, types):
self.__class__._type_map = {}
super().__init__(types)

@classmethod
def reducer(cls, type_map: dict, of_type):
if of_type is None:
return type_map
return super().reducer(map=type_map, type=cls.get_graphql_type(of_type))

@classmethod
def get_graphql_type(cls, of_type):
if isinstance(of_type, (graphql.GraphQLNonNull, graphql.GraphQLList)):
return type(of_type)(type=cls.get_graphql_type(of_type.of_type))
if of_type.name in cls._type_map:
return cls._type_map[of_type.name]
if not of_type.name.startswith('__') and isinstance(of_type, graphql.GraphQLObjectType):
fields = of_type.fields
of_type = graphql.GraphQLObjectType(
name=of_type.name,
fields={},
interfaces=of_type.interfaces,
is_type_of=of_type.is_type_of,
description=of_type.description,
)
cls._type_map[of_type.name] = of_type
of_type.fields = cls.construct_fields(fields)
cls._type_map[of_type.name] = of_type
return of_type

from slothql.types.base import LazyType, resolve_lazy_type
@classmethod
def construct_fields(cls, fields: FieldMap) -> FieldMap:
return {
snake_to_camelcase(name): graphql.GraphQLField(
type=cls.get_graphql_type(field.type),
args={snake_to_camelcase(name): arg for name, arg in field.args.items()},
resolver=field.resolver,
deprecation_reason=field.deprecation_reason,
description=field.description,
)
for name, field in fields.items()
}


class Schema(graphql.GraphQLSchema):
def __init__(self, query: LazyType, mutation=None, subscription=None, directives=None, types=None):
super().__init__(
query=resolve_lazy_type(query)._type,
mutation=mutation,
subscription=subscription,
directives=directives,
types=types,
)
def __init__(self, query: LazyType, mutation=None, subscription=None, directives=None, types=None,
auto_camelcase: bool = False):
query = resolve_lazy_type(query)._type

assert isinstance(query, graphql.GraphQLObjectType), f'Schema query must be Object Type but got: {query}.'
if mutation:
assert isinstance(mutation, graphql.GraphQLObjectType), \
f'Schema mutation must be Object Type but got: {mutation}.'

if subscription:
assert isinstance(subscription, graphql.GraphQLObjectType), \
f'Schema subscription must be Object Type but got: {subscription}.'

if types:
assert isinstance(types, collections.Iterable), \
f'Schema types must be iterable if provided but got: {types}.'

self._query = query
self._mutation = mutation
self._subscription = subscription
if directives is None:
directives = graphql.specified_directives

assert all(isinstance(d, graphql.GraphQLDirective) for d in directives), \
f'Schema directives must be List[GraphQLDirective] if provided but got: {directives}.'
self._directives = directives

initial_types = [
query,
mutation,
subscription,
IntrospectionSchema,
]
if types:
initial_types += types
self._type_map = CamelCaseTypeMap(initial_types) if auto_camelcase else GraphQLTypeMap(initial_types)

def get_query_type(self):
return self._type_map[self._query.name]
34 changes: 30 additions & 4 deletions slothql/tests/schema.py
@@ -1,6 +1,6 @@
import pytest

from graphql import graphql
import graphql

import slothql

Expand All @@ -22,19 +22,45 @@ def test_can_init_with_callable_query(self):
def test_execution(self):
schema = slothql.Schema(query=self.query_class)
query = 'query { hello }'
assert 'world' == graphql(schema, query).data['hello']
assert {'data': {'hello': 'world'}} == slothql.gql(schema, query)

@pytest.mark.parametrize('call', (True, False))
def test_complex_schema(self, call):
class Nested(slothql.Object):
nested = slothql.Field(self.query_class() if call else self.query_class, lambda *_: {'world': 'not hello'})

query = 'query { nested { hello } }'
assert {'nested': {'hello': 'world'}} == graphql(slothql.Schema(query=Nested), query).data
assert {'data': {'nested': {'hello': 'world'}}} == slothql.gql(slothql.Schema(query=Nested), query)

def test_nested_in_null(self):
class Nested(slothql.Object):
nested = slothql.Field(self.query_class(), resolver=lambda *_: None)

query = 'query { nested { hello } }'
assert {'nested': None} == graphql(slothql.Schema(query=Nested), query).data
assert {'data': {'nested': None}} == slothql.gql(slothql.Schema(query=Nested), query)


def test_camelcase_schema_integration__queries():
class FooBar(slothql.Object):
foo_bar = slothql.Field(lambda: FooBar, resolver=lambda: {})
some_weird_field = slothql.String(resolver=lambda: 'some weird field')

class Query(slothql.Object):
string_field = slothql.String(resolver=lambda: 'string field')
foo_bar = slothql.Field(FooBar, resolver=lambda: {})

schema = slothql.Schema(query=Query, auto_camelcase=True)
assert {'data': {'stringField': 'string field', 'fooBar': {'someWeirdField': 'some weird field'}}} \
== slothql.gql(schema, 'query { stringField fooBar { someWeirdField } }')

# shouldn't modify the actual types
assert {'string_field', 'foo_bar'} == Query()._type.fields.keys() == Query()._type._fields.keys()
assert {'some_weird_field', 'foo_bar'} == FooBar()._type.fields.keys() == FooBar()._type._fields.keys()


def test_camelcase_schema_integration__introspection_query():
class Query(slothql.Object):
field = slothql.String()

schema = slothql.Schema(query=Query, auto_camelcase=True)
assert 'errors' not in slothql.gql(schema, graphql.introspection_query)
7 changes: 4 additions & 3 deletions slothql/types/fields/field.py
Expand Up @@ -2,13 +2,13 @@

import graphql

from slothql.arguments.utils import parse_argument
from slothql.utils import LazyInitMixin
from slothql import types
from slothql.utils import LazyInitMixin
from slothql.arguments.utils import parse_argument
from slothql.types.base import LazyType, resolve_lazy_type, BaseType

from .mixins import ListMixin
from .resolver import get_resolver, Resolver, PartialResolver, ResolveArgs
from .resolver import get_resolver, Resolver, PartialResolver, ResolveArgs, is_valid_resolver


class Field(LazyInitMixin, ListMixin, graphql.GraphQLField):
Expand All @@ -23,6 +23,7 @@ def get_resolver(self, resolver: PartialResolver, of_type: BaseType) -> Resolver
return get_resolver(self, resolver) or self.get_default_resolver(of_type)

def __init__(self, of_type: LazyType, resolver: PartialResolver = None, source: str = None, **kwargs):
assert resolver is None or is_valid_resolver(resolver), f'Resolver has to be callable, but got {resolver}'
of_type = resolve_lazy_type(of_type)
resolver = self.get_resolver(resolver, of_type)
assert callable(resolver), f'resolver needs to be callable, not {resolver}'
Expand Down
5 changes: 5 additions & 0 deletions slothql/types/fields/resolver.py
@@ -1,3 +1,4 @@
import inspect
import functools
from typing import Callable, Dict, Optional, Any

Expand Down Expand Up @@ -42,3 +43,7 @@ def resolver(parent, info, args):
def get_resolver(field, resolver: PartialResolver) -> Resolver:
func = _get_function(field, resolver)
return func and _inject_missing_args(func)


def is_valid_resolver(resolver: PartialResolver) -> bool:
return inspect.isfunction(resolver) or isinstance(resolver, (classmethod, staticmethod))
2 changes: 1 addition & 1 deletion slothql/types/fields/tests/lazy_fields.py
Expand Up @@ -13,7 +13,7 @@ class B(slothql.Object):
class Query(slothql.Object):
root = slothql.Field(A, resolver=lambda: {'b': {'a': {'field': 'foo'}}})

schema = slothql.Schema(Query)
schema = slothql.Schema(query=Query)

query = 'query { root { b { a { field } } } }'
assert {'data': {'root': {'b': {'a': {'field': 'foo'}}}}} == slothql.gql(schema, query)
Expand Down
2 changes: 2 additions & 0 deletions slothql/utils/__init__.py
@@ -1,9 +1,11 @@
from .attr import get_attr_fields, is_magic_name, get_attrs
from .case import snake_to_camelcase
from .query import query_from_raw_json
from .laziness import LazyInitMixin

__all__ = (
'get_attr_fields', 'is_magic_name', 'get_attrs',
'snake_to_camelcase',
'query_from_raw_json',
'LazyInitMixin',
)
5 changes: 5 additions & 0 deletions slothql/utils/case.py
@@ -0,0 +1,5 @@
def snake_to_camelcase(string: str) -> str:
first_char = next((i for i, c in enumerate(string) if c != '_'), len(string))
prefix, suffix = string[:first_char], string[first_char:]
words = [i or '_' for i in suffix.split('_')] if suffix else []
return prefix + ''.join(word.title() if i else word for i, word in enumerate(words))
20 changes: 20 additions & 0 deletions slothql/utils/tests/case.py
@@ -0,0 +1,20 @@
import pytest

from ..case import snake_to_camelcase


@pytest.mark.parametrize('value, expected', (
('', ''), ('_', '_'), ('__', '__'),
('fooBar', 'fooBar'),
('FooBar', 'FooBar'),
('foo_bar_baz', 'fooBarBaz'),
('foo__bar__baz', 'foo_Bar_Baz'),
('_foo_bar_baz', '_fooBarBaz'),
('__foo_bar_baz', '__fooBarBaz'),
('foo_bar_baz_', 'fooBarBaz_'),
('foo_bar_baz__', 'fooBarBaz__'),
('_foo_bar_baz_', '_fooBarBaz_'),
('__foo_bar_baz__', '__fooBarBaz__'),
))
def test_to_camelcase(value, expected):
assert expected == snake_to_camelcase(value)

0 comments on commit 04b0118

Please sign in to comment.