Skip to content

Commit 5662d9f

Browse files
elprans1st1
authored and committed
Add support for COPY OUT
This commit adds two new Connection methods: copy_from_table() and copy_from_query() that allow copying the contents of a table or the results of a query using PostgreSQL COPY protocol. With help from @1st1. Issue MagicStack#21.
1 parent df64f55 commit 5662d9f

13 files changed

+935
-22
lines changed

Diff for: asyncpg/compat.py

+27
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,13 @@
66

77

88
import functools
9+
import os
910
import sys
1011

1112

13+
PY_36 = sys.version_info >= (3, 6)
14+
15+
1216
if sys.version_info < (3, 5, 2):
1317
def aiter_compat(func):
1418
@functools.wraps(func)
@@ -18,3 +22,26 @@ async def wrapper(self):
1822
else:
1923
def aiter_compat(func):
2024
return func
25+
26+
27+
if PY_36:
    # Python 3.6+ ships the real implementation.
    fspath = os.fspath
else:
    def fspath(path):
        """Return the file system representation of *path*.

        Fallback for interpreters that lack ``os.fspath``.  Objects
        exposing a callable ``__fspath__`` attribute are converted by
        calling it; plain ``str`` and ``bytes`` values are returned
        unchanged.

        :raises TypeError:
            If *path* is neither ``str``/``bytes`` nor path-like, or if
            its ``__fspath__()`` returns something other than ``str`` or
            ``bytes``.
        """
        hook = getattr(path, '__fspath__', None)

        if not callable(hook):
            # No usable protocol hook: only raw str/bytes are acceptable.
            if isinstance(path, (str, bytes)):
                return path
            raise TypeError(
                'expected str, bytes or path-like object, not {}'.format(
                    type(path).__name__
                )
            )

        converted = hook()
        if isinstance(converted, (str, bytes)):
            return converted

        raise TypeError(
            'expected {}() to return str or bytes, not {}'.format(
                hook.__qualname__, type(converted).__name__
            ))

Diff for: asyncpg/connection.py

+174
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import struct
1111
import time
1212

13+
from . import compat
1314
from . import connect_utils
1415
from . import cursor
1516
from . import exceptions
@@ -18,6 +19,7 @@
1819
from . import protocol
1920
from . import serverversion
2021
from . import transaction
22+
from . import utils
2123

2224

2325
class ConnectionMeta(type):
@@ -345,6 +347,178 @@ async def fetchrow(self, query, *args, timeout=None):
345347
return None
346348
return data[0]
347349

350+
async def copy_from_table(self, table_name, *, output,
                          columns=None, schema_name=None, timeout=None,
                          format=None, oids=None, delimiter=None,
                          null=None, header=None, quote=None,
                          escape=None, force_quote=None, encoding=None):
    """Copy table contents to a file or file-like object.

    A ``COPY <table> TO STDOUT`` statement is assembled from the
    arguments and its output is delivered to *output*.

    :param str table_name:
        The name of the table to copy data from.

    :param output:
        A :term:`path-like object <python:path-like object>`,
        or a :term:`file-like object <python:file-like object>`, or
        a :term:`coroutine function <python:coroutine function>`
        that takes a ``bytes`` instance as a sole argument.

    :param list columns:
        An optional list of column names to copy.

    :param str schema_name:
        An optional schema name to qualify the table.

    :param float timeout:
        Optional timeout value in seconds.

    The remaining keyword arguments are ``COPY`` statement options,
    see `COPY statement documentation`_ for details.

    :return: The status string of the COPY command.

    .. versionadded:: 0.11.0

    .. _`COPY statement documentation`: https://www.postgresql.org/docs/\
                                        current/static/sql-copy.html

    """
    # Identifiers are quoted individually, then joined as schema.table.
    qualified_name = utils._quote_ident(table_name)
    if schema_name:
        qualified_name = '{}.{}'.format(
            utils._quote_ident(schema_name), qualified_name)

    column_list = ''
    if columns:
        quoted_columns = [utils._quote_ident(name) for name in columns]
        column_list = '(' + ', '.join(quoted_columns) + ')'

    copy_options = self._format_copy_opts(
        format=format, oids=oids, delimiter=delimiter,
        null=null, header=header, quote=quote, escape=escape,
        force_quote=force_quote, encoding=encoding
    )

    statement = 'COPY {tab}{cols} TO STDOUT {opts}'.format(
        tab=qualified_name, cols=column_list, opts=copy_options)

    return await self._copy_out(statement, output, timeout)
406+
407+
async def copy_from_query(self, query, *args, output,
                          timeout=None, format=None, oids=None,
                          delimiter=None, null=None, header=None,
                          quote=None, escape=None, force_quote=None,
                          encoding=None):
    """Copy the results of a query to a file or file-like object.

    The query is wrapped in a ``COPY (<query>) TO STDOUT`` statement
    and its output is delivered to *output*.

    :param str query:
        The query to copy the results of.

    :param args:
        Query arguments.

    :param output:
        A :term:`path-like object <python:path-like object>`,
        or a :term:`file-like object <python:file-like object>`, or
        a :term:`coroutine function <python:coroutine function>`
        that takes a ``bytes`` instance as a sole argument.

    :param float timeout:
        Optional timeout value in seconds.

    The remaining keyword arguments are ``COPY`` statement options,
    see `COPY statement documentation`_ for details.

    :return: The status string of the COPY command.

    .. versionadded:: 0.11.0

    .. _`COPY statement documentation`: https://www.postgresql.org/docs/\
                                        current/static/sql-copy.html

    """
    copy_options = self._format_copy_opts(
        format=format, oids=oids, delimiter=delimiter,
        null=null, header=header, quote=quote, escape=escape,
        force_quote=force_quote, encoding=encoding
    )

    inner_query = query
    if args:
        # NOTE(review): presumably _mogrify() inlines the arguments into
        # the query text -- confirm against utils._mogrify.
        inner_query = await utils._mogrify(self, query, args)

    statement = 'COPY ({query}) TO STDOUT {opts}'.format(
        query=inner_query, opts=copy_options)

    return await self._copy_out(statement, output, timeout)
453+
454+
def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
455+
delimiter=None, null=None, header=None, quote=None,
456+
escape=None, force_quote=None, force_not_null=None,
457+
force_null=None, encoding=None):
458+
kwargs = dict(locals())
459+
kwargs.pop('self')
460+
opts = []
461+
462+
if force_quote is not None and isinstance(force_quote, bool):
463+
kwargs.pop('force_quote')
464+
if force_quote:
465+
opts.append('FORCE_QUOTE *')
466+
467+
for k, v in kwargs.items():
468+
if v is not None:
469+
if k in ('force_not_null', 'force_null', 'force_quote'):
470+
v = '(' + ', '.join(utils._quote_ident(c) for c in v) + ')'
471+
elif k in ('oids', 'freeze', 'header'):
472+
v = str(v)
473+
else:
474+
v = utils._quote_literal(v)
475+
476+
opts.append('{} {}'.format(k.upper(), v))
477+
478+
if opts:
479+
return '(' + ', '.join(opts) + ')'
480+
else:
481+
return ''
482+
483+
async def _copy_out(self, copy_stmt, output, timeout):
    """Run *copy_stmt* and stream the COPY data to *output*.

    *output* may be a path-like object (opened here in ``'wb'`` mode
    and closed afterwards), an object with a ``write`` attribute, or a
    callable that is handed each chunk and is assumed to return an
    awaitable.  File writes are dispatched through the loop's executor.

    :return: Whatever ``self._protocol.copy_out()`` returns.
    :raises TypeError: If *output* matches none of the supported kinds.
    """
    try:
        target_path = compat.fspath(output)
    except TypeError:
        # output is not a path-like object
        target_path = None

    sink = None
    close_when_done = False
    in_executor = self._loop.run_in_executor

    if target_path is not None:
        # A path: we open the file, so we are responsible for closing it.
        sink = await in_executor(None, open, target_path, 'wb')
        close_when_done = True
    elif hasattr(output, 'write'):
        # File-like: the caller owns it, leave it open when done.
        sink = output
    elif not callable(output):
        raise TypeError(
            'output is expected to be a file-like object, '
            'a path-like object or a coroutine function, '
            'not {}'.format(type(output).__name__)
        )

    if sink is None:
        # Assuming calling output returns an awaitable.
        consume = output
    else:
        async def consume(data):
            await in_executor(None, sink.write, data)

    try:
        return await self._protocol.copy_out(copy_stmt, consume, timeout)
    finally:
        if close_when_done:
            sink.close()
521+
348522
async def set_type_codec(self, typename, *,
349523
schema='public', encoder, decoder, binary=False):
350524
"""Set an encoder/decoder pair for the specified data type.

Diff for: asyncpg/protocol/buffer.pxd

+4-1
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,17 @@ cdef class ReadBuffer:
9999
cdef _switch_to_next_buf(self)
100100
cdef inline read_byte(self)
101101
cdef inline char* _try_read_bytes(self, ssize_t nbytes)
102-
cdef inline read(self, ssize_t nbytes)
102+
cdef inline _read(self, char *buf, ssize_t nbytes)
103+
cdef read(self, ssize_t nbytes)
103104
cdef inline read_bytes(self, ssize_t n)
104105
cdef inline read_int32(self)
105106
cdef inline read_int16(self)
106107
cdef inline read_cstr(self)
107108
cdef int32_t has_message(self) except -1
109+
cdef inline int32_t has_message_type(self, char mtype) except -1
108110
cdef inline char* try_consume_message(self, ssize_t* len)
109111
cdef Memory consume_message(self)
112+
cdef bytearray consume_messages(self, char mtype)
110113
cdef discard_message(self)
111114
cdef inline _discard_message(self)
112115
cdef inline char get_message_type(self)

Diff for: asyncpg/protocol/buffer.pyx

+66-19
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,11 @@ cdef class ReadBuffer:
296296
'debug: second buffer of ReadBuffer is empty')
297297

298298
cdef inline char* _try_read_bytes(self, ssize_t nbytes):
299+
# Try to read *nbytes* from the first buffer.
300+
#
301+
# Returns pointer to data if there is at least *nbytes*
302+
# in the buffer, NULL otherwise.
303+
#
299304
# Important: caller must call _ensure_first_buf() prior
300305
# to calling try_read_bytes, and must not overread
301306

@@ -321,9 +326,32 @@ cdef class ReadBuffer:
321326
else:
322327
return NULL
323328

324-
cdef inline read(self, ssize_t nbytes):
329+
cdef inline _read(self, char *buf, ssize_t nbytes):
    # Copy exactly *nbytes* from the buffered data into *buf*, moving on
    # to the next underlying buffer whenever the current first buffer
    # does not hold enough bytes.  Position/length bookkeeping
    # (_pos0, _length) is updated as data is copied.
    cdef:
        ssize_t available
        char *src

    src = cpython.PyBytes_AS_STRING(self._buf0)

    while self._pos0 + nbytes > self._len0:
        # Drain whatever remains in the first buffer, then advance.
        available = self._len0 - self._pos0
        memcpy(buf, src + self._pos0, <size_t>available)
        self._pos0 = self._len0
        self._length -= available
        nbytes -= available
        buf += available
        self._ensure_first_buf()
        # _buf0 may have been replaced; refresh the raw pointer.
        src = cpython.PyBytes_AS_STRING(self._buf0)

    # The rest of the request fits in the current first buffer.
    memcpy(buf, src + self._pos0, <size_t>nbytes)
    self._pos0 += nbytes
    self._length -= nbytes
351+
352+
cdef read(self, ssize_t nbytes):
353+
cdef:
354+
bytearray result
327355
ssize_t nread
328356
char *buf
329357

@@ -341,24 +369,11 @@ cdef class ReadBuffer:
341369
if self._current_message_len_unread < 0:
342370
raise BufferError('buffer overread')
343371

344-
result = bytearray()
345-
while True:
346-
if self._pos0 + nbytes > self._len0:
347-
result.extend(self._buf0[self._pos0:])
348-
nread = self._len0 - self._pos0
349-
self._pos0 = self._len0
350-
self._length -= nread
351-
nbytes -= nread
352-
self._ensure_first_buf()
372+
result = PyByteArray_FromStringAndSize(NULL, nbytes)
373+
buf = PyByteArray_AsString(result)
374+
self._read(buf, nbytes)
353375

354-
else:
355-
result.extend(self._buf0[self._pos0:self._pos0 + nbytes])
356-
self._pos0 += nbytes
357-
self._length -= nbytes
358-
return Memory.new(
359-
PyByteArray_AsString(result),
360-
result,
361-
len(result))
376+
return Memory.new(buf, result, nbytes)
362377

363378
cdef inline read_byte(self):
364379
cdef char* first_byte
@@ -506,6 +521,9 @@ cdef class ReadBuffer:
506521
self._current_message_ready = 1
507522
return 1
508523

524+
cdef inline int32_t has_message_type(self, char mtype) except -1:
    # Non-zero only when has_message() reports a message and that
    # message's type equals *mtype*.
    if not self.has_message():
        return 0
    return self.get_message_type() == mtype
526+
509527
cdef inline char* try_consume_message(self, ssize_t* len):
510528
cdef ssize_t buf_len
511529

@@ -530,6 +548,35 @@ cdef class ReadBuffer:
530548
self._discard_message()
531549
return mem
532550

551+
cdef bytearray consume_messages(self, char mtype):
    """Consume consecutive messages of type *mtype*.

    Returns a bytearray with the concatenated unread payloads of those
    messages, or None if the next message is not of type *mtype*.
    """
    cdef:
        bytearray payload
        char *cursor
        ssize_t chunk_len
        ssize_t consumed = 0

    if not self.has_message_type(mtype):
        return None

    # consume_messages is a volume-oriented method, so we assume that
    # the remainder of the buffer will contain messages of the requested
    # type and allocate for all currently buffered bytes up front.
    payload = PyByteArray_FromStringAndSize(NULL, self._length)
    cursor = PyByteArray_AsString(payload)

    while self.has_message_type(mtype):
        chunk_len = self._current_message_len_unread
        self._read(cursor, chunk_len)
        cursor += chunk_len
        consumed += chunk_len
        self._discard_message()

    # Shrink the over-allocated result to the number of bytes copied.
    PyByteArray_Resize(payload, consumed)

    return payload
579+
533580
cdef discard_message(self):
534581
if self._current_message_type == 0:
535582
# Already discarded

0 commit comments

Comments
 (0)