-
Notifications
You must be signed in to change notification settings - Fork 391
/
prepared_stmt.pyx
354 lines (281 loc) · 11 KB
/
prepared_stmt.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from asyncpg import exceptions
@cython.final
cdef class PreparedStatementState:
    """Client-side bookkeeping for one server-side prepared statement.

    Stores the statement name and query text, the decoded parameter and
    row descriptions, the codecs resolved for them, a reference count
    adjusted by attach()/detach(), and a reusable FastReadBuffer that
    _decode_row() re-points at each incoming row.
    """

    def __cinit__(self, str name, str query, BaseProtocol protocol):
        self.name = name
        self.query = query
        self.protocol = protocol
        # Codec lookups and the text encoding come from the protocol's
        # connection settings.
        self.settings = protocol.settings
        # Populated later by _set_row_desc()/_set_args_desc() and the
        # _ensure_*() methods.
        self.row_desc = self.parameters_desc = None
        self.args_codecs = self.rows_codecs = None
        self.args_num = self.cols_num = 0
        self.cols_desc = None
        self.closed = False
        self.refs = 0
        # Reused by every _decode_row() call; its buf/len are overwritten
        # per row.
        self.buffer = FastReadBuffer.new()

    def _get_parameters(self):
        """Return a tuple of apg_types.Type, one per statement parameter.

        Raises:
            RuntimeError: if no codec is registered for a parameter OID.
        """
        cdef Codec codec

        result = []
        for oid in self.parameters_desc:
            codec = self.settings.get_data_codec(oid)
            if codec is None:
                raise RuntimeError(
                    'no codec for parameter OID {}'.format(oid))
            result.append(apg_types.Type(
                oid, codec.name, codec.kind, codec.schema))

        return tuple(result)

    def _get_attributes(self):
        """Return a tuple of apg_types.Attribute, one per result column.

        Returns an empty tuple when there is no row description.

        Raises:
            RuntimeError: if no codec is registered for a column type OID.
        """
        cdef Codec codec

        if not self.row_desc:
            return ()

        result = []
        for d in self.row_desc:
            # Row description tuples are (name, table_oid, column_num,
            # type_oid, type_size, type_mod, format).
            name = d[0]
            oid = d[3]

            codec = self.settings.get_data_codec(oid)
            if codec is None:
                raise RuntimeError(
                    'no codec for column type OID {}'.format(oid))

            name = name.decode(self.settings._encoding)

            result.append(
                apg_types.Attribute(name,
                    apg_types.Type(oid, codec.name, codec.kind, codec.schema)))

        return tuple(result)

    def _init_types(self):
        """Return the set of type OIDs that still lack a usable codec.

        A parameter OID is reported when it has no codec or no encoder.
        A column OID is reported when it has no codec or no decoder.
        """
        cdef:
            Codec codec
            set missing = set()

        if self.parameters_desc:
            for p_oid in self.parameters_desc:
                codec = self.settings.get_data_codec(<uint32_t>p_oid)
                if codec is None or not codec.has_encoder():
                    missing.add(p_oid)

        if self.row_desc:
            for rdesc in self.row_desc:
                codec = self.settings.get_data_codec(<uint32_t>(rdesc[3]))
                if codec is None or not codec.has_decoder():
                    missing.add(rdesc[3])

        return missing

    cpdef _init_codecs(self):
        """Resolve the argument encoders and the row decoders."""
        self._ensure_args_encoder()
        self._ensure_rows_decoder()

    def attach(self):
        """Increment the reference count."""
        self.refs += 1

    def detach(self):
        """Decrement the reference count."""
        self.refs -= 1

    def mark_closed(self):
        """Flag this statement state as closed."""
        self.closed = True

    cdef _encode_bind_msg(self, args):
        """Encode *args* into a new WriteBuffer for binding.

        Writes, in order:
          1. the parameter format codes;
          2. the parameter count and the encoded values (-1 length for None);
          3. the result-column format codes.

        Presumably this is the tail of a Bind message; the statement and
        portal names are not written here.

        Raises:
            InterfaceError: too many arguments (> 32767), an argument-count
                mismatch, or a descriptive error raised by a codec.
            DataError: a codec failed on an argument with any other error.
        """
        cdef:
            int idx
            WriteBuffer writer
            Codec codec

        # The argument count is sent as an int16.
        if len(args) > 32767:
            raise exceptions.InterfaceError(
                'the number of query arguments cannot exceed 32767')

        writer = WriteBuffer.new()

        num_args_passed = len(args)
        if self.args_num != num_args_passed:
            hint = 'Check the query against the passed list of arguments.'

            if self.args_num == 0:
                # If the server was expecting zero arguments, it is likely
                # that the user tried to parametrize a statement that does
                # not support parameters.
                hint += (r' Note that parameters are supported only in'
                         r' SELECT, INSERT, UPDATE, DELETE, and VALUES'
                         r' statements, and will *not* work in statements '
                         r' like CREATE VIEW or DECLARE CURSOR.')

            raise exceptions.InterfaceError(
                'the server expects {x} argument{s} for this query, '
                '{y} {w} passed'.format(
                    x=self.args_num, s='s' if self.args_num != 1 else '',
                    y=num_args_passed,
                    w='was' if num_args_passed == 1 else 'were'),
                hint=hint)

        # Parameter format codes.
        if self.have_text_args:
            # At least one parameter uses a non-binary codec:
            # send one format code per parameter.
            writer.write_int16(self.args_num)
            for idx in range(self.args_num):
                codec = <Codec>(self.args_codecs[idx])
                writer.write_int16(codec.format)
        else:
            # All arguments are in binary format: a single format code
            # (count 1, code 1) applies to every parameter.
            writer.write_int32(0x00010001)

        # Parameter value count, then the values themselves.  This follows
        # the format codes in both branches above.
        writer.write_int16(self.args_num)

        for idx in range(self.args_num):
            arg = args[idx]
            if arg is None:
                # A length of -1 encodes NULL.
                writer.write_int32(-1)
            else:
                codec = <Codec>(self.args_codecs[idx])
                try:
                    codec.encode(self.settings, writer, arg)
                except (AssertionError, exceptions.InternalClientError):
                    # These are internal errors and should raise as-is.
                    raise
                except exceptions.InterfaceError:
                    # This is already a descriptive error.
                    raise
                except Exception as e:
                    # Everything else is assumed to be an encoding error
                    # due to invalid input.
                    value_repr = repr(arg)
                    if len(value_repr) > 40:
                        value_repr = value_repr[:40] + '...'

                    raise exceptions.DataError(
                        'invalid input for query argument'
                        ' ${n}: {v} ({msg})'.format(
                            n=idx + 1, v=value_repr, msg=e)) from e

        # Result column format codes.
        if self.have_text_cols:
            writer.write_int16(self.cols_num)
            for idx in range(self.cols_num):
                codec = <Codec>(self.rows_codecs[idx])
                writer.write_int16(codec.format)
        else:
            # All columns are in binary format
            writer.write_int32(0x00010001)

        return writer

    cdef _ensure_rows_decoder(self):
        """Build the record descriptor and per-column decoders, once.

        Sets cols_desc, rows_codecs and have_text_cols.  This is a no-op
        once cols_desc has been set.

        Raises:
            RuntimeError: if a column type has no decoder.
        """
        cdef:
            list cols_names
            object cols_mapping
            tuple row
            uint32_t oid
            Codec codec
            list codecs

        if self.cols_desc is not None:
            return

        if self.cols_num == 0:
            self.cols_desc = record.ApgRecordDesc_New({}, ())
            return

        # Column name -> position; insertion order matches column order.
        cols_mapping = collections.OrderedDict()
        cols_names = []
        codecs = []
        for i from 0 <= i < self.cols_num:
            row = self.row_desc[i]
            col_name = row[0].decode(self.settings._encoding)
            cols_mapping[col_name] = i
            cols_names.append(col_name)
            oid = row[3]
            codec = self.settings.get_data_codec(oid)
            if codec is None or not codec.has_decoder():
                raise RuntimeError('no decoder for OID {}'.format(oid))
            if not codec.is_binary():
                self.have_text_cols = True

            codecs.append(codec)

        self.cols_desc = record.ApgRecordDesc_New(
            cols_mapping, tuple(cols_names))

        self.rows_codecs = tuple(codecs)

    cdef _ensure_args_encoder(self):
        """Resolve one encoder per parameter, once.

        Sets args_codecs and have_text_args.  This is a no-op when there
        are no parameters or when the codecs are already resolved.

        Raises:
            RuntimeError: if a parameter type has no encoder.
        """
        cdef:
            uint32_t p_oid
            Codec codec
            list codecs = []

        if self.args_num == 0 or self.args_codecs is not None:
            return

        for i from 0 <= i < self.args_num:
            p_oid = self.parameters_desc[i]
            codec = self.settings.get_data_codec(p_oid)
            if codec is None or not codec.has_encoder():
                raise RuntimeError('no encoder for OID {}'.format(p_oid))
            # Mirror _ensure_rows_decoder: any non-binary codec forces
            # per-parameter format codes in _encode_bind_msg.
            if not codec.is_binary():
                self.have_text_args = True

            codecs.append(codec)

        self.args_codecs = tuple(codecs)

    cdef _set_row_desc(self, object desc):
        """Decode and store a row description; update cols_num."""
        self.row_desc = _decode_row_desc(desc)
        self.cols_num = <int16_t>(len(self.row_desc))

    cdef _set_args_desc(self, object desc):
        """Decode and store a parameter description; update args_num."""
        self.parameters_desc = _decode_parameters_desc(desc)
        self.args_num = <int16_t>(len(self.parameters_desc))

    cdef _decode_row(self, const char* cbuf, ssize_t buf_len):
        """Decode one data-row payload into a record.

        The payload read here is an int16 field count, followed for each
        field by an int32 length (-1 means NULL) and that many bytes.

        Raises:
            RuntimeError: the field count differs from cols_num, or
                rows_codecs is missing or too short.
            BufferError: a codec left bytes unconsumed, or bytes remain
                after the last field.
        """
        cdef:
            Codec codec
            int16_t fnum
            int32_t flen
            object dec_row
            tuple rows_codecs = self.rows_codecs
            ConnectionSettings settings = self.settings
            int32_t i
            FastReadBuffer rbuf = self.buffer
            ssize_t bl

        rbuf.buf = cbuf
        rbuf.len = buf_len

        fnum = hton.unpack_int16(rbuf.read(2))

        if fnum != self.cols_num:
            raise RuntimeError(
                'number of columns in result ({}) is '
                'different from what was described ({})'.format(
                    fnum, self.cols_num))

        if rows_codecs is None or len(rows_codecs) < fnum:
            if fnum > 0:
                # It's OK to have no rows_codecs for empty records
                raise RuntimeError('invalid rows_codecs')

        dec_row = record.ApgRecord_New(self.cols_desc, fnum)
        for i in range(fnum):
            flen = hton.unpack_int32(rbuf.read(4))

            if flen == -1:
                val = None
            else:
                # Clamp buffer size to that of the reported field length
                # to make sure that codecs can rely on read_all() working
                # properly.
                bl = rbuf.len
                if flen > bl:
                    # Check for overflow
                    rbuf._raise_ins_err(flen, bl)
                rbuf.len = flen
                codec = <Codec>cpython.PyTuple_GET_ITEM(rows_codecs, i)
                val = codec.decode(settings, rbuf)
                if rbuf.len != 0:
                    raise BufferError(
                        'unexpected trailing {} bytes in buffer'.format(
                            rbuf.len))
                # Restore the remaining length for the following fields.
                rbuf.len = bl - flen

            # NOTE(review): ApgRecord_SET_ITEM presumably steals a
            # reference (like PyTuple_SET_ITEM), hence the explicit INCREF.
            cpython.Py_INCREF(val)
            record.ApgRecord_SET_ITEM(dec_row, i, val)

        if rbuf.len != 0:
            raise BufferError('unexpected trailing {} bytes in buffer'.format(
                rbuf.len))

        return dec_row
cdef _decode_parameters_desc(object desc):
    """Parse a parameter-description payload into a list of type OIDs.

    The payload read here is an int16 count followed by that many int32
    values.  Each value is reinterpreted as an unsigned 32-bit OID.
    """
    cdef:
        ReadBuffer buf
        int16_t count
        uint32_t type_oid
        list oids = []

    buf = ReadBuffer.new_message_parser(desc)
    count = buf.read_int16()

    for _ in range(count):
        type_oid = <uint32_t>buf.read_int32()
        oids.append(type_oid)

    return oids
cdef _decode_row_desc(object desc):
    """Parse a row-description payload into a list of per-field tuples.

    The payload read here is an int16 field count, then per field:
    a C-string name, a uint32 table OID, an int16 column number,
    a uint32 type OID, an int16 type size, an int32 type modifier,
    and an int16 format code.

    Each tuple is (name, table_oid, column_num, type_oid, type_size,
    type_mod, format), with the name kept as raw bytes.
    """
    cdef:
        ReadBuffer buf
        int16_t field_count
        bytes name
        uint32_t table_oid
        int16_t column_num
        uint32_t type_oid
        int16_t type_size
        int32_t type_mod
        int16_t fmt
        list fields = []

    buf = ReadBuffer.new_message_parser(desc)
    field_count = buf.read_int16()

    for _ in range(field_count):
        # The read order below must match the wire layout exactly.
        name = buf.read_cstr()
        table_oid = <uint32_t>buf.read_int32()
        column_num = buf.read_int16()
        type_oid = <uint32_t>buf.read_int32()
        type_size = buf.read_int16()
        type_mod = buf.read_int32()
        fmt = buf.read_int16()

        fields.append((name, table_oid, column_num, type_oid,
                       type_size, type_mod, fmt))

    return fields