Skip to content

Commit

Permalink
Support qualified table names in copy_from/to (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
Cito committed Jul 12, 2020
1 parent 18a1ceb commit 17914b6
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 6 deletions.
6 changes: 6 additions & 0 deletions docs/contents/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
ChangeLog
=========

Version 5.2.1 (to be released)
------------------------------
- Changes to the DB-API 2 module (pgdb):
- The `copy_to()` and `copy_from()` methods now also work with table names
containing schema qualifiers (#47).

Version 5.2 (2020-06-21)
------------------------
- We now require Python version 2.7 or 3.5 and newer.
Expand Down
16 changes: 10 additions & 6 deletions pgdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,10 +1268,11 @@ def chunks():

if not table or not isinstance(table, basestring):
raise TypeError("Need a table to copy to")
if table.lower().startswith('select'):
if table.lower().startswith('select '):
raise ValueError("Must specify a table, not a query")
else:
table = '"%s"' % (table,)
table = '.'.join(map(
self.connection._cnx.escape_identifier, table.split('.', 1)))
operation = ['copy %s' % (table,)]
options = []
params = []
Expand Down Expand Up @@ -1299,7 +1300,8 @@ def chunks():
params.append(null)
if columns:
if not isinstance(columns, basestring):
columns = ','.join('"%s"' % (col,) for col in columns)
columns = ','.join(map(
self.connection._cnx.escape_identifier, columns))
operation.append('(%s)' % (columns,))
operation.append("from stdin")
if options:
Expand Down Expand Up @@ -1350,12 +1352,13 @@ def copy_to(self, stream, table,
raise TypeError("Need an output stream to copy to")
if not table or not isinstance(table, basestring):
raise TypeError("Need a table to copy to")
if table.lower().startswith('select'):
if table.lower().startswith('select '):
if columns:
raise ValueError("Columns must be specified in the query")
table = '(%s)' % (table,)
else:
table = '"%s"' % (table,)
table = '.'.join(map(
self.connection._cnx.escape_identifier, table.split('.', 1)))
operation = ['copy %s' % (table,)]
options = []
params = []
Expand Down Expand Up @@ -1394,7 +1397,8 @@ def copy_to(self, stream, table,
"The decode option is not allowed with binary format")
if columns:
if not isinstance(columns, basestring):
columns = ','.join('"%s"' % (col,) for col in columns)
columns = ','.join(map(
self.connection._cnx.escape_identifier, columns))
operation.append('(%s)' % (columns,))

operation.append("to stdout")
Expand Down
8 changes: 8 additions & 0 deletions tests/test_dbapi20_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ def test_input_string(self):
self.assertEqual(self.table_data, [(42, 'Hello, world!')])
self.check_rowcount(1)

def test_input_string_with_schema_name(self):
self.cursor.copy_from('42\tHello, world!', 'public.copytest')
self.assertEqual(self.table_data, [(42, 'Hello, world!')])

def test_input_string_with_newline(self):
self.copy_from('42\tHello, world!\n')
self.assertEqual(self.table_data, [(42, 'Hello, world!')])
Expand Down Expand Up @@ -449,6 +453,10 @@ def test_generator(self):
self.assertEqual(rows, self.data_text)
self.check_rowcount()

def test_generator_with_schema_name(self):
ret = self.cursor.copy_to(None, 'public.copytest')
self.assertEqual(''.join(ret), self.data_text)

if str is unicode: # Python >= 3.0

def test_generator_bytes(self):
Expand Down

0 comments on commit 17914b6

Please sign in to comment.