diff --git a/asyncpg/pool.py b/asyncpg/pool.py index eaf501f4..b02fe597 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -14,7 +14,6 @@ from . import compat from . import connection -from . import connect_utils from . import exceptions from . import protocol @@ -311,7 +310,6 @@ class Pool: __slots__ = ( '_queue', '_loop', '_minsize', '_maxsize', '_init', '_connect_args', '_connect_kwargs', - '_working_addr', '_working_config', '_working_params', '_holders', '_initialized', '_initializing', '_closing', '_closed', '_connection_class', '_record_class', '_generation', '_setup', '_max_queries', '_max_inactive_connection_lifetime' @@ -377,10 +375,6 @@ def __init__(self, *connect_args, self._initializing = False self._queue = None - self._working_addr = None - self._working_config = None - self._working_params = None - self._connection_class = connection_class self._record_class = record_class @@ -430,9 +424,8 @@ async def _initialize(self): # first few connections in the queue, therefore we want to walk # `self._holders` in reverse. - # Connect the first connection holder in the queue so that it - # can record `_working_addr` and `_working_opts`, which will - # speed up successive connection attempts. + # Connect the first connection holder in the queue so that + # any connection issues are visible early. first_ch = self._holders[-1] # type: PoolConnectionHolder await first_ch.connect() @@ -504,36 +497,15 @@ def set_connect_args(self, dsn=None, **connect_kwargs): self._connect_args = [dsn] self._connect_kwargs = connect_kwargs - self._working_addr = None - self._working_config = None - self._working_params = None async def _get_new_connection(self): - if self._working_addr is None: - # First connection attempt on this pool. - con = await connection.connect( - *self._connect_args, - loop=self._loop, - connection_class=self._connection_class, - record_class=self._record_class, - **self._connect_kwargs) - - self._working_addr = con._addr - self._working_config = con._config - self._working_params = con._params - - else: - # We've connected before and have a resolved address, - # and parsed options and config. - con = await connect_utils._connect_addr( - loop=self._loop, - addr=self._working_addr, - timeout=self._working_params.connect_timeout, - config=self._working_config, - params=self._working_params, - connection_class=self._connection_class, - record_class=self._record_class, - ) + con = await connection.connect( + *self._connect_args, + loop=self._loop, + connection_class=self._connection_class, + record_class=self._record_class, + **self._connect_kwargs, + ) if self._init is not None: try: diff --git a/tests/test_pool.py b/tests/test_pool.py index 540efb08..2407b817 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -8,6 +8,7 @@ import asyncio import inspect import os +import pathlib import platform import random import textwrap @@ -18,6 +19,7 @@ from asyncpg import _testbase as tb from asyncpg import connection as pg_connection from asyncpg import pool as pg_pool +from asyncpg import cluster as pg_cluster _system = platform.uname().system @@ -969,6 +971,70 @@ async def worker(): await pool.release(conn) +@unittest.skipIf(os.environ.get('PGHOST'), 'unmanaged cluster') +class TestPoolReconnectWithTargetSessionAttrs(tb.ClusterTestCase): + + @classmethod + def setup_cluster(cls): + cls.cluster = cls.new_cluster(pg_cluster.TempCluster) + cls.start_cluster(cls.cluster) + + async def simulate_cluster_recovery_mode(self): + port = self.cluster.get_connection_spec()['port'] + await self.loop.run_in_executor( + None, + lambda: self.cluster.stop() + ) + + # Simulate recovery mode + (pathlib.Path(self.cluster._data_dir) / 'standby.signal').touch() + + await self.loop.run_in_executor( + None, + lambda: self.cluster.start( + port=port, + server_settings=self.get_server_settings(), + ) + ) + + async def test_full_reconnect_on_node_change_role(self): + if self.cluster.get_pg_version() < (12, 0): + self.skipTest("PostgreSQL < 12 cannot support standby.signal") + return + + pool = await self.create_pool( + min_size=1, + max_size=1, + target_session_attrs='primary' + ) + + # Force a new connection to be created + await pool.fetchval('SELECT 1') + + await self.simulate_cluster_recovery_mode() + + # current pool connection info cache is expired, + # but we don't know it yet + with self.assertRaises(asyncpg.TargetServerAttributeNotMatched) as cm: + await pool.execute('SELECT 1') + + self.assertEqual( + cm.exception.args[0], + "None of the hosts match the target attribute requirement " + "" + ) + + # force reconnect + with self.assertRaises(asyncpg.TargetServerAttributeNotMatched) as cm: + await pool.execute('SELECT 1') + + self.assertEqual( + cm.exception.args[0], + "None of the hosts match the target attribute requirement " + "" + ) + + @unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing') class TestHotStandby(tb.HotStandbyTestCase): def create_pool(self, **kwargs):