Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/MySQLdb/cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class BaseCursor:

def __init__(self, connection):
self.connection = connection
self.warning_count = 0
self.description = None
self.description_flags = None
self.rowcount = 0
Expand Down Expand Up @@ -140,6 +141,7 @@ def _do_get_result(self, db):
self.description = result.describe()
self.description_flags = result.field_flags()

self.warning_count = db.warning_count()
self.rowcount = db.affected_rows()
self.rownumber = 0
self.lastrowid = db.insert_id()
Expand Down Expand Up @@ -325,6 +327,7 @@ def callproc(self, procname, args=()):
def _query(self, q):
db = self._get_db()
self._result = None
self.warning_count = 0
self.rowcount = None
self.lastrowid = None
db.query(q)
Expand Down Expand Up @@ -435,6 +438,7 @@ def fetchone(self):
self._check_executed()
r = self._fetch_row(1)
if not r:
self.warning_count = self._get_db().warning_count()
return None
self.rownumber = self.rownumber + 1
return r[0]
Expand All @@ -443,14 +447,18 @@ def fetchmany(self, size=None):
"""Fetch up to size rows from the cursor. Result set may be smaller
than size. If size is not defined, cursor.arraysize is used."""
self._check_executed()
r = self._fetch_row(size or self.arraysize)
size = size or self.arraysize
r = self._fetch_row(size)
if len(r) < size:
self.warning_count = self._get_db().warning_count()
self.rownumber = self.rownumber + len(r)
return r

def fetchall(self):
"""Fetches all available rows from the cursor."""
self._check_executed()
r = self._fetch_row(0)
self.warning_count = self._get_db().warning_count()
self.rownumber = self.rownumber + len(r)
return r

Expand Down
41 changes: 41 additions & 0 deletions tests/test_cursor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import MySQLdb.cursors
from MySQLdb.constants import ER
from configdb import connection_factory


Expand Down Expand Up @@ -243,3 +244,43 @@ def test_binary_prefix():
"INSERT INTO test_binary_prefix (id, json) VALUES (%(id)s, %(json)s)",
({"id": 1, "json": "{}"}, {"id": 2, "json": "{}"}),
)


def test_warning_count():
conn = connect()
cursor = conn.cursor()

cursor.execute("DROP TABLE IF EXISTS `no_exists_table`")
assert cursor.warning_count == 1

cursor.execute("SHOW WARNINGS")
warning = cursor.fetchone()
assert warning[1] == ER.BAD_TABLE_ERROR
assert "no_exists_table" in warning[2]

cursor.execute("SELECT 1")
assert cursor.warning_count == 0


def test_sscursor_warning_count():
conn = connect()
cursor = conn.cursor(MySQLdb.cursors.SSCursor)

cursor.execute("DROP TABLE IF EXISTS `no_exists_table`")
assert cursor.warning_count == 1

cursor.execute("SHOW WARNINGS")
warning = cursor.fetchone()
assert warning[1] == ER.BAD_TABLE_ERROR
assert "no_exists_table" in warning[2]
assert cursor.fetchone() is None

cursor.execute("SELECT 1")
assert cursor.fetchone() == (1,)
assert cursor.fetchone() is None
assert cursor.warning_count == 0

cursor.execute("SELECT CAST('abc' AS SIGNED)")
rows = cursor.fetchmany(2)
assert len(rows) == 1
assert cursor.warning_count == 1
Loading