Skip to content

Commit

Permalink
Split dotted table names when escaping in inserttable
Browse files Browse the repository at this point in the history
This is the same pragmatic solution as used in the copy methods of pgdb.
We should implement a proper solution by allowing tuples
or a separate schema parameter in the next version.
  • Loading branch information
Cito committed Jan 30, 2022
1 parent 663dd90 commit b06ffaf
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
17 changes: 13 additions & 4 deletions pgconn.c
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ static PyObject *
conn_inserttable(connObject *self, PyObject *args)
{
PGresult *result;
char *table, *buffer, *bufpt, *bufmax;
char *table, *buffer, *bufpt, *bufmax, *s, *t;
int encoding;
size_t bufsiz;
PyObject *rows, *iter_row, *item, *columns = NULL;
Expand Down Expand Up @@ -753,9 +753,18 @@ conn_inserttable(connObject *self, PyObject *args)
/* starts query */
bufpt = buffer;
bufmax = bufpt + MAX_BUFFER_SIZE;
table = PQescapeIdentifier(self->cnx, table, strlen(table));
bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), "copy %s", table);
PQfreemem(table);
bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), "copy ");

s = table;
do {
t = strchr(s, '.'); if (!t) t = s + strlen(s);
table = PQescapeIdentifier(self->cnx, s, (size_t) (t - s));
if (bufpt < bufmax)
bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), "%s", table);
PQfreemem(table);
s = t; if (*s && bufpt < bufmax) *bufpt++ = *s++;
} while (*s);

if (columns) {
/* adds a string like f" ({','.join(columns)})" */
if (bufpt < bufmax)
Expand Down
13 changes: 13 additions & 0 deletions tests/test_classic_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1963,6 +1963,11 @@ def testInserttableOnlyTwoColumns(self):
+ (None,) * 6 for i in range(20)]
self.assertEqual(self.get_back(), data)

def testInserttableWithDottedTableName(self):
data = self.data
self.c.inserttable('public.test', data)
self.assertEqual(self.get_back(), data)

def testInserttableWithInvalidTableName(self):
data = [(42,)]
# check that the table name is not inserted unescaped
Expand All @@ -1976,6 +1981,14 @@ def testInserttableWithInvalidTableName(self):
# make sure that it works if parameters are passed properly
self.c.inserttable('test', data, ['i4'])

def testInserttableWithInvalidDataType(self):
try:
self.c.inserttable('test', 42)
except TypeError as e:
self.assertIn('expects an iterable as second argument', str(e))
else:
self.assertFalse('expected an error')

def testInserttableWithInvalidColumnName(self):
data = [(2, 4)]
# check that the column names are not inserted unescaped
Expand Down

0 comments on commit b06ffaf

Please sign in to comment.