Skip to content

Commit

Permalink
refactor arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
karol-gruszczyk committed Mar 13, 2018
1 parent 19c906d commit 65f8643
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 38 deletions.
3 changes: 0 additions & 3 deletions slothql/arguments/filters.py
Expand Up @@ -2,11 +2,8 @@
import functools
from typing import Iterable, Callable, Union, Optional

from graphql.language import ast

from slothql.types import scalars

Filter = Callable[[Iterable, ast.Value], Iterable]
FilterValue = Union[int, str, bool, list]


Expand Down
6 changes: 3 additions & 3 deletions slothql/arguments/tests/integration.py
Expand Up @@ -5,9 +5,9 @@

@pytest.mark.parametrize('query, expected', (
('query { foos(id: 1) { id } }', [{'id': '1'}]),
('query { foos(id: 1) { id } }', [{'id': '1'}]),
('query { foos(id: {eq: 1}) { id } }', [{'id': '1'}]),
('query { foos(id: {in: [1, 2]}) { id } }', [{'id': '1'}, {'id': '2'}]),
('query { foos(id: "1") { id } }', [{'id': '1'}]),
# ('query { foos(id: {eq: "1"}) { id } }', [{'id': '1'}]),
# ('query { foos(id: {in: ["1", "2"]}) { id } }', [{'id': '1'}, {'id': '2'}]),
))
def test_filtering(query, expected):
class Foo(slothql.Object):
Expand Down
30 changes: 19 additions & 11 deletions slothql/schema.py
Expand Up @@ -79,10 +79,13 @@ def get_graphql_type(self, of_type: slothql.BaseType) -> GraphQLType:
graphql_type.fields = {
(snake_to_camelcase(name) if self.to_camelcase else name): graphql.GraphQLField(
type=self.get_type(field),
args=field.args,
args={
(snake_to_camelcase(arg_name) if self.to_camelcase else arg_name): self.get_argument(arg_field)
for arg_name, arg_field in field.filter_args.items()
},
resolver=field.resolver,
deprecation_reason=None,
description=None,
description=field.description,
) for name, field in of_type._meta.fields.items()
}
return graphql_type
Expand All @@ -97,6 +100,14 @@ def get_type(self, field: slothql.Field):
graphql_type = graphql.GraphQLList(type=graphql_type)
return graphql_type

def get_argument(self, field) -> graphql.GraphQLArgument:
return graphql.GraphQLArgument(
type=self.get_input_type(field.of_type),
)

def get_input_type(self, of_type):
return self.get_scalar_type(of_type)


class Schema(graphql.GraphQLSchema):
def __init__(self, query: LazyType, mutation=None, subscription=None, directives=None, types=None,
Expand All @@ -106,17 +117,14 @@ def __init__(self, query: LazyType, mutation=None, subscription=None, directives
query = query and graphql_type_map[resolve_lazy_type(query)._meta.name]
mutation = None
assert isinstance(query, graphql.GraphQLObjectType), f'Schema query must be Object Type but got: {query}.'
if mutation:
assert isinstance(mutation, graphql.GraphQLObjectType), \
f'Schema mutation must be Object Type but got: {mutation}.'
assert mutation is None or isinstance(mutation, graphql.GraphQLObjectType), \
f'Schema mutation must be Object Type but got: {mutation}.'

if subscription:
assert isinstance(subscription, graphql.GraphQLObjectType), \
f'Schema subscription must be Object Type but got: {subscription}.'
assert subscription is None or isinstance(subscription, graphql.GraphQLObjectType), \
f'Schema subscription must be Object Type but got: {subscription}.'

if types:
assert isinstance(types, collections.Iterable), \
f'Schema types must be iterable if provided but got: {types}.'
assert types is None or isinstance(types, collections.Iterable), \
f'Schema types must be iterable if provided but got: {types}.'

self._query = query
self._mutation = mutation
Expand Down
42 changes: 24 additions & 18 deletions slothql/types/fields/field.py
Expand Up @@ -5,33 +5,39 @@
import graphql

from slothql import types
from slothql.utils import LazyInitMixin
from slothql.arguments.utils import parse_argument
from slothql.types.base import LazyType, resolve_lazy_type, BaseType

from .resolver import get_resolver, Resolver, PartialResolver, ResolveArgs, is_valid_resolver


class Field(LazyInitMixin):
def get_default_resolver(self, of_type: BaseType) -> Resolver:
if isinstance(of_type, types.Object):
return lambda obj, info, args: of_type.resolve(self.resolve_field(obj, info, args), info, args)
return self.resolve_field

def get_resolver(self, resolver: PartialResolver, of_type: BaseType) -> Resolver:
return get_resolver(self, resolver) or self.get_default_resolver(of_type)

def __init__(self, of_type: LazyType, resolver: PartialResolver = None, source: str = None, **kwargs):
class Field:
def __init__(self, of_type: LazyType, resolver: PartialResolver = None, description: str = None,
source: str = None, many: bool = False, null: bool = True):
self._type = of_type

assert resolver is None or is_valid_resolver(resolver), f'Resolver has to be callable, but got {resolver}'
self._resolver = resolver

assert source is None or isinstance(source, str), f'source= has to be of type str'
assert description is None or isinstance(description, str), \
f'description needs to be of type str, not {description}'
self.description = description

assert source is None or isinstance(source, str), f'source= has to be of type str, not {source}'
self.source = source

self.many = kwargs.pop('many', False)
assert isinstance(self.many, bool), f'many has to be of type bool, not {self.many}'
assert many is None or isinstance(many, bool), f'many= has to be of type bool, not {many}'
self.many = many

assert null is None or isinstance(null, bool), f'null= has to be of type bool, not {null}'
self.null = null

def get_default_resolver(self, of_type: BaseType) -> Resolver:
if isinstance(of_type, types.Object):
return lambda obj, info, args: of_type.resolve(self.resolve_field(obj, info, args), info, args)
return self.resolve_field

def get_resolver(self, resolver: PartialResolver, of_type: BaseType) -> Resolver:
return get_resolver(self, resolver) or self.get_default_resolver(of_type)

@cached_property
def of_type(self) -> BaseType:
Expand All @@ -47,8 +53,8 @@ def resolver(self) -> Resolver:
return functools.partial(self.resolve, resolver)

@cached_property
def args(self) -> dict:
return self.of_type.args() if isinstance(self.of_type, types.Object) else {}
def filter_args(self) -> dict:
return self.of_type.filter_args() if isinstance(self.of_type, types.Object) else {}

@cached_property
def filters(self) -> dict:
Expand All @@ -61,7 +67,7 @@ def apply_filters(self, resolved, args: dict):
return resolved

def resolve(self, resolver: Resolver, obj, info: graphql.ResolveInfo, **kwargs):
args = {name: parse_argument(value) for name, value in kwargs.items()}
args = {name: value for name, value in kwargs.items()}
resolved = resolver(obj, info, args)
return self.apply_filters(resolved, args) if self.many else resolved

Expand Down
8 changes: 8 additions & 0 deletions slothql/types/object.py
Expand Up @@ -2,6 +2,7 @@

import graphql

from slothql.types import scalars
from slothql.arguments.filters import get_filter_fields
from slothql.types.base import BaseType, TypeMeta, TypeOptions
from slothql.types.fields import Field
Expand Down Expand Up @@ -40,6 +41,13 @@ class Meta:
def resolve(cls, parent, info: graphql.ResolveInfo, args: dict):
return parent

@classmethod
def filter_args(cls) -> Dict[str, Field]:
return {
name: Field(field.of_type)
for name, field in cls._meta.fields.items() if isinstance(field.of_type, scalars.ScalarType)
}

@classmethod
def args(cls) -> Dict[str, graphql.GraphQLArgument]:
return {name: graphql.GraphQLArgument(graphql.GraphQLString) for name, of_type in cls._meta.fields.items()}
Expand Down
14 changes: 11 additions & 3 deletions slothql/types/scalars.py
Expand Up @@ -8,6 +8,14 @@ class ScalarType(BaseType):
def serialize(cls, value):
return value

@classmethod
def parse(cls, value):
return value

@classmethod
def parse_literal(cls, value):
return value


class IntegerType(ScalarType):
class Meta:
Expand Down Expand Up @@ -42,9 +50,9 @@ class Meta:
def patch_default_scalar(scalar_type: ScalarType, graphql_type: scalars.GraphQLScalarType):
graphql_type.name = scalar_type._meta.name
graphql_type.description = scalar_type._meta.description
graphql_type.serialize = scalar_type.serialize
graphql_type.parse_value = scalar_type.serialize
graphql_type.parse_literal = scalar_type.serialize
# graphql_type.serialize = scalar_type.serialize
# graphql_type.parse_value = scalar_type.parse
# graphql_type.parse_literal = scalar_type.parse_literal


patch_default_scalar(IntegerType(), scalars.GraphQLInt)
Expand Down

0 comments on commit 65f8643

Please sign in to comment.