Skip to content

Commit

Permalink
Fix Connection.escape() with Unicode input (#608)
Browse files Browse the repository at this point in the history
After aed1dd2, Connection.escape() used ASCII to escape Unicode input.
This commit makes it uses connection encoding instead.
  • Loading branch information
methane committed May 18, 2023
1 parent 44d0f7a commit b162ddd
Showing 1 changed file with 43 additions and 21 deletions.
64 changes: 43 additions & 21 deletions src/MySQLdb/_mysql.c
Original file line number Diff line number Diff line change
Expand Up @@ -943,7 +943,7 @@ _mysql_escape_string(
{
PyObject *str;
char *in, *out;
int len;
unsigned long len;
Py_ssize_t size;
if (!PyArg_ParseTuple(args, "s#:escape_string", &in, &size)) return NULL;
str = PyBytes_FromStringAndSize((char *) NULL, size*2+1);
Expand Down Expand Up @@ -980,35 +980,52 @@ _mysql_string_literal(
_mysql_ConnectionObject *self,
PyObject *o)
{
PyObject *str, *s;
char *in, *out;
unsigned long len;
Py_ssize_t size;
PyObject *s; // input string or bytes. need to decref.

if (self && PyModule_Check((PyObject*)self))
self = NULL;

if (PyBytes_Check(o)) {
s = o;
Py_INCREF(s);
} else {
s = PyObject_Str(o);
if (!s) return NULL;
{
PyObject *t = PyUnicode_AsASCIIString(s);
Py_DECREF(s);
if (!t) return NULL;
}
else {
PyObject *t = PyObject_Str(o);
if (!t) return NULL;

const char *encoding = (self && self->open) ?
_get_encoding(&self->connection) : utf8;
if (encoding == utf8) {
s = t;
}
else {
s = PyUnicode_AsEncodedString(t, encoding, "strict");
Py_DECREF(t);
if (!s) return NULL;
}
}
in = PyBytes_AsString(s);
size = PyBytes_GET_SIZE(s);
str = PyBytes_FromStringAndSize((char *) NULL, size*2+3);

// Prepare input string (in, size)
const char *in;
Py_ssize_t size;
if (PyUnicode_Check(s)) {
in = PyUnicode_AsUTF8AndSize(s, &size);
} else {
assert(PyBytes_Check(s));
in = PyBytes_AsString(s);
size = PyBytes_GET_SIZE(s);
}

// Prepare output buffer (str, out)
PyObject *str = PyBytes_FromStringAndSize((char *) NULL, size*2+3);
if (!str) {
Py_DECREF(s);
return PyErr_NoMemory();
}
out = PyBytes_AS_STRING(str);
char *out = PyBytes_AS_STRING(str);

// escape
unsigned long len;
if (self && self->open) {
#if MYSQL_VERSION_ID >= 50707 && !defined(MARIADB_BASE_VERSION) && !defined(MARIADB_VERSION_ID)
len = mysql_real_escape_string_quote(&(self->connection), out+1, in, size, '\'');
Expand All @@ -1018,10 +1035,14 @@ _mysql_string_literal(
} else {
len = mysql_escape_string(out+1, in, size);
}
*out = *(out+len+1) = '\'';
if (_PyBytes_Resize(&str, len+2) < 0) return NULL;

Py_DECREF(s);
return (str);
*out = *(out+len+1) = '\'';
if (_PyBytes_Resize(&str, len+2) < 0) {
Py_DECREF(str);
return NULL;
}
return str;
}

static PyObject *
Expand Down Expand Up @@ -1499,8 +1520,9 @@ _mysql_ResultObject_discard(
// do nothing
}
Py_END_ALLOW_THREADS
if (mysql_errno(self->conn)) {
return _mysql_Exception(self->conn);
_mysql_ConnectionObject *conn = (_mysql_ConnectionObject *)self->conn;
if (mysql_errno(&conn->connection)) {
return _mysql_Exception(conn);
}
Py_RETURN_NONE;
}
Expand Down

0 comments on commit b162ddd

Please sign in to comment.