From 0b0eaf25bea5e133fc747e6c38976791f1e87d51 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 30 Nov 2023 22:23:39 +0000 Subject: [PATCH 1/2] feat: add socket_callback to connect --- asyncpg/connect_utils.py | 17 +++++++++++++---- asyncpg/connection.py | 6 ++++-- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 414231fd..75408aa8 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -56,6 +56,7 @@ def parse(cls, sslmode): 'direct_tls', 'server_settings', 'target_session_attrs', + 'socket_callback', ]) @@ -261,7 +262,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]: def _parse_connect_dsn_and_args(*, dsn, host, port, user, password, passfile, database, ssl, direct_tls, server_settings, - target_session_attrs): + target_session_attrs, socket_callback): # `auth_hosts` is the version of host information for the purposes # of reading the pgpass file. auth_hosts = None @@ -654,7 +655,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, user=user, password=password, database=database, ssl=ssl, sslmode=sslmode, direct_tls=direct_tls, server_settings=server_settings, - target_session_attrs=target_session_attrs) + target_session_attrs=target_session_attrs, + socket_callback=socket_callback) return addrs, params @@ -665,7 +667,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, max_cached_statement_lifetime, max_cacheable_statement_size, ssl, direct_tls, server_settings, - target_session_attrs): + target_session_attrs, socket_callback): local_vars = locals() for var_name in {'max_cacheable_statement_size', 'max_cached_statement_lifetime', @@ -694,7 +696,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, password=password, passfile=passfile, ssl=ssl, direct_tls=direct_tls, database=database, server_settings=server_settings, - target_session_attrs=target_session_attrs) + target_session_attrs=target_session_attrs, + socket_callback=socket_callback) config = _ClientConfiguration( command_timeout=command_timeout, @@ -863,6 +866,12 @@ async def __connect_addr( proto_factory, *addr, ssl=params.ssl ) + elif params.socket_callback: + # if socket factory callback is given, create socket and use + # for connection + sock = await params.socket_callback() + connector = loop.create_connection(proto_factory, sock=sock) + elif params.ssl: connector = _create_ssl_connection( proto_factory, *addr, loop=loop, ssl_context=params.ssl, diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 0367e365..9ffae48f 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -2007,7 +2007,8 @@ async def connect(dsn=None, *, connection_class=Connection, record_class=protocol.Record, server_settings=None, - target_session_attrs=None): + target_session_attrs=None, + socket_callback=None): r"""A coroutine to establish a connection to a PostgreSQL server. The connection parameters may be specified either as a connection @@ -2344,7 +2345,8 @@ async def connect(dsn=None, *, statement_cache_size=statement_cache_size, max_cached_statement_lifetime=max_cached_statement_lifetime, max_cacheable_statement_size=max_cacheable_statement_size, - target_session_attrs=target_session_attrs + target_session_attrs=target_session_attrs, + socket_callback=socket_callback ) From ff82aab23435d86ddccb0b553c2f116be11013a9 Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Fri, 1 Dec 2023 17:03:47 -0500 Subject: [PATCH 2/2] chore: add option for `ssl` with `socket_callback` (#2) --- asyncpg/connect_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 75408aa8..e1ea9b3a 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -870,7 +870,11 @@ async def __connect_addr( # if socket factory callback is given, create socket and use # for connection sock = await params.socket_callback() - connector = loop.create_connection(proto_factory, sock=sock) + if params.ssl: + host, _ = sock.getpeername() + connector = loop.create_connection(proto_factory, sock=sock, ssl=params.ssl, server_hostname=host) + else: + connector = loop.create_connection(proto_factory, sock=sock) elif params.ssl: connector = _create_ssl_connection(