Skip to content

Commit

Permalink
generate graphql types in schema
Browse files Browse the repository at this point in the history
  • Loading branch information
karol-gruszczyk committed Mar 12, 2018
1 parent d6d1703 commit 5d52d01
Show file tree
Hide file tree
Showing 16 changed files with 112 additions and 172 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
graphql-core>=2<3
cached-property==1.4

# optional
django>=2.0<3
Expand Down
15 changes: 8 additions & 7 deletions slothql/arguments/filters.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import operator
import functools
from typing import Iterable, Callable, Dict, Union
from typing import Iterable, Callable, Union, Optional

import graphql
from graphql.language import ast

from slothql.types import scalars

Filter = Callable[[Iterable, ast.Value], Iterable]
FilterValue = Union[int, str, bool, list]

Expand Down Expand Up @@ -57,11 +58,11 @@ def apply(self, collection: Iterable, field_name: str, value: FilterValue) -> It
}, 'eq')


def get_filter_fields(of_type: graphql.GraphQLScalarType) -> Dict[str, graphql.GraphQLField]:
if of_type == graphql.GraphQLID:
def get_filter_fields(scalar_type: scalars.ScalarType) -> Optional[FilterSet]:
if isinstance(scalar_type, scalars.IDType):
return IDFilterSet
elif of_type == graphql.GraphQLString:
elif isinstance(scalar_type, scalars.StringType):
return StringFilterSet
elif of_type == graphql.GraphQLInt:
elif isinstance(scalar_type, scalars.IntegerType):
return IntegerFilterSet
return {}
return None
116 changes: 46 additions & 70 deletions slothql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import Dict

import graphql
from graphql.type import GraphQLEnumValue
from graphql.type.definition import GraphQLType
from graphql.type.introspection import IntrospectionSchema
from graphql.type.typemap import GraphQLTypeMap

Expand All @@ -14,54 +16,6 @@
FieldMap = Dict[str, graphql.GraphQLField]


class CamelCaseTypeMap(GraphQLTypeMap):
# FIXME: this is a really bad workaround, needs to be fixed ASAP
_type_map = {}

def __init__(self, types):
self.__class__._type_map = {}
super().__init__(types)

@classmethod
def reducer(cls, type_map: dict, of_type):
if of_type is None:
return type_map
return super().reducer(map=type_map, type=cls.get_graphql_type(of_type))

@classmethod
def get_graphql_type(cls, of_type):
if isinstance(of_type, (graphql.GraphQLNonNull, graphql.GraphQLList)):
return type(of_type)(type=cls.get_graphql_type(of_type.of_type))
if of_type.name in cls._type_map:
return cls._type_map[of_type.name]
if not of_type.name.startswith('__') and isinstance(of_type, graphql.GraphQLObjectType):
fields = of_type.fields
of_type = graphql.GraphQLObjectType(
name=of_type.name,
fields={},
interfaces=of_type.interfaces,
is_type_of=of_type.is_type_of,
description=of_type.description,
)
cls._type_map[of_type.name] = of_type
of_type.fields = cls.construct_fields(fields)
cls._type_map[of_type.name] = of_type
return of_type

@classmethod
def construct_fields(cls, fields: FieldMap) -> FieldMap:
return {
snake_to_camelcase(name): graphql.GraphQLField(
type=cls.get_graphql_type(field.type),
args={snake_to_camelcase(name): arg for name, arg in field.args.items()},
resolver=field.resolver,
deprecation_reason=field.deprecation_reason,
description=field.description,
)
for name, field in fields.items()
}


class TypeMap(dict):
def __init__(self, *types: slothql.BaseType):
super().__init__(functools.reduce(self.type_reducer, types, {}))
Expand All @@ -80,37 +34,65 @@ def type_reducer(self, type_map: dict, of_type: slothql.BaseType):

class ProxyTypeMap(dict):
def __init__(self, type_map: TypeMap, to_camelcase: bool = False):
super().__init__()
self.to_camelcase = to_camelcase
super().__init__(functools.reduce(self.type_reducer, type_map.values(), {}))

def type_reducer(self, type_map: dict, of_type: slothql.BaseType) -> dict:
if of_type._meta.name in type_map:
return type_map

if isinstance(of_type, slothql.Object):
for of_type in type_map.values():
self.get_graphql_type(of_type)

def get_scalar_type(self, of_type: scalars.ScalarType):
if isinstance(of_type, scalars.IDType):
return graphql.GraphQLID
elif isinstance(of_type, scalars.StringType):
return graphql.GraphQLString
elif isinstance(of_type, scalars.BooleanType):
return graphql.GraphQLBoolean
elif isinstance(of_type, scalars.IntegerType):
return graphql.GraphQLInt
elif isinstance(of_type, scalars.FloatType):
return graphql.GraphQLFloat
raise NotImplementedError(f'{of_type} conversion is not implemented')

def get_graphql_type(self, of_type: slothql.BaseType) -> GraphQLType:
if of_type._meta.name in self:
return self[of_type._meta.name]
elif isinstance(of_type, scalars.ScalarType):
graphql_type = self.get_scalar_type(of_type)
elif isinstance(of_type, slothql.Enum):
return graphql.GraphQLEnumType(
name=of_type._meta.name,
values={
(snake_to_camelcase(name) if self.to_camelcase else name): GraphQLEnumValue(
value=value.value, description=value.description)
for name, value in of_type._meta.enum_values.items()
},
description=of_type._meta.description,
)
elif isinstance(of_type, slothql.Object):
graphql_type = graphql.GraphQLObjectType(
name=of_type._meta.name,
fields={},
interfaces=None,
is_type_of=None,
description=None,
)
type_map[of_type._meta.name] = graphql_type
self[graphql_type.name] = graphql_type
graphql_type.fields = {
(snake_to_camelcase(name) if self.to_camelcase else name): graphql.GraphQLField(
type=self.get_type(type_map, field),
type=self.get_type(field),
args=field.args,
resolver=field.resolver,
deprecation_reason=field.deprecation_reason,
description=field.description,
deprecation_reason=None,
description=None,
) for name, field in of_type._meta.fields.items()
}
if isinstance(of_type, scalars.ScalarType):
type_map[of_type._meta.name] = of_type._type
return type_map
return graphql_type
else:
raise NotImplementedError(f'Unsupported type {of_type}')
self[graphql_type.name] = graphql_type
return graphql_type

def get_type(self, type_map: dict, field: slothql.Field):
graphql_type = self.type_reducer(type_map, field.of_type)[field.of_type._meta.name]
def get_type(self, field: slothql.Field):
graphql_type = self.get_graphql_type(field.of_type)
if field.many:
graphql_type = graphql.GraphQLList(type=graphql_type)
return graphql_type
Expand Down Expand Up @@ -153,9 +135,3 @@ def __init__(self, query: LazyType, mutation=None, subscription=None, directives
IntrospectionSchema,
] + (types or [])
self._type_map = GraphQLTypeMap(initial_types)

def build_query_type(self, root_type: slothql.Object) -> graphql.GraphQLObjectType:
raise NotImplementedError

def get_query_type(self):
return self._type_map[self._query.name]
4 changes: 2 additions & 2 deletions slothql/tests/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ class Query(slothql.Object):
== slothql.gql(schema, 'query { stringField fooBar { someWeirdField } }')

# shouldn't modify the actual types
assert {'string_field', 'foo_bar'} == Query()._type.fields.keys() == Query()._type._fields.keys()
assert {'some_weird_field', 'foo_bar'} == FooBar()._type.fields.keys() == FooBar()._type._fields.keys()
assert {'string_field', 'foo_bar'} == Query()._meta.fields.keys()
assert {'some_weird_field', 'foo_bar'} == FooBar()._meta.fields.keys()


def test_camelcase_schema_integration__introspection_query():
Expand Down
9 changes: 0 additions & 9 deletions slothql/types/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import inspect
from typing import Union, Type, Callable, Tuple, Iterable

from graphql.type.definition import GraphQLType

from slothql.utils import is_magic_name, get_attr_fields
from slothql.utils.singleton import Singleton

Expand Down Expand Up @@ -71,13 +69,6 @@ def __new__(cls, *more):
assert not cls._meta.abstract, f'Abstract type {cls.__name__} can not be instantiated'
return super().__new__(cls)

def __init__(self, type_: GraphQLType):
self._type = type_

@classmethod
def get_output_type(cls) -> GraphQLType:
raise NotImplementedError


LazyType = Union[Type[BaseType], BaseType, Callable]

Expand Down
23 changes: 9 additions & 14 deletions slothql/types/enum.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import graphql
from typing import Type

from graphql.type import GraphQLEnumValue

from .base import BaseType, TypeMeta, TypeOptions


class EnumOptions(TypeOptions):
__slots__ = 'values',
__slots__ = 'enum_values',

def __init__(self, attrs: dict):
super().__init__(attrs)
assert self.abstract or self.values, f'"{self.name}" is missing valid `Enum` values'
assert self.abstract or self.enum_values, f'"{self.name}" is missing valid `Enum` values'


class EnumMeta(TypeMeta):
Expand All @@ -21,20 +18,18 @@ def __new__(mcs, *args, options_class: Type[EnumOptions] = EnumOptions, **kwargs
@classmethod
def get_option_attrs(mcs, name: str, base_attrs: dict, attrs: dict, meta_attrs: dict):
return {**super().get_option_attrs(name, base_attrs, attrs, meta_attrs), **{
'values': {field_name: field for field_name, field in attrs.items() if isinstance(field, EnumValue)},
'enum_values': {field_name: field for field_name, field in attrs.items() if isinstance(field, EnumValue)},
}}


class EnumValue(GraphQLEnumValue):
pass
class EnumValue:
__slots__ = 'value', 'description'

def __init__(self, value, description: str = None):
self.value = value
self.description = description


class Enum(BaseType, metaclass=EnumMeta):
class Meta:
abstract = True

def __init__(self):
super().__init__(type_=graphql.GraphQLEnumType(
name=self._meta.name,
values=self._meta.values,
))
43 changes: 27 additions & 16 deletions slothql/types/fields/field.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import functools

from cached_property import cached_property

import graphql

from slothql import types
from slothql.utils import LazyInitMixin
from slothql.arguments.utils import parse_argument
from slothql.types.base import LazyType, resolve_lazy_type, BaseType

from .mixins import ListMixin
from .resolver import get_resolver, Resolver, PartialResolver, ResolveArgs, is_valid_resolver


class Field(LazyInitMixin, ListMixin, graphql.GraphQLField):
__slots__ = ()

class Field(LazyInitMixin):
def get_default_resolver(self, of_type: BaseType) -> Resolver:
if isinstance(of_type, types.Object):
return lambda obj, info, args: of_type.resolve(self.resolve_field(obj, info, args), info, args)
Expand All @@ -23,25 +22,37 @@ def get_resolver(self, resolver: PartialResolver, of_type: BaseType) -> Resolver
return get_resolver(self, resolver) or self.get_default_resolver(of_type)

def __init__(self, of_type: LazyType, resolver: PartialResolver = None, source: str = None, **kwargs):
self._type = of_type

assert resolver is None or is_valid_resolver(resolver), f'Resolver has to be callable, but got {resolver}'
of_type = resolve_lazy_type(of_type)
resolver = self.get_resolver(resolver, of_type)
assert callable(resolver), f'resolver needs to be callable, not {resolver}'
self._resolver = resolver

assert source is None or isinstance(source, str), f'source= has to be of type str'
self.source = source

args = of_type.args() if isinstance(of_type, types.Object) else {}
self.many = kwargs.pop('many', False)
assert isinstance(self.many, bool), f'many has to be of type bool, not {self.many}'

@cached_property
def of_type(self) -> BaseType:
resolved_type = resolve_lazy_type(self._type)
assert isinstance(resolved_type, BaseType), \
f'{self} "of_type" needs to be of type BaseType, not {resolved_type}'
return resolved_type

@cached_property
def resolver(self) -> Resolver:
resolver = self.get_resolver(self._resolver, self.of_type)
assert callable(resolver), f'resolver needs to be callable, not {resolver}'
return functools.partial(self.resolve, resolver)

super().__init__(
type=of_type._type,
resolver=functools.partial(self.resolve, resolver),
args=args,
**kwargs
)
self.of_type = of_type
@cached_property
def args(self) -> dict:
return self.of_type.args() if isinstance(self.of_type, types.Object) else {}

self.filters = of_type.filters() if isinstance(of_type, types.Object) else {}
@cached_property
def filters(self) -> dict:
return self.of_type.filters() if isinstance(self.of_type, types.Object) else {}

def apply_filters(self, resolved, args: dict):
for field_name, value in args.items():
Expand Down
12 changes: 0 additions & 12 deletions slothql/types/fields/mixins.py

This file was deleted.

2 changes: 1 addition & 1 deletion slothql/types/fields/tests/shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
slothql.Time,
))
def test_field_init(field):
assert field().type
assert field()._type
9 changes: 1 addition & 8 deletions slothql/types/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,9 @@ def get_option_attrs(mcs, name: str, base_attrs: dict, attrs: dict, meta_attrs:


class Object(BaseType, metaclass=ObjectMeta):
def __init__(self):
super().__init__(graphql.GraphQLObjectType(name=self.__class__.__name__, fields=self._meta.fields))

class Meta:
abstract = True

@classmethod
def get_output_type(cls) -> graphql.GraphQLObjectType:
return graphql.GraphQLObjectType(name=cls.__name__, fields=cls._meta.fields)

@classmethod
def resolve(cls, parent, info: graphql.ResolveInfo, args: dict):
return parent
Expand All @@ -53,4 +46,4 @@ def args(cls) -> Dict[str, graphql.GraphQLArgument]:

@classmethod
def filters(cls) -> Dict[str, FilterSet]:
return {name: get_filter_fields(field.type) for name, field in cls._meta.fields.items()}
return {name: get_filter_fields(field.of_type) for name, field in cls._meta.fields.items()}

0 comments on commit 5d52d01

Please sign in to comment.