Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reconnect when server role changed #1053

Merged
merged 1 commit into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 9 additions & 37 deletions asyncpg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from . import compat
from . import connection
from . import connect_utils
from . import exceptions
from . import protocol

Expand Down Expand Up @@ -311,7 +310,6 @@ class Pool:
__slots__ = (
'_queue', '_loop', '_minsize', '_maxsize',
'_init', '_connect_args', '_connect_kwargs',
'_working_addr', '_working_config', '_working_params',
'_holders', '_initialized', '_initializing', '_closing',
'_closed', '_connection_class', '_record_class', '_generation',
'_setup', '_max_queries', '_max_inactive_connection_lifetime'
Expand Down Expand Up @@ -377,10 +375,6 @@ def __init__(self, *connect_args,
self._initializing = False
self._queue = None

self._working_addr = None
self._working_config = None
self._working_params = None

self._connection_class = connection_class
self._record_class = record_class

Expand Down Expand Up @@ -430,9 +424,8 @@ async def _initialize(self):
# first few connections in the queue, therefore we want to walk
# `self._holders` in reverse.

# Connect the first connection holder in the queue so that it
# can record `_working_addr` and `_working_opts`, which will
# speed up successive connection attempts.
# Connect the first connection holder in the queue so that
# any connection issues are visible early.
first_ch = self._holders[-1] # type: PoolConnectionHolder
await first_ch.connect()

Expand Down Expand Up @@ -504,36 +497,15 @@ def set_connect_args(self, dsn=None, **connect_kwargs):

self._connect_args = [dsn]
self._connect_kwargs = connect_kwargs
self._working_addr = None
self._working_config = None
self._working_params = None

async def _get_new_connection(self):
if self._working_addr is None:
# First connection attempt on this pool.
con = await connection.connect(
*self._connect_args,
loop=self._loop,
connection_class=self._connection_class,
record_class=self._record_class,
**self._connect_kwargs)

self._working_addr = con._addr
self._working_config = con._config
self._working_params = con._params

else:
# We've connected before and have a resolved address,
# and parsed options and config.
con = await connect_utils._connect_addr(
loop=self._loop,
addr=self._working_addr,
timeout=self._working_params.connect_timeout,
config=self._working_config,
params=self._working_params,
connection_class=self._connection_class,
record_class=self._record_class,
)
con = await connection.connect(
*self._connect_args,
loop=self._loop,
connection_class=self._connection_class,
record_class=self._record_class,
**self._connect_kwargs,
)

if self._init is not None:
try:
Expand Down
66 changes: 66 additions & 0 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import asyncio
import inspect
import os
import pathlib
import platform
import random
import textwrap
Expand All @@ -18,6 +19,7 @@
from asyncpg import _testbase as tb
from asyncpg import connection as pg_connection
from asyncpg import pool as pg_pool
from asyncpg import cluster as pg_cluster

_system = platform.uname().system

Expand Down Expand Up @@ -969,6 +971,70 @@ async def worker():
await pool.release(conn)


@unittest.skipIf(os.environ.get('PGHOST'), 'unmanaged cluster')
class TestPoolReconnectWithTargetSessionAttrs(tb.ClusterTestCase):

@classmethod
def setup_cluster(cls):
cls.cluster = cls.new_cluster(pg_cluster.TempCluster)
cls.start_cluster(cls.cluster)

async def simulate_cluster_recovery_mode(self):
port = self.cluster.get_connection_spec()['port']
await self.loop.run_in_executor(
None,
lambda: self.cluster.stop()
)

# Simulate recovery mode
(pathlib.Path(self.cluster._data_dir) / 'standby.signal').touch()

await self.loop.run_in_executor(
None,
lambda: self.cluster.start(
port=port,
server_settings=self.get_server_settings(),
)
)

async def test_full_reconnect_on_node_change_role(self):
if self.cluster.get_pg_version() < (12, 0):
self.skipTest("PostgreSQL < 12 cannot support standby.signal")
return

pool = await self.create_pool(
min_size=1,
max_size=1,
target_session_attrs='primary'
)

# Force a new connection to be created
await pool.fetchval('SELECT 1')

await self.simulate_cluster_recovery_mode()

# current pool connection info cache is expired,
# but we don't know it yet
with self.assertRaises(asyncpg.TargetServerAttributeNotMatched) as cm:
await pool.execute('SELECT 1')

self.assertEqual(
cm.exception.args[0],
"None of the hosts match the target attribute requirement "
"<SessionAttribute.primary: 'primary'>"
)

# force reconnect
with self.assertRaises(asyncpg.TargetServerAttributeNotMatched) as cm:
await pool.execute('SELECT 1')

self.assertEqual(
cm.exception.args[0],
"None of the hosts match the target attribute requirement "
"<SessionAttribute.primary: 'primary'>"
)


@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
class TestHotStandby(tb.HotStandbyTestCase):
def create_pool(self, **kwargs):
Expand Down