From ca9f03be3c64984311dbefbbd9e8ff0806a7f772 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Mon, 9 Oct 2023 12:34:02 -0700 Subject: [PATCH] Close cursor portals once the iterator is exhausted (#1088) When iterating on a cursor, make sure to close the portal once iteration is done. This prevents the cursor from holding onto resources until the end of transaction. Fixes: #1008 --- asyncpg/cursor.py | 16 +++++++++++++++- asyncpg/protocol/protocol.pyx | 23 +++++++++++++++++++++++ tests/test_cursor.py | 20 +++++++++++++++----- 3 files changed, 53 insertions(+), 6 deletions(-) diff --git a/asyncpg/cursor.py b/asyncpg/cursor.py index 7ec159ba..b4abeed1 100644 --- a/asyncpg/cursor.py +++ b/asyncpg/cursor.py @@ -158,6 +158,17 @@ async def _exec(self, n, timeout): self._state, self._portal_name, n, True, timeout) return buffer + async def _close_portal(self, timeout): + self._check_ready() + + if not self._portal_name: + raise exceptions.InterfaceError( + 'cursor does not have an open portal') + + protocol = self._connection._protocol + await protocol.close_portal(self._portal_name, timeout) + self._portal_name = None + def __repr__(self): attrs = [] if self._exhausted: @@ -219,7 +230,7 @@ async def __anext__(self): ) self._state.attach() - if not self._portal_name: + if not self._portal_name and not self._exhausted: buffer = await self._bind_exec(self._prefetch, self._timeout) self._buffer.extend(buffer) @@ -227,6 +238,9 @@ async def __anext__(self): buffer = await self._exec(self._prefetch, self._timeout) self._buffer.extend(buffer) + if self._portal_name and self._exhausted: + await self._close_portal(self._timeout) + if self._buffer: return self._buffer.popleft() diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index 1f739cc2..76c62dfc 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -327,6 +327,29 @@ cdef class BaseProtocol(CoreProtocol): finally: return await waiter + @cython.iterable_coroutine + async def close_portal(self, str portal_name, timeout): + + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + timeout = self._get_timeout_impl(timeout) + + waiter = self._new_waiter(timeout) + try: + self._close( + portal_name, + True) # network op + except Exception as ex: + waiter.set_exception(ex) + self._coreproto_error() + finally: + return await waiter + @cython.iterable_coroutine async def query(self, query, timeout): if self.cancel_waiter is not None: diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 565def85..ad446bc3 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -84,11 +84,21 @@ async def test_cursor_iterable_06(self): recs = [] async with self.con.transaction(): - async for rec in self.con.cursor( - 'SELECT generate_series(0, $1::int)', 10): - recs.append(rec) - - self.assertEqual(recs, [(i,) for i in range(11)]) + await self.con.execute(''' + CREATE TABLE cursor_iterable_06 (id int); + INSERT INTO cursor_iterable_06 VALUES (0), (1); + ''') + try: + cur = self.con.cursor('SELECT * FROM cursor_iterable_06') + async for rec in cur: + recs.append(rec) + finally: + # Check that after iteration has exhausted the cursor, + # its associated portal is closed properly, unlocking + # the table. + await self.con.execute('DROP TABLE cursor_iterable_06') + + self.assertEqual(recs, [(i,) for i in range(2)]) class TestCursor(tb.ConnectedTestCase):