-
Notifications
You must be signed in to change notification settings - Fork 4
/
manager.py
226 lines (182 loc) · 9.18 KB
/
manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
import django
from django.db import transaction, models
from django.db.models import sql, Field, QuerySet
from typing import Dict, Any, List, Type, Optional, Tuple
from django.db.models.query import EmptyResultSet
from .compatibility import chain_query, get_model_fields
from .queryset import ReturningQuerySet
class UpdateReturningMixin(object):
    """
    QuerySet mixin adding UPDATE ... RETURNING, DELETE ... RETURNING and
    INSERT ... RETURNING (bulk_create_returning) operations.
    Results are wrapped in ReturningQuerySet objects.
    """

    @staticmethod
    def _get_loaded_field_cb(target, model, fields):
        """
        Callback passed to Query.deferred_to_data() by _get_fields().
        Stores the list of fields selected for model in the target dictionary.
        """
        target[model] = fields

    def _insert(self, objs, fields, return_id=False, raw=False, using=None, ignore_conflicts=False):
        """
        Replaces standard insert procedure for bulk_create_returning.

        When the model's _insert_returning flag is not set, this delegates to the standard
        QuerySet._insert(). Otherwise the INSERT is executed with a RETURNING clause, and its
        result is appended to the model's _insert_returning_cache list.

        :param objs: Model instances to insert
        :param fields: Fields to insert
        :param return_id: If set, a list of returned primary key column values is returned
        :param raw: Passed to InsertQuery.insert_values()
        :param using: Database alias; defaults to self.db
        :param ignore_conflicts: bulk_create() ignore_conflicts flag (only forwarded on django 2.2+)
        :return: QuerySet._insert() result in the non-returning case; otherwise a list of
                 primary key values if return_id is set, else None
        """
        # InsertQuery and QuerySet._insert() accept ignore_conflicts only since django 2.2.
        kwargs = {} if django.VERSION < (2, 2) else {'ignore_conflicts': ignore_conflicts}

        if not getattr(self.model, '_insert_returning', False):
            # Forward ignore_conflicts too, so plain bulk_create(ignore_conflicts=True) keeps working.
            return QuerySet._insert(self, objs, fields, return_id=return_id, raw=raw, using=using, **kwargs)

        # Returns attname, not column.
        # Before django 1.10 pk fields hasn't been returned from postgres.
        # In this case, I can't match bulk_create results and return values by primary key.
        # So I select all data from returned results
        return_fields = self._get_fields(ignore_deferred=(django.VERSION < (1, 10)))
        assert len(return_fields) == 1 and list(return_fields.keys())[0] == self.model, \
            "You can't fetch relative model fields with returning operation"

        self._for_write = True
        query = sql.InsertQuery(self.model, **kwargs)
        query.insert_values(fields, objs, raw=raw)
        returned = self._execute_sql(query, return_fields, using=using)

        # bulk_create() may call _insert() several times: once per batch, and separately for
        # objects with and without preset primary keys. Keep every result, not only the last one.
        cache = getattr(self.model, '_insert_returning_cache', None)
        if not isinstance(cache, list):
            cache = self.model._insert_returning_cache = []
        cache.append(returned)

        if not return_id:
            return None
        # Django's _batched_insert() extends its id list only when it receives a real list.
        return list(returned.values_list(self.model._meta.pk.column, flat=True))

    _insert.alters_data = True
    _insert.queryset_only = False

    def _get_fields(self, ignore_deferred=False):  # type: (bool) -> Dict[Type[models.Model], List[models.Field]]
        """
        Gets a dictionary of fields for each model, selected by .only() and .defer() methods
        :param ignore_deferred: If set, ignores .only() and .defer() filters
        :return: A dictionary with model as key, fields list as value
        """
        fields = {}
        if not ignore_deferred:
            self.query.deferred_to_data(fields, self._get_loaded_field_cb)

        # No .only() or .defer() operations
        if not fields:
            # Remove all fields without columns in table
            fields = {self.model: get_model_fields(self.model, concrete=True)}

        return fields

    def _execute_sql(self, query, return_fields, using=None):
        """
        Compiles query, appends a RETURNING clause and wraps the SQL in a ReturningQuerySet.
        :param query: Update, delete or insert query to compile
        :param return_fields: Output of _get_fields(); only the fields of self.model are returned
        :param using: Database alias; defaults to self.db
        :return: ReturningQuerySet; ReturningQuerySet(None) if compiling raises EmptyResultSet
        :raises AssertionError: If the compiler produces several statements (related model update)
        """
        return_fields_str = ', '.join('"%s"' % str(f.column) for f in return_fields[self.model])

        if using is None:
            using = self.db

        self._result_cache = None
        try:
            res = query.get_compiler(using).as_sql()
            # Some compilers return a list of (sql, params) statements.
            if isinstance(res, list):
                assert len(res) == 1, "Can't update relative model with returning"
                res = res[0]
            query_sql, query_params = res
        except EmptyResultSet:
            return ReturningQuerySet(None)

        query_sql = query_sql + ' RETURNING %s' % return_fields_str

        with transaction.atomic(using=using, savepoint=False):
            return ReturningQuerySet(query_sql, model=self.model, params=query_params, using=using,
                                     fields=[f.attname for f in return_fields[self.model]])

    def _get_returning_qs(self, query_type, values=None, **updates):
        # type: (Type[sql.Query], Optional[List[Tuple[Field, Any, Any]]], **Any) -> ReturningQuerySet
        """
        Partial for update_returning functions
        :param query_type: Query class to build (sql.UpdateQuery or sql.DeleteQuery)
        :param values: (field, model, value) tuples passed to query.add_update_fields()
        :param updates: Data to pass to update(**updates) method
        :return: ReturningQuerySet of results
        :raises AssertionError: If input data is invalid
        """
        assert self.query.can_filter(), "Can not update or delete once a slice has been taken."
        assert getattr(self, '_fields', None) is None, \
            "Can not call delete() or update() after .values() or .values_list()"

        # Returns attname, not column.
        fields = self._get_fields()
        assert len(fields) == 1 and list(fields.keys())[0] == self.model, \
            "You can't fetch relative model fields with returning operation"

        self._for_write = True

        query = chain_query(self, query_type)
        if updates:
            query.add_update_values(updates)
        if values:
            query.add_update_fields(values)

        # Disable not supported fields.
        query._annotations = None
        query.select_for_update = False
        query.select_related = False
        query.clear_ordering(force_empty=True)

        return self._execute_sql(query, fields)

    def update_returning(self, **updates):
        # type: (**Any) -> ReturningQuerySet
        """
        Executes UPDATE ... RETURNING for the rows matched by this QuerySet.
        :param updates: Field name -> new value, as for QuerySet.update()
        :return: ReturningQuerySet of the updated rows
        :raises AssertionError: If no updates are given or the QuerySet can't be updated
        """
        assert updates, "No updates where provided"
        return self._get_returning_qs(sql.UpdateQuery, **updates)

    def _update_returning(self, values):
        # type: (List[Tuple[Field, Any, Any]]) -> ReturningQuerySet
        """
        A version of update_returning() that accepts field objects instead of field names.
        Used primarily for model saving and not intended for use by general
        code (it requires too much poking around at model internals to be
        useful at that level).
        """
        assert values, "No updates where provided"
        return self._get_returning_qs(sql.UpdateQuery, values=values)

    def delete_returning(self):  # type: () -> ReturningQuerySet
        """
        Executes DELETE ... RETURNING for the rows matched by this QuerySet.
        :return: ReturningQuerySet of the deleted rows
        """
        return self._get_returning_qs(sql.DeleteQuery)

    def bulk_create_returning(self, objs, batch_size=None):
        """
        Works like bulk_create(), but inserts with INSERT ... RETURNING.
        On django 1.10+, returned values are copied onto the created objects, matched by
        primary key. When primary keys are not available (django before 1.10), the returned
        rows themselves are returned instead.
        :param objs: Model instances to create
        :param batch_size: Passed to bulk_create()
        :return: List of created objects, or of returned rows (see above)
        """
        model = self.model
        # It's more logical to use QuerySet object to store this data.
        # But django before 1.10 calls self.model._base_manager._insert instead of self._insert
        # And generates other QuerySet.
        model._insert_returning = True
        # One entry per _insert() call (bulk_create() may insert in several batches).
        model._insert_returning_cache = []
        try:
            if django.VERSION < (1, 10):
                base_manager = model._base_manager
                try:
                    # Compatibility for old django versions which call self.model._base_manager._insert
                    # instead of self._insert
                    model._base_manager = self.as_manager()
                    model._base_manager.model = model
                    result = self.bulk_create(objs, batch_size=batch_size)
                finally:
                    # Restore base manager after operation, even if it failed.
                    # If not restored, it will be shared by other code
                    model._base_manager = base_manager
            else:
                result = self.bulk_create(objs, batch_size=batch_size)

            batches = model._insert_returning_cache
            if result and result[0].pk:
                # For django 1.10+ where objects can be matched
                pk_column = model._meta.pk.column
                values_dict = {row[pk_column]: row for batch in batches for row in batch.values()}
                for item in result:
                    for k, v in values_dict[item.pk].items():
                        setattr(item, k, v)
            else:
                # For django before 1.10 which doesn't fetch primary key
                result = [row for batch in batches for row in batch]
        finally:
            # Clean up even on failure: a stale flag would route every later insert
            # of this model through the RETURNING path.
            model._insert_returning = False
            model._insert_returning_cache = []

        return result
class UpdateReturningQuerySet(UpdateReturningMixin, models.QuerySet):
    """
    Django QuerySet extended with update_returning(), delete_returning()
    and bulk_create_returning().
    """

    @classmethod
    def clone_query_set(cls, qs):  # type: (QuerySet) -> UpdateReturningQuerySet
        """
        Re-creates qs as an instance of this class, copying the state a QuerySet clone carries.
        :param qs: Source QuerySet
        :return: New UpdateReturningQuerySet with the same model, database, hints and a chained query
        """
        clone = cls(model=qs.model, query=chain_query(qs), using=qs._db, hints=qs._hints)

        for attr_name in ('_sticky_filter', '_for_write', '_known_related_objects'):
            setattr(clone, attr_name, getattr(qs, attr_name))
        clone._prefetch_related_lookups = qs._prefetch_related_lookups[:]

        # These attributes do not exist on every django version: copy them only when present.
        for optional_name in ('_iterable_class', '_fields'):
            if hasattr(qs, optional_name):
                setattr(clone, optional_name, getattr(qs, optional_name))

        return clone
class UpdateReturningManager(models.Manager):
    """
    Manager returning UpdateReturningQuerySet and exposing its *_returning operations.
    """

    def bulk_create_returning(self, objs, batch_size=None):
        """Proxy for UpdateReturningQuerySet.bulk_create_returning()."""
        # In early django automatic fetching QuerySet public methods fails
        return self.get_queryset().bulk_create_returning(objs, batch_size=batch_size)

    def update_returning(self, **updates):
        """Proxy for UpdateReturningQuerySet.update_returning()."""
        # In early django automatic fetching QuerySet public methods fails
        return self.get_queryset().update_returning(**updates)

    def delete_returning(self):
        """Proxy for UpdateReturningQuerySet.delete_returning()."""
        # In early django automatic fetching QuerySet public methods fails
        return self.get_queryset().delete_returning()

    def get_queryset(self):
        """
        Returns an UpdateReturningQuerySet for the manager's model.
        Only an explicitly chosen database (self._db) and the manager hints are passed on,
        like django's own Manager.get_queryset() does. Passing self.db would resolve the
        *read* database eagerly and pin it. Update/delete/insert operations would then
        bypass router.db_for_write() in multi-database setups.
        """
        return UpdateReturningQuerySet(model=self.model, using=self._db, hints=self._hints)