diff --git a/src/MySQLdb/cursors.py b/src/MySQLdb/cursors.py index 70fbeea4..c5fdf0c7 100644 --- a/src/MySQLdb/cursors.py +++ b/src/MySQLdb/cursors.py @@ -64,6 +64,7 @@ class BaseCursor: def __init__(self, connection): self.connection = connection + self.warning_count = 0 self.description = None self.description_flags = None self.rowcount = 0 @@ -140,6 +141,7 @@ def _do_get_result(self, db): self.description = result.describe() self.description_flags = result.field_flags() + self.warning_count = db.warning_count() self.rowcount = db.affected_rows() self.rownumber = 0 self.lastrowid = db.insert_id() @@ -325,6 +327,7 @@ def callproc(self, procname, args=()): def _query(self, q): db = self._get_db() self._result = None + self.warning_count = 0 self.rowcount = None self.lastrowid = None db.query(q) @@ -435,6 +438,7 @@ def fetchone(self): self._check_executed() r = self._fetch_row(1) if not r: + self.warning_count = self._get_db().warning_count() return None self.rownumber = self.rownumber + 1 return r[0] @@ -443,7 +447,10 @@ def fetchmany(self, size=None): """Fetch up to size rows from the cursor. Result set may be smaller than size. If size is not defined, cursor.arraysize is used.""" self._check_executed() - r = self._fetch_row(size or self.arraysize) + size = size or self.arraysize + r = self._fetch_row(size) + if len(r) < size: + self.warning_count = self._get_db().warning_count() self.rownumber = self.rownumber + len(r) return r @@ -451,6 +458,7 @@ def fetchall(self): """Fetches all available rows from the cursor.""" self._check_executed() r = self._fetch_row(0) + self.warning_count = self._get_db().warning_count() self.rownumber = self.rownumber + len(r) return r diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 1d2c3655..65cb236d 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -1,5 +1,6 @@ import pytest import MySQLdb.cursors +from MySQLdb.constants import ER from configdb import connection_factory @@ -243,3 +244,43 @@ def test_binary_prefix(): "INSERT INTO test_binary_prefix (id, json) VALUES (%(id)s, %(json)s)", ({"id": 1, "json": "{}"}, {"id": 2, "json": "{}"}), ) + + +def test_warning_count(): + conn = connect() + cursor = conn.cursor() + + cursor.execute("DROP TABLE IF EXISTS `no_exists_table`") + assert cursor.warning_count == 1 + + cursor.execute("SHOW WARNINGS") + warning = cursor.fetchone() + assert warning[1] == ER.BAD_TABLE_ERROR + assert "no_exists_table" in warning[2] + + cursor.execute("SELECT 1") + assert cursor.warning_count == 0 + + +def test_sscursor_warning_count(): + conn = connect() + cursor = conn.cursor(MySQLdb.cursors.SSCursor) + + cursor.execute("DROP TABLE IF EXISTS `no_exists_table`") + assert cursor.warning_count == 1 + + cursor.execute("SHOW WARNINGS") + warning = cursor.fetchone() + assert warning[1] == ER.BAD_TABLE_ERROR + assert "no_exists_table" in warning[2] + assert cursor.fetchone() is None + + cursor.execute("SELECT 1") + assert cursor.fetchone() == (1,) + assert cursor.fetchone() is None + assert cursor.warning_count == 0 + + cursor.execute("SELECT CAST('abc' AS SIGNED)") + rows = cursor.fetchmany(2) + assert len(rows) == 1 + assert cursor.warning_count == 1