diff --git a/src/MySQLdb/_mysql.c b/src/MySQLdb/_mysql.c index b030af16..1f52d90b 100644 --- a/src/MySQLdb/_mysql.c +++ b/src/MySQLdb/_mysql.c @@ -943,7 +943,7 @@ _mysql_escape_string( { PyObject *str; char *in, *out; - int len; + unsigned long len; Py_ssize_t size; if (!PyArg_ParseTuple(args, "s#:escape_string", &in, &size)) return NULL; str = PyBytes_FromStringAndSize((char *) NULL, size*2+1); @@ -980,10 +980,7 @@ _mysql_string_literal( _mysql_ConnectionObject *self, PyObject *o) { - PyObject *str, *s; - char *in, *out; - unsigned long len; - Py_ssize_t size; + PyObject *s; // input string or bytes. need to decref. if (self && PyModule_Check((PyObject*)self)) self = NULL; @@ -991,24 +988,44 @@ _mysql_string_literal( if (PyBytes_Check(o)) { s = o; Py_INCREF(s); - } else { - s = PyObject_Str(o); - if (!s) return NULL; - { - PyObject *t = PyUnicode_AsASCIIString(s); - Py_DECREF(s); - if (!t) return NULL; + } + else { + PyObject *t = PyObject_Str(o); + if (!t) return NULL; + + const char *encoding = (self && self->open) ? + _get_encoding(&self->connection) : utf8; + if (encoding == utf8) { s = t; } + else { + s = PyUnicode_AsEncodedString(t, encoding, "strict"); + Py_DECREF(t); + if (!s) return NULL; + } } - in = PyBytes_AsString(s); - size = PyBytes_GET_SIZE(s); - str = PyBytes_FromStringAndSize((char *) NULL, size*2+3); + + // Prepare input string (in, size) + const char *in; + Py_ssize_t size; + if (PyUnicode_Check(s)) { + in = PyUnicode_AsUTF8AndSize(s, &size); + } else { + assert(PyBytes_Check(s)); + in = PyBytes_AsString(s); + size = PyBytes_GET_SIZE(s); + } + + // Prepare output buffer (str, out) + PyObject *str = PyBytes_FromStringAndSize((char *) NULL, size*2+3); if (!str) { Py_DECREF(s); return PyErr_NoMemory(); } - out = PyBytes_AS_STRING(str); + char *out = PyBytes_AS_STRING(str); + + // escape + unsigned long len; if (self && self->open) { #if MYSQL_VERSION_ID >= 50707 && !defined(MARIADB_BASE_VERSION) && !defined(MARIADB_VERSION_ID) len = mysql_real_escape_string_quote(&(self->connection), out+1, in, size, '\''); @@ -1018,10 +1035,14 @@ _mysql_string_literal( } else { len = mysql_escape_string(out+1, in, size); } - *out = *(out+len+1) = '\''; - if (_PyBytes_Resize(&str, len+2) < 0) return NULL; + Py_DECREF(s); - return (str); + *out = *(out+len+1) = '\''; + if (_PyBytes_Resize(&str, len+2) < 0) { + Py_DECREF(str); + return NULL; + } + return str; } static PyObject * @@ -1499,8 +1520,9 @@ _mysql_ResultObject_discard( // do nothing } Py_END_ALLOW_THREADS - if (mysql_errno(self->conn)) { - return _mysql_Exception(self->conn); + _mysql_ConnectionObject *conn = (_mysql_ConnectionObject *)self->conn; + if (mysql_errno(&conn->connection)) { + return _mysql_Exception(conn); } Py_RETURN_NONE; }