Skip to content

Commit

Permalink
Merge pull request #52 from aio-libs/fix-results-seq
Browse files Browse the repository at this point in the history
Proper fix fo multiple results issure ported from pymysql
  • Loading branch information
jettify committed Jan 8, 2016
2 parents a9c029a + 3b91d38 commit 6bd3cfd
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 26 deletions.
51 changes: 27 additions & 24 deletions aiomysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

from pymysql.charset import charset_by_name, charset_by_id
from pymysql.constants import SERVER_STATUS
from pymysql.constants.CLIENT import * # noqa
from pymysql.constants.COMMAND import * # noqa
from pymysql.constants import CLIENT
from pymysql.constants import COMMAND
from pymysql.util import byte2int, int2byte
from pymysql.converters import escape_item, encoders, decoders, escape_string
from pymysql.err import (Warning, Error,
Expand Down Expand Up @@ -177,12 +177,12 @@ def __init__(self, host="localhost", user=None, password="",
self._encoding = charset_by_name(self._charset).encoding

if local_infile:
client_flag |= LOCAL_FILES
client_flag |= CLIENT.LOCAL_FILES

client_flag |= CAPABILITIES
client_flag |= MULTI_STATEMENTS
client_flag |= CLIENT.CAPABILITIES
client_flag |= CLIENT.MULTI_STATEMENTS
if self._db:
client_flag |= CONNECT_WITH_DB
client_flag |= CLIENT.CONNECT_WITH_DB
self.client_flag = client_flag

self.cursorclass = cursorclass
Expand Down Expand Up @@ -268,7 +268,7 @@ def ensure_closed(self):
if self._writer is None:
# connection has been closed
return
send_data = struct.pack('<i', 1) + int2byte(COM_QUIT)
send_data = struct.pack('<i', 1) + int2byte(COMMAND.COM_QUIT)
self._writer.write(send_data)
yield from self._writer.drain()
self.close()
Expand Down Expand Up @@ -305,32 +305,32 @@ def _read_ok_packet(self):
def _send_autocommit_mode(self):
"""Set whether or not to commit after every execute() """
yield from self._execute_command(
COM_QUERY,
COMMAND.COM_QUERY,
"SET AUTOCOMMIT = %s" % self.escape(self.autocommit_mode))
yield from self._read_ok_packet()

@asyncio.coroutine
def begin(self):
"""Begin transaction."""
yield from self._execute_command(COM_QUERY, "BEGIN")
yield from self._execute_command(COMMAND.COM_QUERY, "BEGIN")
yield from self._read_ok_packet()

@asyncio.coroutine
def commit(self):
"""Commit changes to stable storage."""
yield from self._execute_command(COM_QUERY, "COMMIT")
yield from self._execute_command(COMMAND.COM_QUERY, "COMMIT")
yield from self._read_ok_packet()

@asyncio.coroutine
def rollback(self):
"""Roll back the current transaction."""
yield from self._execute_command(COM_QUERY, "ROLLBACK")
yield from self._execute_command(COMMAND.COM_QUERY, "ROLLBACK")
yield from self._read_ok_packet()

@asyncio.coroutine
def select_db(self, db):
"""Set current db"""
yield from self._execute_command(COM_INIT_DB, db)
yield from self._execute_command(COMMAND.COM_INIT_DB, db)
yield from self._read_ok_packet()

def escape(self, obj):
Expand Down Expand Up @@ -372,12 +372,9 @@ def cursor(self, cursor=None):
@asyncio.coroutine
def query(self, sql, unbuffered=False):
# logger.debug("DEBUG: sending query: %s", _convert_to_str(sql))
if self._result is not None and self._result.has_next:
raise ProgrammingError("Previous results have not been fetched. "
"You may not close previous cursor.")
if isinstance(sql, str):
sql = sql.encode(self.encoding, 'surrogateescape')
yield from self._execute_command(COM_QUERY, sql)
yield from self._execute_command(COMMAND.COM_QUERY, sql)
yield from self._read_query_result(unbuffered=unbuffered)
return self._affected_rows

Expand All @@ -392,7 +389,7 @@ def affected_rows(self):
@asyncio.coroutine
def kill(self, thread_id):
arg = struct.pack('<I', thread_id)
yield from self._execute_command(COM_PROCESS_KILL, arg)
yield from self._execute_command(COMMAND.COM_PROCESS_KILL, arg)
yield from self._read_ok_packet()

@asyncio.coroutine
Expand All @@ -405,7 +402,7 @@ def ping(self, reconnect=True):
else:
raise Error("Already closed")
try:
yield from self._execute_command(COM_PING, "")
yield from self._execute_command(COMMAND.COM_PING, "")
yield from self._read_ok_packet()
except Exception:
if reconnect:
Expand All @@ -419,7 +416,7 @@ def set_charset(self, charset):
"""Sets the character set for the current connection"""
# Make sure charset is supported.
encoding = charset_by_name(charset).encoding
yield from self._execute_command(COM_QUERY, "SET NAMES %s"
yield from self._execute_command(COMMAND.COM_QUERY, "SET NAMES %s"
% self.escape(charset))
yield from self._read_packet()
self._charset = charset
Expand Down Expand Up @@ -550,8 +547,13 @@ def _execute_command(self, command, sql):

# If the last query was unbuffered, make sure it finishes before
# sending new commands
if self._result is not None and self._result.unbuffered_active:
yield from self._result._finish_unbuffered_query()
if self._result is not None:
if self._result.unbuffered_active:
warnings.warn("Previous unbuffered result was left incomplete")
self._result._finish_unbuffered_query()
while self._result.has_next:
yield from self.next_result()
self._result = None

if isinstance(sql, str):
sql = sql.encode(self._encoding)
Expand Down Expand Up @@ -579,9 +581,9 @@ def _execute_command(self, command, sql):

@asyncio.coroutine
def _request_authentication(self):
self.client_flag |= CAPABILITIES
self.client_flag |= CLIENT.CAPABILITIES
if int(self.server_version.split('.', 1)[0]) >= 5:
self.client_flag |= MULTI_RESULTS
self.client_flag |= CLIENT.MULTI_RESULTS

if self._user is None:
raise ValueError("Did not specify a username")
Expand Down Expand Up @@ -780,7 +782,8 @@ def _read_load_local_packet(self, first_packet):

@asyncio.coroutine
def _print_warnings(self):
yield from self.connection._execute_command(COM_QUERY, 'SHOW WARNINGS')
yield from self.connection._execute_command(
COMMAND.COM_QUERY, 'SHOW WARNINGS')
yield from self.read()
if self.rows:
message = "\n"
Expand Down
15 changes: 13 additions & 2 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,5 +250,16 @@ def test_previous_cursor_not_closed(self):
cur1 = yield from conn.cursor()
yield from cur1.execute("SELECT 1; SELECT 2")
cur2 = yield from conn.cursor()
with self.assertRaises(aiomysql.ProgrammingError):
yield from cur2.execute("SELECT 3")
yield from cur2.execute("SELECT 3;")
resp = yield from cur2.fetchone()
self.assertEqual(resp[0], 3)

@run_until_complete
def test_commit_during_multi_result(self):
conn = yield from self.connect()
cur = yield from conn.cursor()
yield from cur.execute("SELECT 1; SELECT 2;")
yield from conn.commit()
yield from cur.execute("SELECT 3;")
resp = yield from cur.fetchone()
self.assertEqual(resp[0], 3)
2 changes: 2 additions & 0 deletions tests/test_sscursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def test_sscursor_fetchmany(self):
self.assertEqual(len(fetched_data), 2,
'fetchmany failed. Number of rows does not match')

yield from cursor.close()
# test default fetchmany size
cursor = yield from conn.cursor(SSCursor)
yield from cursor.execute('SELECT * FROM tz_data;')
fetched_data = yield from cursor.fetchmany()
self.assertEqual(len(fetched_data), 1)
Expand Down

0 comments on commit 6bd3cfd

Please sign in to comment.