Skip to content

Commit

Permalink
Make inserttable() accept any iterable (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
Cito committed Jan 28, 2022
1 parent a0025f8 commit 8b77708
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 83 deletions.
28 changes: 14 additions & 14 deletions docs/contents/pg/connection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,12 @@ returns without waiting for the query to complete. The database connection
cannot be used for other operations until the query completes, but the
application can do other things, including executing queries using other
database connections. The application can call ``select()`` using the
``fileno`` obtained by the connection's :meth:`Connection.fileno` method
``fileno`` obtained by the connection's :meth:`Connection.fileno` method
to determine when the query has results to return.

This method always returns a :class:`Query` object. This object differs
from the :class:`Query` object returned by :meth:`Connection.query` in a
few ways. Most importantly, when :meth:`Connection.send_query` is used, the
few ways. Most importantly, when :meth:`Connection.send_query` is used, the
application must call one of the result-returning methods such as
:meth:`Query.getresult` or :meth:`Query.dictresult` until it either raises
an exception or returns ``None``.
Expand Down Expand Up @@ -285,7 +285,7 @@ it's no different from a connection made using blocking calls.

The required steps are to pass the parameter ``nowait=True`` to the
:meth:`pg.connect` call, then call :meth:`Connection.poll` until it either
returns :const:`POLLING_OK` or raises an exception. To avoid blocking
returns :const:`POLLING_OK` or raises an exception. To avoid blocking
in :meth:`Connection.poll`, use `select()` or `poll()` to wait for the
connection to be readable or writable, depending on the return code of the
previous call to :meth:`Connection.poll`. The initial state of the connection
Expand Down Expand Up @@ -484,27 +484,27 @@ first, otherwise :meth:`Connection.getnotify` will always return ``None``.
.. versionchanged:: 4.1
Support for payload strings was added in version 4.1.

inserttable -- insert a list into a table
-----------------------------------------
inserttable -- insert an iterable into a table
----------------------------------------------

.. method:: Connection.inserttable(table, values, [columns])

Insert a Python list into a database table
Insert a Python iterable into a database table

:param str table: the table name
:param list values: list of rows values
:param list columns: list of column names
:param values: iterable of row values, each of which must be a list or tuple
:param list columns: list or tuple of column names
:rtype: None
:raises TypeError: invalid connection, bad argument type, or too many arguments
:raises MemoryError: insert buffer could not be allocated
:raises ValueError: unsupported values

This method allows to *quickly* insert large blocks of data in a table:
It inserts the whole values list into the given table. Internally, it
uses the COPY command of the PostgreSQL database. The list is a list
of tuples/lists that define the values for each inserted row. The rows
values may contain string, integer, long or double (real) values.
``columns`` is an optional sequence of column names to be passed on
This method allows you to *quickly* insert large blocks of data into a table.
Internally, it uses the COPY command of the PostgreSQL database.
The method takes an iterable of row values which must be tuples or lists
of the same size, containing the values for each inserted row.
These may contain string, integer, long or double (real) values.
``columns`` is an optional tuple or list of column names to be passed on
to the COPY command.

.. warning::
Expand Down
130 changes: 66 additions & 64 deletions pgconn.c
Original file line number Diff line number Diff line change
Expand Up @@ -682,9 +682,9 @@ conn_is_non_blocking(connObject *self, PyObject *noargs)

/* Insert table */
static char conn_inserttable__doc__[] =
"inserttable(table, data, [columns]) -- insert list into table\n\n"
"The fields in the list must be in the same order as in the table\n"
"or in the list of columns if one is specified.\n";
"inserttable(table, data, [columns]) -- insert iterable into table\n\n"
"The fields in the iterable must be in the same order as in the table\n"
"or in the list or tuple of columns if one is specified.\n";

static PyObject *
conn_inserttable(connObject *self, PyObject *args)
Expand All @@ -693,69 +693,60 @@ conn_inserttable(connObject *self, PyObject *args)
char *table, *buffer, *bufpt, *bufmax;
int encoding;
size_t bufsiz;
PyObject *list, *sublist, *item, *columns = NULL;
PyObject *(*getitem) (PyObject *, Py_ssize_t);
PyObject *(*getsubitem) (PyObject *, Py_ssize_t);
PyObject *(*getcolumn) (PyObject *, Py_ssize_t);
Py_ssize_t i, j, m, n = 0;
PyObject *rows, *iter_row, *item, *columns = NULL;
Py_ssize_t i, j, m, n;

if (!self->cnx) {
PyErr_SetString(PyExc_TypeError, "Connection is not valid");
return NULL;
}

/* gets arguments */
if (!PyArg_ParseTuple(args, "sO|O", &table, &list, &columns)) {
if (!PyArg_ParseTuple(args, "sO|O", &table, &rows, &columns)) {
PyErr_SetString(
PyExc_TypeError,
"Method inserttable() expects a string and a list as arguments");
return NULL;
}

/* checks list type */
if (PyList_Check(list)) {
m = PyList_Size(list);
getitem = PyList_GetItem;
}
else if (PyTuple_Check(list)) {
m = PyTuple_Size(list);
getitem = PyTuple_GetItem;
}
else {
if (!(iter_row = PyObject_GetIter(rows)))
{
PyErr_SetString(
PyExc_TypeError,
"Method inserttable() expects a list or a tuple"
"Method inserttable() expects an iterable"
" as second argument");
return NULL;
}
m = PySequence_Check(rows) ? PySequence_Size(rows) : -1;
if (!m) {
/* no rows specified, nothing to do */
Py_DECREF(iter_row); Py_INCREF(Py_None); return Py_None;
}

/* checks columns type */
if (columns) {
if (PyList_Check(columns)) {
n = PyList_Size(columns);
getcolumn = PyList_GetItem;
}
else if (PyTuple_Check(columns)) {
n = PyTuple_Size(columns);
getcolumn = PyTuple_GetItem;
}
else {
if (!(PyTuple_Check(columns) || PyList_Check(columns))) {
PyErr_SetString(
PyExc_TypeError,
"Method inserttable() expects a list or a tuple"
" as third argument");
"Method inserttable() expects a tuple or a list"
" as second argument");

This comment has been minimized.

Copy link
@justinpryzby

justinpryzby Jan 29, 2022

Contributor

third

return NULL;
}

n = PySequence_Fast_GET_SIZE(columns);
if (!n) {
/* no columns specified, nothing to do */
Py_INCREF(Py_None);
return Py_None;
Py_DECREF(iter_row); Py_INCREF(Py_None); return Py_None;
}
} else {
n = -1; /* number of columns not yet known */
}

/* allocate buffer */
if (!(buffer = PyMem_Malloc(MAX_BUFFER_SIZE)))
return PyErr_NoMemory();
if (!(buffer = PyMem_Malloc(MAX_BUFFER_SIZE))) {
Py_DECREF(iter_row); return PyErr_NoMemory();
}

encoding = PQclientEncoding(self->cnx);

Expand All @@ -770,7 +761,7 @@ conn_inserttable(connObject *self, PyObject *args)
if (bufpt < bufmax)
bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), " (");
for (j = 0; j < n; ++j) {
PyObject *obj = getcolumn(columns, j);
PyObject *obj = PySequence_Fast_GET_ITEM(columns, j);
Py_ssize_t slen;
char *col;

Expand All @@ -779,13 +770,18 @@ conn_inserttable(connObject *self, PyObject *args)
}
else if (PyUnicode_Check(obj)) {
obj = get_encoded_string(obj, encoding);
if (!obj) return NULL; /* pass the UnicodeEncodeError */
if (!obj) {
Py_DECREF(iter_row);
return NULL; /* pass the UnicodeEncodeError */
}
PyBytes_AsStringAndSize(obj, &col, &slen);
Py_DECREF(obj);
} else {
PyErr_SetString(
PyExc_TypeError,
"The third argument must contain only strings");
Py_DECREF(iter_row);
return NULL;
}
col = PQescapeIdentifier(self->cnx, col, (size_t) slen);
if (bufpt < bufmax)
Expand All @@ -797,49 +793,46 @@ conn_inserttable(connObject *self, PyObject *args)
if (bufpt < bufmax)
snprintf(bufpt, (size_t) (bufmax - bufpt), " from stdin");
if (bufpt >= bufmax) {
PyMem_Free(buffer); return PyErr_NoMemory();
PyMem_Free(buffer); Py_DECREF(iter_row);
return PyErr_NoMemory();
}

Py_BEGIN_ALLOW_THREADS
result = PQexec(self->cnx, buffer);
Py_END_ALLOW_THREADS

if (!result) {
PyMem_Free(buffer);
PyMem_Free(buffer); Py_DECREF(iter_row);
PyErr_SetString(PyExc_ValueError, PQerrorMessage(self->cnx));
return NULL;
}

PQclear(result);

/* feed table */
for (i = 0; i < m; ++i) {
sublist = getitem(list, i);
if (PyTuple_Check(sublist)) {
j = PyTuple_Size(sublist);
getsubitem = PyTuple_GetItem;
}
else if (PyList_Check(sublist)) {
j = PyList_Size(sublist);
getsubitem = PyList_GetItem;
}
else {
for (i = 0; m < 0 || i < m; ++i) {

if (!(columns = PyIter_Next(iter_row))) break;

if (!(PyTuple_Check(columns) || PyList_Check(columns))) {
PyMem_Free(buffer);
Py_DECREF(columns); Py_DECREF(columns); Py_DECREF(iter_row);
PyErr_SetString(
PyExc_TypeError,
"The second argument must contain a tuple or a list");
"The second argument must contain tuples or lists");
return NULL;
}
if (i) {
if (j != n) {
PyMem_Free(buffer);
PyErr_SetString(
PyExc_TypeError,
"Arrays contained in second arg must have same size");
return NULL;
}
}
else {
n = j; /* never used before this assignment */

j = PySequence_Fast_GET_SIZE(columns);
if (n < 0) {
n = j;
} else if (j != n) {
PyMem_Free(buffer);
Py_DECREF(columns); Py_DECREF(iter_row);
PyErr_SetString(
PyExc_TypeError,
"The second arg must contain sequences of the same size");
return NULL;
}

/* builds insert line */
Expand All @@ -851,7 +844,7 @@ conn_inserttable(connObject *self, PyObject *args)
*bufpt++ = '\t'; --bufsiz;
}

item = getsubitem(sublist, j);
item = PySequence_Fast_GET_ITEM(columns, j);

/* convert item to string and append to buffer */
if (item == Py_None) {
Expand All @@ -877,6 +870,7 @@ conn_inserttable(connObject *self, PyObject *args)
PyObject *s = get_encoded_string(item, encoding);
if (!s) {
PyMem_Free(buffer);
Py_DECREF(item); Py_DECREF(columns); Py_DECREF(iter_row);
return NULL; /* pass the UnicodeEncodeError */
}
else {
Expand Down Expand Up @@ -916,22 +910,30 @@ conn_inserttable(connObject *self, PyObject *args)
}

if (bufsiz <= 0) {
PyMem_Free(buffer); return PyErr_NoMemory();
PyMem_Free(buffer); Py_DECREF(columns); Py_DECREF(iter_row);
return PyErr_NoMemory();
}

}

Py_DECREF(columns);

*bufpt++ = '\n'; *bufpt = '\0';

/* sends data */
if (PQputline(self->cnx, buffer)) {
PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx));
PQendcopy(self->cnx);
PyMem_Free(buffer);
PyMem_Free(buffer); Py_DECREF(iter_row);
return NULL;
}
}

Py_DECREF(iter_row);
if (PyErr_Occurred()) {
PyMem_Free(buffer); return NULL; /* pass the iteration error */
}

/* ends query */
if (PQputline(self->cnx, "\\.\n")) {
PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx));
Expand Down
37 changes: 32 additions & 5 deletions tests/test_classic_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1890,15 +1890,42 @@ def testInserttableFromTupleOfLists(self):
self.c.inserttable('test', data)
self.assertEqual(self.get_back(), self.data)

def testInserttableFromSetofTuples(self):
data = {row for row in self.data}
def testInserttableWithDifferentRowSizes(self):
data = self.data[:-1] + [self.data[-1][:-1]]
try:
self.c.inserttable('test', data)
except TypeError as e:
r = str(e)
else:
r = 'this is fine'
self.assertIn('list or a tuple as second argument', r)
self.assertIn('second arg must contain sequences of the same size', r)

def testInserttableFromSetofTuples(self):
data = {row for row in self.data}
self.c.inserttable('test', data)
self.assertEqual(self.get_back(), self.data)

def testInserttableFromDictAsInterable(self):
data = {row: None for row in self.data}
self.c.inserttable('test', data)
self.assertEqual(self.get_back(), self.data)

def testInserttableFromDictKeys(self):
data = {row: None for row in self.data}
keys = data.keys()
self.c.inserttable('test', keys)
self.assertEqual(self.get_back(), self.data)

def testInserttableFromDictValues(self):
data = {i: row for i, row in enumerate(self.data)}
values = data.values()
self.c.inserttable('test', values)
self.assertEqual(self.get_back(), self.data)

def testInserttableFromGeneratorOfTuples(self):
data = (row for row in self.data)
self.c.inserttable('test', data)
self.assertEqual(self.get_back(), self.data)

def testInserttableFromListOfSets(self):
data = [set(row) for row in self.data]
Expand All @@ -1908,7 +1935,7 @@ def testInserttableFromListOfSets(self):
r = str(e)
else:
r = 'this is fine'
self.assertIn('second argument must contain a tuple or a list', r)
self.assertIn('second argument must contain tuples or lists', r)

def testInserttableMultipleRows(self):
num_rows = 100
Expand Down Expand Up @@ -2078,7 +2105,7 @@ def testInserttableNoEncoding(self):
def testInserttableTooLargeColumnSpecification(self):
# should catch buffer overflow when building the column specification
self.assertRaises(MemoryError, self.c.inserttable,
'test', [], ['very_long_column_name'] * 1000)
'test', self.data, ['very_long_column_name'] * 1000)


class TestDirectSocketAccess(unittest.TestCase):
Expand Down

0 comments on commit 8b77708

Please sign in to comment.