11import asyncio
2+ import enum
23import getpass
34import os
45
1213
1314
1415class Connection :
16+
17+ __slots__ = ('_protocol' , '_transport' , '_loop' , '_types_stmt' ,
18+ '_type_by_name_stmt' , '_top_xact' , '_uid' )
19+
1520 def __init__ (self , protocol , transport , loop ):
1621 self ._protocol = protocol
1722 self ._transport = transport
1823 self ._loop = loop
1924 self ._types_stmt = None
2025 self ._type_by_name_stmt = None
26+ self ._top_xact = None
27+ self ._uid = 0
2128
2229 def get_settings (self ):
2330 return self ._protocol .get_settings ()
2431
32+ def transaction (self , * , isolation = 'read_committed' , readonly = False ,
33+ deferrable = False ):
34+
35+ return Transaction (self , isolation , readonly , deferrable )
36+
2537 async def execute_script (self , script ):
2638 await self ._protocol .query (script )
2739
@@ -82,8 +94,15 @@ async def set_type_codec(self, typename, *,
8294 def close (self ):
8395 self ._transport .close ()
8496
97+ def _get_unique_id (self ):
98+ self ._uid += 1
99+ return 'id{}' .format (self ._uid )
100+
85101
86102class PreparedStatement :
103+
104+ __slots__ = ('_connection' , '_state' )
105+
87106 def __init__ (self , connection , state ):
88107 self ._connection = connection
89108 self ._state = state
@@ -99,6 +118,133 @@ async def execute(self, *args):
99118 return await protocol .execute (self ._state , args )
100119
101120
121+ class TransactionState (enum .Enum ):
122+ NEW = 0
123+ STARTED = 1
124+ COMMITTED = 2
125+ ROLLEDBACK = 3
126+ FAILED = 4
127+
128+
129+ class Transaction :
130+
131+ ISOLATION_LEVELS = {'read_committed' , 'serializable' , 'repeatable_read' }
132+
133+ __slots__ = ('_connection' , '_isolation' , '_readonly' , '_deferrable' ,
134+ '_state' , '_nested' , '_id' )
135+
136+ def __init__ (self , connection , isolation , readonly , deferrable ):
137+ if isolation not in self .ISOLATION_LEVELS :
138+ raise ValueError (
139+ 'isolation is expected to be either of {}, '
140+ 'got {!r}' .format (self .ISOLATION_LEVELS , isolation ))
141+
142+ if isolation != 'serializable' :
143+ if readonly :
144+ raise ValueError (
145+ '"readonly" is only supported for '
146+ 'serializable transactions' )
147+
148+ if deferrable and not readonly :
149+ raise ValueError (
150+ '"deferrable" is only supported for '
151+ 'serializable readonly transactions' )
152+
153+ self ._connection = connection
154+ self ._isolation = isolation
155+ self ._readonly = readonly
156+ self ._deferrable = deferrable
157+ self ._state = TransactionState .NEW
158+ self ._nested = False
159+ self ._id = None
160+
161+ async def __aenter__ (self ):
162+ await self .start ()
163+
164+ async def __aexit__ (self , extype , ex , tb ):
165+ if extype is not None :
166+ await self .rollback ()
167+
168+ async def start (self ):
169+ if self ._state is not TransactionState .NEW :
170+ raise FatalError ('cannot start transaction: inconsistent state' )
171+
172+ con = self ._connection
173+
174+ if con ._top_xact is None :
175+ con ._top_xact = self
176+ else :
177+ # Nested transaction block
178+ top_xact = con ._top_xact
179+ if self ._isolation != top_xact ._isolation :
180+ raise FatalError (
181+ 'nested transaction has different isolation level: '
182+ 'current {!r} != outer {!r}' .format (
183+ self ._isolation , top_xact ._isolation ))
184+ self ._nested = True
185+
186+ if self ._nested :
187+ self ._id = con ._get_unique_id ()
188+ query = 'SAVEPOINT {};' .format (self ._id )
189+ else :
190+ if self ._isolation == 'read_committed' :
191+ query = 'BEGIN;'
192+ elif self ._isolation == 'repeatable_read' :
193+ query = 'BEGIN ISOLATION LEVEL REPEATABLE READ;'
194+ else :
195+ query = 'BEGIN ISOLATION LEVEL SERIALIZABLE'
196+ if self ._readonly :
197+ query += ' READ ONLY'
198+ if self ._deferrable :
199+ query += ' DEFERRABLE'
200+ query += ';'
201+
202+ try :
203+ await self ._connection .execute_script (query )
204+ except :
205+ self ._state = TransactionState .FAILED
206+ raise
207+ else :
208+ self ._state = TransactionState .STARTED
209+
210+ async def commit (self ):
211+ if self ._state is not TransactionState .STARTED :
212+ raise FatalError ('cannot commit transaction: inconsistent state' )
213+
214+ if self ._nested :
215+ query = 'RELEASE SAVEPOINT {};' .format (self ._id )
216+ else :
217+ query = 'COMMIT;'
218+
219+ try :
220+ await self ._connection .execute_script (query )
221+ except :
222+ self ._state = TransactionState .FAILED
223+ raise
224+ else :
225+ self ._state = TransactionState .COMMITTED
226+
227+ async def rollback (self ):
228+ if self ._connection ._top_xact is self :
229+ self ._connection ._top_xact = None
230+
231+ if self ._state is not TransactionState .STARTED :
232+ raise FatalError ('cannot rollback transaction: inconsistent state' )
233+
234+ if self ._nested :
235+ query = 'ROLLBACK TO {};' .format (self ._id )
236+ else :
237+ query = 'ROLLBACK;'
238+
239+ try :
240+ await self ._connection .execute_script (query )
241+ except :
242+ self ._state = TransactionState .FAILED
243+ raise
244+ else :
245+ self ._state = TransactionState .ROLLEDBACK
246+
247+
102248async def connect (iri = None , * ,
103249 host = None , port = None ,
104250 user = None , password = None ,
0 commit comments