Skip to content

Commit cc73b6f

Browse files
committed
Add rudimentary QueryResult; tweak API
1. Old Connection.execute() is gone; 2. Connection.execute_script() is renamed to Connection.execute(); 3. PreparedStatement.execute() is gone; new way of how one should asynchronously iterate through results is to use 'async for' on the prepared statement: st = await con.prepare('SELECT ...') async for row in st(arg1, arg2): print(row) 4. New methods: PreparedStatement.get_value(*args, column=0) and PreparedStatement.get_first_row(*args) 5. New method PreparedStatement.get_list() returns all query results in one Python list.
1 parent 12d855f commit cc73b6f

File tree

9 files changed

+153
-109
lines changed

9 files changed

+153
-109
lines changed

asyncpg/__init__.py

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55

66
from . import introspection as _intro
77
from .exceptions import *
8-
from .transaction import Transaction
8+
from .prepared_stmt import PreparedStatement
99
from .protocol import Protocol
10+
from .transaction import Transaction
1011

1112

1213
__all__ = ('connect',) + exceptions.__all__
@@ -34,13 +35,9 @@ def transaction(self, *, isolation='read_committed', readonly=False,
3435

3536
return Transaction(self, isolation, readonly, deferrable)
3637

37-
async def execute_script(self, script):
38+
async def execute(self, script):
3839
await self._protocol.query(script)
3940

40-
async def execute(self, query, *args):
41-
stmt = await self._prepare('', query)
42-
return await stmt.execute(*args)
43-
4441
async def prepare(self, query):
4542
return await self._prepare(None, query)
4643

@@ -53,7 +50,7 @@ async def _prepare(self, name, query):
5350
self._types_stmt = await self.prepare(
5451
_intro.INTRO_LOOKUP_TYPES)
5552

56-
types = await self._types_stmt.execute(list(ready))
53+
types = await self._types_stmt.get_list(list(ready))
5754
self._protocol._add_types(types)
5855

5956
return PreparedStatement(self, state)
@@ -76,10 +73,10 @@ async def set_type_codec(self, typename, *,
7673
if self._type_by_name_stmt is None:
7774
self._type_by_name_stmt = await self.prepare(_intro.TYPE_BY_NAME)
7875

79-
typeinfo = await self._type_by_name_stmt.execute(typename, schema)
76+
typeinfo = await self._type_by_name_stmt.get_first_row(
77+
typename, schema)
8078
if not typeinfo:
8179
raise ValueError('unknown type: {}.{}'.format(schema, typename))
82-
typeinfo = list(typeinfo)[0]
8380

8481
oid = typeinfo['oid']
8582
if typeinfo['kind'] != b'b' or typeinfo['elemtype']:
@@ -99,25 +96,6 @@ def _get_unique_id(self):
9996
return 'id{}'.format(self._uid)
10097

10198

102-
class PreparedStatement:
103-
104-
__slots__ = ('_connection', '_state')
105-
106-
def __init__(self, connection, state):
107-
self._connection = connection
108-
self._state = state
109-
110-
def get_parameters(self):
111-
return self._state._get_parameters()
112-
113-
def get_attributes(self):
114-
return self._state._get_attributes()
115-
116-
async def execute(self, *args):
117-
protocol = self._connection._protocol
118-
return await protocol.execute(self._state, args)
119-
120-
12199
async def connect(iri=None, *,
122100
host=None, port=None,
123101
user=None, password=None,

asyncpg/compat.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import functools
2+
import sys
3+
4+
5+
if sys.version_info < (3, 5, 2):
6+
def aiter_compat(func):
7+
@functools.wraps(func)
8+
async def wrapper(self):
9+
return func(self)
10+
return wrapper
11+
else:
12+
def aiter_compat(func):
13+
return func

asyncpg/prepared_stmt.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from . import compat
2+
3+
4+
class PreparedStatement:
5+
6+
__slots__ = ('_connection', '_state')
7+
8+
def __init__(self, connection, state):
9+
self._connection = connection
10+
self._state = state
11+
12+
def get_parameters(self):
13+
return self._state._get_parameters()
14+
15+
def get_attributes(self):
16+
return self._state._get_attributes()
17+
18+
def __call__(self, *args):
19+
return PreparedStatementIterator(self, args)
20+
21+
async def get_list(self, *args):
22+
protocol = self._connection._protocol
23+
data = await protocol.execute(self._state, args)
24+
if data is None:
25+
data = []
26+
return data
27+
28+
async def get_value(self, *args, column=0):
29+
protocol = self._connection._protocol
30+
data = await protocol.execute(self._state, args)
31+
if data is None:
32+
return None
33+
return data[0][column]
34+
35+
async def get_first_row(self, *args):
36+
protocol = self._connection._protocol
37+
data = await protocol.execute(self._state, args)
38+
if data is None:
39+
return None
40+
return data[0]
41+
42+
43+
class PreparedStatementIterator:
44+
45+
__slots__ = ('_stmt', '_args', '_iter')
46+
47+
def __init__(self, stmt, args):
48+
self._stmt = stmt
49+
self._args = args
50+
self._iter = None
51+
52+
@compat.aiter_compat
53+
def __aiter__(self):
54+
return self
55+
56+
async def __anext__(self):
57+
if self._iter is None:
58+
protocol = self._stmt._connection._protocol
59+
data = await protocol.execute(self._stmt._state, self._args)
60+
if data is None:
61+
data = ()
62+
self._iter = iter(data)
63+
64+
try:
65+
return next(self._iter)
66+
except StopIteration:
67+
raise StopAsyncIteration() from None

asyncpg/transaction.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ async def start(self):
8686
query += ';'
8787

8888
try:
89-
await self._connection.execute_script(query)
89+
await self._connection.execute(query)
9090
except:
9191
self._state = TransactionState.FAILED
9292
raise
@@ -104,7 +104,7 @@ async def commit(self):
104104
query = 'COMMIT;'
105105

106106
try:
107-
await self._connection.execute_script(query)
107+
await self._connection.execute(query)
108108
except:
109109
self._state = TransactionState.FAILED
110110
raise
@@ -125,7 +125,7 @@ async def rollback(self):
125125
query = 'ROLLBACK;'
126126

127127
try:
128-
await self._connection.execute_script(query)
128+
await self._connection.execute(query)
129129
except:
130130
self._state = TransactionState.FAILED
131131
raise

tests/test_codecs.py

Lines changed: 16 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ async def test_standard_codecs(self):
224224

225225
for sample in sample_data:
226226
with self.subTest(sample=sample, typname=typname):
227-
rsample = list(await st.execute(sample))[0][0]
227+
rsample = await st.get_value(sample)
228228
self.assertEqual(
229229
rsample, sample,
230230
("failed to return {} object data as-is; "
@@ -238,16 +238,14 @@ async def test_standard_codecs(self):
238238
async def test_composites(self):
239239
"""Test encoding/decoding of composite types"""
240240

241-
st = await self.con.prepare('''
241+
await self.con.execute('''
242242
CREATE TYPE test_composite AS (
243243
a int,
244244
b text,
245245
c int[]
246246
)
247247
''')
248248

249-
await st.execute()
250-
251249
try:
252250
st = await self.con.prepare('''
253251
SELECT ROW(
@@ -257,7 +255,8 @@ async def test_composites(self):
257255
)::test_composite AS test
258256
''')
259257

260-
res = list(await st.execute())[0][0]
258+
res = await st.get_list()
259+
res = res[0]['test']
261260

262261
self.assertIsNone(res['a'])
263262
self.assertEqual(res['b'], '5678')
@@ -274,38 +273,31 @@ async def test_composites(self):
274273
self.assertEqual(at[0].type.kind, 'composite')
275274

276275
finally:
277-
st = await self.con.prepare('DROP TYPE test_composite')
278-
await st.execute()
276+
await self.con.execute('DROP TYPE test_composite')
279277

280278
async def test_domains(self):
281279
"""Test encoding/decoding of composite types"""
282280

283-
st = await self.con.prepare('''
281+
await self.con.execute('''
284282
CREATE DOMAIN my_dom AS int
285283
''')
286284

287-
await st.execute()
288-
289-
st = await self.con.prepare('''
285+
await self.con.execute('''
290286
CREATE DOMAIN my_dom2 AS my_dom
291287
''')
292288

293-
await st.execute()
294-
295289
try:
296290
st = await self.con.prepare('''
297291
SELECT 3::my_dom2
298292
''')
299-
300-
res = list(await st.execute())[0][0]
293+
res = await st.get_value()
301294

302295
self.assertEqual(res, 3)
303296

304297
st = await self.con.prepare('''
305298
SELECT NULL::my_dom2
306299
''')
307-
308-
res = list(await st.execute())[0][0]
300+
res = await st.get_value()
309301

310302
self.assertIsNone(res)
311303

@@ -316,20 +308,16 @@ async def test_domains(self):
316308
self.assertEqual(at[0].type.kind, 'scalar')
317309

318310
finally:
319-
st = await self.con.prepare('DROP DOMAIN my_dom2')
320-
await st.execute()
321-
st = await self.con.prepare('DROP DOMAIN my_dom')
322-
await st.execute()
311+
await self.con.execute('DROP DOMAIN my_dom2')
312+
await self.con.execute('DROP DOMAIN my_dom')
323313

324314
async def test_custom_codec_text(self):
325315
"""Test encoding/decoding using a custom codec in text mode"""
326316

327-
st = await self.con.prepare('''
317+
await self.con.execute('''
328318
CREATE EXTENSION IF NOT EXISTS hstore
329319
''')
330320

331-
await st.execute()
332-
333321
def hstore_decoder(data):
334322
result = {}
335323
items = data.split(',')
@@ -349,7 +337,8 @@ def hstore_encoder(obj):
349337
SELECT $1::hstore AS result
350338
''')
351339

352-
res = list(await st.execute({'ham': 'spam'}))[0]['result']
340+
res = await st.get_first_row({'ham': 'spam'})
341+
res = res['result']
353342

354343
self.assertEqual(res, {'ham': 'spam'})
355344

@@ -371,21 +360,17 @@ def hstore_encoder(obj):
371360
await self.con.set_type_codec('_hstore', encoder=hstore_encoder,
372361
decoder=hstore_decoder)
373362

374-
st = await self.con.prepare('''
363+
await self.con.execute('''
375364
CREATE TYPE mytype AS (a int);
376365
''')
377366

378-
await st.execute()
379-
380367
try:
381368
err = 'cannot use custom codec on non-scalar type public.mytype'
382369
with self.assertRaisesRegex(ValueError, err):
383370
await self.con.set_type_codec(
384371
'mytype', encoder=hstore_encoder,
385372
decoder=hstore_decoder)
386373
finally:
387-
st = await self.con.prepare('''
374+
await self.con.execute('''
388375
DROP TYPE mytype;
389376
''')
390-
391-
await st.execute()

tests/test_execute.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,27 @@
1+
import asyncpg
12
from asyncpg import _testbase as tb
23

34

4-
class TestExecute(tb.ConnectedTestCase):
5+
class TestExecuteScript(tb.ConnectedTestCase):
56

6-
async def test_execute_1(self):
7-
r = await self.con.execute('SELECT $1::smallint', 10)
8-
self.assertEqual(r[0][0], 10)
7+
async def test_execute_script_1(self):
8+
r = await self.con.execute('''
9+
SELECT 1;
910
10-
r = await self.con.execute('SELECT $1::smallint * 2', 10)
11-
self.assertEqual(r[0][0], 20)
11+
SELECT true FROM pg_type WHERE false = true;
1212
13-
async def test_execute_unknownoid(self):
14-
r = await self.con.execute("SELECT 'test'")
15-
self.assertEqual(r[0][0], 'test')
13+
SELECT 2;
14+
''')
15+
self.assertIsNone(r)
16+
17+
async def test_execute_script_check_transactionality(self):
18+
with self.assertRaises(asyncpg.Error):
19+
await self.con.execute('''
20+
CREATE TABLE mytab (a int);
21+
SELECT * FROM mytab WHERE 1 / 0 = 1;
22+
''')
23+
24+
with self.assertRaisesRegex(asyncpg.Error, '"mytab" does not exist'):
25+
await self.con.prepare('''
26+
SELECT * FROM mytab
27+
''')

tests/test_execute_script.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

0 commit comments

Comments
 (0)