Skip to content

Commit

Permalink
refactor resolvers in order to accept query args
Browse files Browse the repository at this point in the history
  • Loading branch information
karol-gruszczyk committed Feb 26, 2018
1 parent f890ac5 commit 87395a1
Show file tree
Hide file tree
Showing 10 changed files with 111 additions and 80 deletions.
11 changes: 11 additions & 0 deletions slothql/conftest.py
Expand Up @@ -3,6 +3,7 @@

import django
from django.conf import settings
from django.db import models

import graphql
from graphql.type.definition import GraphQLType
Expand Down Expand Up @@ -55,3 +56,13 @@ def field_mock():
@pytest.fixture()
def partials_equal():
return lambda p1, p2: p1.func == p2.func and p1.args == p2.args and p1.keywords == p2.keywords


@pytest.fixture()
def manager_mock():
return mock.Mock(spec=models.Manager)


@pytest.fixture()
def queryset_mock():
return mock.Mock(spec=models.QuerySet)
16 changes: 12 additions & 4 deletions slothql/django/types/model.py
@@ -1,6 +1,7 @@
import inspect
from typing import Type, Iterable, Dict, Union

import graphql
from django.db import models

from slothql import Field
Expand Down Expand Up @@ -77,7 +78,14 @@ class Meta:
abstract = True

@classmethod
def resolve(cls, obj, info):
if obj is None:
return cls._meta.model._default_manager.get_queryset()
return obj.get_queryset() if isinstance(obj, models.Manager) else obj
def filter_queryset(cls, queryset: models.QuerySet, args: dict):
assert isinstance(queryset, models.QuerySet), f'expected QuerySet, received {repr(queryset)}'
return queryset.filter()

@classmethod
def resolve(cls, parent, info: graphql.ResolveInfo, args: dict):
if parent is None:
queryset = cls._meta.model._default_manager.get_queryset()
else:
queryset = parent.get_queryset() if isinstance(parent, models.Manager) else parent
return cls.filter_queryset(queryset, args)
16 changes: 8 additions & 8 deletions slothql/django/types/tests/resolution.py
Expand Up @@ -36,14 +36,14 @@ class Meta:
fields = '__all__'


def test_resolve__relation(info_mock):
manager = mock.Mock(models.Manager)
manager.get_queryset.return_value = [1, 2, 3]
parent = mock.Mock(spec=Parent, children=manager)
assert [1, 2, 3] == Parent.children.resolver(parent, info_mock(field_name='children'))
def test_resolve__relation(info_mock, manager_mock, queryset_mock):
manager_mock.get_queryset.return_value = queryset_mock
parent = mock.Mock(spec=Parent, children=manager_mock)
assert queryset_mock.filter() == Parent.children.resolver(parent, info_mock(field_name='children'))


def test_resolve__default(info_mock):
with mock.patch.object(Child._meta.model._default_manager, 'get_queryset', return_value=[1, 2, 3]) as get_queryset:
assert [1, 2, 3] == Parent.children.resolver(None, info_mock(field_name='children'))
def test_resolve__default(info_mock, queryset_mock):
model = Child._meta.model
with mock.patch.object(model._default_manager, 'get_queryset', return_value=queryset_mock) as get_queryset:
assert queryset_mock.filter() == Parent.children.resolver(None, info_mock(field_name='children'))
get_queryset.assert_called_with()
22 changes: 11 additions & 11 deletions slothql/types/fields/field.py
Expand Up @@ -3,45 +3,45 @@
import graphql

from slothql.utils import LazyInitMixin
from slothql import types
from slothql.types.base import LazyType, resolve_lazy_type, BaseType

from .mixins import ListMixin
from .resolver import Resolver
from .resolver import get_resolver, Resolver, PartialResolver, ResolveArgs


class Field(LazyInitMixin, ListMixin, graphql.GraphQLField):
__slots__ = ()

def get_default_resolver(self, of_type):
from slothql import types
def get_default_resolver(self, of_type: BaseType) -> Resolver:
if isinstance(of_type, types.Object):
return lambda obj, info: of_type.resolve(self.resolve_field(obj, info), info)
return lambda obj, info, args: of_type.resolve(self.resolve_field(obj, info, args), info, args)
return self.resolve_field

def get_resolver(self, resolver, of_type: BaseType):
return Resolver(self, resolver).func or self.get_default_resolver(of_type)
def get_resolver(self, resolver: PartialResolver, of_type: BaseType) -> Resolver:
return get_resolver(self, resolver) or self.get_default_resolver(of_type)

def __init__(self, of_type: LazyType, resolver=None, source: str = None, **kwargs):
def __init__(self, of_type: LazyType, resolver: PartialResolver = None, source: str = None, **kwargs):
of_type = resolve_lazy_type(of_type)
resolver = self.get_resolver(resolver, of_type)
assert callable(resolver), f'resolver needs to be callable, not {resolver}'

assert source is None or isinstance(source, str), f'source= has to be of type str'
self.source = source

super().__init__(type=of_type._type, resolver=functools.partial(self.resolve, resolver), **kwargs)
super().__init__(type=of_type._type, resolver=functools.partial(self.resolve, resolver), args={}, **kwargs)

@classmethod
def resolve(cls, resolver, obj, info: graphql.ResolveInfo):
return resolver(obj, info)
def resolve(cls, resolver: Resolver, obj, info: graphql.ResolveInfo, **kwargs):
return resolver(obj, info, kwargs)

def __repr__(self) -> str:
return f'<Field: {repr(self.type)}>'

def get_internal_name(self, name: str) -> str:
return self.source or name

def resolve_field(self, obj, info: graphql.ResolveInfo):
def resolve_field(self, obj, info: graphql.ResolveInfo, args: ResolveArgs):
if obj is None:
return None
name = self.get_internal_name(info.field_name)
Expand Down
77 changes: 39 additions & 38 deletions slothql/types/fields/resolver.py
@@ -1,43 +1,44 @@
import inspect
import functools
from typing import Callable, Dict, Optional, Any

import graphql
from graphql.language.ast import Value

from slothql.utils.functional import is_method, get_function_signature

ResolveArgs = Dict[str, Value]
PartialResolver = Callable[[Optional[Any], Optional[graphql.ResolveInfo], Optional[ResolveArgs]], Callable]
Resolver = Callable[[Any, graphql.ResolveInfo, ResolveArgs], Callable]


def _get_function(field, resolver: PartialResolver = None) -> Optional[Resolver]:
if resolver is None:
return None
if isinstance(resolver, staticmethod):
return resolver.__func__
if isinstance(resolver, classmethod):
return functools.partial(resolver.__func__, type(field))
if is_method(resolver):
return functools.partial(resolver, field)
return resolver


def _inject_missing_args(func: PartialResolver) -> Resolver:
signature, arg_count = get_function_signature(func)
assert arg_count <= 3, f'{func} expected arguments to be of signature (parent, info, args), received {signature}'
if arg_count < 3:
@functools.wraps(func)
def resolver(parent, info, args):
if arg_count == 0:
return func()
elif arg_count == 1:
return func(parent)
return func(parent, info)

class Resolver:
@staticmethod
def is_method(func):
return func.__code__.co_varnames and 'self' == func.__code__.co_varnames[0]

def __init__(self, field, resolver):
func = self.get_function(field, resolver)
self.func = func and self.inject_missing_args(func)

@classmethod
def get_function(cls, field, resolver):
if resolver is None:
return None
if isinstance(resolver, staticmethod):
return resolver.__func__
if isinstance(resolver, classmethod):
return functools.partial(resolver.__func__, type(field))
if cls.is_method(resolver):
return functools.partial(resolver, field)
return resolver
return func


@classmethod
def inject_missing_args(cls, func):
if isinstance(func, functools.partial):
signature = inspect.signature(func.func)
arg_count = len(signature.parameters) - len(func.args)
else:
assert callable(func), f'Expected callable, got {func}'
signature = inspect.signature(func)
arg_count = len(signature.parameters)
assert arg_count <= 2, \
f'{func} expected arguments to be of signature (obj, info), received {signature}'
if arg_count < 2:
@functools.wraps(func)
def resolver(obj, info):
return func() if arg_count == 0 else func(obj)

return resolver
return func
def get_resolver(field, resolver: PartialResolver) -> Resolver:
func = _get_function(field, resolver)
return func and _inject_missing_args(func)
2 changes: 1 addition & 1 deletion slothql/types/fields/tests/field.py
Expand Up @@ -35,4 +35,4 @@ def test_default_resolver(self, type_mock, partials_equal):
))
def test_resolve_field(obj, expected, info_mock, type_mock):
for field_name, expected_value in expected.items():
assert expected_value == Field(type_mock()).resolve_field(obj, info_mock(field_name=field_name))
assert expected_value == Field(type_mock()).resolve_field(obj, info_mock(field_name=field_name), {})
25 changes: 10 additions & 15 deletions slothql/types/fields/tests/resolver.py
@@ -1,6 +1,6 @@
import pytest

from ..resolver import Resolver
from ..resolver import PartialResolver, get_resolver


class A:
Expand Down Expand Up @@ -38,19 +38,14 @@ def class3(cls, obj, info):
return 'foo'


@pytest.mark.parametrize('func', (
@pytest.mark.parametrize('resolver', (
lambda: 'foo',
lambda o: 'foo',
lambda o, i: 'foo',
A.method1,
A.method2,
A.method3,
A.static1,
A.static2,
A.static3,
A.class1,
A.class2,
A.class3,
lambda parent: 'foo',
lambda parent, info: 'foo',
lambda parent, info, args: 'foo',
A.method1, A.method2, A.method3,
A.static1, A.static2, A.static3,
A.class1, A.class2, A.class3,
))
def test_resolve(func):
assert 'foo' == Resolver(None, func).func(None, None)
def test_resolve(resolver: PartialResolver, field_mock):
assert 'foo' == get_resolver(field_mock, resolver)(None, None, None)
2 changes: 1 addition & 1 deletion slothql/types/fields/tests/source.py
Expand Up @@ -9,7 +9,7 @@ def test_field_source(type_mock, info_mock):
class A(slothql.Object):
foo = slothql.Field(type_mock(), source='bar')

assert 'baz' == A.foo.resolve_field({'bar': 'baz'}, info_mock(field_name='foo'))
assert 'baz' == A.foo.resolve_field({'bar': 'baz'}, info_mock(field_name='foo'), {})


def test_integration():
Expand Down
4 changes: 2 additions & 2 deletions slothql/types/object.py
Expand Up @@ -38,5 +38,5 @@ class Meta:
abstract = True

@classmethod
def resolve(cls, obj, info):
return obj
def resolve(cls, parent, info: graphql.ResolveInfo, args: dict):
return parent
16 changes: 16 additions & 0 deletions slothql/utils/functional.py
@@ -0,0 +1,16 @@
import inspect
import functools
from typing import Callable, Tuple


def is_method(func):
return func.__code__.co_varnames and 'self' == func.__code__.co_varnames[0] # FIXME: noqa


def get_function_signature(func: Callable) -> Tuple[inspect.Signature, int]:
if isinstance(func, functools.partial):
signature = inspect.signature(func.func)
return signature, len(signature.parameters) - len(func.args)
assert callable(func), f'Expected callable, got {func}'
signature = inspect.signature(func)
return signature, len(signature.parameters)

0 comments on commit 87395a1

Please sign in to comment.