Skip to content

Commit 348dbf8

Browse files
committed
Implement connection.transaction()
1 parent 6c4883d commit 348dbf8

File tree

2 files changed

+218
-0
lines changed

2 files changed

+218
-0
lines changed

asyncpg/__init__.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import enum
23
import getpass
34
import os
45

@@ -12,16 +13,27 @@
1213

1314

1415
class Connection:
16+
17+
__slots__ = ('_protocol', '_transport', '_loop', '_types_stmt',
18+
'_type_by_name_stmt', '_top_xact', '_uid')
19+
1520
def __init__(self, protocol, transport, loop):
1621
self._protocol = protocol
1722
self._transport = transport
1823
self._loop = loop
1924
self._types_stmt = None
2025
self._type_by_name_stmt = None
26+
self._top_xact = None
27+
self._uid = 0
2128

2229
def get_settings(self):
2330
return self._protocol.get_settings()
2431

32+
def transaction(self, *, isolation='read_committed', readonly=False,
33+
deferrable=False):
34+
35+
return Transaction(self, isolation, readonly, deferrable)
36+
2537
async def execute_script(self, script):
2638
await self._protocol.query(script)
2739

@@ -82,8 +94,15 @@ async def set_type_codec(self, typename, *,
8294
def close(self):
8395
self._transport.close()
8496

97+
def _get_unique_id(self):
98+
self._uid += 1
99+
return 'id{}'.format(self._uid)
100+
85101

86102
class PreparedStatement:
103+
104+
__slots__ = ('_connection', '_state')
105+
87106
def __init__(self, connection, state):
88107
self._connection = connection
89108
self._state = state
@@ -99,6 +118,133 @@ async def execute(self, *args):
99118
return await protocol.execute(self._state, args)
100119

101120

121+
class TransactionState(enum.Enum):
122+
NEW = 0
123+
STARTED = 1
124+
COMMITTED = 2
125+
ROLLEDBACK = 3
126+
FAILED = 4
127+
128+
129+
class Transaction:
130+
131+
ISOLATION_LEVELS = {'read_committed', 'serializable', 'repeatable_read'}
132+
133+
__slots__ = ('_connection', '_isolation', '_readonly', '_deferrable',
134+
'_state', '_nested', '_id')
135+
136+
def __init__(self, connection, isolation, readonly, deferrable):
137+
if isolation not in self.ISOLATION_LEVELS:
138+
raise ValueError(
139+
'isolation is expected to be either of {}, '
140+
'got {!r}'.format(self.ISOLATION_LEVELS, isolation))
141+
142+
if isolation != 'serializable':
143+
if readonly:
144+
raise ValueError(
145+
'"readonly" is only supported for '
146+
'serializable transactions')
147+
148+
if deferrable and not readonly:
149+
raise ValueError(
150+
'"deferrable" is only supported for '
151+
'serializable readonly transactions')
152+
153+
self._connection = connection
154+
self._isolation = isolation
155+
self._readonly = readonly
156+
self._deferrable = deferrable
157+
self._state = TransactionState.NEW
158+
self._nested = False
159+
self._id = None
160+
161+
async def __aenter__(self):
162+
await self.start()
163+
164+
async def __aexit__(self, extype, ex, tb):
165+
if extype is not None:
166+
await self.rollback()
167+
168+
async def start(self):
169+
if self._state is not TransactionState.NEW:
170+
raise FatalError('cannot start transaction: inconsistent state')
171+
172+
con = self._connection
173+
174+
if con._top_xact is None:
175+
con._top_xact = self
176+
else:
177+
# Nested transaction block
178+
top_xact = con._top_xact
179+
if self._isolation != top_xact._isolation:
180+
raise FatalError(
181+
'nested transaction has different isolation level: '
182+
'current {!r} != outer {!r}'.format(
183+
self._isolation, top_xact._isolation))
184+
self._nested = True
185+
186+
if self._nested:
187+
self._id = con._get_unique_id()
188+
query = 'SAVEPOINT {};'.format(self._id)
189+
else:
190+
if self._isolation == 'read_committed':
191+
query = 'BEGIN;'
192+
elif self._isolation == 'repeatable_read':
193+
query = 'BEGIN ISOLATION LEVEL REPEATABLE READ;'
194+
else:
195+
query = 'BEGIN ISOLATION LEVEL SERIALIZABLE'
196+
if self._readonly:
197+
query += ' READ ONLY'
198+
if self._deferrable:
199+
query += ' DEFERRABLE'
200+
query += ';'
201+
202+
try:
203+
await self._connection.execute_script(query)
204+
except:
205+
self._state = TransactionState.FAILED
206+
raise
207+
else:
208+
self._state = TransactionState.STARTED
209+
210+
async def commit(self):
211+
if self._state is not TransactionState.STARTED:
212+
raise FatalError('cannot commit transaction: inconsistent state')
213+
214+
if self._nested:
215+
query = 'RELEASE SAVEPOINT {};'.format(self._id)
216+
else:
217+
query = 'COMMIT;'
218+
219+
try:
220+
await self._connection.execute_script(query)
221+
except:
222+
self._state = TransactionState.FAILED
223+
raise
224+
else:
225+
self._state = TransactionState.COMMITTED
226+
227+
async def rollback(self):
228+
if self._connection._top_xact is self:
229+
self._connection._top_xact = None
230+
231+
if self._state is not TransactionState.STARTED:
232+
raise FatalError('cannot rollback transaction: inconsistent state')
233+
234+
if self._nested:
235+
query = 'ROLLBACK TO {};'.format(self._id)
236+
else:
237+
query = 'ROLLBACK;'
238+
239+
try:
240+
await self._connection.execute_script(query)
241+
except:
242+
self._state = TransactionState.FAILED
243+
raise
244+
else:
245+
self._state = TransactionState.ROLLEDBACK
246+
247+
102248
async def connect(iri=None, *,
103249
host=None, port=None,
104250
user=None, password=None,

tests/test_transaction.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import asyncpg
2+
3+
from asyncpg import _testbase as tb
4+
5+
6+
class TestTransaction(tb.ConnectedTestCase):
7+
8+
async def test_transaction_regular(self):
9+
10+
try:
11+
async with self.con.transaction():
12+
13+
await self.con.execute_script('''
14+
CREATE TABLE mytab (a int);
15+
''')
16+
17+
1 / 0
18+
19+
except ZeroDivisionError:
20+
pass
21+
else:
22+
self.fail('ZeroDivisionError was not raised')
23+
24+
with self.assertRaisesRegex(asyncpg.Error, '"mytab" does not exist'):
25+
await self.con.prepare('''
26+
SELECT * FROM mytab
27+
''')
28+
29+
async def test_transaction_nested(self):
30+
31+
try:
32+
async with self.con.transaction():
33+
34+
await self.con.execute_script('''
35+
CREATE TABLE mytab (a int);
36+
''')
37+
38+
async with self.con.transaction():
39+
40+
await self.con.execute_script('''
41+
INSERT INTO mytab (a) VALUES (1), (2);
42+
''')
43+
44+
try:
45+
async with self.con.transaction():
46+
47+
await self.con.execute_script('''
48+
INSERT INTO mytab (a) VALUES (3), (4);
49+
''')
50+
51+
1 / 0
52+
except ZeroDivisionError:
53+
pass
54+
else:
55+
self.fail('ZeroDivisionError was not raised')
56+
57+
res = await self.con.execute('SELECT * FROM mytab;')
58+
self.assertEqual(len(res), 2)
59+
self.assertEqual(res[0][0], 1)
60+
self.assertEqual(res[1][0], 2)
61+
62+
1 / 0
63+
64+
except ZeroDivisionError:
65+
pass
66+
else:
67+
self.fail('ZeroDivisionError was not raised')
68+
69+
with self.assertRaisesRegex(asyncpg.Error, '"mytab" does not exist'):
70+
await self.con.prepare('''
71+
SELECT * FROM mytab
72+
''')

0 commit comments

Comments
 (0)