diff --git a/MySQLdb/connection.py b/MySQLdb/connection.py index 96e53a6..f58fa9d 100644 --- a/MySQLdb/connection.py +++ b/MySQLdb/connection.py @@ -27,7 +27,8 @@ class Connection(object): def __init__(self, host=None, user=None, passwd=None, db=None, port=0, client_flag=0, charset=None, init_command=None, connect_timeout=None, - sql_mode=None, encoders=None, decoders=None, use_unicode=True): + sql_mode=None, encoders=None, decoders=None, use_unicode=True, + conv=None): self._db = libmysql.c.mysql_init(None) @@ -54,8 +55,13 @@ def __init__(self, host=None, user=None, passwd=None, db=None, port=0, encoders = converters.DEFAULT_ENCODERS if decoders is None: decoders = converters.DEFAULT_DECODERS - self.encoders = encoders - self.decoders = decoders + self.real_encoders = encoders + self.real_decoders = decoders + + # MySQLdb compatibility + if conv is None: + conv = converters.conversions + self.encoders = dict(conv.iteritems()) if charset is not None: res = libmysql.c.mysql_set_character_set(self._db, charset) @@ -123,9 +129,9 @@ def cursor(self, cursor_class=None, encoders=None, decoders=None): if cursor_class is None: cursor_class = cursors.Cursor if encoders is None: - encoders = self.encoders[:] + encoders = self.real_encoders[:] if decoders is None: - decoders = self.decoders[:] + decoders = self.real_decoders[:] return cursor_class(self, encoders=encoders, decoders=decoders) def string_literal(self, obj): @@ -143,5 +149,8 @@ def get_server_info(self): self._check_closed() return libmysql.c.mysql_get_server_info(self._db) + def ping(self): + return libmysql.c.mysql_ping(self._db) + def connect(*args, **kwargs): - return Connection(*args, **kwargs) \ No newline at end of file + return Connection(*args, **kwargs) diff --git a/MySQLdb/converters.py b/MySQLdb/converters.py index c17cbc7..50c0239 100644 --- a/MySQLdb/converters.py +++ b/MySQLdb/converters.py @@ -30,6 +30,7 @@ def datetime_encoder(connection, obj): int: literal_encoder, bool: lambda connection, obj: str(int(obj)), unicode: unicode_to_quoted_sql, + str: unicode_to_quoted_sql, datetime: datetime_encoder, } @@ -98,4 +99,8 @@ def fallback_decoder(connection, field): DEFAULT_DECODERS = [ fallback_decoder, -] \ No newline at end of file +] + +# MySQLdb compatibility +conversions = _simple_field_decoders +conversions.update(_simple_field_encoders) diff --git a/MySQLdb/cursors.py b/MySQLdb/cursors.py index c72c51c..0c4f2e5 100644 --- a/MySQLdb/cursors.py +++ b/MySQLdb/cursors.py @@ -302,7 +302,9 @@ def fetchmany(self, size): break self.rows.append(row) if self.row_index >= len(self.rows): - return [] + # MySQLdb compatbility: There are applications checking for tuple + # instead of empty lists or even an iterator... + return () row_end = self.row_index + size if row_end >= len(self.rows): row_end = len(self.rows) diff --git a/MySQLdb/libmysql.py b/MySQLdb/libmysql.py index fb1bc2a..ccd3cad 100644 --- a/MySQLdb/libmysql.py +++ b/MySQLdb/libmysql.py @@ -139,6 +139,9 @@ class MYSQL_FIELD(ctypes.Structure): c.mysql_character_set_name.argtypes = [MYSQL_P] c.mysql_character_set_name.restype = ctypes.c_char_p +c.mysql_ping.argtypes = [MYSQL_P] +c.mysql_ping.restype = None + # Second thing is an enum, it looks to be a long on Linux systems. c.mysql_options.argtypes = [MYSQL_P, ctypes.c_long, ctypes.c_char_p] c.mysql_options.restype = ctypes.c_int diff --git a/tests/test_connection.py b/tests/test_connection.py index ad6b6ca..83145a5 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -38,3 +38,6 @@ def test_closed_error(self, connection): with py.test.raises(connection.InterfaceError) as exc: connection.rollback() assert str(exc.value) == "(0, '')" + + def test_ping(self, connection): + connection.ping()