-
Notifications
You must be signed in to change notification settings - Fork 227
/
Copy pathbatching.py
143 lines (119 loc) · 5.42 KB
/
batching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""The dataloader uses "select in loading" strategy to load related entities."""
from asyncio import get_event_loop
from typing import Any, Dict
import sqlalchemy
from sqlalchemy.orm import Session, strategies
from sqlalchemy.orm.query import QueryContext
from sqlalchemy.util import immutabledict
from .utils import (
SQL_VERSION_HIGHER_EQUAL_THAN_1_4,
SQL_VERSION_HIGHER_EQUAL_THAN_2,
is_graphene_version_less_than,
)
def get_data_loader_impl() -> Any:  # pragma: no cover
    """Return the ``DataLoader`` class that matches the installed graphene.

    Graphene >= 3.1.1 ships its own copy of aiodataloader with minor fixes,
    so that bundled copy is used when it is available. For older graphene
    versions the standalone ``aiodataloader`` package is used instead, which
    preserves backward compatibility.
    """
    needs_standalone_package = is_graphene_version_less_than("3.1.1")
    if not needs_standalone_package:
        from graphene.utils.dataloader import DataLoader as loader_class

        return loader_class

    from aiodataloader import DataLoader as loader_class

    return loader_class
# Resolve the DataLoader implementation once, at import time, so that the
# classes defined in this module can subclass it directly.
DataLoader = get_data_loader_impl()
class RelationshipLoader(DataLoader):
    """DataLoader that loads one relationship for a whole batch of parents.

    Each instance is bound to a single relationship property and to the
    `selectin` loader strategy object used to populate that relationship.
    """

    # NOTE(review): presumably turns off the DataLoader's per-key result
    # cache -- confirm against the DataLoader implementation in use.
    cache = False

    def __init__(self, relationship_prop, selectin_loader):
        """Store the relationship property and the `selectin` loader to use."""
        super().__init__()
        self.relationship_prop = relationship_prop
        self.selectin_loader = selectin_loader

    async def batch_load_fn(self, parents):
        """
        Load the relationship of every parent with a single SQL statement.

        SQLAlchemy offers no out-of-the-box way to do this, so we piggyback
        on some internal APIs of the `selectin` eager loading strategy.
        It is a bit hacky, but preferable to re-implementing and
        maintaining a big chunk of the `selectin` loader logic ourselves.

        Conceptually we build a regular query that selects the parent and
        `selectin`-loads the relationship. Such a query would normally emit
        two `SELECT` statements when calling `all()`; here the first
        `SELECT` is skipped and we jump right to the point where the
        `selectin` loader is invoked. To do that, the objects that are
        normally built during the first part of the query are constructed
        by hand so that `SelectInLoader._load_for_path` can be called
        directly.

        Returns a list holding, for each parent and in the same order, the
        value of the relationship attribute after loading.

        TODO Move this logic to a util in the SQLAlchemy repo as per
          SQLAlchemy's main maintainer suggestion.
          See https://git.io/JewQ7
        """
        prop = self.relationship_prop
        child_mapper = prop.mapper
        parent_mapper = prop.parent
        session = Session.object_session(parents[0])

        # Sanity checks; these issues are very unlikely to happen in practice.
        for instance in parents:
            # assert parent.__mapper__ is parent_mapper
            # Every instance has to belong to the same session.
            assert session is Session.object_session(instance)
            # `selectin` behaves in an undefined way when the parent is dirty.
            assert instance not in session.dirty

        # Should the boolean be set to False? Does it matter for our purposes?
        states = [(state, True) for state in map(sqlalchemy.inspect, parents)]

        # For our purposes, the query context is only used to get the session.
        parent_query = session.query(parent_mapper.entity)
        if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
            query_context = parent_query._compile_context()
        else:
            query_context = QueryContext(parent_query)

        # `_load_for_path` takes additional trailing positional arguments in
        # newer SQLAlchemy versions; pick the ones matching this install.
        if SQL_VERSION_HIGHER_EQUAL_THAN_2:  # pragma: no cover
            trailing_args = (
                None,
                None,  # recursion depth can be none
                immutabledict(),  # default value for selectinload->lazyload
            )
        elif SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
            trailing_args = (None,)
        else:
            trailing_args = ()

        self.selectin_loader._load_for_path(
            query_context,
            parent_mapper._path_registry,
            states,
            None,
            child_mapper,
            *trailing_args,
        )

        # The relationship attribute is now populated on every parent.
        key = prop.key
        return [getattr(instance, key) for instance in parents]
# Module-level cache holding one RelationshipLoader per relationship property,
# so that the same loader is reused across `batch_load_fn` calls.
# Reusing the loader is what allows SQL string generation to be cached
# under the hood via `bakery`.
RELATIONSHIP_LOADERS_CACHE: Dict[
    sqlalchemy.orm.relationships.RelationshipProperty, RelationshipLoader
] = {}
def get_batch_resolver(relationship_prop):
    """Build the async resolve function for the given relationship.

    The returned coroutine function takes ``(root, info, **args)`` and
    resolves the relationship of ``root`` through a ``RelationshipLoader``
    that is cached per relationship property.
    """

    def _get_loader(prop):
        """Return the cached loader for ``prop``, creating it when needed.

        A new loader is created (and replaces the cached one) when none is
        cached yet or when the cached loader's ``loop`` differs from the
        current event loop.
        """
        loader = RELATIONSHIP_LOADERS_CACHE.get(prop)
        must_rebuild = loader is None or loader.loop != get_event_loop()
        if not must_rebuild:
            return loader

        loader = RelationshipLoader(
            relationship_prop=prop,
            selectin_loader=strategies.SelectInLoader(
                prop, (("lazy", "selectin"),)
            ),
        )
        RELATIONSHIP_LOADERS_CACHE[prop] = loader
        return loader

    async def resolve(root, info, **args):
        # `info` and `args` are accepted but not used by this resolver.
        loader = _get_loader(relationship_prop)
        return await loader.load(root)

    return resolve