Skip to content

Commit

Permalink
Split dotted table names when escaping in inserttable (#61)
Browse files Browse the repository at this point in the history
This is the same pragmatic solution as used in the copy methods of pgdb.
We should implement a proper solution by allowing tuples
or a separate schema parameter in the next version.
  • Loading branch information
Cito committed Jan 30, 2022
1 parent 663dd90 commit 307ec95
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/contents/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Version 5.2.3 (to be released)
- This version officially supports the new Python 3.10 and PostgreSQL 14.
- Some improvements and fixes in the `inserttable()` method of the `pg` module:
- Sync with `PQendcopy()` when there was an error (#60)
- Allow specifying a schema in the table name (#61)
- Improved check for internal result (#62)
- Catch buffer overflows when building the copy command
- Data can now be passed as an iterable, not just list or tuple (#66)
Expand Down
17 changes: 13 additions & 4 deletions pgconn.c
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ static PyObject *
conn_inserttable(connObject *self, PyObject *args)
{
PGresult *result;
char *table, *buffer, *bufpt, *bufmax;
char *table, *buffer, *bufpt, *bufmax, *s, *t;
int encoding;
size_t bufsiz;
PyObject *rows, *iter_row, *item, *columns = NULL;
Expand Down Expand Up @@ -753,9 +753,18 @@ conn_inserttable(connObject *self, PyObject *args)
/* starts query */
bufpt = buffer;
bufmax = bufpt + MAX_BUFFER_SIZE;
table = PQescapeIdentifier(self->cnx, table, strlen(table));
bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), "copy %s", table);
PQfreemem(table);
bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), "copy ");

s = table;
do {
t = strchr(s, '.'); if (!t) t = s + strlen(s);
table = PQescapeIdentifier(self->cnx, s, (size_t) (t - s));
if (bufpt < bufmax)
bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), "%s", table);
PQfreemem(table);
s = t; if (*s && bufpt < bufmax) *bufpt++ = *s++;
} while (*s);

if (columns) {
/* adds a string like f" ({','.join(columns)})" */
if (bufpt < bufmax)
Expand Down
13 changes: 13 additions & 0 deletions tests/test_classic_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1963,6 +1963,11 @@ def testInserttableOnlyTwoColumns(self):
+ (None,) * 6 for i in range(20)]
self.assertEqual(self.get_back(), data)

def testInserttableWithDottedTableName(self):
data = self.data
self.c.inserttable('public.test', data)
self.assertEqual(self.get_back(), data)

def testInserttableWithInvalidTableName(self):
data = [(42,)]
# check that the table name is not inserted unescaped
Expand All @@ -1976,6 +1981,14 @@ def testInserttableWithInvalidTableName(self):
# make sure that it works if parameters are passed properly
self.c.inserttable('test', data, ['i4'])

def testInserttableWithInvalidDataType(self):
try:
self.c.inserttable('test', 42)
except TypeError as e:
self.assertIn('expects an iterable as second argument', str(e))
else:
self.assertFalse('expected an error')

def testInserttableWithInvalidColumnName(self):
data = [(2, 4)]
# check that the column names are not inserted unescaped
Expand Down

0 comments on commit 307ec95

Please sign in to comment.