diff --git a/open_city_profile/graphene.py b/open_city_profile/graphene.py index 8ead93ec..6bad7d7a 100644 --- a/open_city_profile/graphene.py +++ b/open_city_profile/graphene.py @@ -8,15 +8,16 @@ from graphene_django import DjangoObjectType from graphene_django.forms.converter import convert_form_field from graphene_django.types import ALL_FIELDS +from graphql_sync_dataloaders import SyncDataLoader from parler.models import TranslatableModel from profiles.loaders import ( - AddressesByProfileIdLoader, - EmailsByProfileIdLoader, - PhonesByProfileIdLoader, - PrimaryAddressForProfileLoader, - PrimaryEmailForProfileLoader, - PrimaryPhoneForProfileLoader, + addresses_by_profile_id_loader, + emails_by_profile_id_loader, + phones_by_profile_id_loader, + primary_address_for_profile_loader, + primary_email_for_profile_loader, + primary_phone_for_profile_loader, ) @@ -52,12 +53,12 @@ class UUIDMultipleChoiceFilter(MultipleChoiceFilter): _LOADERS = { - "addresses_by_profile_id_loader": AddressesByProfileIdLoader, - "emails_by_profile_id_loader": EmailsByProfileIdLoader, - "phones_by_profile_id_loader": PhonesByProfileIdLoader, - "primary_address_for_profile_loader": PrimaryAddressForProfileLoader, - "primary_email_for_profile_loader": PrimaryEmailForProfileLoader, - "primary_phone_for_profile_loader": PrimaryPhoneForProfileLoader, + "addresses_by_profile_id_loader": addresses_by_profile_id_loader, + "emails_by_profile_id_loader": emails_by_profile_id_loader, + "phones_by_profile_id_loader": phones_by_profile_id_loader, + "primary_address_for_profile_loader": primary_address_for_profile_loader, + "primary_email_for_profile_loader": primary_email_for_profile_loader, + "primary_phone_for_profile_loader": primary_phone_for_profile_loader, } @@ -69,9 +70,8 @@ def resolve(self, next, root, info, **kwargs): context = info.context if not self.cached_loaders: - for loader_name, loader_class in _LOADERS.items(): - setattr(context, loader_name, loader_class()) - + for loader_name, loader_function in _LOADERS.items(): + setattr(context, loader_name, SyncDataLoader(loader_function)) self.cached_loaders = True return next(root, info, **kwargs) diff --git a/open_city_profile/tests/conftest.py b/open_city_profile/tests/conftest.py index 78df621b..9684d075 100644 --- a/open_city_profile/tests/conftest.py +++ b/open_city_profile/tests/conftest.py @@ -11,6 +11,7 @@ from graphene_django.settings import graphene_settings from graphene_django.views import instantiate_middleware from graphql import build_client_schema, get_introspection_query +from graphql_sync_dataloaders import DeferredExecutionContext from helusers.authz import UserAuthorization from open_city_profile.schema import schema @@ -127,6 +128,11 @@ def keycloak_setup(settings): settings.KEYCLOAK_CLIENT_SECRET = "test-keycloak-client-secret" +@pytest.fixture +def execution_context_class(): + return DeferredExecutionContext + + @pytest.fixture def user(): return UserFactory() @@ -178,9 +184,10 @@ def superuser_gql_client(): @pytest.fixture -def gql_schema(anon_user_gql_client): +def gql_schema(anon_user_gql_client, execution_context_class): introspection = anon_user_gql_client.execute( - get_introspection_query(descriptions=False) + get_introspection_query(descriptions=False), + execution_context_class=execution_context_class, ) return build_client_schema(introspection["data"]) diff --git a/open_city_profile/urls.py b/open_city_profile/urls.py index 472036f3..dbcd592c 100644 --- a/open_city_profile/urls.py +++ b/open_city_profile/urls.py @@ -6,6 +6,7 @@ from django.utils.translation import gettext_lazy as _ from django.views.decorators.csrf import csrf_exempt from django.views.generic import TemplateView +from graphql_sync_dataloaders import DeferredExecutionContext from open_city_profile.views import GraphQLView @@ -14,7 +15,10 @@ path( "graphql/", csrf_exempt( - GraphQLView.as_view(graphiql=settings.ENABLE_GRAPHIQL or settings.DEBUG) + GraphQLView.as_view( + graphiql=settings.ENABLE_GRAPHIQL or settings.DEBUG, + execution_context_class=DeferredExecutionContext, + ) ), ), path("auth/", include("helusers.urls")), diff --git a/profiles/loaders.py b/profiles/loaders.py index 3292eec9..29892c35 100644 --- a/profiles/loaders.py +++ b/profiles/loaders.py @@ -1,54 +1,48 @@ +import uuid from collections import defaultdict - -from promise import Promise -from promise.dataloader import DataLoader +from typing import Callable, List from profiles.models import Address, Email, Phone -def loader_for_profile(model): - class BaseByProfileIdLoader(DataLoader): - def batch_load_fn(self, profile_ids): - items_by_profile_ids = defaultdict(list) - for item in model.objects.filter(profile_id__in=profile_ids).iterator(): - items_by_profile_ids[item.profile_id].append(item) - return Promise.resolve( - [items_by_profile_ids.get(profile_id, []) for profile_id in profile_ids] - ) +def loader_for_profile(model) -> Callable: + def batch_load_fn(profile_ids: List[uuid.UUID]) -> List[List[model]]: + items_by_profile_ids = defaultdict(list) + for item in model.objects.filter(profile_id__in=profile_ids).iterator(): + items_by_profile_ids[item.profile_id].append(item) + + return [items_by_profile_ids[profile_id] for profile_id in profile_ids] - return BaseByProfileIdLoader + return batch_load_fn -def loader_for_profile_primary(model): - class BaseByProfileIdPrimaryLoader(DataLoader): - def batch_load_fn(self, profile_ids): - items_by_profile_ids = defaultdict() - for item in model.objects.filter( - profile_id__in=profile_ids, primary=True - ).iterator(): - items_by_profile_ids[item.profile_id] = item +def loader_for_profile_primary(model) -> Callable: + def batch_load_fn(profile_ids: List[uuid.UUID]) -> List[model]: + items_by_profile_ids = {} + for item in model.objects.filter( + profile_id__in=profile_ids, primary=True + ).iterator(): + items_by_profile_ids[item.profile_id] = item - return Promise.resolve( - [items_by_profile_ids.get(profile_id) for profile_id in profile_ids] - ) + return [items_by_profile_ids.get(profile_id) for profile_id in profile_ids] - return BaseByProfileIdPrimaryLoader + return batch_load_fn -EmailsByProfileIdLoader = loader_for_profile(Email) -PhonesByProfileIdLoader = loader_for_profile(Phone) -AddressesByProfileIdLoader = loader_for_profile(Address) +addresses_by_profile_id_loader = loader_for_profile(Address) +emails_by_profile_id_loader = loader_for_profile(Email) +phones_by_profile_id_loader = loader_for_profile(Phone) -PrimaryEmailForProfileLoader = loader_for_profile_primary(Email) -PrimaryPhoneForProfileLoader = loader_for_profile_primary(Phone) -PrimaryAddressForProfileLoader = loader_for_profile_primary(Address) +primary_address_for_profile_loader = loader_for_profile_primary(Address) +primary_email_for_profile_loader = loader_for_profile_primary(Email) +primary_phone_for_profile_loader = loader_for_profile_primary(Phone) __all__ = [ - "AddressesByProfileIdLoader", - "EmailsByProfileIdLoader", - "PhonesByProfileIdLoader", - "PrimaryAddressForProfileLoader", - "PrimaryEmailForProfileLoader", - "PrimaryPhoneForProfileLoader", + "addresses_by_profile_id_loader", + "emails_by_profile_id_loader", + "phones_by_profile_id_loader", + "primary_address_for_profile_loader", + "primary_email_for_profile_loader", + "primary_phone_for_profile_loader", ] diff --git a/profiles/schema.py b/profiles/schema.py index 33f3bb46..ce46fd0b 100644 --- a/profiles/schema.py +++ b/profiles/schema.py @@ -547,13 +547,13 @@ def resolve_primary_address(self: Profile, info, **kwargs): return info.context.primary_address_for_profile_loader.load(self.id) def resolve_emails(self: Profile, info, **kwargs): - return info.context.emails_by_profile_id_loader.load(self.id) + return self.emails.all() def resolve_phones(self: Profile, info, **kwargs): - return info.context.phones_by_profile_id_loader.load(self.id) + return self.phones.all() def resolve_addresses(self: Profile, info, **kwargs): - return info.context.addresses_by_profile_id_loader.load(self.id) + return self.addresses.all() @key(fields="id") diff --git a/profiles/tests/test_gql_claim_profile_mutation.py b/profiles/tests/test_gql_claim_profile_mutation.py index 415b471e..f66e6d2f 100644 --- a/profiles/tests/test_gql_claim_profile_mutation.py +++ b/profiles/tests/test_gql_claim_profile_mutation.py @@ -129,7 +129,9 @@ def test_can_not_delete_primary_email(user_gql_client): assert_match_error_code(executed, "PROFILE_MUST_HAVE_PRIMARY_EMAIL") -def test_changing_an_email_address_marks_it_unverified(user_gql_client): +def test_changing_an_email_address_marks_it_unverified( + user_gql_client, execution_context_class +): profile = ProfileFactory(user=None) email = EmailFactory(profile=profile, verified=True) claim_token = ClaimTokenFactory(profile=profile) @@ -163,7 +165,11 @@ def test_changing_an_email_address_marks_it_unverified(user_gql_client): }, } - executed = user_gql_client.execute(CLAIM_PROFILE_MUTATION, variables=variables) + executed = user_gql_client.execute( + CLAIM_PROFILE_MUTATION, + variables=variables, + execution_context_class=execution_context_class, + ) assert "errors" not in executed assert executed["data"] == expected_data diff --git a/profiles/tests/test_gql_my_profile_query.py b/profiles/tests/test_gql_my_profile_query.py index 55f7ca21..2b79339c 100644 --- a/profiles/tests/test_gql_my_profile_query.py +++ b/profiles/tests/test_gql_my_profile_query.py @@ -130,7 +130,9 @@ def test_normal_user_can_query_addresses(user_gql_client): assert dict(executed["data"]) == expected_data -def test_normal_user_can_query_primary_contact_details(user_gql_client): +def test_normal_user_can_query_primary_contact_details( + user_gql_client, execution_context_class +): profile = ProfileFactory(user=user_gql_client.user) phone = PhoneFactory(profile=profile, primary=True) email = EmailFactory(profile=profile, primary=True) @@ -179,7 +181,9 @@ def test_normal_user_can_query_primary_contact_details(user_gql_client): }, } } - executed = user_gql_client.execute(query) + executed = user_gql_client.execute( + query, execution_context_class=execution_context_class + ) assert dict(executed["data"]) == expected_data diff --git a/profiles/tests/test_gql_update_my_profile_mutation.py b/profiles/tests/test_gql_update_my_profile_mutation.py index a2f579fa..b919b570 100644 --- a/profiles/tests/test_gql_update_my_profile_mutation.py +++ b/profiles/tests/test_gql_update_my_profile_mutation.py @@ -1061,7 +1061,7 @@ def test_can_not_remove_address_of_another_profile(user_gql_client): def test_change_primary_contact_details( - user_gql_client, email_data, phone_data, address_data + user_gql_client, email_data, phone_data, address_data, execution_context_class ): profile = ProfileFactory(user=user_gql_client.user) PhoneFactory(profile=profile, primary=True) @@ -1159,7 +1159,10 @@ def test_change_primary_contact_details( address_type=address_data["address_type"], primary="true", ) - executed = user_gql_client.execute(mutation) + executed = user_gql_client.execute( + mutation, execution_context_class=execution_context_class + ) + assert "errors" not in executed assert dict(executed["data"]) == expected_data diff --git a/requirements.in b/requirements.in index 452609b7..8d29caad 100644 --- a/requirements.in +++ b/requirements.in @@ -11,6 +11,7 @@ django-sanitized-dump django-searchable-encrypted-fields graphene-django graphene-federation +graphql-sync-dataloaders git+https://github.com/City-of-Helsinki/graphene-validator.git@graphene3 ipython iso3166 diff --git a/requirements.txt b/requirements.txt index c342d89c..e28fcba3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -81,10 +81,13 @@ graphql-core==3.2.3 # graphene-django # graphene-federation # graphql-relay + # graphql-sync-dataloaders graphql-relay==3.2.0 # via # graphene # graphene-django +graphql-sync-dataloaders==0.1.1 + # via -r requirements.in idna==3.6 # via requests ipython==8.21.0