This repository has been archived by the owner on May 13, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 4
/
aiomysql.py
195 lines (156 loc) · 5.8 KB
/
aiomysql.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""
The :ref:`aiomysql` connector for MySQL/MariaDB databases.
"""
import asyncio
import logging
import typing
import aiomysql
import pymysql
from asyncqlio.backends.base import BaseConnector, BaseResultSet, BaseTransaction, DictRow
from asyncqlio.exc import DatabaseException, IntegrityError
logger = logging.getLogger(__name__)
# Hijack aiomysql a bit: swap the ``dict_type`` used by ``DictCursor`` for the project's
# ``DictRow``, so rows produced through a ``DictCursor`` use that type.
# NOTE: this assigns a class attribute on import, so it affects every ``DictCursor`` user
# in the process, not just this module.
aiomysql.DictCursor.dict_type = DictRow
class AiomysqlResultSet(BaseResultSet):
    """
    Represents a result set returned by the MySQL database.

    Thin wrapper around an aiomysql cursor: every fetch is delegated to the cursor, and the
    column keys of the result are remembered from the first row that is seen.
    """

    def __init__(self, cursor: aiomysql.DictCursor):
        """
        :param cursor: The cursor to read rows from.
        """
        #: The underlying cursor rows are fetched from.
        self.cursor = cursor
        #: Keys of the rows in this result set; ``None`` until a row has been fetched.
        self._keys = None

    def _remember_keys(self, row) -> None:
        """
        Records the keys of ``row`` if no keys have been recorded yet.
        """
        if self._keys is None and row is not None:
            self._keys = row.keys()

    @property
    def keys(self):
        """
        The keys of the rows in this result set, or ``None`` if no row has been fetched yet.
        """
        return self._keys

    async def close(self):
        """
        Closes the underlying cursor.
        """
        return await self.cursor.close()

    async def fetch_row(self) -> typing.Optional[typing.Dict[typing.Any, typing.Any]]:
        """
        Fetches the next row in this result set.

        :return: Whatever the cursor's ``fetchone()`` returns (``None`` is handled as "no row").
        """
        row = await self.cursor.fetchone()
        self._remember_keys(row)
        return row

    async def fetch_many(self, n: int):
        """
        Fetches the next N rows.

        :param n: The maximum number of rows to fetch.
        :return: The rows returned by the cursor's ``fetchmany()``.
        """
        rows = await self.cursor.fetchmany(size=n)
        # Keep ``keys`` in step with fetch_row(): learn them from the first row of the batch.
        if rows:
            self._remember_keys(rows[0])
        return rows

    async def fetch_all(self):
        """
        Fetches ALL the rows.

        :return: The rows returned by the cursor's ``fetchall()``.
        """
        rows = await self.cursor.fetchall()
        # Same as fetch_many(): make sure ``keys`` is populated on this path too.
        if rows:
            self._remember_keys(rows[0])
        return rows
class AiomysqlTransaction(BaseTransaction):
    """
    Represents a transaction for aiomysql.

    A connection is acquired from the connector's pool in :meth:`begin` and handed back in
    :meth:`close`; every statement in between runs on that single connection.
    """

    def __init__(self, connector: 'AiomysqlConnector'):
        """
        :param connector: The connector whose pool this transaction takes its connection from.
        """
        super().__init__(connector)

        #: The current acquired connection for this transaction.
        #: ``None`` before :meth:`begin` and after :meth:`close`.
        self.connection = None  # type: aiomysql.Connection

    async def close(self, *, has_error: bool = False):
        """
        Closes the current connection.

        Does nothing if no connection is currently held, so it is safe to call twice.

        :param has_error: If True, the connection itself is closed before it is released.
        """
        connection = self.connection
        if connection is None:
            return

        # Forget the connection first so a repeated close() cannot release it a second time.
        self.connection = None
        try:
            if has_error:
                connection.close()
        finally:
            # release it back to the pool so we don't eat all the connections
            self.connector.pool.release(connection)

    async def begin(self):
        """
        Begins the current transaction.

        :return: This transaction.
        """
        self.connection = await self.connector.pool.acquire()  # type: aiomysql.Connection
        try:
            await self.connection.begin()
        except BaseException:
            # Starting the transaction failed (or was cancelled): hand the connection back
            # instead of leaking it from the pool, then let the original error propagate.
            await self.close(has_error=True)
            raise

        return self

    async def execute(self, sql: str, params=None):
        """
        Executes some SQL in the current transaction.

        :param sql: The SQL statement to run.
        :param params: The parameters for the statement, if any.
        :return: The result of the cursor's ``execute()`` call.
        :raises IntegrityError: If pymysql reports an integrity error.
        :raises DatabaseException: If pymysql reports a programming or internal error.
        """
        # pass DictCursor in order to get a dict-like cursor back
        # this will use the custom DictRow class set on DictCursor at import time
        cursor = await self.connection.cursor(cursor=aiomysql.DictCursor)
        # NOTE: the original author observed that a dict can be passed as params instead of
        # a list/tuple even though the docs say otherwise - treat that with some caution.
        try:
            return await cursor.execute(sql, params)
        except pymysql.err.IntegrityError as e:
            raise IntegrityError(*e.args) from e
        except (pymysql.err.ProgrammingError, pymysql.err.InternalError) as e:
            raise DatabaseException(*e.args) from e
        finally:
            # the cursor is only needed for this one statement
            await cursor.close()

    async def cursor(self, sql: str, params: typing.Union[typing.Mapping, typing.Iterable] = None) \
            -> 'AiomysqlResultSet':
        """
        Returns a :class:`.AiomysqlResultSet` for the specified SQL.

        The returned result set owns the cursor; closing the result set closes the cursor.

        :param sql: The SQL query to run.
        :param params: The parameters for the query, if any.
        """
        logger.debug("Executing query {} with params {}".format(sql, params))
        cursor = await self.connection.cursor(cursor=aiomysql.DictCursor)
        try:
            await cursor.execute(sql, params)
        except BaseException:
            # Nobody will ever receive this cursor, so close it here before re-raising.
            await cursor.close()
            raise

        return AiomysqlResultSet(cursor)

    async def rollback(self, checkpoint: str = None):
        """
        Rolls back the current transaction.

        :param checkpoint: Ignored.
        """
        await self.connection.rollback()

    async def commit(self):
        """
        Commits the current transaction.
        """
        await self.connection.commit()
class AiomysqlConnector(BaseConnector):
    """
    A connector that uses the `aiomysql <https://github.com/aio-libs/aiomysql>`_ library.
    """

    def __init__(self, dsn):
        """
        :param dsn: The DSN to connect with; it is passed on to the base connector.
        """
        super().__init__(dsn)

        #: The current connection pool for this connector. ``None`` until :meth:`connect`.
        self.pool = None  # type: aiomysql.Pool

    async def connect(self, *, loop: asyncio.AbstractEventLoop = None) -> 'AiomysqlConnector':
        """
        Connects this connector.

        :param loop: The event loop to create the pool on. Defaults to the current event loop.
        :return: This connector.
        """
        # aiomysql doesnt support a nice dsn, so the individual parts are passed instead
        port = self.port or 3306
        loop = loop or asyncio.get_event_loop()

        # XXX: Force SQL mode to be ANSI.
        # This means we don't break randomly, because we attempt to use ANSI when possible.
        self.params['sql_mode'] = 'ansi'

        logger.info("Connecting to MySQL on mysql://{}:{}/{}".format(self.host, port, self.db))
        self.pool = await aiomysql.create_pool(host=self.host, user=self.username,
                                               password=self.password, port=port,
                                               db=self.db, loop=loop, **self.params)
        return self

    async def close(self, forcefully: bool = False):
        """
        Closes this connector.

        Does nothing if the connector was never connected.

        :param forcefully: If True, the pool is terminated rather than closed.
        """
        if self.pool is None:
            return

        if forcefully:
            self.pool.terminate()
        else:
            self.pool.close()

        # Neither call above is awaited; wait here until the pool has actually shut down.
        await self.pool.wait_closed()

    def get_transaction(self) -> BaseTransaction:
        """
        Gets a new transaction object.
        """
        return AiomysqlTransaction(self)

    def emit_param(self, name: str) -> str:
        """
        Emits the placeholder for a named parameter, in pymysql's parameter style.

        :param name: The name of the parameter.
        :return: ``%(name)s`` for the ``pyformat`` style, ``:name`` for the ``named`` style.
        :raises ValueError: If pymysql reports any other paramstyle.
        """
        if pymysql.paramstyle == "pyformat":
            return "%({})s".format(name)
        elif pymysql.paramstyle == "named":
            return ":{}".format(name)
        else:
            raise ValueError("Cannot work with paramstyle {}".format(pymysql.paramstyle))

    async def get_db_server_version(self):
        """
        Gets the version of the database server, as reported by ``SELECT VERSION();``.

        :return: The first column of the single row the query returns.
        """
        tr = self.get_transaction()
        async with tr:
            cur = await tr.cursor("SELECT VERSION();")
            try:
                row = await cur.fetch_row()
            finally:
                # the result set holds an open cursor; close it once the row has been read
                await cur.close()

        return row[0]
# The connector class this backend module provides.
# NOTE(review): presumably looked up by this name when the backend is loaded - confirm
# against the code that imports the backend modules.
CONNECTOR_TYPE = AiomysqlConnector